From a5cc510ff3983b974578a09c28afb007b1112eca Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 2 Sep 2024 17:08:25 +0200 Subject: [PATCH] use vec and elide bound checks --- .../polars-ops/src/frame/join/iejoin/mod.rs | 60 ++++++++++--------- crates/polars-ops/src/frame/join/mod.rs | 31 ++++++++++ 2 files changed, 63 insertions(+), 28 deletions(-) diff --git a/crates/polars-ops/src/frame/join/iejoin/mod.rs b/crates/polars-ops/src/frame/join/iejoin/mod.rs index 2198647d6cf6..8f0909262fe3 100644 --- a/crates/polars-ops/src/frame/join/iejoin/mod.rs +++ b/crates/polars-ops/src/frame/join/iejoin/mod.rs @@ -7,7 +7,7 @@ use polars_core::chunked_array::ChunkedArray; use polars_core::datatypes::{IdxCa, IdxType, NumericNative, PolarsNumericType}; use polars_core::frame::DataFrame; use polars_core::prelude::*; -use polars_core::with_match_physical_numeric_polars_type; +use polars_core::{with_match_physical_numeric_polars_type, POOL}; use polars_error::{polars_err, PolarsResult}; use polars_utils::total_ord::{TotalEq, TotalOrd}; use polars_utils::IdxSize; @@ -89,7 +89,7 @@ pub fn iejoin( build_l1_array(ca, &l1_order, left.height() as IdxSize) })?; - let y_ordered = y.take(&l1_order)?; + let y_ordered = unsafe { y.take_unchecked(&l1_order) }; let l2_sort_options = SortOptions::default() .with_maintain_order(true) .with_nulls_last(false) @@ -105,8 +105,8 @@ pub fn iejoin( // denoting which entries have been visited while traversing L2. let mut bit_array = FilteredBitArray::from_len_zeroed(l1_order.len()); - let mut left_row_ids_builder = PrimitiveChunkedBuilder::::new("".into(), 0); - let mut right_row_ids_builder = PrimitiveChunkedBuilder::::new("".into(), 0); + let mut left_row_idx: Vec = vec![]; + let mut right_row_idx: Vec = vec![]; let slice_end = match slice { Some((offset, len)) if offset >= 0 => Some(offset.saturating_add_unsigned(len as u64)), @@ -125,8 +125,8 @@ pub fn iejoin( p as usize, &mut bit_array, op1, - &mut left_row_ids_builder, - &mut right_row_ids_builder, + &mut left_row_idx, + &mut right_row_idx, ); if slice_end.is_some_and(|end| match_count >= end) { @@ -152,8 +152,8 @@ pub fn iejoin( p as usize, &bit_array, op1, - &mut left_row_ids_builder, - &mut right_row_ids_builder, + &mut left_row_idx, + &mut right_row_idx, ); } @@ -166,17 +166,21 @@ pub fn iejoin( } } - let left_rows = left_row_ids_builder.finish(); - let right_rows = right_row_ids_builder.finish(); - debug_assert_eq!(left_rows.len(), right_rows.len()); - let (left_rows, right_rows) = match slice { - None => (left_rows, right_rows), - Some((offset, len)) => (left_rows.slice(offset, len), right_rows.slice(offset, len)), + debug_assert_eq!(left_row_idx.len(), right_row_idx.len()); + let left_row_idx = IdxCa::from_vec("".into(), left_row_idx); + let right_row_idx = IdxCa::from_vec("".into(), right_row_idx); + let (left_row_idx, right_row_idx) = match slice { + None => (left_row_idx, right_row_idx), + Some((offset, len)) => (left_row_idx.slice(offset, len), right_row_idx.slice(offset, len)), }; - let join_left = left.take(&left_rows)?; - let join_right = right.take(&right_rows)?; + let (join_left, join_right) = unsafe { POOL.join(|| { + left.take_unchecked(&left_row_idx) + }, + || { + right.take_unchecked(&right_row_idx) + }) }; _finish_join(join_left, join_right, suffix) } @@ -205,8 +209,8 @@ trait L1Array { l1_index: usize, bit_array: &mut FilteredBitArray, op1: InequalityOperator, - left_row_ids: &mut PrimitiveChunkedBuilder, - right_row_ids: &mut PrimitiveChunkedBuilder, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, ) -> i64; fn process_lhs_entry( @@ -214,8 +218,8 @@ trait L1Array { l1_index: usize, bit_array: &FilteredBitArray, op1: InequalityOperator, - left_row_ids: &mut PrimitiveChunkedBuilder, - right_row_ids: &mut PrimitiveChunkedBuilder, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, ) -> i64; fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray); @@ -292,8 +296,8 @@ fn find_matches_in_l1( row_index: i64, bit_array: &FilteredBitArray, op1: InequalityOperator, - left_row_ids: &mut PrimitiveChunkedBuilder, - right_row_ids: &mut PrimitiveChunkedBuilder, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, ) -> i64 where T: NumericNative, @@ -311,8 +315,8 @@ where bit_array.on_set_bits_from(start_index, |set_bit: usize| { let right_row_index = l1_array[set_bit].row_index; debug_assert!(right_row_index < 0); - left_row_ids.append_value((row_index - 1) as IdxSize); - right_row_ids.append_value((-right_row_index) as IdxSize - 1); + left_row_ids.push((row_index - 1) as IdxSize); + right_row_ids.push((-right_row_index) as IdxSize - 1); match_count += 1; }); @@ -328,8 +332,8 @@ where l1_index: usize, bit_array: &mut FilteredBitArray, op1: InequalityOperator, - left_row_ids: &mut PrimitiveChunkedBuilder, - right_row_ids: &mut PrimitiveChunkedBuilder, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, ) -> i64 { let row_index = self[l1_index].row_index; let from_lhs = row_index > 0; @@ -354,8 +358,8 @@ where l1_index: usize, bit_array: &FilteredBitArray, op1: InequalityOperator, - left_row_ids: &mut PrimitiveChunkedBuilder, - right_row_ids: &mut PrimitiveChunkedBuilder, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, ) -> i64 { let row_index = self[l1_index].row_index; let from_lhs = row_index > 0; diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 6df584c4f946..6661dce64320 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -530,3 +530,34 @@ pub fn private_left_join_multiple_keys( let b = prepare_keys_multiple(b.get_columns(), join_nulls)?.into_series(); sort_or_hash_left(&a, &b, false, JoinValidation::ManyToMany, join_nulls) } + + +#[test] +fn test_foo() { + let west = df![ + "t_id" => [404, 498, 676, 742], + "time" => [100, 140, 80, 90], + "cost" => [6, 11, 10, 5], + "cores" => [4, 2, 1, 4] + ].unwrap(); + + let time = west.column("time").unwrap(); + let cost = west.column("cost").unwrap(); + + let selected = vec![time.clone(), cost.clone()]; + + let out = west._join_impl( + &west.clone(), + selected.clone(), + selected, + JoinArgs::new(JoinType::IEJoin(IEJoinOptions { + operator1: InequalityOperator::Gt, + operator2: InequalityOperator::Lt + })), + false, + false + ).unwrap(); + + dbg!(out); + +} \ No newline at end of file