From b7e60483bd929b7d085960d78144d9ac398232ce Mon Sep 17 00:00:00 2001
From: lkarthee
Date: Wed, 17 Jan 2024 14:00:10 +0530
Subject: [PATCH] Series.concat handle multiple integer, float and null types (#812)

---
 lib/explorer/series.ex | 73 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 62 insertions(+), 11 deletions(-)

diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex
index de31d1a90..d27c87ce5 100644
--- a/lib/explorer/series.ex
+++ b/lib/explorer/series.ex
@@ -2097,20 +2097,71 @@ defmodule Explorer.Series do
   @doc type: :shape
   @spec concat([Series.t()]) :: Series.t()
   def concat([%Series{} | _t] = series) do
-    dtypes = series |> Enum.map(& &1.dtype) |> Enum.uniq()
+    # Fold the input dtypes into `{null?, dtypes_map}`: `null?` records whether a
+    # `:null` series was seen, and `dtypes_map` keeps one entry per distinct dtype,
+    # except that numeric `{tag, size}` dtypes sharing a tag collapse into a single
+    # entry holding the largest size seen.
+    {null?, dtypes_map} =
+      Enum.reduce(series, {false, %{}}, fn
+        %Series{dtype: :null}, {_null?, dtypes} ->
+          {true, dtypes}
 
-    case dtypes do
-      [_dtype] ->
-        impl!(series).concat(series)
+        %Series{dtype: {type, n} = dt}, {null?, dtypes} when is_numeric_dtype(dt) ->
+          dtypes =
+            case dtypes[type] do
+              nil -> Map.put(dtypes, type, dt)
+              {_t, tn} when n > tn -> Map.put(dtypes, type, dt)
+              _ -> dtypes
+            end
+
+          {null?, dtypes}
 
-      [a, b] when K.and(is_numeric_dtype(a), is_numeric_dtype(b)) ->
-        series = Enum.map(series, &cast(&1, {:f, 64}))
-        impl!(series).concat(series)
+        # Any other dtype is keyed by its full value, so two-element tuple dtypes
+        # that are not numeric and differ in their second element stay distinct
+        # and are rejected below instead of being ordered with `>` and silently
+        # cast to one of them.
+        %Series{dtype: dt}, {null?, dtypes} ->
+          {null?, Map.put_new(dtypes, dt, dt)}
+      end)
 
-      incompatible ->
-        raise ArgumentError,
-          "cannot concatenate series with mismatched dtypes: #{inspect(incompatible)}. " <>
-            "First cast the series to the desired dtype."
+    dtypes = Map.values(dtypes_map)
+
+    series =
+      series
+      |> maybe_raise_mismatched!(dtypes)
+      |> maybe_cast(null?, dtypes)
+
+    impl!(series).concat(series)
+  end
+
+  # Every series is `:null`, so there is no dtype to unify to.
+  defp maybe_cast(series, true, []), do: series
+
+  # One target dtype: cast only the series that do not already have it.
+  defp maybe_cast(series, _null?, [dtype]) do
+    Enum.map(series, fn
+      %Series{dtype: ^dtype} = same -> same
+      other -> cast(other, dtype)
+    end)
+  end
+
+  # Several numeric dtypes remain: fall back to `{:f, 64}` for all series.
+  defp maybe_cast(series, _null?, _dtypes), do: Enum.map(series, &cast(&1, {:f, 64}))
+
+  # Returns `series` unchanged when the collected dtypes can be unified and
+  # raises `ArgumentError` otherwise.
+  defp maybe_raise_mismatched!(series, [_dtype]), do: series
+
+  defp maybe_raise_mismatched!(series, dtypes) do
+    # `Enum.all?/2` is true for `[]`, which lets an all-`:null` input through.
+    if Enum.all?(dtypes, &is_numeric_dtype/1) do
+      series
+    else
+      dtypes = Enum.map(series, & &1.dtype)
+
+      raise ArgumentError,
+        "cannot concatenate series with mismatched dtypes: #{inspect(dtypes)}. " <>
+          "First cast the series to the desired dtype."
     end
   end
 