diff --git a/lib/explorer/backend/series.ex b/lib/explorer/backend/series.ex index 90ddf6f1f..64f8e9e31 100644 --- a/lib/explorer/backend/series.ex +++ b/lib/explorer/backend/series.ex @@ -305,7 +305,7 @@ defmodule Explorer.Backend.Series do # Struct @callback field(s, String.t()) :: s - @callback json_decode(s, dtype() | nil, non_neg_integer() | nil) :: s + @callback json_decode(s, option(dtype()), option(non_neg_integer())) :: s # Functions diff --git a/lib/explorer/polars_backend/shared.ex b/lib/explorer/polars_backend/shared.ex index 56c2df183..d426acb28 100644 --- a/lib/explorer/polars_backend/shared.ex +++ b/lib/explorer/polars_backend/shared.ex @@ -42,36 +42,55 @@ defmodule Explorer.PolarsBackend.Shared do def apply_dataframe(%DataFrame{} = df, %DataFrame{} = out_df, fun, args) do case apply(Native, fun, [df.data | args]) do {:ok, %module{} = new_df} when module in @polars_df -> - if @check_frames do - # We need to collect here, because the lazy frame may not have - # the full picture of the result yet. - check_df = - if match?(%PolarsLazyFrame{}, new_df) do - {:ok, new_df} = Native.lf_collect(new_df) - create_dataframe(new_df) - else - create_dataframe(new_df) + {struct?, dtypes} = + if @check_frames do + # We need to collect here, because the lazy frame may not have + # the full picture of the result yet. + check_df = + if match?(%PolarsLazyFrame{}, new_df) do + {:ok, new_df} = Native.lf_collect(new_df) + create_dataframe(new_df) + else + create_dataframe(new_df) + end + + # When dealing with structs in mutate, we may not know dtype of struct series. + # We have to accept the dtype returned by polars, else we will have mismatch error. + {struct?, out_dtypes} = + if fun == :df_mutate_with_exprs do + Enum.reduce(check_df.dtypes, {false, out_df.dtypes}, fn + {key, {:struct, _} = dtype}, {_, dtypes} -> {true, Map.put(dtypes, key, dtype)} + _, acc -> acc + end) + else + {false, out_df.dtypes} + end + + if Enum.sort(out_df.names) != Enum.sort(check_df.names) or + out_dtypes != check_df.dtypes do + raise """ + DataFrame mismatch. + + expected: + + names: #{inspect(out_df.names)} + dtypes: #{inspect(out_df.dtypes)} + + got: + + names: #{inspect(check_df.names)} + dtypes: #{inspect(check_df.dtypes)} + """ end - if Enum.sort(out_df.names) != Enum.sort(check_df.names) or - out_df.dtypes != check_df.dtypes do - raise """ - DataFrame mismatch. - - expected: - - names: #{inspect(out_df.names)} - dtypes: #{inspect(out_df.dtypes)} - - got: - - names: #{inspect(check_df.names)} - dtypes: #{inspect(check_df.dtypes)} - """ + {struct?, out_dtypes} end - end - %{out_df | data: new_df} + if struct? do + %{out_df | data: new_df, dtypes: dtypes} + else + %{out_df | data: new_df} + end {:error, error} -> raise runtime_error(error) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 8c8da30c6..0b1bdd379 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -6044,7 +6044,7 @@ defmodule Explorer.Series do end @doc """ - Decode json from string + Decode json from string. ## Examples @@ -6054,6 +6054,8 @@ defmodule Explorer.Series do Polars[1] struct[1] [%{"a" => 1}] > + + Will raise `RuntimeError` for invalid json. """ @doc type: :struct_wise @spec json_decode(Series.t(), Keyword.t()) :: Series.t() diff --git a/native/explorer/src/expressions.rs b/native/explorer/src/expressions.rs index 3d51a5d7f..1acd3e460 100644 --- a/native/explorer/src/expressions.rs +++ b/native/explorer/src/expressions.rs @@ -1072,7 +1072,7 @@ pub fn expr_json_decode( ex_dtype: Option, infer_schema_length: Option, ) -> ExExpr { - let dtype = ex_dtype.map(|x| DataType::try_from(&x).unwrap()); //DataType::try_from().unwrap(); + let dtype = ex_dtype.map(|x| DataType::try_from(&x).unwrap()); let expr = expr .clone_inner() .str() diff --git a/test/explorer/data_frame_test.exs b/test/explorer/data_frame_test.exs index b08997c54..4fdaa773c 100644 --- a/test/explorer/data_frame_test.exs +++ b/test/explorer/data_frame_test.exs @@ -1870,6 +1870,32 @@ defmodule Explorer.DataFrameTest do member?: [true, false] } end + + test "extracts struct from json - json_decode" do + df = DF.new([%{a: "{\"n\": 1}"}]) + dfj = DF.mutate(df, aj: json_decode(a, dtype: {:struct, %{"n" => {:s, 64}}})) + assert dfj.dtypes == %{"a" => :string, "aj" => {:struct, %{"n" => {:s, 64}}}} + assert DF.to_rows(dfj) == [%{"a" => "{\"n\": 1}", "aj" => %{"n" => 1}}] + end + + test "extracts struct from json - json_decode with dtype" do + df = DF.new([%{a: "{\"n\": 1}"}]) + dfj = DF.mutate(df, aj: json_decode(a, dtype: {:struct, %{"n" => {:f, 64}}})) + assert dfj.dtypes == %{"a" => :string, "aj" => {:struct, %{"n" => {:f, 64}}}} + assert DF.to_rows(dfj) == [%{"a" => "{\"n\": 1}", "aj" => %{"n" => 1.0}}] + end + + test "extracts struct from json - json_decode with infer_schema_length" do + df = DF.new([%{a: "{\"n\": 1}"}]) + + dfj = + DF.mutate(df, + aj: json_decode(a, infer_schema_length: 100) + ) + + assert dfj.dtypes == %{"a" => :string, "aj" => {:struct, %{"n" => {:s, 64}}}} + assert DF.to_rows(dfj) == [%{"a" => "{\"n\": 1}", "aj" => %{"n" => 1}}] + end end describe "sort_by/3" do