Skip to content

Commit

Permalink
use vec and elide bound checks
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 2, 2024
1 parent 70f8b6f commit a5cc510
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 28 deletions.
60 changes: 32 additions & 28 deletions crates/polars-ops/src/frame/join/iejoin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use polars_core::chunked_array::ChunkedArray;
use polars_core::datatypes::{IdxCa, IdxType, NumericNative, PolarsNumericType};
use polars_core::frame::DataFrame;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;
use polars_core::{with_match_physical_numeric_polars_type, POOL};
use polars_error::{polars_err, PolarsResult};
use polars_utils::total_ord::{TotalEq, TotalOrd};
use polars_utils::IdxSize;
Expand Down Expand Up @@ -89,7 +89,7 @@ pub fn iejoin(
build_l1_array(ca, &l1_order, left.height() as IdxSize)
})?;

let y_ordered = y.take(&l1_order)?;
let y_ordered = unsafe { y.take_unchecked(&l1_order) };
let l2_sort_options = SortOptions::default()
.with_maintain_order(true)
.with_nulls_last(false)
Expand All @@ -105,8 +105,8 @@ pub fn iejoin(
// denoting which entries have been visited while traversing L2.
let mut bit_array = FilteredBitArray::from_len_zeroed(l1_order.len());

let mut left_row_ids_builder = PrimitiveChunkedBuilder::<IdxType>::new("".into(), 0);
let mut right_row_ids_builder = PrimitiveChunkedBuilder::<IdxType>::new("".into(), 0);
let mut left_row_idx: Vec<IdxSize> = vec![];
let mut right_row_idx: Vec<IdxSize> = vec![];

let slice_end = match slice {
Some((offset, len)) if offset >= 0 => Some(offset.saturating_add_unsigned(len as u64)),
Expand All @@ -125,8 +125,8 @@ pub fn iejoin(
p as usize,
&mut bit_array,
op1,
&mut left_row_ids_builder,
&mut right_row_ids_builder,
&mut left_row_idx,
&mut right_row_idx,
);

if slice_end.is_some_and(|end| match_count >= end) {
Expand All @@ -152,8 +152,8 @@ pub fn iejoin(
p as usize,
&bit_array,
op1,
&mut left_row_ids_builder,
&mut right_row_ids_builder,
&mut left_row_idx,
&mut right_row_idx,
);
}

Expand All @@ -166,17 +166,21 @@ pub fn iejoin(
}
}

let left_rows = left_row_ids_builder.finish();
let right_rows = right_row_ids_builder.finish();

debug_assert_eq!(left_rows.len(), right_rows.len());
let (left_rows, right_rows) = match slice {
None => (left_rows, right_rows),
Some((offset, len)) => (left_rows.slice(offset, len), right_rows.slice(offset, len)),
debug_assert_eq!(left_row_idx.len(), right_row_idx.len());
let left_row_idx = IdxCa::from_vec("".into(), left_row_idx);
let right_row_idx = IdxCa::from_vec("".into(), right_row_idx);
let (left_row_idx, right_row_idx) = match slice {
None => (left_row_idx, right_row_idx),
Some((offset, len)) => (left_row_idx.slice(offset, len), right_row_idx.slice(offset, len)),
};

let join_left = left.take(&left_rows)?;
let join_right = right.take(&right_rows)?;
let (join_left, join_right) = unsafe { POOL.join(|| {
left.take_unchecked(&left_row_idx)
},
|| {
right.take_unchecked(&right_row_idx)
}) };

_finish_join(join_left, join_right, suffix)
}
Expand Down Expand Up @@ -205,17 +209,17 @@ trait L1Array {
l1_index: usize,
bit_array: &mut FilteredBitArray,
op1: InequalityOperator,
left_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
right_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
left_row_ids: &mut Vec<IdxSize>,
right_row_ids: &mut Vec<IdxSize>,
) -> i64;

fn process_lhs_entry(
&self,
l1_index: usize,
bit_array: &FilteredBitArray,
op1: InequalityOperator,
left_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
right_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
left_row_ids: &mut Vec<IdxSize>,
right_row_ids: &mut Vec<IdxSize>,
) -> i64;

fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray);
Expand Down Expand Up @@ -292,8 +296,8 @@ fn find_matches_in_l1<T>(
row_index: i64,
bit_array: &FilteredBitArray,
op1: InequalityOperator,
left_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
right_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
left_row_ids: &mut Vec<IdxSize>,
right_row_ids: &mut Vec<IdxSize>,
) -> i64
where
T: NumericNative,
Expand All @@ -311,8 +315,8 @@ where
bit_array.on_set_bits_from(start_index, |set_bit: usize| {
let right_row_index = l1_array[set_bit].row_index;
debug_assert!(right_row_index < 0);
left_row_ids.append_value((row_index - 1) as IdxSize);
right_row_ids.append_value((-right_row_index) as IdxSize - 1);
left_row_ids.push((row_index - 1) as IdxSize);
right_row_ids.push((-right_row_index) as IdxSize - 1);
match_count += 1;
});

Expand All @@ -328,8 +332,8 @@ where
l1_index: usize,
bit_array: &mut FilteredBitArray,
op1: InequalityOperator,
left_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
right_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
left_row_ids: &mut Vec<IdxSize>,
right_row_ids: &mut Vec<IdxSize>,
) -> i64 {
let row_index = self[l1_index].row_index;
let from_lhs = row_index > 0;
Expand All @@ -354,8 +358,8 @@ where
l1_index: usize,
bit_array: &FilteredBitArray,
op1: InequalityOperator,
left_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
right_row_ids: &mut PrimitiveChunkedBuilder<IdxType>,
left_row_ids: &mut Vec<IdxSize>,
right_row_ids: &mut Vec<IdxSize>,
) -> i64 {
let row_index = self[l1_index].row_index;
let from_lhs = row_index > 0;
Expand Down
31 changes: 31 additions & 0 deletions crates/polars-ops/src/frame/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,34 @@ pub fn private_left_join_multiple_keys(
let b = prepare_keys_multiple(b.get_columns(), join_nulls)?.into_series();
sort_or_hash_left(&a, &b, false, JoinValidation::ManyToMany, join_nulls)
}


/// Checks the inequality join on a small self-join of `west`.
///
/// The join predicate is `time > time_right AND cost < cost_right`
/// (`operator1` applies to the first selected column pair, `operator2` to
/// the second). For the rows below exactly two pairs satisfy it:
///
/// * left `t_id = 404` (time 100, cost 6) with right `t_id = 676` (time 80, cost 10)
/// * left `t_id = 742` (time 90, cost 5) with right `t_id = 676` (time 80, cost 10)
#[test]
fn test_foo() {
    let west = df![
        "t_id" => [404, 498, 676, 742],
        "time" => [100, 140, 80, 90],
        "cost" => [6, 11, 10, 5],
        "cores" => [4, 2, 1, 4]
    ]
    .unwrap();

    let time = west.column("time").unwrap();
    let cost = west.column("cost").unwrap();

    // The same two key columns are used on both sides of the self-join.
    let selected = vec![time.clone(), cost.clone()];

    let out = west
        ._join_impl(
            // Borrow the frame directly; a clone is not needed for a shared reference.
            &west,
            selected.clone(),
            selected,
            JoinArgs::new(JoinType::IEJoin(IEJoinOptions {
                operator1: InequalityOperator::Gt,
                operator2: InequalityOperator::Lt,
            })),
            // NOTE(review): the two trailing flags are passed as `false` exactly as
            // before; their meaning is defined by `_join_impl`, not visible here.
            false,
            false,
        )
        .unwrap();

    // Assert on the result instead of only printing it, so that a regression
    // in the join kernel actually fails the test.
    assert_eq!(out.height(), 2);
}

0 comments on commit a5cc510

Please sign in to comment.