Skip to content

Commit ece0079

Browse files
adamreeve authored and ritchie46 committed
Tidy ups
1 parent 1fb3b15 commit ece0079

File tree

7 files changed

+92
-55
lines changed

7 files changed

+92
-55
lines changed

crates/polars-ops/src/frame/join/iejoin/filtered_bit_array.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ use std::cmp::min;
22

33
use arrow::bitmap::MutableBitmap;
44

5-
/// Bit array with a filter to speed up searching for set bits,
6-
/// based on section 4.1 in Khayyat et al.
5+
/// Bit array with a filter to speed up searching for set bits when sparse,
6+
/// based on section 4.1 from Khayyat et al. 2015,
7+
/// "Lightning Fast and Space Efficient Inequality Joins"
78
pub struct FilteredBitArray {
89
bit_array: MutableBitmap,
910
filter: MutableBitmap,

crates/polars-ops/src/frame/join/iejoin/mod.rs

+68-28
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ pub struct IEJoinOptions {
3333
pub operator2: InequalityOperator,
3434
}
3535

36-
/// Inequality join. Matches rows from this DataFrame with rows from another DataFrame
37-
/// using two inequality operators (one of [<, <=, >, >=]).
36+
/// Inequality join. Matches rows between two DataFrames using two inequality operators
37+
/// (one of [<, <=, >, >=]).
3838
/// Based on Khayyat et al. 2015, "Lightning Fast and Space Efficient Inequality Joins"
3939
/// and extended to work with duplicate values.
40-
pub fn join_dataframes(
40+
pub fn iejoin(
4141
left: &DataFrame,
4242
right: &DataFrame,
4343
selected_left: Vec<Series>,
@@ -221,6 +221,8 @@ trait L1Array {
221221
fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray);
222222
}
223223

224+
/// Find the position in the L1 array where we should begin checking for matches,
225+
/// given the index in L1 corresponding to the current position in L2.
224226
fn find_search_start_index<T>(
225227
l1_array: &[L1Item<T>],
226228
index: usize,
@@ -284,6 +286,39 @@ where
284286
}
285287
}
286288

289+
fn find_matches_in_l1<T>(
290+
l1_array: &[L1Item<T>],
291+
l1_index: usize,
292+
row_index: i64,
293+
bit_array: &FilteredBitArray,
294+
op1: InequalityOperator,
295+
left_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
296+
right_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
297+
) -> i64
298+
where
299+
T: NumericNative,
300+
T: TotalOrd,
301+
{
302+
debug_assert!(row_index > 0);
303+
let mut match_count = 0;
304+
305+
// This entry comes from the left hand side DataFrame.
306+
// Find all following entries in L1 (meaning they satisfy the first operator)
307+
// that have already been visited (so satisfy the second operator).
308+
// Because we use a stable sort for l2, we know that we won't find any
309+
// matches for duplicate y values when traversing forwards in l1.
310+
let start_index = find_search_start_index(l1_array, l1_index, op1);
311+
bit_array.on_set_bits_from(start_index, |set_bit: usize| {
312+
let right_row_index = l1_array[set_bit].row_index;
313+
debug_assert!(right_row_index < 0);
314+
left_row_ids.append_value((row_index - 1) as IdxSize);
315+
right_row_ids.append_value((-right_row_index) as IdxSize - 1);
316+
match_count += 1;
317+
});
318+
319+
match_count
320+
}
321+
287322
impl<T> L1Array for Vec<L1Item<T>>
288323
where
289324
T: NumericNative,
@@ -296,27 +331,22 @@ where
296331
left_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
297332
right_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
298333
) -> i64 {
299-
let mut match_count = 0;
300334
let row_index = self[l1_index].row_index;
301335
let from_lhs = row_index > 0;
302336
if from_lhs {
303-
// This entry comes from the left hand side DataFrame.
304-
// Find all following entries in L1 (meaning they satisfy the first operator)
305-
// that have already been visited (so satisfy the second operator).
306-
// Because we use a stable sort for l2, we know that we won't find any
307-
// matches for duplicate y values when traversing forwards in l1.
308-
let start_index = find_search_start_index(self, l1_index, op1);
309-
bit_array.on_set_bits_from(start_index, |set_bit: usize| {
310-
let right_row_index = self[set_bit].row_index;
311-
debug_assert!(right_row_index < 0);
312-
left_row_ids.append_value((row_index - 1) as IdxSize);
313-
right_row_ids.append_value((-right_row_index) as IdxSize - 1);
314-
match_count += 1;
315-
});
337+
find_matches_in_l1(
338+
self,
339+
l1_index,
340+
row_index,
341+
bit_array,
342+
op1,
343+
left_row_ids,
344+
right_row_ids,
345+
)
316346
} else {
317347
bit_array.set_bit(l1_index);
348+
0
318349
}
319-
match_count
320350
}
321351

322352
fn process_lhs_entry(
@@ -327,30 +357,35 @@ where
327357
left_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
328358
right_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
329359
) -> i64 {
330-
let mut match_count = 0;
331360
let row_index = self[l1_index].row_index;
332361
let from_lhs = row_index > 0;
333362
if from_lhs {
334-
let start_index = find_search_start_index(self, l1_index, op1);
335-
bit_array.on_set_bits_from(start_index, |set_bit: usize| {
336-
let right_row_index = self[set_bit].row_index;
337-
debug_assert!(right_row_index < 0);
338-
left_row_ids.append_value((row_index - 1) as IdxSize);
339-
right_row_ids.append_value((-right_row_index) as IdxSize - 1);
340-
match_count += 1;
341-
});
363+
find_matches_in_l1(
364+
self,
365+
l1_index,
366+
row_index,
367+
bit_array,
368+
op1,
369+
left_row_ids,
370+
right_row_ids,
371+
)
372+
} else {
373+
0
342374
}
343-
match_count
344375
}
345376

346377
fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray) {
347378
let from_lhs = self[index].row_index > 0;
379+
// We only mark RHS entries as visited,
380+
// so that we don't try to match LHS entries with other LHS entries.
348381
if !from_lhs {
349382
bit_array.set_bit(index);
350383
}
351384
}
352385
}
353386

387+
/// Create a vector of L1 items from the array of LHS x values concatenated with RHS x values
388+
/// and their ordering.
354389
fn build_l1_array<T>(
355390
ca: &ChunkedArray<T>,
356391
order: &IdxCa,
@@ -366,15 +401,20 @@ where
366401
// Nulls should have been skipped over
367402
.ok_or_else(|| polars_err!(ComputeError: "Unexpected null value in IEJoin data"))?;
368403
let row_index = if index < right_df_offset {
404+
// Row from LHS
369405
index as i64 + 1
370406
} else {
407+
// Row from RHS
371408
-((index - right_df_offset) as i64) - 1
372409
};
373410
array.push(L1Item { row_index, value });
374411
}
375412
Ok(Box::new(array))
376413
}
377414

415+
/// Create a vector of L2 items from the array of y values ordered according to the L1 order,
416+
/// and their ordering. We don't need to store actual y values but only track whether we're at
417+
/// the end of a run of equal values.
378418
fn build_l2_array<T>(ca: &ChunkedArray<T>, order: &IdxCa) -> PolarsResult<Vec<L2Item>>
379419
where
380420
T: PolarsNumericType,

crates/polars-ops/src/frame/join/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ pub trait DataFrameJoinOps: IntoDf {
200200
}
201201

202202
if let JoinType::IEJoin(options) = args.how {
203-
return iejoin::join_dataframes(
203+
return iejoin::iejoin(
204204
left_df,
205205
other,
206206
selected_left,

crates/polars-python/src/lazyframe/general.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,7 @@ impl PyLazyFrame {
969969
.into())
970970
}
971971

972-
fn ie_join(&self, other: Self, on: Vec<PyExpr>, suffix: String) -> PyResult<Self> {
972+
fn inequality_join(&self, other: Self, on: Vec<PyExpr>, suffix: String) -> PyResult<Self> {
973973
let ldf = self.ldf.clone();
974974
let other = other.ldf;
975975
let (left_on, operators, right_on) = parse_ie_join_expressions(on)?;

py-polars/polars/dataframe/frame.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -7084,41 +7084,39 @@ def join(
70847084
.collect(_eager=True)
70857085
)
70867086

7087-
def ie_join(
7087+
def inequality_join(
70887088
self,
70897089
other: DataFrame,
70907090
*,
70917091
on: Sequence[Expr],
70927092
suffix: str = "_right",
70937093
) -> DataFrame:
70947094
"""
7095-
Perform a join using inequality operations.
7095+
Perform a join using two inequality expressions.
70967096
70977097
Parameters
70987098
----------
70997099
other
71007100
DataFrame to join with.
71017101
on
7102-
Inequality expressions to join with,
7102+
A sequence of two inequality expressions to join on, where each expression
7103+
is in the form `left_hand_side_expr op right_hand_side_expr` and op
7104+
is one of <, <=, >, >=.
71037105
for example [pl.col("a") < pl.col("b"), pl.col("c") > pl.col("d")]
71047106
suffix
71057107
Suffix to append to columns with a duplicate name.
71067108
71077109
Returns
71087110
-------
71097111
DataFrame
7110-
7111-
See Also
7112-
--------
7113-
join_asof
71147112
"""
71157113
if not isinstance(other, DataFrame):
71167114
msg = f"expected `other` join table to be a DataFrame, got {type(other).__name__!r}"
71177115
raise TypeError(msg)
71187116

71197117
return (
71207118
self.lazy()
7121-
.ie_join(
7119+
.inequality_join(
71227120
other.lazy(),
71237121
on=on,
71247122
suffix=suffix,

py-polars/polars/lazyframe/frame.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -4341,33 +4341,31 @@ def join_asof(
43414341
)
43424342
)
43434343

4344-
def ie_join(
4344+
def inequality_join(
43454345
self,
43464346
other: LazyFrame,
43474347
*,
43484348
on: Sequence[Expr],
43494349
suffix: str = "_right",
43504350
) -> LazyFrame:
43514351
"""
4352-
Perform a join using inequality operations.
4352+
Perform a join using two inequality expressions.
43534353
43544354
Parameters
43554355
----------
43564356
other
43574357
LazyFrame to join with.
43584358
on
4359-
Inequality expressions to join with,
4359+
A sequence of two inequality expressions to join on, where each expression
4360+
is in the form `left_hand_side_expr op right_hand_side_expr` and op
4361+
is one of <, <=, >, >=.
43604362
for example [pl.col("a") < pl.col("b"), pl.col("c") > pl.col("d")]
43614363
suffix
43624364
Suffix to append to columns with a duplicate name.
43634365
43644366
Returns
43654367
-------
43664368
LazyFrame
4367-
4368-
See Also
4369-
--------
4370-
join_asof
43714369
"""
43724370
if not isinstance(other, LazyFrame):
43734371
msg = f"expected `other` join table to be a LazyFrame, not a {type(other).__name__!r}"
@@ -4382,7 +4380,7 @@ def ie_join(
43824380
raise ValueError(msg)
43834381

43844382
return self._from_pyldf(
4385-
self._ldf.ie_join(
4383+
self._ldf.inequality_join(
43864384
other._ldf,
43874385
[expr._pyexpr for expr in on],
43884386
suffix,

py-polars/tests/unit/operations/test_ie_join.py → py-polars/tests/unit/operations/test_inequality_join.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_self_join() -> None:
2323
}
2424
)
2525

26-
actual = west.ie_join(
26+
actual = west.inequality_join(
2727
west, on=[pl.col("time") > pl.col("time"), pl.col("cost") < pl.col("cost")]
2828
)
2929

@@ -60,7 +60,7 @@ def test_basic_ie_join() -> None:
6060
}
6161
)
6262

63-
actual = east.ie_join(
63+
actual = east.inequality_join(
6464
west, on=[pl.col("dur") < pl.col("time"), pl.col("rev") > pl.col("cost")]
6565
)
6666

@@ -102,7 +102,7 @@ def test_ie_join_with_slice(offset: int, length: int) -> None:
102102
).lazy()
103103

104104
actual = (
105-
east.ie_join(
105+
east.inequality_join(
106106
west, on=[pl.col("dur") < pl.col("time"), pl.col("rev") < pl.col("cost")]
107107
)
108108
.slice(offset, length)
@@ -144,7 +144,7 @@ def test_ie_join_with_expressions() -> None:
144144
}
145145
)
146146

147-
actual = east.ie_join(
147+
actual = east.inequality_join(
148148
west,
149149
on=[
150150
(pl.col("dur") * 2) < pl.col("time"),
@@ -272,7 +272,7 @@ def test_ie_join(east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str) ->
272272
expr0 = _inequality_expression("dur", op1, "time")
273273
expr1 = _inequality_expression("rev", op2, "cost")
274274

275-
actual = east.ie_join(west, on=[expr0, expr1])
275+
actual = east.inequality_join(west, on=[expr0, expr1])
276276

277277
expected = east.join(west, how="cross").filter(expr0 & expr1)
278278
assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)
@@ -290,7 +290,7 @@ def test_ie_join_with_nulls(
290290
expr0 = _inequality_expression("dur", op1, "time")
291291
expr1 = _inequality_expression("rev", op2, "cost")
292292

293-
actual = east.ie_join(west, on=[expr0, expr1])
293+
actual = east.inequality_join(west, on=[expr0, expr1])
294294

295295
expected = east.join(west, how="cross").filter(expr0 & expr1)
296296
assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)
@@ -308,7 +308,7 @@ def test_ie_join_with_floats(
308308
expr0 = _inequality_expression("dur", op1, "time")
309309
expr1 = _inequality_expression("rev", op2, "cost")
310310

311-
actual = east.ie_join(west, on=[expr0, expr1])
311+
actual = east.inequality_join(west, on=[expr0, expr1])
312312

313313
expected = east.join(west, how="cross").filter(expr0 & expr1)
314314
assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)

0 commit comments

Comments
 (0)