Skip to content

Commit

Permalink
feat: Add join_nulls option to asof join
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasMuellerQC committed Feb 10, 2025
1 parent 495466f commit 98e0d1a
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 39 deletions.
82 changes: 55 additions & 27 deletions crates/polars-ops/src/frame/join/asof/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ fn asof_join_by_binary<B, T, A, F>(
right_asof: &ChunkedArray<T>,
filter: F,
allow_eq: bool,
join_nulls: bool,
) -> IdxArr
where
B: PolarsDataType,
Expand All @@ -178,7 +179,7 @@ where

let (prep_by_left, prep_by_right, _, _) = prepare_binary::<B>(by_left, by_right, false);
let offsets = compute_len_offsets(prep_by_left.iter().map(|s| s.len()));
let hash_tbls = build_tables(prep_by_right, false);
let hash_tbls = build_tables(prep_by_right, join_nulls);
let n_tables = hash_tbls.len();

// Now we probe the right hand side for each left hand side.
Expand Down Expand Up @@ -228,6 +229,7 @@ fn dispatch_join_by_type<T, A, F>(
right_by: &mut DataFrame,
filter: F,
allow_eq: bool,
join_nulls: bool,
) -> PolarsResult<IdxArr>
where
T: PolarsDataType,
Expand All @@ -242,22 +244,22 @@ where
polars_ensure!(left_dtype == right_dtype,
ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{left_dtype}` and `{right_dtype}`",
);
match left_dtype {
DataType::String => {
match (left_dtype, join_nulls) {
(DataType::String, _) => {
let left_by = &left_by_s.str().unwrap().as_binary();
let right_by = right_by_s.str().unwrap().as_binary();
asof_join_by_binary::<BinaryType, T, A, F>(
left_by, &right_by, left_asof, right_asof, filter, allow_eq,
left_by, &right_by, left_asof, right_asof, filter, allow_eq, join_nulls,
)
},
DataType::Binary => {
(DataType::Binary, _) => {
let left_by = &left_by_s.binary().unwrap();
let right_by = right_by_s.binary().unwrap();
asof_join_by_binary::<BinaryType, T, A, F>(
left_by, right_by, left_asof, right_asof, filter, allow_eq,
left_by, right_by, left_asof, right_asof, filter, allow_eq, join_nulls,
)
},
x if x.is_float() => {
(x, false) if x.is_float() => {
with_match_physical_float_polars_type!(left_by_s.dtype(), |$T| {
let left_by: &ChunkedArray<$T> = left_by_s.as_materialized_series().as_ref().as_ref().as_ref();
let right_by: &ChunkedArray<$T> = right_by_s.as_materialized_series().as_ref().as_ref().as_ref();
Expand All @@ -266,7 +268,7 @@ where
)?
})
},
_ => {
(_, false) => {
let left_by = left_by_s.bit_repr();
let right_by = right_by_s.bit_repr();

Expand All @@ -290,6 +292,16 @@ where
_ => unreachable!(),
}
},
(_, true) => {
let left_by_bin = left_by_s.strict_cast(&DataType::Binary)?;
let right_by_bin = right_by_s.strict_cast(&DataType::Binary)?;

let left_by = left_by_bin.binary().unwrap();
let right_by = right_by_bin.binary().unwrap();
asof_join_by_binary::<BinaryType, T, A, F>(
left_by, right_by, left_asof, right_asof, filter, allow_eq, join_nulls,
)
},
}
} else {
for (lhs, rhs) in left_by.get_columns().iter().zip(right_by.get_columns()) {
Expand All @@ -303,10 +315,10 @@ where
// TODO: @scalar-opt.
let left_by_series: Vec<_> = left_by.materialized_column_iter().cloned().collect();
let right_by_series: Vec<_> = right_by.materialized_column_iter().cloned().collect();
let lhs_keys = prepare_keys_multiple(&left_by_series, false)?;
let rhs_keys = prepare_keys_multiple(&right_by_series, false)?;
let lhs_keys = prepare_keys_multiple(&left_by_series, join_nulls)?;
let rhs_keys = prepare_keys_multiple(&right_by_series, join_nulls)?;
asof_join_by_binary::<BinaryOffsetType, T, A, F>(
&lhs_keys, &rhs_keys, left_asof, right_asof, filter, allow_eq,
&lhs_keys, &rhs_keys, left_asof, right_asof, filter, allow_eq, join_nulls,
)
};
Ok(out)
Expand All @@ -320,6 +332,7 @@ fn dispatch_join_strategy<T: PolarsDataType>(
right_by: &mut DataFrame,
strategy: AsofStrategy,
allow_eq: bool,
join_nulls: bool,
) -> PolarsResult<IdxArr>
where
for<'a> T::Physical<'a>: PartialOrd,
Expand All @@ -329,10 +342,10 @@ where
let filter = |_a: T::Physical<'_>, _b: T::Physical<'_>| true;
match strategy {
AsofStrategy::Backward => dispatch_join_by_type::<T, AsofJoinBackwardState, _>(
left_asof, right_asof, left_by, right_by, filter, allow_eq,
left_asof, right_asof, left_by, right_by, filter, allow_eq, join_nulls,
),
AsofStrategy::Forward => dispatch_join_by_type::<T, AsofJoinForwardState, _>(
left_asof, right_asof, left_by, right_by, filter, allow_eq,
left_asof, right_asof, left_by, right_by, filter, allow_eq, join_nulls,
),
AsofStrategy::Nearest => unimplemented!(),
}
Expand All @@ -347,6 +360,7 @@ fn dispatch_join_strategy_numeric<T: PolarsNumericType>(
strategy: AsofStrategy,
tolerance: Option<AnyValue<'static>>,
allow_eq: bool,
join_nulls: bool,
) -> PolarsResult<IdxArr> {
let right_ca = left_asof.unpack_series_matching_type(right_asof)?;

Expand All @@ -356,26 +370,26 @@ fn dispatch_join_strategy_numeric<T: PolarsNumericType>(
let filter = |a: T::Native, b: T::Native| a.abs_diff(b) <= abs_tolerance;
match strategy {
AsofStrategy::Backward => dispatch_join_by_type::<T, AsofJoinBackwardState, _>(
left_asof, right_ca, left_by, right_by, filter, allow_eq,
left_asof, right_ca, left_by, right_by, filter, allow_eq, join_nulls,
),
AsofStrategy::Forward => dispatch_join_by_type::<T, AsofJoinForwardState, _>(
left_asof, right_ca, left_by, right_by, filter, allow_eq,
left_asof, right_ca, left_by, right_by, filter, allow_eq, join_nulls,
),
AsofStrategy::Nearest => dispatch_join_by_type::<T, AsofJoinNearestState, _>(
left_asof, right_ca, left_by, right_by, filter, allow_eq,
left_asof, right_ca, left_by, right_by, filter, allow_eq, join_nulls,
),
}
} else {
let filter = |_a: T::Physical<'_>, _b: T::Physical<'_>| true;
match strategy {
AsofStrategy::Backward => dispatch_join_by_type::<T, AsofJoinBackwardState, _>(
left_asof, right_ca, left_by, right_by, filter, allow_eq,
left_asof, right_ca, left_by, right_by, filter, allow_eq, join_nulls,
),
AsofStrategy::Forward => dispatch_join_by_type::<T, AsofJoinForwardState, _>(
left_asof, right_ca, left_by, right_by, filter, allow_eq,
left_asof, right_ca, left_by, right_by, filter, allow_eq, join_nulls,
),
AsofStrategy::Nearest => dispatch_join_by_type::<T, AsofJoinNearestState, _>(
left_asof, right_ca, left_by, right_by, filter, allow_eq,
left_asof, right_ca, left_by, right_by, filter, allow_eq, join_nulls,
),
}
}
Expand All @@ -390,54 +404,55 @@ fn dispatch_join_type(
strategy: AsofStrategy,
tolerance: Option<AnyValue<'static>>,
allow_eq: bool,
join_nulls: bool,
) -> PolarsResult<IdxArr> {
match left_asof.dtype() {
DataType::Int64 => {
let ca = left_asof.i64().unwrap();
dispatch_join_strategy_numeric(
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, join_nulls,
)
},
DataType::Int32 => {
let ca = left_asof.i32().unwrap();
dispatch_join_strategy_numeric(
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, join_nulls,
)
},
DataType::UInt64 => {
let ca = left_asof.u64().unwrap();
dispatch_join_strategy_numeric(
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, join_nulls,
)
},
DataType::UInt32 => {
let ca = left_asof.u32().unwrap();
dispatch_join_strategy_numeric(
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, join_nulls,
)
},
DataType::Float32 => {
let ca = left_asof.f32().unwrap();
dispatch_join_strategy_numeric(
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, join_nulls,
)
},
DataType::Float64 => {
let ca = left_asof.f64().unwrap();
dispatch_join_strategy_numeric(
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, join_nulls,
)
},
DataType::Boolean => {
let ca = left_asof.bool().unwrap();
dispatch_join_strategy::<BooleanType>(
ca, right_asof, left_by, right_by, strategy, allow_eq,
ca, right_asof, left_by, right_by, strategy, allow_eq, join_nulls,
)
},
DataType::Binary => {
let ca = left_asof.binary().unwrap();
dispatch_join_strategy::<BinaryType>(
ca, right_asof, left_by, right_by, strategy, allow_eq,
ca, right_asof, left_by, right_by, strategy, allow_eq, join_nulls,
)
},
DataType::String => {
Expand All @@ -450,6 +465,7 @@ fn dispatch_join_type(
right_by,
strategy,
allow_eq,
join_nulls,
)
},
_ => {
Expand All @@ -464,6 +480,7 @@ fn dispatch_join_type(
strategy,
tolerance,
allow_eq,
join_nulls,
)
},
}
Expand All @@ -486,6 +503,7 @@ pub trait AsofJoinBy: IntoDf {
coalesce: bool,
allow_eq: bool,
check_sortedness: bool,
join_nulls: bool,
) -> PolarsResult<DataFrame> {
let (self_sliced_slot, other_sliced_slot, left_slice_s, right_slice_s); // Keeps temporaries alive.
let (self_df, other_df, left_key, right_key);
Expand Down Expand Up @@ -541,6 +559,7 @@ pub trait AsofJoinBy: IntoDf {
strategy,
tolerance,
allow_eq,
join_nulls,
)?;

let mut drop_these = right_by.get_column_names();
Expand Down Expand Up @@ -582,6 +601,7 @@ pub trait AsofJoinBy: IntoDf {
tolerance: Option<AnyValue<'static>>,
allow_eq: bool,
check_sortedness: bool,
join_nulls: bool,
) -> PolarsResult<DataFrame>
where
I: IntoIterator<Item = S>,
Expand All @@ -605,6 +625,7 @@ pub trait AsofJoinBy: IntoDf {
true,
allow_eq,
check_sortedness,
join_nulls,
)
}
}
Expand Down Expand Up @@ -638,6 +659,7 @@ mod test {
None,
true,
true,
false,
)?;
assert_eq!(out.get_column_names(), &["a", "b", "right_vals"]);
let out = out.column("right_vals").unwrap();
Expand Down Expand Up @@ -683,6 +705,7 @@ mod test {
None,
true,
true,
false,
)?;
let a = out.column("bid_right").unwrap();
let a = a.f64().unwrap();
Expand All @@ -700,6 +723,7 @@ mod test {
None,
true,
true,
false,
)?;
let a = out.column("bid_right").unwrap();
let a = a.f64().unwrap();
Expand Down Expand Up @@ -732,6 +756,7 @@ mod test {
None,
true,
true,
false,
)?;
assert_eq!(out.get_column_names(), &["a", "b", "right_vals"]);
let out = out.column("right_vals").unwrap();
Expand All @@ -751,6 +776,7 @@ mod test {
Some(AnyValue::Int32(1)),
true,
true,
false,
)?;
assert_eq!(out.get_column_names(), &["a", "b", "right_vals"]);
let out = out.column("right_vals").unwrap();
Expand Down Expand Up @@ -821,6 +847,7 @@ mod test {
None,
true,
true,
false,
)?;
let a = out.column("bid_right").unwrap();
let a = a.f64().unwrap();
Expand All @@ -844,6 +871,7 @@ mod test {
None,
true,
true,
false,
)?;
let a = out.column("bid_right").unwrap();
let a = a.f64().unwrap();
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/src/frame/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ pub trait DataFrameJoinOps: IntoDf {
should_coalesce,
options.allow_eq,
options.check_sortedness,
args.join_nulls,
),
(None, None) => left_df._join_asof(
other,
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-python/src/lazyframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ impl PyLazyFrame {
}

#[cfg(feature = "asof_join")]
#[pyo3(signature = (other, left_on, right_on, left_by, right_by, allow_parallel, force_parallel, suffix, strategy, tolerance, tolerance_str, coalesce, allow_eq, check_sortedness))]
#[pyo3(signature = (other, left_on, right_on, left_by, right_by, allow_parallel, force_parallel, suffix, strategy, tolerance, tolerance_str, join_nulls, coalesce, allow_eq, check_sortedness))]
fn join_asof(
&self,
other: Self,
Expand All @@ -983,6 +983,7 @@ impl PyLazyFrame {
strategy: Wrap<AsofStrategy>,
tolerance: Option<Wrap<AnyValue<'_>>>,
tolerance_str: Option<String>,
join_nulls: bool,
coalesce: bool,
allow_eq: bool,
check_sortedness: bool,
Expand All @@ -1004,6 +1005,7 @@ impl PyLazyFrame {
.allow_parallel(allow_parallel)
.force_parallel(force_parallel)
.coalesce(coalesce)
.join_nulls(join_nulls)
.how(JoinType::AsOf(AsOfOptions {
strategy: strategy.0,
left_by: left_by.map(strings_to_pl_smallstr),
Expand Down
2 changes: 2 additions & 0 deletions docs/source/src/rust/user-guide/transformations/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
None,
true,
true,
false,
)?;
println!("{}", result);
// --8<-- [end:asof]
Expand All @@ -269,6 +270,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Some(AnyValue::Duration(60000, TimeUnit::Milliseconds)),
true,
true,
false,
)?;
println!("{}", result);
// --8<-- [end:asof-tolerance]
Expand Down
5 changes: 5 additions & 0 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6982,6 +6982,7 @@ def join_asof(
tolerance: str | int | float | timedelta | None = None,
allow_parallel: bool = True,
force_parallel: bool = False,
join_nulls: bool = False,
coalesce: bool = True,
allow_exact_matches: bool = True,
check_sortedness: bool = True,
Expand Down Expand Up @@ -7062,6 +7063,9 @@ def join_asof(
force_parallel
Force the physical plan to evaluate the computation of both DataFrames up to
the join in parallel.
join_nulls
Join on null values in the `by` columns, so that null keys match other null keys.
By default, null values never produce matches.
coalesce
Coalescing behavior (merging of `on` / `left_on` / `right_on` columns):
Expand Down Expand Up @@ -7312,6 +7316,7 @@ def join_asof(
tolerance=tolerance,
allow_parallel=allow_parallel,
force_parallel=force_parallel,
join_nulls=join_nulls,
coalesce=coalesce,
allow_exact_matches=allow_exact_matches,
check_sortedness=check_sortedness,
Expand Down
Loading

0 comments on commit 98e0d1a

Please sign in to comment.