Skip to content

Commit

Permalink
Series.concat handle multiple integer, float and null types (#812)
Browse files Browse the repository at this point in the history
  • Loading branch information
lkarthee authored Jan 17, 2024
1 parent 8fa5071 commit b7e6048
Showing 1 changed file with 46 additions and 11 deletions.
57 changes: 46 additions & 11 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2097,20 +2097,55 @@ defmodule Explorer.Series do
# Concatenates a non-empty list of series into a single series.
#
# NOTE(review): this hunk is a rendered diff without +/- markers, so the lines
# the commit removed are interleaved with the lines that replace them. The
# "[removed]" / "[added]" markers below are inferred from the code itself
# (the page header reports 11 deletions) -- confirm against the raw diff.
@doc type: :shape
@spec concat([Series.t()]) :: Series.t()
def concat([%Series{} | _t] = series) do
# [removed] Old approach: collect the distinct dtypes, then branch on them.
dtypes = series |> Enum.map(& &1.dtype) |> Enum.uniq()
# [added] New approach: one pass over the series that yields a flag telling
# whether any :null series was seen and a map of the remaining dtypes.
{null?, dtypes_map} =
Enum.reduce(series, {false, %{}}, fn
# :null series while the flag is already set: accumulator is returned as is.
%Series{dtype: :null}, {true, _dtypes} = acc ->
acc

# [removed] Exactly one distinct dtype: concatenate directly.
case dtypes do
[_dtype] ->
impl!(series).concat(series)
# [added] :null series otherwise: set the flag, keep the dtypes collected so far.
%Series{dtype: :null}, {_, dtypes} ->
{true, dtypes}

# [removed] Exactly two distinct dtypes, both numeric: cast everything to
# {:f, 64} and concatenate.
[a, b] when K.and(is_numeric_dtype(a), is_numeric_dtype(b)) ->
series = Enum.map(series, &cast(&1, {:f, 64}))
impl!(series).concat(series)
# [added] Atom dtype: stored under its own name unless that key already exists.
%Series{dtype: dt}, {null?, dtypes} when is_atom(dt) ->
{null?, Map.put_new(dtypes, dt, dt)}

# [removed] Any other combination of dtypes: raise.
incompatible ->
raise ArgumentError,
"cannot concatenate series with mismatched dtypes: #{inspect(incompatible)}. " <>
"First cast the series to the desired dtype."
# [added] Two-element tuple dtype {type, n}: a single entry is kept per `type`,
# and it is replaced only when the incoming `n` compares greater than the
# stored one.
# NOTE(review): `>` uses term ordering when `n` is not a number; if tuple
# dtypes with a non-integer second element can reach this clause, confirm
# that picking the "greater" one is intended.
%Series{dtype: {type, n} = dt}, {null?, dtypes} ->
dtypes =
case dtypes[type] do
nil -> Map.put(dtypes, type, dt)
{_t, tn} when n > tn -> Map.put(dtypes, type, dt)
_ -> dtypes
end

{null?, dtypes}
end)

# [added] Only the collected dtype values matter from here on.
dtypes = Map.values(dtypes_map)

# [added] Validate the dtype combination first (may raise ArgumentError),
# then cast the series as decided by maybe_cast/3.
series =
series
|> maybe_raise_mismatched!(dtypes)
|> maybe_cast(null?, dtypes)

# [added] Hand the resulting list to the implementation module's concat.
impl!(series).concat(series)
end

# Casts the given list of series according to the collected dtypes.
#
#   * null flag `true` and no dtypes -> the list is returned untouched
#   * exactly one dtype              -> every series is cast to that dtype
#   * anything else                  -> every series is cast to `{:f, 64}`
defp maybe_cast(series, null?, dtypes) do
  case {null?, dtypes} do
    {true, []} ->
      series

    {_null?, [only_dtype]} ->
      Enum.map(series, fn s -> cast(s, only_dtype) end)

    _several ->
      Enum.map(series, fn s -> cast(s, {:f, 64}) end)
  end
end

# Returns `series` unchanged when the collected `dtypes` can be concatenated,
# i.e. there is exactly one dtype or every dtype satisfies `is_numeric_dtype/1`.
#
# Raises `ArgumentError` otherwise; the message lists the dtype of every
# series that was passed in, not the collected `dtypes`.
defp maybe_raise_mismatched!(series, dtypes) do
  compatible? =
    case dtypes do
      [_only] -> true
      several -> Enum.all?(several, &is_numeric_dtype/1)
    end

  if compatible? do
    series
  else
    dtypes = Enum.map(series, fn s -> s.dtype end)

    raise ArgumentError,
          "cannot concatenate series with mismatched dtypes: #{inspect(dtypes)}. " <>
            "First cast the series to the desired dtype."
  end
end

Expand Down

0 comments on commit b7e6048

Please sign in to comment.