Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix Series.from_list/2 bugs #826

Closed
wants to merge 13 commits into from
5 changes: 3 additions & 2 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ defmodule Explorer.Series do
* `:dtype` - Cast the series to a given `:dtype`. By default this is `nil`, which means
that Explorer will infer the type from the values in the list.
See the module docs for the list of valid dtypes and aliases.
* `:strict` - When `true`, raises if an integer value overflows or underflows its dtype. Defaults to `false`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When it is false, what happens? The value is simply discarded?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit confused because it seems Rustler will raise on overflow-underflow if strict: false.

Copy link
Member Author

@lkarthee lkarthee Jan 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When `strict` is `false`, overflow/underflow can happen. Values are only discarded when a series is created as {:s, 64} and then cast to {:s, 8}. No such cast happens in the existing code, with or without `strict`.

Series.from_list([-129, 256], dtype: {:s, 16}) |> Series.cast({:s, 8})
#Explorer.Series<
  Polars[2]
  s8 [nil, nil]
>

More details about rustler behaviour - #826 (comment)


## Examples

Expand Down Expand Up @@ -406,12 +407,12 @@ defmodule Explorer.Series do
@doc type: :conversion
@spec from_list(list :: list(), opts :: Keyword.t()) :: Series.t()
def from_list(list, opts \\ []) do
opts = Keyword.validate!(opts, [:dtype, :backend])
opts = Keyword.validate!(opts, [:dtype, :backend, strict: false])
backend = backend_from_options!(opts)

normalised_dtype = if opts[:dtype], do: Shared.normalise_dtype!(opts[:dtype])

type = Shared.dtype_from_list!(list, normalised_dtype)
type = Shared.dtype_from_list!(list, normalised_dtype, opts[:strict])
list = Shared.cast_numerics(list, type)

series = backend.from_list(list, type)
Expand Down
161 changes: 130 additions & 31 deletions lib/explorer/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ defmodule Explorer.Shared do
Gets the `dtype` of a list or raise error if not possible.
"""
def dtype_from_list!(list) do
Enum.reduce(list, :null, &infer_type/2)
dtype_from_list!(list, nil, false)
end

@doc """
Expand All @@ -270,71 +270,170 @@ defmodule Explorer.Shared do

If no preferred type is given (nil), then the inferred type is returned.
"""
def dtype_from_list!(_list, :null), do: :null
def dtype_from_list!(list, preferred_type, strict \\ false)

def dtype_from_list!(list, nil), do: dtype_from_list!(list)
def dtype_from_list!(_list, :null, _strict), do: :null

def dtype_from_list!(list, preferred_type) do
def dtype_from_list!(list, nil, strict) do
Enum.reduce(list, :null, &infer_type(&1, &2, nil, strict))
end

def dtype_from_list!(list, preferred_type, strict) do
list
|> dtype_from_list!()
|> Enum.reduce(:null, &infer_type(&1, &2, preferred_type, strict))
|> merge_preferred(preferred_type)
end

@non_finite [:nan, :infinity, :neg_infinity]

defp infer_type(nil, type), do: type
defp infer_type(item, :null), do: infer_type(item)
defp infer_type(integer, {:f, 64}) when is_integer(integer), do: {:f, 64}
defp infer_type(float, {:s, 64}) when is_float(float) or float in @non_finite, do: {:f, 64}
defp infer_type(list, {:list, type}) when is_list(list), do: infer_list(list, type)
defp infer_type(%{} = map, {:struct, inner}), do: infer_struct(map, inner)
# Accumulator-aware inference: folds one item into the dtype inferred so far.

# A nil item carries no type information, so the accumulated dtype is kept.
defp infer_type(nil, acc, _preferred, _strict), do: acc

# Nothing has been inferred yet: derive the dtype from the item alone.
defp infer_type(value, :null, preferred, strict), do: infer_type(value, preferred, strict)

# An integer folded into a float accumulator keeps the dtype at {:f, 64}.
defp infer_type(int, {:f, 64}, _preferred, _strict) when is_integer(int), do: {:f, 64}

# A float (or one of the non-finite atoms) folded into a signed or unsigned
# integer accumulator of any width promotes the dtype to {:f, 64}.
defp infer_type(value, {sign, _width}, _preferred, _strict)
     when sign in [:s, :u] and (is_float(value) or value in @non_finite),
     do: {:f, 64}

# Folds a list item into an already inferred {:list, inner} dtype.
# When the preferred dtype is itself a list dtype, its inner dtype is what
# applies to the elements; any other preferred value is passed through as is.
defp infer_type(items, {:list, inner_type}, preferred, strict) when is_list(items) do
  element_preferred = with {:list, inner} <- preferred, do: inner

  infer_list(items, inner_type, element_preferred, strict)
end

defp infer_type(%{} = map, {:struct, inner}, preferred, strict) do
preferred =
case preferred do
{:struct, preferred} -> preferred
_ -> preferred
end

defp infer_type(item, type) do
if infer_type(item) == type do
infer_struct(map, inner, preferred, strict)
end

# Fallback for items not handled by the more specific clauses: the dtype
# inferred from the item alone must equal the dtype accumulated so far.
#
# Returns the accumulated `type` on a match and raises `ArgumentError`
# otherwise. The error is raised through `raise_mismatched_dtype!/2` so the
# message is built in a single place instead of being duplicated here.
defp infer_type(item, type, preferred, strict) do
  if infer_type(item, preferred, strict) == type do
    type
  else
    raise_mismatched_dtype!(item, type)
  end
end

defp infer_type(%Date{} = _item), do: :date
defp infer_type(%Time{} = _item), do: :time
defp infer_type(%NaiveDateTime{} = _item), do: {:datetime, :microsecond}
defp infer_type(%Explorer.Duration{precision: precision} = _item), do: {:duration, precision}
defp infer_type(item) when is_integer(item), do: {:s, 64}
defp infer_type(item) when is_float(item) or item in @non_finite, do: {:f, 64}
defp infer_type(item) when is_boolean(item), do: :boolean
defp infer_type(item) when is_binary(item), do: :string
defp infer_type(list) when is_list(list), do: infer_list(list, :null)
defp infer_type(%{} = map), do: infer_struct(map, nil)
defp infer_type(item), do: raise(ArgumentError, "unsupported datatype: #{inspect(item)}")
# Item-only inference (no accumulated dtype) for calendar and duration structs.
defp infer_type(%Date{}, _preferred, _strict), do: :date
defp infer_type(%Time{}, _preferred, _strict), do: :time
defp infer_type(%NaiveDateTime{}, _preferred, _strict), do: {:datetime, :microsecond}

# A duration keeps the precision stored in the struct.
defp infer_type(%Explorer.Duration{precision: precision}, _preferred, _strict) do
  {:duration, precision}
end

# Non-strict mode: every integer is inferred as {:s, 64}, whatever dtype is preferred.
defp infer_type(int, _preferred, false) when is_integer(int), do: {:s, 64}

# Strict inference with no preferred dtype: the integer is typed as {:s, 64},
# so anything outside the signed 64-bit range is rejected.
defp infer_type(item, nil, true) when is_integer(item) do
  if item in -9_223_372_036_854_775_808..9_223_372_036_854_775_807 do
    {:s, 64}
  else
    raise_mismatched_dtype!(item, {:s, 64})
  end
end

# Strict inference with a preferred dtype.
#
# For a signed or unsigned integer dtype of width 8, 16, 32 or 64 the item must
# fit in that width; the preferred dtype is returned when it does and an
# ArgumentError is raised when it does not. Integer dtypes with any other
# width always raise. Every other preferred dtype yields {:s, 64}.
defp infer_type(item, preferred, true) when is_integer(item) do
  case preferred do
    {:s, bits} when bits in [8, 16, 32, 64] ->
      # Signed range is [-2^(bits - 1), 2^(bits - 1)).
      bound = Integer.pow(2, bits - 1)

      if item >= -bound and item < bound do
        preferred
      else
        raise_mismatched_dtype!(item, preferred)
      end

    {:u, bits} when bits in [8, 16, 32, 64] ->
      # Unsigned range is [0, 2^bits).
      if item >= 0 and item < Integer.pow(2, bits) do
        preferred
      else
        raise_mismatched_dtype!(item, preferred)
      end

    {sign, _other_width} when sign in [:s, :u] ->
      raise_mismatched_dtype!(item, preferred)

    _ ->
      {:s, 64}
  end
end

# Floats and the non-finite atoms are always inferred as {:f, 64}.
defp infer_type(value, _preferred, _strict) when is_float(value) or value in @non_finite do
  {:f, 64}
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, we also have precision issues with floats, since f32 cannot represent numbers as large as f64 can. :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But that should probably be another pull request.


defp infer_type(bool, _preferred, _strict) when is_boolean(bool), do: :boolean
defp infer_type(binary, _preferred, _strict) when is_binary(binary), do: :string

# Containers are inferred with no previously accumulated inner dtype
# (:null for lists, nil for structs) by delegating to the container helpers.
defp infer_type(items, preferred, strict) when is_list(items) do
  infer_list(items, :null, preferred, strict)
end

defp infer_type(%{} = fields, preferred, strict) do
  infer_struct(fields, nil, preferred, strict)
end

defp infer_list(list, type) do
{:list, Enum.reduce(list, type, &infer_type/2)}
# Catch-all: no clause above recognised the item, so it has no supported dtype.
defp infer_type(unsupported, _preferred, _strict) do
  raise ArgumentError, "unsupported datatype: #{inspect(unsupported)}"
end

# Infers the dtype of a list item by folding each element into `type`,
# returning it wrapped as {:list, inner}. A preferred list dtype is unwrapped
# one level so its inner dtype applies to the elements; any other preferred
# value is handed to the elements unchanged.
defp infer_list(list, type, preferred, strict) do
  element_preferred =
    case preferred do
      {:list, inner} -> inner
      other -> other
    end

  inner_type =
    Enum.reduce(list, type, fn element, acc ->
      infer_type(element, acc, element_preferred, strict)
    end)

  {:list, inner_type}
end

defp infer_struct(%{} = map, types) do
defp infer_struct(%{} = map, types, preferred, strict) do
types =
for {key, value} <- map, into: %{} do
key = to_string(key)

cond do
types == nil ->
{key, infer_type(value, :null)}
{key, infer_type(value, :null, preferred, strict)}

type = types[key] ->
{key, infer_type(value, type)}
{key, infer_type(value, type, preferred, strict)}

true ->
raise ArgumentError,
"the value #{inspect(map)} does not match the inferred dtype #{inspect({:struct, types})}"
raise_mismatched_dtype!(map, {:struct, types})
end
end

{:struct, types}
end

# Raises an ArgumentError reporting that `value` does not match `type`.
defp raise_mismatched_dtype!(value, type) do
  message = "the value #{inspect(value)} does not match the inferred dtype #{inspect(type)}"

  raise ArgumentError, message
end

defp merge_preferred(type, type), do: type
defp merge_preferred(:null, type), do: type
defp merge_preferred({:s, 64}, {:u, _} = type), do: type
Expand Down
45 changes: 45 additions & 0 deletions test/explorer/series/list_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,51 @@ defmodule Explorer.Series.ListTest do
assert Series.to_list(series) == [[[1]]]
end

test "list of lists of one integer & one u64" do
  # 9_223_372_036_854_775_808 is one past the signed 64-bit maximum.
  rows = [[1, 9_223_372_036_854_775_808]]
  s = Series.from_list(rows, dtype: {:list, {:u, 64}})

  assert s.dtype == {:list, {:u, 64}}
  assert s[0] == [1, 9_223_372_036_854_775_808]
  assert Series.to_list(s) == rows
end

test "list of lists of one negative integer & one u64 " do
  # No dtype is given, so the strict check is made against {:s, 64}.
  message = "the value 9223372036854775808 does not match the inferred dtype {:s, 64}"
  build = fn -> Series.from_list([[-1, 9_223_372_036_854_775_808]], strict: true) end

  assert_raise ArgumentError, message, build
end

test "list of lists of integers with one u64" do
  # Mixes small values, a value above the signed 64-bit maximum, nils and an empty list.
  rows = [[0], [1], [2, 9_223_372_036_854_775_808], [3, nil, 4], nil, []]
  s = Series.from_list(rows, dtype: {:list, {:u, 64}}, strict: true)

  assert s.dtype == {:list, {:u, 64}}
  assert s[0] == [0]
  assert Series.to_list(s) == rows
end

test "list of lists of integers recursively - u64" do
  # The preferred dtype is nested two list levels deep.
  dtype = {:list, {:list, {:u, 64}}}
  rows = [[[1, 9_223_372_036_854_775_808]]]
  s = Series.from_list(rows, dtype: dtype)

  assert s.dtype == dtype
  assert s[0] == [[1, 9_223_372_036_854_775_808]]
  assert Series.to_list(s) == rows
end

test "list of lists of floats recursively" do
series = Series.from_list([[[1.52]]])

Expand Down
66 changes: 66 additions & 0 deletions test/explorer/series/struct_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,72 @@ defmodule Explorer.Series.StructTest do
assert Series.to_list(s) == [%{"a" => 1}, %{"a" => 3}, %{"a" => 5}]
end

test "allows struct values with dtype u64" do
  dtype = {:struct, %{"a" => {:u, 64}}}
  series = Series.from_list([%{a: 1}, %{a: 3}, %{a: 5}], dtype: dtype)

  assert series.dtype == dtype

  # The atom keys given as input are expected back as string keys.
  assert Series.to_list(series) == [%{"a" => 1}, %{"a" => 3}, %{"a" => 5}]
end

test "allows struct values - u64 integers" do
  # The large value is still below the signed 64-bit maximum, so the
  # dtype asserted below is {:s, 64}.
  large = 9_223_372_036_854_775_806
  series = Series.from_list([%{a: large}, %{a: 3}, %{a: 5}])

  assert series.dtype == {:struct, %{"a" => {:s, 64}}}

  assert Series.to_list(series) == [%{"a" => large}, %{"a" => 3}, %{"a" => 5}]
end

test "allows struct values - list of u64 integers" do
  # Same large-but-signed value as above, this time nested inside a list field.
  large = 9_223_372_036_854_775_806
  series = Series.from_list([%{a: [large]}, %{a: [3]}, %{a: [5]}])

  expected = [%{"a" => [large]}, %{"a" => [3]}, %{"a" => [5]}]

  assert series.dtype == {:struct, %{"a" => {:list, {:s, 64}}}}
  assert Series.to_list(series) == expected
end

test "allows struct values - list of u64 integers with dtype" do
  # String-keyed input with an explicit struct-of-list dtype; the values
  # are expected to round-trip unchanged.
  rows = [%{"a" => [9_223_372_036_854_775_802]}, %{"a" => [3]}, %{"a" => [5]}]
  series = Series.from_list(rows, dtype: {:struct, %{"a" => {:list, {:u, 64}}}})

  assert series.dtype == {:struct, %{"a" => {:list, {:u, 64}}}}
  assert Series.to_list(series) == rows
end

test "allows struct values - integers with dtype u32" do
  dtype = {:struct, %{"a" => {:u, 32}}}
  series = Series.from_list([%{a: 1}, %{a: 3}, %{a: 5}], dtype: dtype)

  assert series.dtype == dtype
  assert Series.to_list(series) == [%{"a" => 1}, %{"a" => 3}, %{"a" => 5}]
end

test "allows struct values - integers with dtype s32" do
  dtype = {:struct, %{"a" => {:s, 32}}}
  series = Series.from_list([%{a: 1}, %{a: 3}, %{a: 5}], dtype: dtype)

  assert series.dtype == dtype
  assert Series.to_list(series) == [%{"a" => 1}, %{"a" => 3}, %{"a" => 5}]
end

test "allows struct values - s64 integers with dtype u64" do
  # NOTE(review): the input, dtype and assertions appear identical to the
  # "allows struct values with dtype u64" test - confirm whether both are needed.
  rows = [%{a: 1}, %{a: 3}, %{a: 5}]
  series = Series.from_list(rows, dtype: {:struct, %{"a" => {:u, 64}}})

  assert series.dtype == {:struct, %{"a" => {:u, 64}}}
  assert Series.to_list(series) == [%{"a" => 1}, %{"a" => 3}, %{"a" => 5}]
end

test "allows structs with nil values" do
s =
Series.from_list([
Expand Down
Loading
Loading