Skip to content

Commit

Permalink
Fix mismatched types in Series.pow (#821)
Browse files Browse the repository at this point in the history
* first attempt: use lazy version underneath

* fix tests

* missed some tests

* fix a warning related to default defp args

* informative panics (whoops...)

* drop todo -- infer signed int is likely here to stay

* pull dtype off ExSeries

* fix rust linter error

* also test pow(uint, sint) == float64

* simplify dtype call

Co-authored-by: lkarthee <lkarthee@users.noreply.github.com>

* whoops, wrong variable

* use match instead of if/else if/else

Co-authored-by: lkarthee <lkarthee@users.noreply.github.com>

* keep everything in elixir

* also fix doctests

* use === in tests with float assertions

* use cast_to_pow to declare out_dtype

* fix warning (again)

* consilidate cases

Co-authored-by: josevalim <jose.valim@dashbit.co>

* pre-cast to ensure precision

---------

Co-authored-by: lkarthee <lkarthee@users.noreply.github.com>
Co-authored-by: josevalim <jose.valim@dashbit.co>
  • Loading branch information
3 people authored Jan 16, 2024
1 parent e30207e commit 8fa5071
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 108 deletions.
4 changes: 2 additions & 2 deletions lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ defmodule Explorer.Backend.LazySeries do

@comparison_operations [:equal, :not_equal, :greater, :greater_equal, :less, :less_equal]

@basic_arithmetic_operations [:add, :subtract, :multiply, :divide]
@other_arithmetic_operations [:pow, :quotient, :remainder]
@basic_arithmetic_operations [:add, :subtract, :multiply, :divide, :pow]
@other_arithmetic_operations [:quotient, :remainder]

@aggregation_operations [
:sum,
Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ defmodule Explorer.Backend.Series do
@callback divide(out_dtype :: dtype(), s, s) :: s
@callback quotient(s, s) :: s
@callback remainder(s, s) :: s
@callback pow(s, s) :: s
@callback pow(out_dtype :: dtype(), s, s) :: s
@callback log(argument :: s) :: s
@callback log(argument :: s, base :: float()) :: s
@callback exp(s) :: s
Expand Down
1 change: 0 additions & 1 deletion lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@ defmodule Explorer.PolarsBackend.Native do
def s_peak_max(_s), do: err()
def s_peak_min(_s), do: err()
def s_select(_pred, _on_true, _on_false), do: err()
def s_pow(_s, _other), do: err()
def s_log_natural(_s_argument), do: err()
def s_log(_s_argument, _base_as_float), do: err()
def s_quantile(_s, _quantile, _strategy), do: err()
Expand Down
24 changes: 22 additions & 2 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,28 @@ defmodule Explorer.PolarsBackend.Series do
do: Shared.apply_series(matching_size!(left, right), :s_remainder, [right.data])

@impl true
def pow(left, right),
do: Shared.apply_series(matching_size!(left, right), :s_pow, [right.data])
def pow(out_dtype, left, right) do
_ = matching_size!(left, right)

# We need to pre-cast or we may lose precision.
left = Explorer.Series.cast(left, out_dtype)

left_lazy = Explorer.Backend.LazySeries.new(:column, ["base"], left.dtype)
right_lazy = Explorer.Backend.LazySeries.new(:column, ["exponent"], right.dtype)

{df_args, pow_args} =
case {size(left), size(right)} do
{n, n} -> {[{"base", left}, {"exponent", right}], [left_lazy, right_lazy]}
{1, _} -> {[{"exponent", right}], [Explorer.Series.at(left, 0), right_lazy]}
{_, 1} -> {[{"base", left}], [left_lazy, Explorer.Series.at(right, 0)]}
end

df = Explorer.PolarsBackend.DataFrame.from_series(df_args)
pow = Explorer.Backend.LazySeries.new(:pow, pow_args, out_dtype)

Explorer.PolarsBackend.DataFrame.mutate_with(df, df, [{"pow", pow}])
|> Explorer.PolarsBackend.DataFrame.pull("pow")
end

@impl true
def log(%Series{} = argument), do: Shared.apply_series(argument, :s_log_natural, [])
Expand Down
22 changes: 18 additions & 4 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3405,7 +3405,7 @@ defmodule Explorer.Series do
iex> Explorer.Series.pow(s, 3)
#Explorer.Series<
Polars[3]
s64 [8, 64, 216]
f64 [8.0, 64.0, 216.0]
>
iex> s = [2, 4, 6] |> Explorer.Series.from_list()
Expand All @@ -3431,7 +3431,23 @@ defmodule Explorer.Series do
"""
@doc type: :element_wise
@spec pow(left :: Series.t() | number(), right :: Series.t() | number()) :: Series.t()
def pow(left, right), do: basic_numeric_operation(:pow, left, right)
def pow(left, right) do
[left, right] = cast_for_arithmetic("pow/2", [left, right])

if out_dtype = cast_to_pow(dtype(left), dtype(right)) do
apply_series_list(:pow, [out_dtype, left, right])
else
dtype_mismatch_error("pow/2", left, right)
end
end

defp cast_to_pow({:u, l}, {:u, r}), do: {:u, max(l, r)}
defp cast_to_pow({:s, s}, {:u, u}), do: {:s, min(64, max(2 * u, s))}
defp cast_to_pow({:f, l}, {:f, r}), do: {:f, max(l, r)}
defp cast_to_pow({:f, l}, {n, _}) when K.in(n, [:u, :s]), do: {:f, l}
defp cast_to_pow({n, _}, {:f, r}) when K.in(n, [:u, :s]), do: {:f, r}
defp cast_to_pow({n, _}, {:s, _}) when K.in(n, [:u, :s]), do: {:f, 64}
defp cast_to_pow(_, _), do: nil

@doc """
Calculates the natural logarithm.
Expand Down Expand Up @@ -3761,8 +3777,6 @@ defmodule Explorer.Series do
def atan(%Series{dtype: dtype}),
do: dtype_error("atan/1", dtype, [{:f, 32}, {:f, 64}])

defp basic_numeric_operation(operation, left, right, args \\ [])

defp basic_numeric_operation(operation, %Series{} = left, right, args) when is_numeric(right),
do: basic_numeric_operation(operation, left, from_same_value(left, right), args)

Expand Down
1 change: 0 additions & 1 deletion native/explorer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,6 @@ rustler::init!(
s_peak_max,
s_peak_min,
s_select,
s_pow,
s_quantile,
s_quotient,
s_rank,
Expand Down
63 changes: 0 additions & 63 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1298,69 +1298,6 @@ pub fn s_n_distinct(s: ExSeries) -> Result<usize, ExplorerError> {
Ok(s.n_unique()?)
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_pow(s: ExSeries, other: ExSeries) -> Result<ExSeries, ExplorerError> {
match (s.dtype().is_integer(), other.dtype().is_integer()) {
(true, true) => {
let cast1 = s.cast(&DataType::Int64)?;
let mut iter1 = cast1.i64()?.into_iter();

match other.strict_cast(&DataType::UInt32) {
Ok(casted) => {
let mut iter2 = casted.u32()?.into_iter();

let res = if s.len() == 1 {
let v1 = iter1.next().unwrap();
iter2
.map(|v2| v1.and_then(|left| v2.map(|right| left.pow(right))))
.collect()
} else if other.len() == 1 {
let v2 = iter2.next().unwrap();
iter1
.map(|v1| v1.and_then(|left| v2.map(|right| left.pow(right))))
.collect()
} else {
iter1
.zip(iter2)
.map(|(v1, v2)| v1.and_then(|left| v2.map(|right| left.pow(right))))
.collect()
};

Ok(ExSeries::new(res))
}
Err(_) => Err(ExplorerError::Other(
"negative exponent with an integer base".into(),
)),
}
}
(_, _) => {
let cast1 = s.cast(&DataType::Float64)?;
let cast2 = other.cast(&DataType::Float64)?;
let mut iter1 = cast1.f64()?.into_iter();
let mut iter2 = cast2.f64()?.into_iter();

let res = if s.len() == 1 {
let v1 = iter1.next().unwrap();
iter2
.map(|v2| v1.and_then(|left| v2.map(|right| left.powf(right))))
.collect()
} else if other.len() == 1 {
let v2 = iter2.next().unwrap();
iter1
.map(|v1| v1.and_then(|left| v2.map(|right| left.powf(right))))
.collect()
} else {
iter1
.zip(iter2)
.map(|(v1, v2)| v1.and_then(|left| v2.map(|right| left.powf(right))))
.collect()
};

Ok(ExSeries::new(res))
}
}
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_cast(s: ExSeries, to_type: ExSeriesDtype) -> Result<ExSeries, ExplorerError> {
let dtype = DataType::try_from(&to_type)?;
Expand Down
23 changes: 11 additions & 12 deletions test/explorer/data_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ defmodule Explorer.DataFrameTest do
df = DF.new(a: [1, 2, 3, 4, 5, 6, 5], b: [9, 8, 7, 6, 5, 4, 3])

message =
"expecting the function to return a boolean LazySeries, but instead it returned a LazySeries of type {:s, 64}"
"expecting the function to return a boolean LazySeries, but instead it returned a LazySeries of type {:f, 64}"

assert_raise ArgumentError, message, fn ->
DF.filter_with(df, fn ldf ->
Expand Down Expand Up @@ -799,7 +799,7 @@ defmodule Explorer.DataFrameTest do
calc10: is_nan(divide(a * 0.0, 0.0))
)

assert DF.to_columns(df1, atom_keys: true) == %{
assert DF.to_columns(df1, atom_keys: true) === %{
a: [1, 2, 4],
calc1: [3, 4, 6],
calc2: [-1, 0, 2],
Expand All @@ -819,7 +819,7 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
"calc5" => {:s, 64},
"calc5" => {:f, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64},
"calc8" => {:f, 64},
Expand All @@ -843,7 +843,7 @@ defmodule Explorer.DataFrameTest do
calc7: remainder(2, a)
)

assert DF.to_columns(df1, atom_keys: true) == %{
assert DF.to_columns(df1, atom_keys: true) === %{
a: [1, 2, 4],
calc1: [3, 4, 6],
calc2: [1, 0, -2],
Expand All @@ -861,8 +861,7 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
# TODO: This should be float after #374 is resolved
"calc5" => {:s, 64},
"calc5" => {:f, 64},
"calc5_1" => {:f, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64}
Expand All @@ -884,7 +883,7 @@ defmodule Explorer.DataFrameTest do
calc7: remainder(a, ^series)
)

assert DF.to_columns(df1, atom_keys: true) == %{
assert DF.to_columns(df1, atom_keys: true) === %{
a: [1, 2, 4],
calc1: [3, 3, 6],
calc2: [-1, 1, 2],
Expand All @@ -901,7 +900,7 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
"calc5" => {:s, 64},
"calc5" => {:f, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64}
}
Expand All @@ -922,7 +921,7 @@ defmodule Explorer.DataFrameTest do
calc7: remainder(^series, a)
)

assert DF.to_columns(df1, atom_keys: true) == %{
assert DF.to_columns(df1, atom_keys: true) === %{
a: [2, 1, 2],
calc1: [3, 3, 6],
calc2: [-1, 1, 2],
Expand All @@ -939,7 +938,7 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
"calc5" => {:s, 64},
"calc5" => {:f, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64}
}
Expand All @@ -959,7 +958,7 @@ defmodule Explorer.DataFrameTest do
calc7: remainder(b, c)
)

assert DF.to_columns(df1, atom_keys: true) == %{
assert DF.to_columns(df1, atom_keys: true) === %{
a: [1, 2, 3],
b: [20, 40, 60],
c: [10, 0, 8],
Expand All @@ -982,7 +981,7 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
"calc5" => {:s, 64},
"calc5" => {:f, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64}
}
Expand Down
Loading

0 comments on commit 8fa5071

Please sign in to comment.