diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index b0e4a7d6e26a..b5d32aab0f8d 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -81,6 +81,23 @@ impl PrivateSeries for NullChunked { ExplodeByOffsets::explode_by_offsets(self, offsets) } + fn subtract(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "subtract") + } + + fn add_to(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "add_to") + } + fn multiply(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "multiply") + } + fn divide(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "divide") + } + fn remainder(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "remainder") + } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { Ok(if self.is_empty() { @@ -98,6 +115,16 @@ impl PrivateSeries for NullChunked { } } +fn null_arithmetic(lhs: &NullChunked, rhs: &Series, op: &str) -> PolarsResult { + let output_len = match (lhs.len(), rhs.len()) { + (1, len_r) => len_r, + (len_l, 1) => len_l, + (len_l, len_r) if len_l == len_r => len_l, + _ => polars_bail!(ComputeError: "Cannot {:?} two series of different lengths.", op), + }; + Ok(NullChunked::new(lhs.name().into(), output_len).into_series()) +} + impl SeriesTrait for NullChunked { fn name(&self) -> &str { self.name.as_ref() diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 317cf68ffba1..15d4f865576b 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -270,3 +270,26 @@ def test_operator_arithmetic_with_nulls(op: Any) -> None: assert_frame_equal(df_expected, op(df, None)) assert_series_equal(s_expected, op(s, None)) + + +@pytest.mark.parametrize( + "op", + [ + operator.add, + operator.mod, + operator.mul, + operator.sub, + ], +) +def test_null_column_arithmetic(op: Any) -> None: + df = pl.DataFrame({"a": [None, None], "b": [None, None]}) + expected_df = pl.DataFrame({"a": [None, None]}) + + output_df = df.select(op(pl.col("a"), pl.col("b"))) + assert_frame_equal(expected_df, output_df) + # test broadcast right + output_df = df.select(op(pl.col("a"), pl.Series([None]))) + assert_frame_equal(expected_df, output_df) + # test broadcast left + output_df = df.select(op(pl.Series("a", [None]), pl.col("a"))) + assert_frame_equal(expected_df, output_df)