From 4e4d259ddf348385c127607ca3abc163a7344bf6 Mon Sep 17 00:00:00 2001 From: Marshall Date: Fri, 28 Feb 2025 08:24:29 -0500 Subject: [PATCH] feat: Add 'nulls_equal' parameter to `is_in` (#21426) --- crates/polars-expr/src/expressions/apply.rs | 4 +- crates/polars-lazy/src/tests/io.rs | 2 +- .../src/tests/predicate_queries.rs | 2 +- crates/polars-lazy/src/tests/queries.rs | 6 +- crates/polars-ops/src/series/ops/is_in.rs | 218 +++++++++--- crates/polars-ops/src/series/ops/replace.rs | 2 +- .../src/dsl/function_expr/array.rs | 1 + .../src/dsl/function_expr/binary.rs | 4 +- .../src/dsl/function_expr/boolean.rs | 11 +- .../polars-plan/src/dsl/function_expr/list.rs | 13 +- crates/polars-plan/src/dsl/mod.rs | 6 +- .../plans/aexpr/predicates/skip_batches.rs | 2 +- .../polars-plan/src/plans/aexpr/properties.rs | 2 +- .../src/plans/conversion/type_coercion/mod.rs | 4 +- .../optimizer/predicate_pushdown/join.rs | 2 +- .../polars-plan/src/plans/python/pyarrow.rs | 2 +- crates/polars-python/src/expr/general.rs | 4 +- .../src/lazyframe/visitor/expr_nodes.rs | 4 +- crates/polars-sql/src/sql_expr.rs | 10 +- .../polars/tests/it/lazy/expressions/is_in.rs | 2 +- .../polars/tests/it/lazy/predicate_queries.rs | 2 +- py-polars/polars/expr/expr.py | 11 +- py-polars/polars/series/series.py | 26 +- py-polars/tests/unit/operations/test_is_in.py | 327 ++++++++++++++++-- 24 files changed, 536 insertions(+), 131 deletions(-) diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 5fe09068bd11..bdff32194cec 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -438,7 +438,7 @@ impl PhysicalExpr for ApplyExpr { match function { FunctionExpr::Boolean(BooleanFunction::IsNull) => Some(self), #[cfg(feature = "is_in")] - FunctionExpr::Boolean(BooleanFunction::IsIn) => Some(self), + FunctionExpr::Boolean(BooleanFunction::IsIn { .. }) => Some(self), #[cfg(feature = "is_between")] FunctionExpr::Boolean(BooleanFunction::IsBetween { closed: _ }) => Some(self), FunctionExpr::Boolean(BooleanFunction::IsNotNull) => Some(self), @@ -573,7 +573,7 @@ impl ApplyExpr { } }, #[cfg(feature = "is_in")] - FunctionExpr::Boolean(BooleanFunction::IsIn) => { + FunctionExpr::Boolean(BooleanFunction::IsIn { .. }) => { let should_read = || -> Option { let root = expr_to_leaf_column_name(&input[0]).ok()?; diff --git a/crates/polars-lazy/src/tests/io.rs b/crates/polars-lazy/src/tests/io.rs index 435fa74a373a..cfe93cfd778b 100644 --- a/crates/polars-lazy/src/tests/io.rs +++ b/crates/polars-lazy/src/tests/io.rs @@ -136,7 +136,7 @@ fn test_parquet_statistics() -> PolarsResult<()> { // issue: 13427 let out = scan_foods_parquet(par) - .filter(col("calories").is_in(lit(Series::new("".into(), [0, 500])))) + .filter(col("calories").is_in(lit(Series::new("".into(), [0, 500])), false)) .collect()?; assert_eq!(out.shape(), (0, 4)); diff --git a/crates/polars-lazy/src/tests/predicate_queries.rs b/crates/polars-lazy/src/tests/predicate_queries.rs index 71d24d1207e1..748bc4fc31f3 100644 --- a/crates/polars-lazy/src/tests/predicate_queries.rs +++ b/crates/polars-lazy/src/tests/predicate_queries.rs @@ -48,7 +48,7 @@ fn test_issue_2472() -> PolarsResult<()> { .extract(lit(r"(\d+-){4}(\w+)-"), 2) .cast(DataType::Int32) .alias("age"); - let predicate = col("age").is_in(lit(Series::new("".into(), [2i32]))); + let predicate = col("age").is_in(lit(Series::new("".into(), [2i32])), false); let out = base .clone() diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index da2063453ac6..e3fbe79e70a1 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -355,7 +355,7 @@ fn test_lazy_query_8() -> PolarsResult<()> { let mut selection = vec![]; for &c in &["A", "B", "C", "D", "E"] { - let e = when(col(c).is_in(col("E"))) + let e = when(col(c).is_in(col("E"), false)) .then(col("A")) .otherwise(Null {}.lit()) .alias(c); @@ -1761,7 +1761,7 @@ fn test_is_in() -> PolarsResult<()> { .clone() .lazy() .group_by_stable([col("fruits")]) - .agg([col("cars").is_in(col("cars").filter(col("cars").eq(lit("beetle"))))]) + .agg([col("cars").is_in(col("cars").filter(col("cars").eq(lit("beetle"))), false)]) .collect()?; let out = out.column("cars").unwrap(); let out = out.explode()?; @@ -1775,7 +1775,7 @@ fn test_is_in() -> PolarsResult<()> { let out = df .lazy() .group_by_stable([col("fruits")]) - .agg([col("cars").is_in(lit(Series::new("a".into(), ["beetle", "vw"])))]) + .agg([col("cars").is_in(lit(Series::new("a".into(), ["beetle", "vw"])), false)]) .collect()?; let out = out.column("cars").unwrap(); diff --git a/crates/polars-ops/src/series/ops/is_in.rs b/crates/polars-ops/src/series/ops/is_in.rs index 0f6e807ec03d..3520658b58a2 100644 --- a/crates/polars-ops/src/series/ops/is_in.rs +++ b/crates/polars-ops/src/series/ops/is_in.rs @@ -1,6 +1,6 @@ use std::hash::Hash; -use polars_core::prelude::arity::unary_elementwise_values; +use polars_core::prelude::arity::{unary_elementwise, unary_elementwise_values}; use polars_core::prelude::*; use polars_core::utils::{try_get_supertype, CustomIterTools}; use polars_core::with_match_physical_numeric_polars_type; @@ -11,6 +11,7 @@ use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; fn is_in_helper_ca<'a, T>( ca: &'a ChunkedArray, other: &'a ChunkedArray, + nulls_equal: bool, ) -> PolarsResult where T: PolarsDataType, @@ -25,20 +26,39 @@ where } }) }); - Ok( - unary_elementwise_values(ca, |val| set.contains(&val.to_total_ord())) - .with_name(ca.name().clone()), - ) + + if nulls_equal { + if other.has_nulls() { + // If the rhs has nulls, then nulls in the left set evaluates to true. + Ok(unary_elementwise(ca, |val| { + val.is_none_or(|v| set.contains(&v.to_total_ord())) + })) + } else { + // The rhs has no nulls; nulls in the left evaluates to false. + Ok(unary_elementwise(ca, |val| { + val.is_some_and(|v| set.contains(&v.to_total_ord())) + })) + } + } else { + Ok( + unary_elementwise_values(ca, |v| set.contains(&v.to_total_ord())) + .with_name(ca.name().clone()), + ) + } } -fn is_in_helper<'a, T>(ca: &'a ChunkedArray, other: &'a Series) -> PolarsResult +fn is_in_helper<'a, T>( + ca: &'a ChunkedArray, + other: &'a Series, + nulls_equal: bool, +) -> PolarsResult where T: PolarsDataType, T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, { let other = ca.unpack_series_matching_type(other)?; - is_in_helper_ca(ca, other) + is_in_helper_ca(ca, other, nulls_equal) } fn is_in_numeric_list(ca_in: &ChunkedArray, other: &Series) -> PolarsResult @@ -48,7 +68,6 @@ where { let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 { let value = ca_in.get(0); - other.list()?.apply_amortized_generic(|opt_s| { Some( opt_s.map(|s| { @@ -112,7 +131,11 @@ where Ok(ca) } -fn is_in_numeric(ca_in: &ChunkedArray, other: &Series) -> PolarsResult +fn is_in_numeric( + ca_in: &ChunkedArray, + other: &Series, + nulls_equal: bool, +) -> PolarsResult where T: PolarsNumericType, T::Native: TotalHash + TotalEq + ToTotalOrd, @@ -125,7 +148,7 @@ where if &st != ca_in.dtype() || **dt != st { let left = ca_in.cast(&st)?; let right = other.cast(&DataType::List(Box::new(st)))?; - return is_in(&left, &right); + return is_in(&left, &right, nulls_equal); }; is_in_numeric_list(ca_in, other) }, @@ -135,7 +158,7 @@ where if &st != ca_in.dtype() || **dt != st { let left = ca_in.cast(&st)?; let right = other.cast(&DataType::Array(Box::new(st), *width))?; - return is_in(&left, &right); + return is_in(&left, &right, nulls_equal); }; is_in_numeric_array(ca_in, other) }, @@ -145,9 +168,9 @@ where let st = try_get_supertype(ca_in.dtype(), other.dtype())?; let left = ca_in.cast(&st)?; let right = other.cast(&st)?; - return is_in(&left, &right); + return is_in(&left, &right, nulls_equal); } - is_in_helper(ca_in, other) + is_in_helper(ca_in, other, nulls_equal) }, } } @@ -205,7 +228,11 @@ fn is_in_string_list_categorical( Ok(ca) } -fn is_in_string(ca_in: &StringChunked, other: &Series) -> PolarsResult { +fn is_in_string( + ca_in: &StringChunked, + other: &Series, + nulls_equal: bool, +) -> PolarsResult { match other.dtype() { #[cfg(feature = "dtype-categorical")] DataType::List(dt) @@ -223,6 +250,7 @@ fn is_in_string(ca_in: &StringChunked, other: &Series) -> PolarsResult is_in_binary( @@ -230,13 +258,16 @@ fn is_in_string(ca_in: &StringChunked, other: &Series) -> PolarsResult is_in_binary( + &ca_in.as_binary(), + &other.cast(&DataType::Binary).unwrap(), + nulls_equal, ), - DataType::String => { - is_in_binary(&ca_in.as_binary(), &other.cast(&DataType::Binary).unwrap()) - }, #[cfg(feature = "dtype-categorical")] DataType::Enum(_, _) | DataType::Categorical(_, _) => { - is_in_string_categorical(ca_in, other.categorical().unwrap()) + is_in_string_categorical(ca_in, other.categorical().unwrap(), nulls_equal) }, _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), } @@ -305,17 +336,26 @@ fn is_in_binary_array(ca_in: &BinaryChunked, other: &Series) -> PolarsResult PolarsResult { +fn is_in_binary( + ca_in: &BinaryChunked, + other: &Series, + nulls_equal: bool, +) -> PolarsResult { match other.dtype() { DataType::List(dt) if DataType::Binary == **dt => is_in_binary_list(ca_in, other), #[cfg(feature = "dtype-array")] DataType::Array(dt, _) if DataType::Binary == **dt => is_in_binary_array(ca_in, other), - DataType::Binary => is_in_helper(ca_in, other), + DataType::Binary => is_in_helper(ca_in, other, nulls_equal), _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), } } -fn is_in_boolean_list(ca_in: &BooleanChunked, other: &Series) -> PolarsResult { +fn is_in_boolean_list( + ca_in: &BooleanChunked, + other: &Series, + _nulls_equal: bool, // NOTE: this is unimplemented at the moment. + // See https://github.com/pola-rs/polars/issues/21485. +) -> PolarsResult { let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 { let value = ca_in.get(0); // SAFETY: we know the iterators len @@ -353,7 +393,12 @@ fn is_in_boolean_list(ca_in: &BooleanChunked, other: &Series) -> PolarsResult PolarsResult { +fn is_in_boolean_array( + ca_in: &BooleanChunked, + other: &Series, + _nulls_equal: bool, // NOTE: this is unimplemented at the moment. + // https://github.com/pola-rs/polars/issues/21485 +) -> PolarsResult { let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 { let value = ca_in.get(0); // SAFETY: we know the iterators len @@ -388,11 +433,19 @@ fn is_in_boolean_array(ca_in: &BooleanChunked, other: &Series) -> PolarsResult PolarsResult { +fn is_in_boolean( + ca_in: &BooleanChunked, + other: &Series, + nulls_equal: bool, +) -> PolarsResult { match other.dtype() { - DataType::List(dt) if ca_in.dtype() == &**dt => is_in_boolean_list(ca_in, other), + DataType::List(dt) if ca_in.dtype() == &**dt => { + is_in_boolean_list(ca_in, other, nulls_equal) + }, #[cfg(feature = "dtype-array")] - DataType::Array(dt, _) if ca_in.dtype() == &**dt => is_in_boolean_array(ca_in, other), + DataType::Array(dt, _) if ca_in.dtype() == &**dt => { + is_in_boolean_array(ca_in, other, nulls_equal) + }, DataType::Boolean => { let other = other.bool().unwrap(); let has_true = other.any(); @@ -403,9 +456,20 @@ fn is_in_boolean(ca_in: &BooleanChunked, other: &Series) -> PolarsResult polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), } @@ -502,7 +566,11 @@ fn is_in_struct_array(ca_in: &StructChunked, other: &Series) -> PolarsResult PolarsResult { +fn is_in_struct( + ca_in: &StructChunked, + other: &Series, + nulls_equal: bool, +) -> PolarsResult { match other.dtype() { DataType::List(_) => is_in_struct_list(ca_in, other), #[cfg(feature = "dtype-array")] @@ -540,19 +608,21 @@ fn is_in_struct(ca_in: &StructChunked, other: &Series) -> PolarsResult 0 { let ca_in = ca_in.rechunk(); let mut ca_in_o = ca_in.get_row_encoded(Default::default())?; ca_in_o.merge_validities(ca_in.chunks()); - let ca_other = other.get_row_encoded(Default::default())?; - is_in_helper_ca(&ca_in_o, &ca_other) + let other = other.rechunk(); + let mut ca_other = other.get_row_encoded(Default::default())?; + ca_other.merge_validities(other.chunks()); + is_in_helper_ca(&ca_in_o, &ca_other, nulls_equal) } else { let ca_in = ca_in.get_row_encoded(Default::default())?; let ca_other = other.get_row_encoded(Default::default())?; - is_in_helper_ca(&ca_in, &ca_other) + is_in_helper_ca(&ca_in, &ca_other, nulls_equal) } }, } @@ -562,6 +632,7 @@ fn is_in_struct(ca_in: &StructChunked, other: &Series) -> PolarsResult PolarsResult { // In case of fast unique, we can directly use the categories. Otherwise we need to // first get the unique physicals @@ -576,16 +647,20 @@ fn is_in_string_categorical( // SAFETY: Invariant of categorical means indices are in bound unsafe { categories.take_unchecked(s.idx()?) } }; - is_in_helper_ca(&ca_in.as_binary(), &other.as_binary()) + is_in_helper_ca(&ca_in.as_binary(), &other.as_binary(), nulls_equal) } #[cfg(feature = "dtype-categorical")] -fn is_in_cat(ca_in: &CategoricalChunked, other: &Series) -> PolarsResult { +fn is_in_cat( + ca_in: &CategoricalChunked, + other: &Series, + nulls_equal: bool, +) -> PolarsResult { match other.dtype() { DataType::Categorical(_, _) | DataType::Enum(_, _) => { let (ca_in, other_in) = make_categoricals_compatible(ca_in, other.categorical().unwrap())?; - is_in_helper_ca(ca_in.physical(), other_in.physical()) + is_in_helper_ca(ca_in.physical(), other_in.physical(), nulls_equal) }, DataType::String => { let ca_other = other.str().unwrap(); @@ -620,10 +695,25 @@ fn is_in_cat(ca_in: &CategoricalChunked, other: &Series) -> PolarsResult PolarsResult PolarsResult { +fn is_in_null(s: &Series, other: &Series, nulls_equal: bool) -> PolarsResult { + if nulls_equal { + let ca_in = s.null()?; + Ok(match other.dtype() { + DataType::List(_) => other.list()?.apply_amortized_generic(|opt_s| { + Some(opt_s.map(|s| s.as_ref().has_nulls()) == Some(true)) + }), + #[cfg(feature = "dtype-array")] + DataType::Array(_, _) => other.array()?.apply_amortized_generic(|opt_s| { + Some(opt_s.map(|s| s.as_ref().has_nulls()) == Some(true)) + }), + _ => { + // If other has null values, then all are true, else all are false. + BooleanChunked::from_iter_values( + ca_in.name().clone(), + std::iter::repeat(other.has_nulls()).take(ca_in.len()), + ) + }, + }) + } else { + let out = s.cast(&DataType::Boolean)?; + let ca_bool = out.bool()?.clone(); + Ok(ca_bool) + } +} + +pub fn is_in(s: &Series, other: &Series, nulls_equal: bool) -> PolarsResult { match s.dtype() { #[cfg(feature = "dtype-categorical")] DataType::Categorical(_, _) | DataType::Enum(_, _) => { let ca = s.categorical().unwrap(); - is_in_cat(ca, other) + is_in_cat(ca, other, nulls_equal) }, #[cfg(feature = "dtype-struct")] DataType::Struct(_) => { let ca = s.struct_().unwrap(); - is_in_struct(ca, other) + is_in_struct(ca, other, nulls_equal) }, DataType::String => { let ca = s.str().unwrap(); - is_in_string(ca, other) + is_in_string(ca, other, nulls_equal) }, DataType::Binary => { let ca = s.binary().unwrap(); - is_in_binary(ca, other) + is_in_binary(ca, other, nulls_equal) }, DataType::Boolean => { let ca = s.bool().unwrap(); - is_in_boolean(ca, other) - }, - DataType::Null => { - let series_bool = s.cast(&DataType::Boolean)?; - let ca = series_bool.bool().unwrap(); - Ok(ca.clone()) + is_in_boolean(ca, other, nulls_equal) }, + DataType::Null => is_in_null(s, other, nulls_equal), #[cfg(feature = "dtype-decimal")] DataType::Decimal(_, _) => { let s = s.decimal()?; @@ -726,13 +838,13 @@ pub fn is_in(s: &Series, other: &Series) -> PolarsResult { let s = s.to_scale(scale)?; let other = other.to_scale(scale)?.into_owned().into_series(); - is_in_numeric(s.physical(), other.to_physical_repr().as_ref()) + is_in_numeric(s.physical(), other.to_physical_repr().as_ref(), nulls_equal) }, dt if dt.to_physical().is_primitive_numeric() => { let s = s.to_physical_repr(); with_match_physical_numeric_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); - is_in_numeric(ca, other) + is_in_numeric(ca, other, nulls_equal) }) }, dt => polars_bail!(opq = is_in, dt), diff --git a/crates/polars-ops/src/series/ops/replace.rs b/crates/polars-ops/src/series/ops/replace.rs index d76427702253..4da56c669fb5 100644 --- a/crates/polars-ops/src/series/ops/replace.rs +++ b/crates/polars-ops/src/series/ops/replace.rs @@ -150,7 +150,7 @@ fn get_replacement_mask(s: &Series, old: &Series) -> PolarsResult PolarsResult { Ok(is_in( item.as_materialized_series(), array.as_materialized_series(), + true, )? .with_name(array.name().clone()) .into_column()) diff --git a/crates/polars-plan/src/dsl/function_expr/binary.rs b/crates/polars-plan/src/dsl/function_expr/binary.rs index 00855bd4e97b..818ca44fd008 100644 --- a/crates/polars-plan/src/dsl/function_expr/binary.rs +++ b/crates/polars-plan/src/dsl/function_expr/binary.rs @@ -27,7 +27,7 @@ impl BinaryFunction { pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { use BinaryFunction::*; match self { - Contains { .. } => mapper.with_dtype(DataType::Boolean), + Contains => mapper.with_dtype(DataType::Boolean), EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean), #[cfg(feature = "binary_encoding")] HexDecode(_) | Base64Decode(_) => mapper.with_same_dtype(), @@ -44,7 +44,7 @@ impl Display for BinaryFunction { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { use BinaryFunction::*; let s = match self { - Contains { .. } => "contains", + Contains => "contains", StartsWith => "starts_with", EndsWith => "ends_with", #[cfg(feature = "binary_encoding")] diff --git a/crates/polars-plan/src/dsl/function_expr/boolean.rs b/crates/polars-plan/src/dsl/function_expr/boolean.rs index 089fed3dc51b..1015d12f35a0 100644 --- a/crates/polars-plan/src/dsl/function_expr/boolean.rs +++ b/crates/polars-plan/src/dsl/function_expr/boolean.rs @@ -36,7 +36,9 @@ pub enum BooleanFunction { closed: ClosedInterval, }, #[cfg(feature = "is_in")] - IsIn, + IsIn { + nulls_equal: bool, + }, AllHorizontal, AnyHorizontal, // Also bitwise negate @@ -84,7 +86,7 @@ impl Display for BooleanFunction { #[cfg(feature = "is_between")] IsBetween { .. } => "is_between", #[cfg(feature = "is_in")] - IsIn => "is_in", + IsIn { .. } => "is_in", AnyHorizontal => "any_horizontal", AllHorizontal => "all_horizontal", Not => "not", @@ -116,7 +118,7 @@ impl From for SpecialEq> { #[cfg(feature = "is_between")] IsBetween { closed } => map_as_slice!(is_between, closed), #[cfg(feature = "is_in")] - IsIn => wrap!(is_in), + IsIn { nulls_equal } => wrap!(is_in, nulls_equal), Not => map!(not), AllHorizontal => map_as_slice!(all_horizontal), AnyHorizontal => map_as_slice!(any_horizontal), @@ -207,12 +209,13 @@ fn is_between(s: &[Column], closed: ClosedInterval) -> PolarsResult { } #[cfg(feature = "is_in")] -fn is_in(s: &mut [Column]) -> PolarsResult> { +fn is_in(s: &mut [Column], nulls_equal: bool) -> PolarsResult> { let left = &s[0]; let other = &s[1]; polars_ops::prelude::is_in( left.as_materialized_series(), other.as_materialized_series(), + nulls_equal, ) .map(|ca| Some(ca.into_column())) } diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 96a3cdc128e7..54f56bda617b 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -268,12 +268,15 @@ pub(super) fn contains(args: &mut [Column]) -> PolarsResult> { polars_ensure!(matches!(list.dtype(), DataType::List(_)), SchemaMismatch: "invalid series dtype: expected `List`, got `{}`", list.dtype(), ); - polars_ops::prelude::is_in(item.as_materialized_series(), list.as_materialized_series()).map( - |mut ca| { - ca.rename(list.name().clone()); - Some(ca.into_column()) - }, + polars_ops::prelude::is_in( + item.as_materialized_series(), + list.as_materialized_series(), + true, ) + .map(|mut ca| { + ca.rename(list.name().clone()); + Some(ca.into_column()) + }) } #[cfg(feature = "list_drop_nulls")] diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 4d88b9ca9291..6cf5f89f35f2 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1196,7 +1196,7 @@ impl Expr { /// Check if the values of the left expression are in the lists of the right expr. #[allow(clippy::wrong_self_convention)] #[cfg(feature = "is_in")] - pub fn is_in>(self, other: E) -> Self { + pub fn is_in>(self, other: E, nulls_equal: bool) -> Self { let other = other.into(); let has_literal = has_leaf_literal(&other); @@ -1207,14 +1207,14 @@ impl Expr { // we don't have to apply on groups, so this is faster if has_literal { self.map_many_private( - BooleanFunction::IsIn.into(), + BooleanFunction::IsIn { nulls_equal }.into(), arguments, returns_scalar, Some(Default::default()), ) } else { self.apply_many_private( - BooleanFunction::IsIn.into(), + BooleanFunction::IsIn { nulls_equal }.into(), arguments, returns_scalar, true, diff --git a/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs b/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs index 3cb4222ffaed..81adb10d765e 100644 --- a/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs +++ b/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs @@ -398,7 +398,7 @@ fn aexpr_to_skip_batch_predicate_rec( } => match function { FunctionExpr::Boolean(f) => match f { #[cfg(feature = "is_in")] - BooleanFunction::IsIn => { + BooleanFunction::IsIn { .. } => { let lv_node = input[1].node(); match ( into_column(input[0].node(), expr_arena, schema, 0), diff --git a/crates/polars-plan/src/plans/aexpr/properties.rs b/crates/polars-plan/src/plans/aexpr/properties.rs index e627e6fe7a7d..aad4301785e2 100644 --- a/crates/polars-plan/src/plans/aexpr/properties.rs +++ b/crates/polars-plan/src/plans/aexpr/properties.rs @@ -64,7 +64,7 @@ pub fn is_elementwise(stack: &mut UnitVec, ae: &AExpr, expr_arena: &Arena< // for inspection. (e.g. `is_in()`). #[cfg(feature = "is_in")] Function { - function: FunctionExpr::Boolean(BooleanFunction::IsIn), + function: FunctionExpr::Boolean(BooleanFunction::IsIn { .. }), input, .. } => (|| { diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs index f8b3d802c5f1..cb320503f853 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -159,7 +159,7 @@ impl OptimizationRule for TypeCoercionRule { } => return process_binary(expr_arena, lp_arena, lp_node, node_left, op, node_right), #[cfg(feature = "is_in")] AExpr::Function { - function: FunctionExpr::Boolean(BooleanFunction::IsIn), + function: FunctionExpr::Boolean(BooleanFunction::IsIn { nulls_equal }), ref input, options, } => { @@ -173,7 +173,7 @@ impl OptimizationRule for TypeCoercionRule { input[1].set_node(other_input); Some(AExpr::Function { - function: FunctionExpr::Boolean(BooleanFunction::IsIn), + function: FunctionExpr::Boolean(BooleanFunction::IsIn { nulls_equal }), input, options, }) diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs index cf991d9469f0..e50aa62f1a43 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs @@ -25,7 +25,7 @@ fn should_block_join_specific( } => join_produces_null(how), #[cfg(feature = "is_in")] Function { - function: FunctionExpr::Boolean(BooleanFunction::IsIn), + function: FunctionExpr::Boolean(BooleanFunction::IsIn { .. }), .. } => join_produces_null(how), // joins can produce duplicates diff --git a/crates/polars-plan/src/plans/python/pyarrow.rs b/crates/polars-plan/src/plans/python/pyarrow.rs index b3920b066440..1f675c6961f3 100644 --- a/crates/polars-plan/src/plans/python/pyarrow.rs +++ b/crates/polars-plan/src/plans/python/pyarrow.rs @@ -116,7 +116,7 @@ pub fn predicate_to_pa( }, #[cfg(feature = "is_in")] AExpr::Function { - function: FunctionExpr::Boolean(BooleanFunction::IsIn), + function: FunctionExpr::Boolean(BooleanFunction::IsIn { .. }), input, .. } => { diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs index 79b2422e1169..9edbd1ab9a12 100644 --- a/crates/polars-python/src/expr/general.rs +++ b/crates/polars-python/src/expr/general.rs @@ -652,8 +652,8 @@ impl PyExpr { } #[cfg(feature = "is_in")] - fn is_in(&self, expr: Self) -> Self { - self.inner.clone().is_in(expr.inner).into() + fn is_in(&self, expr: Self, nulls_equal: bool) -> Self { + self.inner.clone().is_in(expr.inner, nulls_equal).into() } #[cfg(feature = "repeat_by")] diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index 0ab9a01c6e11..83dc5d7479e2 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -1086,7 +1086,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { (PyBooleanFunction::IsBetween, Into::<&str>::into(closed)).into_py_any(py) }, #[cfg(feature = "is_in")] - BooleanFunction::IsIn => (PyBooleanFunction::IsIn,).into_py_any(py), + BooleanFunction::IsIn { nulls_equal } => { + (PyBooleanFunction::IsIn, nulls_equal).into_py_any(py) + }, BooleanFunction::AllHorizontal => { (PyBooleanFunction::AllHorizontal,).into_py_any(py) }, diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 4962eb2b844c..51f0fbb1178f 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -124,7 +124,7 @@ impl SQLExprVisitor<'_> { } => { let expr = self.visit_expr(expr)?; let elems = self.visit_array_expr(list, false, Some(&expr))?; - let is_in = expr.is_in(elems); + let is_in = expr.is_in(elems, false); Ok(if *negated { is_in.not() } else { is_in }) }, SQLExpr::InSubquery { @@ -692,8 +692,8 @@ impl SQLExprVisitor<'_> { SQLBinaryOperator::Lt => Ok(left.lt(right.max())), SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.min())), SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.max())), - SQLBinaryOperator::Eq => Ok(left.is_in(right)), - SQLBinaryOperator::NotEq => Ok(left.is_in(right).not()), + SQLBinaryOperator::Eq => Ok(left.is_in(right, false)), + SQLBinaryOperator::NotEq => Ok(left.is_in(right, false).not()), _ => polars_bail!(SQLInterface: "invalid comparison operator"), } } @@ -917,9 +917,9 @@ impl SQLExprVisitor<'_> { let subquery_result = self.visit_subquery(subquery, SubqueryRestriction::SingleColumn)?; let expr = self.visit_expr(expr)?; Ok(if negated { - expr.is_in(subquery_result).not() + expr.is_in(subquery_result, false).not() } else { - expr.is_in(subquery_result) + expr.is_in(subquery_result, false) }) } diff --git a/crates/polars/tests/it/lazy/expressions/is_in.rs b/crates/polars/tests/it/lazy/expressions/is_in.rs index 73591af48328..46a46e7eee3e 100644 --- a/crates/polars/tests/it/lazy/expressions/is_in.rs +++ b/crates/polars/tests/it/lazy/expressions/is_in.rs @@ -10,7 +10,7 @@ fn test_is_in() -> PolarsResult<()> { let out = df .lazy() - .select([col("y").is_in(lit(s)).alias("isin")]) + .select([col("y").is_in(lit(s), false).alias("isin")]) .collect()?; assert_eq!( Vec::from(out.column("isin")?.bool()?), diff --git a/crates/polars/tests/it/lazy/predicate_queries.rs b/crates/polars/tests/it/lazy/predicate_queries.rs index 49460facc118..feb643925b8e 100644 --- a/crates/polars/tests/it/lazy/predicate_queries.rs +++ b/crates/polars/tests/it/lazy/predicate_queries.rs @@ -140,7 +140,7 @@ fn test_is_in_categorical_3420() -> PolarsResult<()> { let out = df .lazy() .with_column(col("a").strict_cast(DataType::Categorical(None, Default::default()))) - .filter(col("a").is_in(lit(s).alias("x"))) + .filter(col("a").is_in(lit(s).alias("x"), false)) .collect()?; let mut expected = df![ diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 5c4d294a45c7..a06cf950b068 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -5782,7 +5782,12 @@ def xor(self, other: Any) -> Expr: """ return self.__xor__(other) - def is_in(self, other: Expr | Collection[Any] | Series) -> Expr: + def is_in( + self, + other: Expr | Collection[Any] | Series, + *, + nulls_equal: bool = False, + ) -> Expr: """ Check if elements of this expression are present in the other Series. @@ -5790,6 +5795,8 @@ def is_in(self, other: Expr | Collection[Any] | Series) -> Expr: ---------- other Series or sequence of primitive type. + nulls_equal : bool, default False + If True, treat null as a distinct value. Null values will not propagate. Returns ------- @@ -5819,7 +5826,7 @@ def is_in(self, other: Expr | Collection[Any] | Series) -> Expr: other = F.lit(pl.Series(other))._pyexpr else: other = parse_into_expression(other) - return self._from_pyexpr(self._pyexpr.is_in(other)) + return self._from_pyexpr(self._pyexpr.is_in(other, nulls_equal)) def repeat_by(self, by: pl.Series | Expr | str | int) -> Expr: """ diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index be44f072d22e..a60ff4fbf421 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3783,10 +3783,20 @@ def is_not_nan(self) -> Series: ] """ - def is_in(self, other: Series | Collection[Any]) -> Series: + def is_in( + self, + other: Series | Collection[Any], + *, + nulls_equal: bool = False, + ) -> Series: """ Check if elements of this Series are in the other Series. + Parameters + ---------- + nulls_equal : bool, default False + If True, treat null as a distinct value. Null values will not propagate. + Returns ------- Series @@ -3795,13 +3805,23 @@ def is_in(self, other: Series | Collection[Any]) -> Series: Examples -------- >>> s = pl.Series("a", [1, 2, 3]) - >>> s2 = pl.Series("b", [2, 4]) + >>> s2 = pl.Series("b", [2, 4, None]) >>> s2.is_in(s) - shape: (2,) + shape: (3,) Series: 'b' [bool] [ true false + null + ] + >>> # when nulls_equal=True, None is treated as a distinct value + >>> s2.is_in(s, nulls_equal=True) + shape: (3,) + Series: 'b' [bool] + [ + true + false + false ] >>> # check if some values are a member of sublists diff --git a/py-polars/tests/unit/operations/test_is_in.py b/py-polars/tests/unit/operations/test_is_in.py index dc5c1dfdaa5c..b5d69dbd9717 100644 --- a/py-polars/tests/unit/operations/test_is_in.py +++ b/py-polars/tests/unit/operations/test_is_in.py @@ -37,12 +37,88 @@ def test_struct_logical_is_in() -> None: assert s1.is_in(s2).to_list() == [False, False, True, True, True, True, True] -def test_is_in_bool() -> None: +def test_struct_logical_is_in_nonullpropagate() -> None: + s = pl.Series([date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3), None]) + df1 = pl.DataFrame( + { + "x": s, + "y": [0, 4, 6, None], + } + ) + s = pl.Series([date(2022, 2, 1), date(2022, 1, 2), date(2022, 2, 3), None]) + df2 = pl.DataFrame( + { + "x": s, + "y": [6, 4, 3, None], + } + ) + + # Left has no nulls, right has nulls + s1 = df1.select(pl.struct(["x", "y"])).to_series() + s1 = s1.extend_constant(s1[0], 1) + s2 = df2.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1) + assert s1.is_in(s2, nulls_equal=False).to_list() == [ + False, + True, + False, + True, + False, + ] + assert s1.is_in(s2, nulls_equal=True).to_list() == [ + False, + True, + False, + True, + False, + ] + + # Left has nulls, right has no nulls + s1 = df1.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1) + s2 = df2.select(pl.struct(["x", "y"])).to_series() + s2 = s2.extend_constant(s2[0], 1) + assert s1.is_in(s2, nulls_equal=False).to_list() == [ + False, + True, + False, + True, + None, + ] + assert s1.is_in(s2, nulls_equal=True).to_list() == [ + False, + True, + False, + True, + False, + ] + + # Both have nulls + # {None, None} is a valid element unaffected by the missing parameter. + s1 = df1.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1) + s2 = df2.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1) + assert s1.is_in(s2, nulls_equal=False).to_list() == [ + False, + True, + False, + True, + None, + ] + assert s1.is_in(s2, nulls_equal=True).to_list() == [ + False, + True, + False, + True, + True, + ] + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_bool(nulls_equal: bool) -> None: vals = [True, None] df = pl.DataFrame({"A": [True, False, None]}) - assert df.select(pl.col("A").is_in(vals)).to_dict(as_series=False) == { - "A": [True, False, None] - } + missing_value = True if nulls_equal else None + assert df.select(pl.col("A").is_in(vals, nulls_equal=nulls_equal)).to_dict( + as_series=False + ) == {"A": [True, False, missing_value]} def test_is_in_bool_11216() -> None: @@ -51,8 +127,9 @@ def test_is_in_bool_11216() -> None: assert_series_equal(s, expected) -def test_is_in_empty_list_4559() -> None: - assert pl.Series(["a"]).is_in([]).to_list() == [False] +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_empty_list_4559(nulls_equal: bool) -> None: + assert pl.Series(["a"]).is_in([], nulls_equal=nulls_equal).to_list() == [False] def test_is_in_empty_list_4639() -> None: @@ -152,10 +229,72 @@ def test_is_in_series() -> None: assert_series_equal(c, pl.Series("c", [True, False])) -def test_is_in_null() -> None: +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_null(nulls_equal: bool) -> None: + # No nulls in right s = pl.Series([None, None], dtype=pl.Null) - result = s.is_in([1, 2, None]) - expected = pl.Series([None, None], dtype=pl.Boolean) + result = s.is_in([1, 2], nulls_equal=nulls_equal) + missing_value = False if nulls_equal else None + expected = pl.Series([missing_value, missing_value], dtype=pl.Boolean) + assert_series_equal(result, expected) + + # Nulls in right + s = pl.Series([None, None], dtype=pl.Null) + result = s.is_in([None, None], nulls_equal=nulls_equal) + missing_value = True if nulls_equal else None + expected = pl.Series([missing_value, missing_value], dtype=pl.Boolean) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_boolean(nulls_equal: bool) -> None: + # Nulls in neither left nor right + s = pl.Series([True, False]) + result = s.is_in([True, False], nulls_equal=nulls_equal) + expected = pl.Series([True, True]) + assert_series_equal(result, expected) + + # Nulls in left only + s = pl.Series([True, None]) + result = s.is_in([False, False], nulls_equal=nulls_equal) + missing_value = False if nulls_equal else None + expected = pl.Series([False, missing_value]) + assert_series_equal(result, expected) + + # Nulls in right only + s = pl.Series([True, False]) + result = s.is_in([True, None], nulls_equal=nulls_equal) + expected = pl.Series([True, False]) + assert_series_equal(result, expected) + + # Nulls in both + s = pl.Series([True, False, None]) + result = s.is_in([True, None], nulls_equal=nulls_equal) + missing_value = True if nulls_equal else None + expected = pl.Series([True, False, missing_value]) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [pl.List(pl.Boolean), pl.Array(pl.Boolean, 2)]) +def test_is_in_boolean_list(dtype: PolarsDataType) -> None: + # Note list is_in does not propagate nulls. + df = pl.DataFrame( + { + "a": [True, False, None, None, None], + "b": pl.Series( + [ + [True, False], + [True, True], + [None, True], + [False, True], + [True, True], + ], + dtype=dtype, + ), + } + ) + result = df.select(pl.col("a").is_in("b"))["a"] + expected = pl.Series("a", [True, False, True, False, False]) assert_series_equal(result, expected) @@ -225,10 +364,46 @@ def test_is_in_expr_list_series( df.select(expr_is_in) -def test_is_in_null_series() -> None: +@pytest.mark.parametrize( + ("df", "matches"), + [ + ( + pl.DataFrame({"a": [1, None], "b": [[1.0, 2.5, 4.0], [3.0, 4.0, 5.0]]}), + [True, False], + ), + ( + pl.DataFrame({"a": [1, None], "b": [[0.0, 2.5, None], [3.0, 4.0, None]]}), + [False, True], + ), + ( + pl.DataFrame( + {"a": [None, None], "b": [[1, 2], [3, 4]]}, + schema_overrides={"a": pl.Null}, + ), + [False, False], + ), + ( + pl.DataFrame( + {"a": [None, None], "b": [[1, 2], [3, None]]}, + schema_overrides={"a": pl.Null}, + ), + [False, True], + ), + ], +) +def test_is_in_expr_list_series_nonullpropagate( + df: pl.DataFrame, matches: list[bool] +) -> None: + expr_is_in = pl.col("a").is_in(pl.col("b"), nulls_equal=True) + assert df.select(expr_is_in).to_series().to_list() == matches + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_null_series(nulls_equal: bool) -> None: df = pl.DataFrame({"a": ["a", "b", None]}) - result = df.select(pl.col("a").is_in([None])) - expected = pl.DataFrame({"a": [False, False, None]}) + result = df.select(pl.col("a").is_in([None], nulls_equal=nulls_equal)) + missing_value = True if nulls_equal else None + expected = pl.DataFrame({"a": [False, False, missing_value]}) assert_frame_equal(result, expected) @@ -254,55 +429,67 @@ def test_is_in_date_range() -> None: @StringCache() @pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])]) -def test_cat_is_in_series(dtype: pl.DataType) -> None: +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_cat_is_in_series(dtype: pl.DataType, nulls_equal: bool) -> None: s = pl.Series(["a", "b", "c", None], dtype=dtype) s2 = pl.Series(["b", "c"], dtype=dtype) - expected = pl.Series([False, True, True, None]) - assert_series_equal(s.is_in(s2), expected) + missing_value = False if nulls_equal else None + expected = pl.Series([False, True, True, missing_value]) + assert_series_equal(s.is_in(s2, nulls_equal=nulls_equal), expected) s2_str = s2.cast(pl.String) - assert_series_equal(s.is_in(s2_str), expected) + assert_series_equal(s.is_in(s2_str, nulls_equal=nulls_equal), expected) @StringCache() -def test_cat_is_in_series_non_existent() -> None: +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_cat_is_in_series_non_existent(nulls_equal: bool) -> None: dtype = pl.Categorical s = pl.Series(["a", "b", "c", None], dtype=dtype) s2 = pl.Series(["a", "d", "e"], dtype=dtype) - expected = pl.Series([True, False, False, None]) - assert_series_equal(s.is_in(s2), expected) + missing_value = False if nulls_equal else None + expected = pl.Series([True, False, False, missing_value]) + assert_series_equal(s.is_in(s2, nulls_equal=nulls_equal), expected) s2_str = s2.cast(pl.String) - assert_series_equal(s.is_in(s2_str), expected) + assert_series_equal(s.is_in(s2_str, nulls_equal=nulls_equal), expected) @StringCache() -def test_enum_is_in_series_non_existent() -> None: +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_enum_is_in_series_non_existent(nulls_equal: bool) -> None: dtype = pl.Enum(["a", "b", "c"]) + missing_value = False if nulls_equal else None s = pl.Series(["a", "b", "c", None], dtype=dtype) s2_str = pl.Series(["a", "d", "e"]) - expected = pl.Series([True, False, False, None]) - assert_series_equal(s.is_in(s2_str), expected) + expected = pl.Series([True, False, False, missing_value]) + assert_series_equal(s.is_in(s2_str, nulls_equal=nulls_equal), expected) @StringCache() @pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])]) -def test_cat_is_in_with_lit_str(dtype: pl.DataType) -> None: +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_cat_is_in_with_lit_str(dtype: pl.DataType, nulls_equal: bool) -> None: + missing_value = False if nulls_equal else None s = pl.Series(["a", "b", "c", None], dtype=dtype) lit = ["b"] - expected = pl.Series([False, True, False, None]) + expected = pl.Series([False, True, False, missing_value]) - assert_series_equal(s.is_in(lit), expected) + assert_series_equal(s.is_in(lit, nulls_equal=nulls_equal), expected) @StringCache() +@pytest.mark.parametrize("nulls_equal", [False, True]) @pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])]) -def test_cat_is_in_with_lit_str_non_existent(dtype: pl.DataType) -> None: +def test_cat_is_in_with_lit_str_non_existent( + dtype: pl.DataType, nulls_equal: bool +) -> None: + missing_value = False if nulls_equal else None s = pl.Series(["a", "b", "c", None], dtype=dtype) lit = ["d"] - expected = pl.Series([False, False, False, None]) + expected = pl.Series([False, False, False, missing_value]) - assert_series_equal(s.is_in(lit), expected) + assert_series_equal(s.is_in(lit, nulls_equal=nulls_equal), expected) @StringCache() @@ -409,7 +596,8 @@ def test_cat_list_is_in_from_single_str(val: str | None, expected: list[bool]) - assert_frame_equal(res, expected_df) -def test_is_in_struct_enum_17618() -> None: +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_struct_enum_17618(nulls_equal: bool) -> None: df = pl.DataFrame() dtype = pl.Enum(categories=["HBS"]) df = df.insert_column(0, pl.Series("category", [], dtype=dtype)) @@ -418,21 +606,31 @@ def test_is_in_struct_enum_17618() -> None: pl.Series( [{"category": "HBS"}], dtype=pl.Struct({"category": df["category"].dtype}), - ) + ), + nulls_equal=nulls_equal, ) ).shape == (0, 1) -def test_is_in_decimal() -> None: +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_decimal(nulls_equal: bool) -> None: assert pl.DataFrame({"a": [D("0.0"), D("0.2"), D("0.1")]}).select( - pl.col("a").is_in([0.0, 0.1]) + pl.col("a").is_in([0.0, 0.1], nulls_equal=nulls_equal) )["a"].to_list() == [True, False, True] assert pl.DataFrame({"a": [D("0.0"), D("0.2"), D("0.1")]}).select( - pl.col("a").is_in([D("0.0"), D("0.1")]) + pl.col("a").is_in([D("0.0"), D("0.1")], nulls_equal=nulls_equal) )["a"].to_list() == [True, False, True] assert pl.DataFrame({"a": [D("0.0"), D("0.2"), D("0.1")]}).select( - pl.col("a").is_in([1, 0, 2]) + pl.col("a").is_in([1, 0, 2], nulls_equal=nulls_equal) )["a"].to_list() == [True, False, False] + missing_value = True if nulls_equal else None + assert pl.DataFrame({"a": [D("0.0"), D("0.2"), None]}).select( + pl.col("a").is_in([0.0, 0.1, None], nulls_equal=nulls_equal) + )["a"].to_list() == [True, False, missing_value] + missing_value = False if nulls_equal else None + assert pl.DataFrame({"a": [D("0.0"), D("0.2"), None]}).select( + pl.col("a").is_in([0.0, 0.1], nulls_equal=nulls_equal) + )["a"].to_list() == [True, False, missing_value] def test_is_in_collection() -> None: @@ -464,3 +662,62 @@ def __len__(self) -> int: ): res = df.filter(pl.col("val").is_in(constraint_values)) assert set(res["lbl"]) == {"bb", "cc", "dd"} + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_null_propagate_all_paths(nulls_equal: bool) -> None: + # No nulls in either + s = pl.Series([1, 2, 3]) + result = s.is_in([1, 3, 8], nulls_equal=nulls_equal) + expected = pl.Series([True, False, True]) + assert_series_equal(result, expected) + + # Nulls in left only + s = pl.Series([1, 2, None]) + result = s.is_in([1, 3, 8], nulls_equal=nulls_equal) + missing_value = False if nulls_equal else None + expected = pl.Series([True, False, missing_value]) + assert_series_equal(result, expected) + + # Nulls in right only + s = pl.Series([1, 2, 3]) + result = s.is_in([1, 3, None], nulls_equal=nulls_equal) + expected = pl.Series([True, False, True]) + assert_series_equal(result, expected) + + # Nulls in both + s = pl.Series([1, 2, None]) + result = s.is_in([1, 3, None], nulls_equal=nulls_equal) + missing_value = True if nulls_equal else None + expected = pl.Series([True, False, missing_value]) + assert_series_equal(result, expected) + + +@pytest.mark.usefixtures("test_global_and_local") +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_null_propagate_all_paths_cat(nulls_equal: bool) -> None: + # No nulls in either + s = pl.Series(["1", "2", "3"]) + result = s.is_in(["1", "3", "8"], nulls_equal=nulls_equal) + expected = pl.Series([True, False, True]) + assert_series_equal(result, expected) + + # Nulls in left only + s = pl.Series(["1", "2", None]) + result = s.is_in(["1", "3", "8"], nulls_equal=nulls_equal) + missing_value = False if nulls_equal else None + expected = pl.Series([True, False, missing_value]) + assert_series_equal(result, expected) + + # Nulls in right only + s = pl.Series(["1", "2", "3"]) + result = s.is_in(["1", "3", None], nulls_equal=nulls_equal) + expected = pl.Series([True, False, True]) + assert_series_equal(result, expected) + + # Nulls in both + s = pl.Series(["1", "2", None]) + result = s.is_in(["1", "3", None], nulls_equal=nulls_equal) + missing_value = True if nulls_equal else None + expected = pl.Series([True, False, missing_value]) + assert_series_equal(result, expected)