From b8d584f5b8502f272e3258cbe440caf2c2d6378b Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Wed, 19 Feb 2025 12:39:36 +0100 Subject: [PATCH 01/25] wip --- crates/polars-expr/src/hash_keys.rs | 12 +--- .../src/nodes/joins/equi_join.rs | 71 +++++++++++++++++++ 2 files changed, 74 insertions(+), 9 deletions(-) diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index 639fa6fa8fdd..f832e1984fd7 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -74,9 +74,9 @@ impl HashKeys { self.len() == 0 } - /// After this call partition_idxs[p] will contain the indices of hashes - /// that belong to partition p, and the cardinality sketches are updated - /// accordingly. + /// After this call partition_idxs[p] will be extended with the indices of + /// hashes that belong to partition p, and the cardinality sketches are + /// updated accordingly. pub fn gen_partition_idxs( &self, partitioner: &HashPartitioner, @@ -168,9 +168,6 @@ impl RowEncodedKeys { ) { assert!(partition_idxs.len() == partitioner.num_partitions()); assert!(!BUILD_SKETCHES || sketches.len() == partitioner.num_partitions()); - for p in partition_idxs.iter_mut() { - p.clear(); - } if let Some(validity) = self.keys.validity() { for (i, (h, is_v)) in self.hashes.values_iter().zip(validity).enumerate() { @@ -272,9 +269,6 @@ impl SingleKeys { _partition_nulls: bool, ) { assert!(partitioner.num_partitions() == partition_idxs.len()); - for p in partition_idxs.iter_mut() { - p.clear(); - } todo!() } diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 051440985e40..659a44324848 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -456,6 +456,71 @@ impl SampleState { } } +mod new { + use super::*; + + #[derive(Default)] + struct LocalBuilder { + // The complete list of morsels and their computed hashes seen by this builder. + morsels: Vec<(MorselSeq, DataFrame, HashKeys)>, + + // A cardinality sketch per partition for the keys seen by this builder. + sketch_per_p: Vec, + + // morsel_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i] + // for partition p, where start, stop are: + // let start = morsel_idxs_offsets[i * num_partitions + p]; + // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p]; + morsel_idxs_values_per_p: Vec>, + morsel_idxs_offsets_per_p: Vec, + } + + async fn partition_and_sink( + mut recv: Receiver, + local: &mut LocalBuilder, + partitioner: HashPartitioner, + params: &EquiJoinParams, + state: &ExecutionState, + ) -> PolarsResult<()> { + let track_unmatchable = params.emit_unmatched_build(); + local.sketch_per_p.resize_with(partitioner.num_partitions(), Default::default); + local.morsel_idxs_values_per_p.resize_with(partitioner.num_partitions(), Default::default); + + if local.morsel_idxs_offsets_per_p.is_empty() { + local.morsel_idxs_offsets_per_p.resize(partitioner.num_partitions(), 0); + } + + let (key_selectors, payload_selector); + if params.left_is_build.unwrap() { + payload_selector = ¶ms.left_payload_select; + key_selectors = ¶ms.left_key_selectors; + } else { + payload_selector = ¶ms.right_payload_select; + key_selectors = ¶ms.right_key_selectors; + }; + + while let Ok(morsel) = recv.recv().await { + // Compute hashed keys and payload. We must rechunk the payload for + // later gathers. 
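+            // (The partition indices generated below are plain row offsets
+            // into this morsel rather than chunk-aware ids, so the payload
+            // they gather from has to live in a single chunk.)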
+ let hash_keys = select_keys(morsel.df(), key_selectors, params, state).await?; + let mut payload = select_payload(morsel.df().clone(), payload_selector); + payload.rechunk_mut(); + + hash_keys.gen_partition_idxs( + &partitioner, + &mut local.morsel_idxs_values_per_p, + &mut local.sketch_per_p, + track_unmatchable, + ); + + local.morsel_idxs_offsets_per_p.extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len())); + local.morsels.push((morsel.seq(), payload, hash_keys)); + } + Ok(()) + } +} + + #[derive(Default)] struct BuildPartition { hash_keys: Vec, @@ -500,6 +565,9 @@ impl BuildState { payload._deshare_views_mut(); unsafe { + for p in partition_idxs.iter_mut() { + p.clear(); + } hash_keys.gen_partition_idxs( &partitioner, &mut partition_idxs, @@ -672,6 +740,9 @@ impl ProbeState { unsafe { // Partition and probe the tables. + for p in partition_idxs.iter_mut() { + p.clear(); + } hash_keys.gen_partition_idxs( &partitioner, &mut partition_idxs, From aae2ac3c7ea1b5c3544fc7272f672d90734a3add Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 4 Mar 2025 10:36:45 +0100 Subject: [PATCH 02/25] wip --- Cargo.lock | 1 + crates/polars-expr/src/idx_table/mod.rs | 68 + .../polars-expr/src/idx_table/row_encoded.rs | 333 +++++ crates/polars-expr/src/lib.rs | 1 + crates/polars-stream/Cargo.toml | 1 + .../src/nodes/joins/equi_join.rs | 103 +- crates/polars-stream/src/nodes/joins/mod.rs | 1 + .../src/nodes/joins/new_equi_join.rs | 1290 +++++++++++++++++ 8 files changed, 1714 insertions(+), 84 deletions(-) create mode 100644 crates/polars-expr/src/idx_table/mod.rs create mode 100644 crates/polars-expr/src/idx_table/row_encoded.rs create mode 100644 crates/polars-stream/src/nodes/joins/new_equi_join.rs diff --git a/Cargo.lock b/Cargo.lock index dcf382913791..b26c1405cf82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3523,6 +3523,7 @@ dependencies = [ "memmap2", "parking_lot", "pin-project-lite", + "polars-arrow", "polars-core", "polars-error", "polars-expr", diff --git a/crates/polars-expr/src/idx_table/mod.rs b/crates/polars-expr/src/idx_table/mod.rs new file mode 100644 index 000000000000..7996d2f74e77 --- /dev/null +++ b/crates/polars-expr/src/idx_table/mod.rs @@ -0,0 +1,68 @@ +use std::any::Any; + +use polars_core::prelude::*; +use polars_utils::IdxSize; + +use crate::hash_keys::HashKeys; + +mod row_encoded; + +pub trait IdxTable: Any + Send + Sync { + /// Creates a new empty IdxTable similar to this one. + fn new_empty(&self) -> Box; + + /// Reserves space for the given number additional keys. + fn reserve(&mut self, additional: usize); + + /// Returns the number of unique keys in this IdxTable. + fn num_keys(&self) -> IdxSize; + + /// Inserts the given keys into this IdxTable. + fn insert_keys(&mut self, keys: &HashKeys, track_unmatchable: bool); + + /// Inserts a subset of the given keys into this IdxTable. + /// # Safety + /// The provided subset indices must be in-bounds. + unsafe fn insert_keys_subset(&mut self, keys: &HashKeys, subset: &[IdxSize], track_unmatchable: bool); + + /// Probe the table, adding an entry to table_match and probe_match for each + /// match. Will stop processing new keys once limit matches have been + /// generated, returning the number of keys processed. + /// + /// If mark_matches is true, matches are marked in the table as such. + /// + /// If emit_unmatched is true, for keys that do not have a match we emit a + /// match with ChunkId::null() on the table match. 
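+    /// (In this index-based table the unmatched entry pushed into
+    /// table_match is IdxSize::MAX rather than a ChunkId; see the
+    /// EMIT_UNMATCHED branch in RowEncodedIdxTable::probe_impl.)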
+ fn probe( + &self, + hash_keys: &HashKeys, + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize; + + /// The same as probe, except it will only apply to the specified subset of keys. + /// # Safety + /// The provided subset indices must be in-bounds. + #[allow(clippy::too_many_arguments)] + unsafe fn probe_subset( + &self, + hash_keys: &HashKeys, + subset: &[IdxSize], + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize; + + /// Get the ChunkIds for each key which was never marked during probing. + fn unmarked_keys(&self, out: &mut Vec, offset: IdxSize, limit: IdxSize) + -> IdxSize; +} + +pub fn new_idx_table(_key_schema: Arc) -> Box { + Box::new(row_encoded::RowEncodedIdxTable::new()) +} diff --git a/crates/polars-expr/src/idx_table/row_encoded.rs b/crates/polars-expr/src/idx_table/row_encoded.rs new file mode 100644 index 000000000000..b8d09f448586 --- /dev/null +++ b/crates/polars-expr/src/idx_table/row_encoded.rs @@ -0,0 +1,333 @@ +use std::sync::atomic::{AtomicU64, Ordering}; + +use arrow::array::Array; +use polars_utils::idx_map::bytes_idx_map::{BytesIndexMap, Entry}; +use polars_utils::idx_vec::UnitVec; +use polars_utils::itertools::Itertools; +use polars_utils::unitvec; + +use super::*; +use crate::hash_keys::HashKeys; + +#[derive(Default)] +pub struct RowEncodedIdxTable { + // These AtomicU64s actually are IdxSizes, but we use the top bit of the + // first index in each to mark keys during probing. + idx_map: BytesIndexMap>, + idx_offset: IdxSize, + null_keys: Vec, +} + +impl RowEncodedIdxTable { + pub fn new() -> Self { + Self { + idx_map: BytesIndexMap::new(), + idx_offset: 0, + null_keys: Vec::new(), + } + } +} + +impl RowEncodedIdxTable { + #[inline(always)] + fn probe_one( + &self, + key_idx: IdxSize, + hash: u64, + key: &[u8], + table_match: &mut Vec, + probe_match: &mut Vec, + ) -> bool { + if let Some(idxs) = self.idx_map.get(hash, key) { + for idx in &idxs[..] { + // Create matches, making sure to clear top bit. + table_match.push(idx.load(Ordering::Relaxed) & !(1 << 63)); + probe_match.push(key_idx); + } + + // Mark if necessary. This action is idempotent so doesn't + // need any synchronization on the load, nor does it need a + // fetch_or to do it atomically. 
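+            // Racing probes can only ever store the identical marked value,
+            // so a lost store is harmless.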
+ if MARK_MATCHES { + let first_idx = unsafe { idxs.get_unchecked(0) }; + let first_idx_val = first_idx.load(Ordering::Relaxed); + if first_idx_val >> 63 == 0 { + first_idx.store(first_idx_val | (1 << 63), Ordering::Release); + } + } + true + } else { + false + } + } + + fn probe_impl<'a, const MARK_MATCHES: bool, const EMIT_UNMATCHED: bool>( + &self, + hash_keys: impl Iterator)>, + table_match: &mut Vec, + probe_match: &mut Vec, + limit: IdxSize, + ) -> IdxSize { + table_match.clear(); + probe_match.clear(); + + let mut keys_processed = 0; + for (key_idx, hash, key) in hash_keys { + let found_match = if let Some(key) = key { + self.probe_one::(key_idx, hash, key, table_match, probe_match) + } else { + false + }; + + if EMIT_UNMATCHED && !found_match { + table_match.push(IdxSize::MAX); + probe_match.push(key_idx); + } + + keys_processed += 1; + if table_match.len() >= limit as usize { + break; + } + } + keys_processed + } + + fn probe_dispatch<'a>( + &self, + hash_keys: impl Iterator)>, + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize { + match (mark_matches, emit_unmatched) { + (false, false) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + (false, true) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + (true, false) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + (true, true) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + } + } +} + +impl IdxTable for RowEncodedIdxTable { + fn new_empty(&self) -> Box { + Box::new(Self::new()) + } + + fn reserve(&mut self, additional: usize) { + self.idx_map.reserve(additional); + } + + fn num_keys(&self) -> IdxSize { + self.idx_map.len() + } + + fn insert_keys(&mut self, hash_keys: &HashKeys, track_unmatchable: bool) { + let HashKeys::RowEncoded(hash_keys) = hash_keys else { + unreachable!() + }; + let new_idx_offset = (self.idx_offset as usize).checked_add(hash_keys.keys.len()).unwrap(); + assert!(new_idx_offset < IdxSize::MAX as usize, "overly large index in RowEncodedIdxTable"); + + for (i, (hash, key)) in hash_keys + .hashes + .values_iter() + .zip(hash_keys.keys.iter()) + .enumerate_idx() + { + let idx = self.idx_offset + i; + if let Some(key) = key { + match self.idx_map.entry(*hash, key) { + Entry::Occupied(o) => { + o.into_mut().push(AtomicU64::new(idx)); + }, + Entry::Vacant(v) => { + v.insert(unitvec![AtomicU64::new(idx)]); + }, + } + } else if track_unmatchable { + self.null_keys.push(idx); + } + } + + self.idx_offset = new_idx_offset as IdxSize; + } + + unsafe fn insert_keys_subset(&mut self, hash_keys: &HashKeys, subset: &[IdxSize], track_unmatchable: bool) { + let HashKeys::RowEncoded(hash_keys) = hash_keys else { + unreachable!() + }; + let new_idx_offset = (self.idx_offset as usize).checked_add(subset.len()).unwrap(); + assert!(new_idx_offset < IdxSize::MAX as usize, "overly large index in RowEncodedIdxTable"); + + for (i, subset_idx) in subset.iter().enumerate_idx() { + let hash = unsafe { hash_keys.hashes.value_unchecked(*subset_idx as usize) }; + let key = unsafe { hash_keys.keys.get_unchecked(*subset_idx as usize) }; + let idx = self.idx_offset + i; + if let Some(key) = key { + match self.idx_map.entry(hash, key) { + Entry::Occupied(o) => { + o.into_mut().push(AtomicU64::new(idx)); + }, + Entry::Vacant(v) => { + v.insert(unitvec![AtomicU64::new(idx)]); + }, + } + } else if track_unmatchable { + self.null_keys.push(idx); + } + } + + self.idx_offset 
= new_idx_offset as IdxSize; + } + + fn probe( + &self, + hash_keys: &HashKeys, + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize { + let HashKeys::RowEncoded(hash_keys) = hash_keys else { + unreachable!() + }; + + if hash_keys.keys.has_nulls() { + let iter = hash_keys + .hashes + .values_iter() + .copied() + .zip(hash_keys.keys.iter()) + .enumerate_idx() + .map(|(i, (h, k))| (i, h, k)); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) + } else { + let iter = hash_keys + .hashes + .values_iter() + .copied() + .zip(hash_keys.keys.values_iter().map(Some)) + .enumerate_idx() + .map(|(i, (h, k))| (i, h, k)); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) + } + } + + unsafe fn probe_subset( + &self, + hash_keys: &HashKeys, + subset: &[IdxSize], + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize { + let HashKeys::RowEncoded(hash_keys) = hash_keys else { + unreachable!() + }; + + if hash_keys.keys.has_nulls() { + let iter = subset.iter().map(|i| { + ( + *i, + hash_keys.hashes.value_unchecked(*i as usize), + hash_keys.keys.get_unchecked(*i as usize), + ) + }); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) + } else { + let iter = subset.iter().map(|i| { + ( + *i, + hash_keys.hashes.value_unchecked(*i as usize), + Some(hash_keys.keys.value_unchecked(*i as usize)), + ) + }); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) + } + } + + fn unmarked_keys( + &self, + out: &mut Vec, + mut offset: IdxSize, + limit: IdxSize, + ) -> IdxSize { + out.clear(); + + let mut keys_processed = 0; + if (offset as usize) < self.null_keys.len() { + out.extend( + self.null_keys[offset as usize..] + .iter() + .copied() + .take(limit as usize), + ); + keys_processed += out.len() as IdxSize; + offset += out.len() as IdxSize; + if out.len() >= limit as usize { + return keys_processed; + } + } + + offset -= self.null_keys.len() as IdxSize; + + while let Some((_, _, idxs)) = self.idx_map.get_index(offset) { + let first_idx = unsafe { idxs.get_unchecked(0) }; + let first_idx_val = first_idx.load(Ordering::Acquire); + if first_idx_val >> 63 == 0 { + for idx in &idxs[..] 
{ + out.push(idx.load(Ordering::Relaxed) & !(1 << 63)); + } + } + + keys_processed += 1; + offset += 1; + if out.len() >= limit as usize { + break; + } + } + + keys_processed + } +} diff --git a/crates/polars-expr/src/lib.rs b/crates/polars-expr/src/lib.rs index 138068e3c268..ef6e96db1b1a 100644 --- a/crates/polars-expr/src/lib.rs +++ b/crates/polars-expr/src/lib.rs @@ -2,6 +2,7 @@ pub mod chunked_idx_table; mod expressions; pub mod groups; pub mod hash_keys; +pub mod idx_table; pub mod planner; pub mod prelude; pub mod reduce; diff --git a/crates/polars-stream/Cargo.toml b/crates/polars-stream/Cargo.toml index cbf2f9fbe528..dde5bc4cf73b 100644 --- a/crates/polars-stream/Cargo.toml +++ b/crates/polars-stream/Cargo.toml @@ -9,6 +9,7 @@ repository = { workspace = true } description = "Private crate for the streaming execution engine for the Polars DataFrame library" [dependencies] +arrow = { workspace = true } atomic-waker = { workspace = true } crossbeam-deque = { workspace = true } crossbeam-queue = { workspace = true } diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 659a44324848..024133b2f699 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -456,71 +456,6 @@ impl SampleState { } } -mod new { - use super::*; - - #[derive(Default)] - struct LocalBuilder { - // The complete list of morsels and their computed hashes seen by this builder. - morsels: Vec<(MorselSeq, DataFrame, HashKeys)>, - - // A cardinality sketch per partition for the keys seen by this builder. - sketch_per_p: Vec, - - // morsel_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i] - // for partition p, where start, stop are: - // let start = morsel_idxs_offsets[i * num_partitions + p]; - // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p]; - morsel_idxs_values_per_p: Vec>, - morsel_idxs_offsets_per_p: Vec, - } - - async fn partition_and_sink( - mut recv: Receiver, - local: &mut LocalBuilder, - partitioner: HashPartitioner, - params: &EquiJoinParams, - state: &ExecutionState, - ) -> PolarsResult<()> { - let track_unmatchable = params.emit_unmatched_build(); - local.sketch_per_p.resize_with(partitioner.num_partitions(), Default::default); - local.morsel_idxs_values_per_p.resize_with(partitioner.num_partitions(), Default::default); - - if local.morsel_idxs_offsets_per_p.is_empty() { - local.morsel_idxs_offsets_per_p.resize(partitioner.num_partitions(), 0); - } - - let (key_selectors, payload_selector); - if params.left_is_build.unwrap() { - payload_selector = ¶ms.left_payload_select; - key_selectors = ¶ms.left_key_selectors; - } else { - payload_selector = ¶ms.right_payload_select; - key_selectors = ¶ms.right_key_selectors; - }; - - while let Ok(morsel) = recv.recv().await { - // Compute hashed keys and payload. We must rechunk the payload for - // later gathers. 
- let hash_keys = select_keys(morsel.df(), key_selectors, params, state).await?; - let mut payload = select_payload(morsel.df().clone(), payload_selector); - payload.rechunk_mut(); - - hash_keys.gen_partition_idxs( - &partitioner, - &mut local.morsel_idxs_values_per_p, - &mut local.sketch_per_p, - track_unmatchable, - ); - - local.morsel_idxs_offsets_per_p.extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len())); - local.morsels.push((morsel.seq(), payload, hash_keys)); - } - Ok(()) - } -} - - #[derive(Default)] struct BuildPartition { hash_keys: Vec, @@ -670,8 +605,8 @@ impl BuildState { accumulate_dataframes_vertical_unchecked(combined_frames) }; ProbeTable { - table, - df, + hash_table: table, + payload: df, chunk_seq_ids, } }) @@ -689,8 +624,8 @@ impl BuildState { struct ProbeTable { // Important that df is not rechunked, the chunks it was inserted with // into the table must be preserved for chunked gathers. - table: Box, - df: DataFrame, + hash_table: Box, + payload: DataFrame, chunk_seq_ids: Vec, } @@ -755,7 +690,7 @@ impl ProbeState { let mut out_per_partition = Vec::with_capacity(partitioner.num_partitions()); let name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_IDX"); for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { - p.table.probe_subset( + p.hash_table.probe_subset( &hash_keys, idxs_in_p, &mut table_match, @@ -771,9 +706,9 @@ impl ProbeState { // Gather output and add to buffer. let mut build_df = if emit_unmatched { - p.df.take_opt_chunked_unchecked(&table_match, false) + p.payload.take_opt_chunked_unchecked(&table_match, false) } else { - p.df.take_chunked_unchecked(&table_match, IsSorted::Not, false) + p.payload.take_chunked_unchecked(&table_match, IsSorted::Not, false) }; if !payload_rechunked { @@ -824,7 +759,7 @@ impl ProbeState { for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { let mut offset = 0; while offset < idxs_in_p.len() { - offset += p.table.probe_subset( + offset += p.hash_table.probe_subset( &hash_keys, &idxs_in_p[offset..], &mut table_match, @@ -840,9 +775,9 @@ impl ProbeState { // Gather output and send. let mut build_df = if emit_unmatched { - p.df.take_opt_chunked_unchecked(&table_match, false) + p.payload.take_opt_chunked_unchecked(&table_match, false) } else { - p.df.take_chunked_unchecked(&table_match, IsSorted::Not, false) + p.payload.take_chunked_unchecked(&table_match, IsSorted::Not, false) }; if !payload_rechunked { // TODO: can avoid rechunk? We have to rechunk here or else we do it @@ -906,11 +841,11 @@ impl ProbeState { let mut unmarked_idxs = Vec::new(); unsafe { for p in self.table_per_partition.iter() { - p.table.unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX); + p.hash_table.unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX); // Gather and create full-null counterpart. 
let mut build_df = - p.df.take_chunked_unchecked(&unmarked_idxs, IsSorted::Not, false); + p.payload.take_chunked_unchecked(&unmarked_idxs, IsSorted::Not, false); let len = build_df.height(); let mut out_df = if params.left_is_build.unwrap() { let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); @@ -985,7 +920,7 @@ impl EmitUnmatchedState { let total_len: usize = self .partitions .iter() - .map(|p| p.table.num_keys() as usize) + .map(|p| p.hash_table.num_keys() as usize) .sum(); let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1); let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines); @@ -997,7 +932,7 @@ impl EmitUnmatchedState { while let Some(p) = self.partitions.get(self.active_partition_idx) { loop { // Generate a chunk of unmarked key indices. - self.offset_in_active_p += p.table.unmarked_keys( + self.offset_in_active_p += p.hash_table.unmarked_keys( &mut unmarked_idxs, self.offset_in_active_p as IdxSize, morsel_size as IdxSize, @@ -1009,7 +944,7 @@ impl EmitUnmatchedState { // Gather and create full-null counterpart. let out_df = unsafe { let mut build_df = - p.df.take_chunked_unchecked(&unmarked_idxs, IsSorted::Not, false); + p.payload.take_chunked_unchecked(&unmarked_idxs, IsSorted::Not, false); let len = build_df.height(); if params.left_is_build.unwrap() { let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); @@ -1064,8 +999,8 @@ struct EquiJoinParams { right_key_selectors: Vec, left_payload_select: Vec>, right_payload_select: Vec>, - left_payload_schema: Schema, - right_payload_schema: Schema, + left_payload_schema: Arc, + right_payload_schema: Arc, args: JoinArgs, random_state: PlRandomState, } @@ -1154,8 +1089,8 @@ impl EquiJoinNode { EquiJoinState::Sample(SampleState::default()) }; - let left_payload_schema = select_schema(&left_input_schema, &left_payload_select); - let right_payload_schema = select_schema(&right_input_schema, &right_payload_select); + let left_payload_schema = Arc::new(select_schema(&left_input_schema, &left_payload_select)); + let right_payload_schema = Arc::new(select_schema(&right_input_schema, &right_payload_select)); Ok(Self { state, num_pipelines: 0, diff --git a/crates/polars-stream/src/nodes/joins/mod.rs b/crates/polars-stream/src/nodes/joins/mod.rs index f5304162d56a..14706fa4407f 100644 --- a/crates/polars-stream/src/nodes/joins/mod.rs +++ b/crates/polars-stream/src/nodes/joins/mod.rs @@ -1,2 +1,3 @@ pub mod equi_join; +pub mod new_equi_join; pub mod in_memory; diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs new file mode 100644 index 000000000000..0c1b20cad8ab --- /dev/null +++ b/crates/polars-stream/src/nodes/joins/new_equi_join.rs @@ -0,0 +1,1290 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, LazyLock}; + +use crossbeam_queue::ArrayQueue; +use polars_core::prelude::*; +use polars_core::schema::{Schema, SchemaExt}; + use polars_utils::sync::SyncPtr; +use polars_core::series::IsSorted; +use polars_core::utils::accumulate_dataframes_vertical_unchecked; +use polars_core::{config, POOL}; +use polars_expr::idx_table::{new_idx_table, IdxTable}; +use polars_expr::hash_keys::HashKeys; +use polars_io::pl_async::get_runtime; +use polars_ops::frame::{JoinArgs, JoinType, MaintainOrderJoin}; +use polars_ops::prelude::TakeChunked; +use polars_ops::series::coalesce_columns; +use polars_utils::cardinality_sketch::CardinalitySketch; +use polars_utils::hashing::HashPartitioner; +use 
polars_utils::itertools::Itertools; +use polars_utils::pl_str::PlSmallStr; +use polars_utils::{format_pl_smallstr, IdxSize}; + use arrow::array::builder::ShareStrategy; + use polars_core::frame::builder::DataFrameBuilder; +use rayon::prelude::*; + +use crate::async_primitives::connector::{connector, Receiver, Sender}; +use crate::async_primitives::wait_group::WaitGroup; +use crate::expression::StreamExpr; +use crate::morsel::{get_ideal_morsel_size, SourceToken}; +use crate::nodes::compute_node_prelude::*; +use crate::nodes::in_memory_source::InMemorySourceNode; + +static SAMPLE_LIMIT: LazyLock = LazyLock::new(|| { + std::env::var("POLARS_JOIN_SAMPLE_LIMIT") + .map(|limit| limit.parse().unwrap()) + .unwrap_or(10_000_000) +}); + +// If one side is this much bigger than the other side we'll always use the +// smaller side as the build side without checking cardinalities. +const LOPSIDED_SAMPLE_FACTOR: usize = 10; + +/// A payload selector contains for each column whether that column should be +/// included in the payload, and if yes with what name. +fn compute_payload_selector( + this: &Schema, + other: &Schema, + this_key_schema: &Schema, + is_left: bool, + args: &JoinArgs, +) -> PolarsResult>> { + let should_coalesce = args.should_coalesce(); + + this.iter_names() + .enumerate() + .map(|(i, c)| { + let selector = if should_coalesce && this_key_schema.contains(c) { + if is_left != (args.how == JoinType::Right) { + Some(c.clone()) + } else if args.how == JoinType::Full { + // We must keep the right-hand side keycols around for + // coalescing. + Some(format_pl_smallstr!("__POLARS_COALESCE_KEYCOL{i}")) + } else { + None + } + } else if !other.contains(c) || is_left { + Some(c.clone()) + } else { + let suffixed = format_pl_smallstr!("{}{}", c, args.suffix()); + if other.contains(&suffixed) { + polars_bail!(Duplicate: "column with name '{suffixed}' already exists\n\n\ + You may want to try:\n\ + - renaming the column prior to joining\n\ + - using the `suffix` parameter to specify a suffix different to the default one ('_right')") + } + Some(suffixed) + }; + Ok(selector) + }) + .collect() +} + +/// Fixes names and does coalescing of columns post-join. +fn postprocess_join(df: DataFrame, params: &EquiJoinParams) -> DataFrame { + if params.args.how == JoinType::Full && params.args.should_coalesce() { + // TODO: don't do string-based column lookups for each dataframe, pre-compute coalesce indices. + let mut key_idx = 0; + df.get_columns() + .iter() + .filter_map(|c| { + if let Some((key_name, _)) = params.left_key_schema.get_at_index(key_idx) { + if c.name() == key_name { + let other = df + .column(&format_pl_smallstr!("__POLARS_COALESCE_KEYCOL{key_idx}")) + .unwrap(); + key_idx += 1; + return Some(coalesce_columns(&[c.clone(), other.clone()]).unwrap()); + } + } + + if c.name().starts_with("__POLARS_COALESCE_KEYCOL") { + return None; + } + + Some(c.clone()) + }) + .collect() + } else { + df + } +} + +fn select_schema(schema: &Schema, selector: &[Option]) -> Schema { + schema + .iter_fields() + .zip(selector) + .filter_map(|(f, name)| Some(f.with_name(name.clone()?))) + .collect() +} + +async fn select_keys( + df: &DataFrame, + key_selectors: &[StreamExpr], + params: &EquiJoinParams, + state: &ExecutionState, +) -> PolarsResult { + let mut key_columns = Vec::new(); + for (i, selector) in key_selectors.iter().enumerate() { + // We use key columns entirely by position, and allow duplicate names, + // so just assign arbitrary unique names. 
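+        // (These names stay internal to the hash computation; the join
+        // output is produced from the payload selectors instead.)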
+ let unique_name = format_pl_smallstr!("__POLARS_KEYCOL_{i}"); + let s = selector.evaluate(df, state).await?; + key_columns.push(s.into_column().with_name(unique_name)); + } + let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?; + Ok(HashKeys::from_df( + &keys, + params.random_state.clone(), + params.args.nulls_equal, + true, + )) +} + +fn select_payload(df: DataFrame, selector: &[Option]) -> DataFrame { + // Maintain height of zero-width dataframes. + if df.width() == 0 { + return df; + } + + df.take_columns() + .into_iter() + .zip(selector) + .filter_map(|(c, name)| Some(c.with_name(name.clone()?))) + .collect() +} + +fn estimate_cardinality( + morsels: &[Morsel], + key_selectors: &[StreamExpr], + params: &EquiJoinParams, + state: &ExecutionState, +) -> PolarsResult { + // TODO: parallelize. + let mut sketch = CardinalitySketch::new(); + for morsel in morsels { + let hash_keys = + get_runtime().block_on(select_keys(morsel.df(), key_selectors, params, state))?; + hash_keys.sketch_cardinality(&mut sketch); + } + Ok(sketch.estimate()) +} + +struct BufferedStream { + morsels: ArrayQueue, + post_buffer_offset: MorselSeq, +} + +impl BufferedStream { + pub fn new(morsels: Vec, start_offset: MorselSeq) -> Self { + // Relabel so we can insert into parallel streams later. + let mut seq = start_offset; + let queue = ArrayQueue::new(morsels.len().max(1)); + for mut morsel in morsels { + morsel.set_seq(seq); + queue.push(morsel).unwrap(); + seq = seq.successor(); + } + + Self { + morsels: queue, + post_buffer_offset: seq, + } + } + + pub fn is_empty(&self) -> bool { + self.morsels.is_empty() + } + + #[expect(clippy::needless_lifetimes)] + pub fn reinsert<'s, 'env>( + &'s self, + num_pipelines: usize, + recv_port: Option>, + scope: &'s TaskScope<'s, 'env>, + join_handles: &mut Vec>>, + ) -> Option>> { + let receivers = if let Some(p) = recv_port { + p.parallel().into_iter().map(Some).collect_vec() + } else { + (0..num_pipelines).map(|_| None).collect_vec() + }; + + let source_token = SourceToken::new(); + let mut out = Vec::new(); + for orig_recv in receivers { + let (mut new_send, new_recv) = connector(); + out.push(new_recv); + let source_token = source_token.clone(); + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + // Act like an InMemorySource node until cached morsels are consumed. + let wait_group = WaitGroup::default(); + loop { + let Some(mut morsel) = self.morsels.pop() else { + break; + }; + morsel.replace_source_token(source_token.clone()); + morsel.set_consume_token(wait_group.token()); + if new_send.send(morsel).await.is_err() { + return Ok(()); + } + wait_group.wait().await; + // TODO: Unfortunately we can't actually stop here without + // re-buffering morsels from the stream that comes after. + // if source_token.stop_requested() { + // break; + // } + } + + if let Some(mut recv) = orig_recv { + while let Ok(mut morsel) = recv.recv().await { + if source_token.stop_requested() { + morsel.source_token().stop(); + } + morsel.set_seq(morsel.seq().offset_by(self.post_buffer_offset)); + if new_send.send(morsel).await.is_err() { + break; + } + } + } + Ok(()) + })); + } + Some(out) + } +} + +impl Default for BufferedStream { + fn default() -> Self { + Self { + morsels: ArrayQueue::new(1), + post_buffer_offset: MorselSeq::default(), + } + } +} + +impl Drop for BufferedStream { + fn drop(&mut self) { + POOL.install(|| { + // Parallel drop as the state might be quite big. 
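+            // Each pop hands one morsel to a rayon worker, so the buffered
+            // DataFrames are freed concurrently instead of serially here.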
+ (0..self.morsels.len()) + .into_par_iter() + .for_each(|_| drop(self.morsels.pop())); + }) + } +} + +#[derive(Default)] +struct SampleState { + left: Vec, + left_len: usize, + right: Vec, + right_len: usize, +} + +impl SampleState { + async fn sink( + mut recv: Receiver, + morsels: &mut Vec, + len: &mut usize, + this_final_len: Arc, + other_final_len: Arc, + ) -> PolarsResult<()> { + while let Ok(mut morsel) = recv.recv().await { + *len += morsel.df().height(); + if *len >= *SAMPLE_LIMIT + || *len + >= other_final_len + .load(Ordering::Relaxed) + .saturating_mul(LOPSIDED_SAMPLE_FACTOR) + { + morsel.source_token().stop(); + } + + drop(morsel.take_consume_token()); + morsels.push(morsel); + } + this_final_len.store(*len, Ordering::Relaxed); + Ok(()) + } + + fn try_transition_to_build( + &mut self, + recv: &[PortState], + num_pipelines: usize, + params: &mut EquiJoinParams, + table: &mut Option>, + ) -> PolarsResult> { + let left_saturated = self.left_len >= *SAMPLE_LIMIT; + let right_saturated = self.right_len >= *SAMPLE_LIMIT; + let left_done = recv[0] == PortState::Done || left_saturated; + let right_done = recv[1] == PortState::Done || right_saturated; + #[expect(clippy::nonminimal_bool)] + let stop_sampling = (left_done && right_done) + || (left_done && self.right_len >= LOPSIDED_SAMPLE_FACTOR * self.left_len) + || (right_done && self.left_len >= LOPSIDED_SAMPLE_FACTOR * self.right_len); + if !stop_sampling { + return Ok(None); + } + + if config::verbose() { + eprintln!( + "choosing equi-join build side, sample lengths are: {} vs. {}", + self.left_len, self.right_len + ); + } + + let estimate_cardinalities = || { + let execution_state = ExecutionState::new(); + let left_cardinality = estimate_cardinality( + &self.left, + ¶ms.left_key_selectors, + params, + &execution_state, + )?; + let right_cardinality = estimate_cardinality( + &self.right, + ¶ms.right_key_selectors, + params, + &execution_state, + )?; + let norm_left_factor = self.left_len.min(*SAMPLE_LIMIT) as f64 / self.left_len as f64; + let norm_right_factor = + self.right_len.min(*SAMPLE_LIMIT) as f64 / self.right_len as f64; + let norm_left_cardinality = (left_cardinality as f64 * norm_left_factor) as usize; + let norm_right_cardinality = (right_cardinality as f64 * norm_right_factor) as usize; + if config::verbose() { + eprintln!("estimated cardinalities are: {norm_left_cardinality} vs. {norm_right_cardinality}"); + } + PolarsResult::Ok((norm_left_cardinality, norm_right_cardinality)) + }; + + let left_is_build = match (left_saturated, right_saturated) { + (false, false) => { + if self.left_len * LOPSIDED_SAMPLE_FACTOR < self.right_len + || self.left_len > self.right_len * LOPSIDED_SAMPLE_FACTOR + { + // Don't bother estimating cardinality, just choose smaller as it's highly + // imbalanced. + self.left_len < self.right_len + } else { + let (lc, rc) = estimate_cardinalities()?; + // Let's assume for now that per element building a + // table is 3x more expensive than a probe, with + // unique keys getting an additional 3x factor for + // having to update the hash table in addition to the probe. + let left_build_cost = self.left_len * 3 + 3 * lc; + let left_probe_cost = self.left_len; + let right_build_cost = self.right_len * 3 + 3 * rc; + let right_probe_cost = self.right_len; + left_build_cost + right_probe_cost < left_probe_cost + right_build_cost + } + }, + + // Choose the unsaturated side, the saturated side could be + // arbitrarily big. 
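+            // Tuple is (left_saturated, right_saturated); `true` makes the
+            // left side the build side.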
+ (false, true) => true, + (true, false) => false, + + // Estimate cardinality and choose smaller. + (true, true) => { + let (lc, rc) = estimate_cardinalities()?; + lc < rc + }, + }; + + if config::verbose() { + eprintln!( + "build side chosen: {}", + if left_is_build { "left" } else { "right" } + ); + } + + // Transition to building state. + params.left_is_build = Some(left_is_build); + *table = Some(if left_is_build { + new_idx_table(params.left_key_schema.clone()) + } else { + new_idx_table(params.right_key_schema.clone()) + }); + + let mut sampled_build_morsels = + BufferedStream::new(core::mem::take(&mut self.left), MorselSeq::default()); + let mut sampled_probe_morsels = + BufferedStream::new(core::mem::take(&mut self.right), MorselSeq::default()); + if !left_is_build { + core::mem::swap(&mut sampled_build_morsels, &mut sampled_probe_morsels); + } + + let partitioner = HashPartitioner::new(num_pipelines, 0); + let mut build_state = BuildState { + local_builders: (0..num_pipelines).map(|_| LocalBuilder::default()).collect(), + sampled_probe_morsels, + }; + + // Simulate the sample build morsels flowing into the build side. + if !sampled_build_morsels.is_empty() { + let state = ExecutionState::new(); + crate::async_executor::task_scope(|scope| { + let mut join_handles = Vec::new(); + let receivers = sampled_build_morsels + .reinsert(num_pipelines, None, scope, &mut join_handles) + .unwrap(); + + for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) + { + join_handles.push(scope.spawn_task( + TaskPriority::High, + BuildState::partition_and_sink( + recv, + local_builder, + partitioner.clone(), + params, + &state, + ), + )); + } + + polars_io::pl_async::get_runtime().block_on(async move { + for handle in join_handles { + handle.await?; + } + PolarsResult::Ok(()) + }) + })?; + } + + Ok(Some(build_state)) + } +} + +#[derive(Default)] +struct LocalBuilder { + // The complete list of morsels and their computed hashes seen by this builder. + morsels: Vec<(MorselSeq, DataFrame, HashKeys)>, + + // A cardinality sketch per partition for the keys seen by this builder. + sketch_per_p: Vec, + + // morsel_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i] + // for partition p, where start, stop are: + // let start = morsel_idxs_offsets[i * num_partitions + p]; + // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p]; + morsel_idxs_values_per_p: Vec>, + morsel_idxs_offsets_per_p: Vec, +} + +#[derive(Default)] +struct BuildState { + local_builders: Vec, + sampled_probe_morsels: BufferedStream, +} + +impl BuildState { + async fn partition_and_sink( + mut recv: Receiver, + local: &mut LocalBuilder, + partitioner: HashPartitioner, + params: &EquiJoinParams, + state: &ExecutionState, + ) -> PolarsResult<()> { + let track_unmatchable = params.emit_unmatched_build(); + local + .sketch_per_p + .resize_with(partitioner.num_partitions(), Default::default); + local + .morsel_idxs_values_per_p + .resize_with(partitioner.num_partitions(), Default::default); + + if local.morsel_idxs_offsets_per_p.is_empty() { + local + .morsel_idxs_offsets_per_p + .resize(partitioner.num_partitions(), 0); + } + + let (key_selectors, payload_selector); + if params.left_is_build.unwrap() { + payload_selector = ¶ms.left_payload_select; + key_selectors = ¶ms.left_key_selectors; + } else { + payload_selector = ¶ms.right_payload_select; + key_selectors = ¶ms.right_key_selectors; + }; + + while let Ok(morsel) = recv.recv().await { + // Compute hashed keys and payload. 
We must rechunk the payload for + // later gathers. + let hash_keys = select_keys(morsel.df(), key_selectors, params, state).await?; + let mut payload = select_payload(morsel.df().clone(), payload_selector); + payload.rechunk_mut(); + + hash_keys.gen_partition_idxs( + &partitioner, + &mut local.morsel_idxs_values_per_p, + &mut local.sketch_per_p, + track_unmatchable, + ); + + local + .morsel_idxs_offsets_per_p + .extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len())); + local.morsels.push((morsel.seq(), payload, hash_keys)); + } + Ok(()) + } + + fn finalize(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState { + let track_unmatchable = params.emit_unmatched_build(); + let payload_schema = if params.left_is_build.unwrap() { + ¶ms.left_payload_schema + } else { + ¶ms.right_payload_schema + }; + + // To reduce maximum memory usage we want to drop the morsels + // as soon as they're processed, so we move into Arcs. + let morsels_per_local_builder = self + .local_builders + .iter_mut() + .map(|b| Arc::new(core::mem::take(&mut b.morsels))) + .collect_vec(); + let num_partitions = self.local_builders[0].sketch_per_p.len(); + let local_builders = &self.local_builders; + + + let mut probe_tables = POOL.scope(|s| { + let mut probe_tables: Vec = Vec::with_capacity(num_partitions); + let probe_table_ptr = unsafe { SyncPtr::new(probe_tables.as_mut_ptr()) }; + + // Wrap in outer Arc to move to each thread, performing the + // expensive clone on that thread. + let arc_morsels_per_local_builder = Arc::new(morsels_per_local_builder); + for p in 0..num_partitions { + let arc_morsels_per_local_builder = Arc::clone(&arc_morsels_per_local_builder); + s.spawn(move |_| { + // Extract from outer arc and drop outer arc. + let morsels_per_local_builder = Arc::unwrap_or_clone(arc_morsels_per_local_builder); + + // Compute cardinality estimate and total amount of + // payload for this partition. + let mut sketch = CardinalitySketch::new(); + let mut payload_rows = 0; + for l in local_builders { + sketch.combine(&l.sketch_per_p[p]); + let offsets_len = l.morsel_idxs_offsets_per_p.len(); + payload_rows += l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p]; + } + + // Allocate hash table and payload builder. + let mut p_table = table.new_empty(); + p_table.reserve(sketch.estimate() * 5 / 4); + let mut p_payload = DataFrameBuilder::new(payload_schema.clone()); + p_payload.reserve(payload_rows); + + // Build. + for (l, l_morsels) in local_builders.iter().zip(morsels_per_local_builder) { + for (i, morsel) in l_morsels.iter().enumerate() { + let (_mseq, payload, keys) = morsel; + unsafe { + let p_morsel_idxs_start = l.morsel_idxs_offsets_per_p[i * num_partitions + p]; + let p_morsel_idxs_stop = l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p]; + let p_morsel_idxs = &l.morsel_idxs_values_per_p[p][p_morsel_idxs_start..p_morsel_idxs_stop]; + p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable); + p_payload.gather_extend(payload, p_morsel_idxs, ShareStrategy::Never); + } + } + } + + unsafe { + probe_table_ptr.get().add(p).write(ProbeTable { + hash_table: p_table, + payload: p_payload.freeze(), + }); + } + }); + } + + // Drop outer arc after spawning each thread so the inner arcs + // can get dropped as soon as they're processed. + drop(arc_morsels_per_local_builder); + probe_tables + }); + + unsafe { + // SAFETY: all entries are initialized now. 
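+            // Every index in 0..num_partitions was written exactly once by
+            // the scoped task that owned it, and POOL.scope joined them all.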
+ probe_tables.set_len(num_partitions); + } + + todo!() + } +} + +struct ProbeTable { + hash_table: Box, + payload: DataFrame, +} + +struct ProbeState { + table_per_partition: Vec, + max_seq_sent: MorselSeq, + sampled_probe_morsels: BufferedStream, +} + +impl ProbeState { + /// Returns the max morsel sequence sent. + async fn partition_and_probe( + mut recv: Receiver, + mut send: Sender, + partitions: &[ProbeTable], + partitioner: HashPartitioner, + params: &EquiJoinParams, + state: &ExecutionState, + ) -> PolarsResult { + todo!() + /* + // TODO: shuffle after partitioning and keep probe tables thread-local. + let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()]; + let mut table_match = Vec::new(); + let mut probe_match = Vec::new(); + let mut max_seq = MorselSeq::default(); + + let probe_limit = get_ideal_morsel_size() as IdxSize; + let mark_matches = params.emit_unmatched_build(); + let emit_unmatched = params.emit_unmatched_probe(); + + let (key_selectors, payload_selector); + if params.left_is_build.unwrap() { + payload_selector = ¶ms.right_payload_select; + key_selectors = ¶ms.right_key_selectors; + } else { + payload_selector = ¶ms.left_payload_select; + key_selectors = ¶ms.left_key_selectors; + }; + + while let Ok(morsel) = recv.recv().await { + // Compute hashed keys and payload. + let (df, seq, src_token, wait_token) = morsel.into_inner(); + let hash_keys = select_keys(&df, key_selectors, params, state).await?; + let mut payload = select_payload(df, payload_selector); + let mut payload_rechunked = false; // We don't eagerly rechunk because there might be no matches. + + max_seq = seq; + + unsafe { + // Partition and probe the tables. + for p in partition_idxs.iter_mut() { + p.clear(); + } + hash_keys.gen_partition_idxs( + &partitioner, + &mut partition_idxs, + &mut [], + emit_unmatched, + ); + if params.preserve_order_probe { + // TODO: non-sort based implementation, can directly scatter + // after finding matches for each partition. + let mut out_per_partition = Vec::with_capacity(partitioner.num_partitions()); + let name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_IDX"); + for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { + p.hash_table.probe_subset( + &hash_keys, + idxs_in_p, + &mut table_match, + &mut probe_match, + mark_matches, + emit_unmatched, + IdxSize::MAX, + ); + + if table_match.is_empty() { + continue; + } + + // Gather output and add to buffer. + let mut build_df = if emit_unmatched { + p.payload.take_opt_chunked_unchecked(&table_match, false) + } else { + p.payload.take_chunked_unchecked(&table_match, IsSorted::Not, false) + }; + + if !payload_rechunked { + // TODO: can avoid rechunk? We have to rechunk here or else we do it + // multiple times during the gather. 
+ payload.rechunk_mut(); + payload_rechunked = true; + } + let mut probe_df = payload.take_slice_unchecked_impl(&probe_match, false); + + let mut out_df = if params.left_is_build.unwrap() { + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + + let idxs_ca = + IdxCa::from_vec(name.clone(), core::mem::take(&mut probe_match)); + out_df.with_column_unchecked(idxs_ca.into_column()); + out_per_partition.push(out_df); + } + + if !out_per_partition.is_empty() { + let sort_options = SortMultipleOptions { + descending: vec![false], + nulls_last: vec![false], + multithreaded: false, + maintain_order: true, + limit: None, + }; + let mut out_df = + accumulate_dataframes_vertical_unchecked(out_per_partition); + out_df.sort_in_place([name.clone()], sort_options).unwrap(); + out_df.drop_in_place(&name).unwrap(); + out_df = postprocess_join(out_df, params); + + // TODO: break in smaller morsels. + let out_morsel = Morsel::new(out_df, seq, src_token.clone()); + if send.send(out_morsel).await.is_err() { + break; + } + } + } else { + let mut out_frames = Vec::new(); + let mut out_len = 0; + for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { + let mut offset = 0; + while offset < idxs_in_p.len() { + offset += p.hash_table.probe_subset( + &hash_keys, + &idxs_in_p[offset..], + &mut table_match, + &mut probe_match, + mark_matches, + emit_unmatched, + probe_limit - out_len, + ) as usize; + + if table_match.is_empty() { + continue; + } + + // Gather output and send. + let mut build_df = if emit_unmatched { + p.payload.take_opt_chunked_unchecked(&table_match, false) + } else { + p.payload.take_chunked_unchecked(&table_match, IsSorted::Not, false) + }; + if !payload_rechunked { + // TODO: can avoid rechunk? We have to rechunk here or else we do it + // multiple times during the gather. + payload.rechunk_mut(); + payload_rechunked = true; + } + let mut probe_df = + payload.take_slice_unchecked_impl(&probe_match, false); + + let out_df = if params.left_is_build.unwrap() { + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + let out_df = postprocess_join(out_df, params); + + out_len = out_len + .checked_add(out_df.height().try_into().unwrap()) + .unwrap(); + out_frames.push(out_df); + + if out_len >= probe_limit { + out_len = 0; + let df = + accumulate_dataframes_vertical_unchecked(out_frames.drain(..)); + let out_morsel = Morsel::new(df, seq, src_token.clone()); + if send.send(out_morsel).await.is_err() { + break; + } + } + } + } + + if out_len > 0 { + let df = accumulate_dataframes_vertical_unchecked(out_frames.drain(..)); + let out_morsel = Morsel::new(df, seq, src_token.clone()); + if send.send(out_morsel).await.is_err() { + break; + } + } + } + } + + drop(wait_token); + } + + Ok(max_seq) + */ + } + + fn ordered_unmatched( + &mut self, + partitioner: &HashPartitioner, + params: &EquiJoinParams, + ) -> DataFrame { + todo!() + } +} + +impl Drop for ProbeState { + fn drop(&mut self) { + POOL.install(|| { + // Parallel drop as the state might be quite big. 
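+            // par_drain moves each partition's hash table and payload frame
+            // to a rayon worker so deallocation itself runs in parallel.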
+ self.table_per_partition.par_drain(..).for_each(drop); + }) + } +} + +struct EmitUnmatchedState { + partitions: Vec, + active_partition_idx: usize, + offset_in_active_p: usize, + morsel_seq: MorselSeq, +} + +impl EmitUnmatchedState { + async fn emit_unmatched( + &mut self, + mut send: Sender, + params: &EquiJoinParams, + num_pipelines: usize, + ) -> PolarsResult<()> { + todo!() + } +} + + +enum EquiJoinState { + Sample(SampleState), + Build(BuildState), + Probe(ProbeState), + EmitUnmatchedBuild(EmitUnmatchedState), + EmitUnmatchedBuildInOrder(InMemorySourceNode), + Done, +} + +struct EquiJoinParams { + left_is_build: Option, + preserve_order_build: bool, + preserve_order_probe: bool, + left_key_schema: Arc, + left_key_selectors: Vec, + right_key_schema: Arc, + right_key_selectors: Vec, + left_payload_select: Vec>, + right_payload_select: Vec>, + left_payload_schema: Arc, + right_payload_schema: Arc, + args: JoinArgs, + random_state: PlRandomState, +} + +impl EquiJoinParams { + /// Should we emit unmatched rows from the build side? + fn emit_unmatched_build(&self) -> bool { + if self.left_is_build.unwrap() { + self.args.how == JoinType::Left || self.args.how == JoinType::Full + } else { + self.args.how == JoinType::Right || self.args.how == JoinType::Full + } + } + + /// Should we emit unmatched rows from the probe side? + fn emit_unmatched_probe(&self) -> bool { + if self.left_is_build.unwrap() { + self.args.how == JoinType::Right || self.args.how == JoinType::Full + } else { + self.args.how == JoinType::Left || self.args.how == JoinType::Full + } + } +} + +pub struct EquiJoinNode { + state: EquiJoinState, + params: EquiJoinParams, + num_pipelines: usize, + table: Option>, +} + +impl EquiJoinNode { + pub fn new( + left_input_schema: Arc, + right_input_schema: Arc, + left_key_schema: Arc, + right_key_schema: Arc, + left_key_selectors: Vec, + right_key_selectors: Vec, + args: JoinArgs, + ) -> PolarsResult { + let left_is_build = match args.maintain_order { + MaintainOrderJoin::None => { + if *SAMPLE_LIMIT == 0 { + Some(true) + } else { + None + } + }, + MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => Some(false), + MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => Some(true), + }; + + let table = left_is_build.map(|lib| { + if lib { + new_idx_table(left_key_schema.clone()) + } else { + new_idx_table(right_key_schema.clone()) + } + }); + + let preserve_order_probe = args.maintain_order != MaintainOrderJoin::None; + let preserve_order_build = matches!( + args.maintain_order, + MaintainOrderJoin::LeftRight | MaintainOrderJoin::RightLeft + ); + + let left_payload_select = compute_payload_selector( + &left_input_schema, + &right_input_schema, + &left_key_schema, + true, + &args, + )?; + let right_payload_select = compute_payload_selector( + &right_input_schema, + &left_input_schema, + &right_key_schema, + false, + &args, + )?; + + let state = if left_is_build.is_some() { + EquiJoinState::Build(BuildState::default()) + } else { + EquiJoinState::Sample(SampleState::default()) + }; + + let left_payload_schema = Arc::new(select_schema(&left_input_schema, &left_payload_select)); + let right_payload_schema = Arc::new(select_schema(&right_input_schema, &right_payload_select)); + Ok(Self { + state, + num_pipelines: 0, + params: EquiJoinParams { + left_is_build, + preserve_order_build, + preserve_order_probe, + left_key_schema, + left_key_selectors, + right_key_schema, + right_key_selectors, + left_payload_select, + right_payload_select, + left_payload_schema, + right_payload_schema, 
+ args, + random_state: PlRandomState::new(), + }, + table, + }) + } +} + +impl ComputeNode for EquiJoinNode { + fn name(&self) -> &str { + "equi_join" + } + + fn initialize(&mut self, num_pipelines: usize) { + self.num_pipelines = num_pipelines; + } + + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { + assert!(recv.len() == 2 && send.len() == 1); + + // If the output doesn't want any more data, transition to being done. + if send[0] == PortState::Done { + self.state = EquiJoinState::Done; + } + + // If we are sampling and both sides are done/filled, transition to building. + if let EquiJoinState::Sample(sample_state) = &mut self.state { + if let Some(build_state) = sample_state.try_transition_to_build( + recv, + self.num_pipelines, + &mut self.params, + &mut self.table, + )? { + self.state = EquiJoinState::Build(build_state); + } + } + + let build_idx = if self.params.left_is_build == Some(true) { + 0 + } else { + 1 + }; + let probe_idx = 1 - build_idx; + + // If we are building and the build input is done, transition to probing. + if let EquiJoinState::Build(build_state) = &mut self.state { + if recv[build_idx] == PortState::Done { + self.state = EquiJoinState::Probe( + build_state.finalize(&self.params, self.table.as_deref().unwrap()), + ); + } + } + + // If we are probing and the probe input is done, emit unmatched if + // necessary, otherwise we're done. + if let EquiJoinState::Probe(probe_state) = &mut self.state { + let samples_consumed = probe_state.sampled_probe_morsels.is_empty(); + if samples_consumed && recv[probe_idx] == PortState::Done { + if self.params.emit_unmatched_build() { + if self.params.preserve_order_build { + let partitioner = HashPartitioner::new(self.num_pipelines, 0); + let unmatched = probe_state.ordered_unmatched(&partitioner, &self.params); + let mut src = InMemorySourceNode::new( + Arc::new(unmatched), + probe_state.max_seq_sent.successor(), + ); + src.initialize(self.num_pipelines); + self.state = EquiJoinState::EmitUnmatchedBuildInOrder(src); + } else { + self.state = EquiJoinState::EmitUnmatchedBuild(EmitUnmatchedState { + partitions: core::mem::take(&mut probe_state.table_per_partition), + active_partition_idx: 0, + offset_in_active_p: 0, + morsel_seq: probe_state.max_seq_sent.successor(), + }); + } + } else { + self.state = EquiJoinState::Done; + } + } + } + + // Finally, check if we are done emitting unmatched keys. 
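+        // (Once the active partition index walks past the last partition,
+        // all unmatched build rows have been emitted.)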
+ if let EquiJoinState::EmitUnmatchedBuild(emit_state) = &mut self.state { + if emit_state.active_partition_idx >= emit_state.partitions.len() { + self.state = EquiJoinState::Done; + } + } + + match &mut self.state { + EquiJoinState::Sample(sample_state) => { + send[0] = PortState::Blocked; + if recv[0] != PortState::Done { + recv[0] = if sample_state.left_len < *SAMPLE_LIMIT { + PortState::Ready + } else { + PortState::Blocked + }; + } + if recv[1] != PortState::Done { + recv[1] = if sample_state.right_len < *SAMPLE_LIMIT { + PortState::Ready + } else { + PortState::Blocked + }; + } + }, + EquiJoinState::Build(_) => { + send[0] = PortState::Blocked; + if recv[build_idx] != PortState::Done { + recv[build_idx] = PortState::Ready; + } + if recv[probe_idx] != PortState::Done { + recv[probe_idx] = PortState::Blocked; + } + }, + EquiJoinState::Probe(probe_state) => { + if recv[probe_idx] != PortState::Done { + core::mem::swap(&mut send[0], &mut recv[probe_idx]); + } else { + let samples_consumed = probe_state.sampled_probe_morsels.is_empty(); + send[0] = if samples_consumed { + PortState::Done + } else { + PortState::Ready + }; + } + recv[build_idx] = PortState::Done; + }, + EquiJoinState::EmitUnmatchedBuild(_) => { + send[0] = PortState::Ready; + recv[build_idx] = PortState::Done; + recv[probe_idx] = PortState::Done; + }, + EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => { + recv[build_idx] = PortState::Done; + recv[probe_idx] = PortState::Done; + src_node.update_state(&mut [], &mut send[0..1])?; + if send[0] == PortState::Done { + self.state = EquiJoinState::Done; + } + }, + EquiJoinState::Done => { + send[0] = PortState::Done; + recv[0] = PortState::Done; + recv[1] = PortState::Done; + }, + } + Ok(()) + } + + fn is_memory_intensive_pipeline_blocker(&self) -> bool { + matches!( + self.state, + EquiJoinState::Sample { .. } | EquiJoinState::Build { .. 
} + ) + } + + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + recv_ports: &mut [Option>], + send_ports: &mut [Option>], + state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + assert!(recv_ports.len() == 2); + assert!(send_ports.len() == 1); + + let build_idx = if self.params.left_is_build == Some(true) { + 0 + } else { + 1 + }; + let probe_idx = 1 - build_idx; + + match &mut self.state { + EquiJoinState::Sample(sample_state) => { + assert!(send_ports[0].is_none()); + let left_final_len = Arc::new(AtomicUsize::new(if recv_ports[0].is_none() { + sample_state.left_len + } else { + usize::MAX + })); + let right_final_len = Arc::new(AtomicUsize::new(if recv_ports[1].is_none() { + sample_state.right_len + } else { + usize::MAX + })); + + if let Some(left_recv) = recv_ports[0].take() { + join_handles.push(scope.spawn_task( + TaskPriority::High, + SampleState::sink( + left_recv.serial(), + &mut sample_state.left, + &mut sample_state.left_len, + left_final_len.clone(), + right_final_len.clone(), + ), + )); + } + if let Some(right_recv) = recv_ports[1].take() { + join_handles.push(scope.spawn_task( + TaskPriority::High, + SampleState::sink( + right_recv.serial(), + &mut sample_state.right, + &mut sample_state.right_len, + right_final_len, + left_final_len, + ), + )); + } + }, + EquiJoinState::Build(build_state) => { + assert!(send_ports[0].is_none()); + assert!(recv_ports[probe_idx].is_none()); + let receivers = recv_ports[build_idx].take().unwrap().parallel(); + + build_state + .local_builders + .resize_with(self.num_pipelines, Default::default); + let partitioner = HashPartitioner::new(self.num_pipelines, 0); + for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) + { + join_handles.push(scope.spawn_task( + TaskPriority::High, + BuildState::partition_and_sink( + recv, + local_builder, + partitioner.clone(), + &self.params, + state, + ), + )); + } + }, + EquiJoinState::Probe(probe_state) => { + assert!(recv_ports[build_idx].is_none()); + let senders = send_ports[0].take().unwrap().parallel(); + let receivers = probe_state + .sampled_probe_morsels + .reinsert( + self.num_pipelines, + recv_ports[probe_idx].take(), + scope, + join_handles, + ) + .unwrap(); + + let partitioner = HashPartitioner::new(self.num_pipelines, 0); + let probe_tasks = receivers + .into_iter() + .zip(senders) + .map(|(recv, send)| { + scope.spawn_task( + TaskPriority::High, + ProbeState::partition_and_probe( + recv, + send, + &probe_state.table_per_partition, + partitioner.clone(), + &self.params, + state, + ), + ) + }) + .collect_vec(); + + let max_seq_sent = &mut probe_state.max_seq_sent; + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + for probe_task in probe_tasks { + *max_seq_sent = (*max_seq_sent).max(probe_task.await?); + } + Ok(()) + })); + }, + EquiJoinState::EmitUnmatchedBuild(emit_state) => { + assert!(recv_ports[build_idx].is_none()); + assert!(recv_ports[probe_idx].is_none()); + let send = send_ports[0].take().unwrap().serial(); + join_handles.push(scope.spawn_task( + TaskPriority::Low, + emit_state.emit_unmatched(send, &self.params, self.num_pipelines), + )); + }, + EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => { + assert!(recv_ports[build_idx].is_none()); + assert!(recv_ports[probe_idx].is_none()); + src_node.spawn(scope, &mut [], send_ports, state, join_handles); + }, + EquiJoinState::Done => unreachable!(), + } + } +} From 232be2b873ae9a0966604d1d62cee9b97208f030 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 4 
Mar 2025 13:30:44 +0100 Subject: [PATCH 03/25] wip --- .../src/nodes/joins/new_equi_join.rs | 92 ++++--------------- 1 file changed, 17 insertions(+), 75 deletions(-) diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs index 0c1b20cad8ab..d150d9b49c02 100644 --- a/crates/polars-stream/src/nodes/joins/new_equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/new_equi_join.rs @@ -616,7 +616,11 @@ impl BuildState { probe_tables.set_len(num_partitions); } - todo!() + ProbeState { + table_per_partition: probe_tables, + max_seq_sent: MorselSeq::default(), + sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels), + } } } @@ -641,8 +645,6 @@ impl ProbeState { params: &EquiJoinParams, state: &ExecutionState, ) -> PolarsResult { - todo!() - /* // TODO: shuffle after partitioning and keep probe tables thread-local. let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()]; let mut table_match = Vec::new(); @@ -653,13 +655,17 @@ impl ProbeState { let mark_matches = params.emit_unmatched_build(); let emit_unmatched = params.emit_unmatched_probe(); - let (key_selectors, payload_selector); + let (key_selectors, payload_selector, build_payload_schema, probe_payload_schema); if params.left_is_build.unwrap() { - payload_selector = ¶ms.right_payload_select; key_selectors = ¶ms.right_key_selectors; + payload_selector = ¶ms.right_payload_select; + build_payload_schema = ¶ms.left_payload_schema; + probe_payload_schema = ¶ms.right_payload_schema; } else { - payload_selector = ¶ms.left_payload_select; key_selectors = ¶ms.left_key_selectors; + payload_selector = ¶ms.left_payload_select; + build_payload_schema = ¶ms.right_payload_schema; + probe_payload_schema = ¶ms.left_payload_schema; }; while let Ok(morsel) = recv.recv().await { @@ -683,77 +689,13 @@ impl ProbeState { emit_unmatched, ); if params.preserve_order_probe { - // TODO: non-sort based implementation, can directly scatter - // after finding matches for each partition. - let mut out_per_partition = Vec::with_capacity(partitioner.num_partitions()); - let name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_IDX"); - for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { - p.hash_table.probe_subset( - &hash_keys, - idxs_in_p, - &mut table_match, - &mut probe_match, - mark_matches, - emit_unmatched, - IdxSize::MAX, - ); - - if table_match.is_empty() { - continue; - } - - // Gather output and add to buffer. - let mut build_df = if emit_unmatched { - p.payload.take_opt_chunked_unchecked(&table_match, false) - } else { - p.payload.take_chunked_unchecked(&table_match, IsSorted::Not, false) - }; - - if !payload_rechunked { - // TODO: can avoid rechunk? We have to rechunk here or else we do it - // multiple times during the gather. 
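The rechunk concern in the comment above recurs throughout this series: gathering rows by flat index from a multi-chunk DataFrame must re-resolve chunk boundaries on every access, so the payload is made contiguous once, and lazily, because a morsel may produce no matches at all. A minimal sketch of that lazy one-shot rechunk, assuming only DataFrame::rechunk_mut as used in these patches:

use polars_core::prelude::*;
use polars_utils::IdxSize;

fn gather_all(payload: &mut DataFrame, match_batches: &[Vec<IdxSize>]) {
    // Rechunk at most once, and only when a gather is actually imminent.
    let mut payload_rechunked = false;
    for matches in match_batches {
        if matches.is_empty() {
            continue; // nothing to gather, leave the payload chunked
        }
        if !payload_rechunked {
            payload.rechunk_mut(); // make every column single-chunk
            payload_rechunked = true;
        }
        // ... gather `matches` from `payload` here ...
    }
}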
- payload.rechunk_mut(); - payload_rechunked = true; - } - let mut probe_df = payload.take_slice_unchecked_impl(&probe_match, false); - - let mut out_df = if params.left_is_build.unwrap() { - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df - }; - - let idxs_ca = - IdxCa::from_vec(name.clone(), core::mem::take(&mut probe_match)); - out_df.with_column_unchecked(idxs_ca.into_column()); - out_per_partition.push(out_df); - } - - if !out_per_partition.is_empty() { - let sort_options = SortMultipleOptions { - descending: vec![false], - nulls_last: vec![false], - multithreaded: false, - maintain_order: true, - limit: None, - }; - let mut out_df = - accumulate_dataframes_vertical_unchecked(out_per_partition); - out_df.sort_in_place([name.clone()], sort_options).unwrap(); - out_df.drop_in_place(&name).unwrap(); - out_df = postprocess_join(out_df, params); - - // TODO: break in smaller morsels. - let out_morsel = Morsel::new(out_df, seq, src_token.clone()); - if send.send(out_morsel).await.is_err() { - break; - } - } + todo!() } else { let mut out_frames = Vec::new(); + let mut build_out = DataFrameBuilder::new(build_payload_schema.clone()); + let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone()); let mut out_len = 0; + /* for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { let mut offset = 0; while offset < idxs_in_p.len() { @@ -811,6 +753,7 @@ impl ProbeState { } } } + */ if out_len > 0 { let df = accumulate_dataframes_vertical_unchecked(out_frames.drain(..)); @@ -826,7 +769,6 @@ impl ProbeState { } Ok(max_seq) - */ } fn ordered_unmatched( From 5335fc69cdb5dabd443a7a8cd7f0997204f1e531 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 4 Mar 2025 14:19:58 +0100 Subject: [PATCH 04/25] wip --- .../src/nodes/joins/new_equi_join.rs | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs index d150d9b49c02..fe6b81735202 100644 --- a/crates/polars-stream/src/nodes/joins/new_equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/new_equi_join.rs @@ -695,7 +695,7 @@ impl ProbeState { let mut build_out = DataFrameBuilder::new(build_payload_schema.clone()); let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone()); let mut out_len = 0; - /* + for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { let mut offset = 0; while offset < idxs_in_p.len() { @@ -714,28 +714,29 @@ impl ProbeState { } // Gather output and send. - let mut build_df = if emit_unmatched { - p.payload.take_opt_chunked_unchecked(&table_match, false) + if emit_unmatched { + build_out.opt_gather_extend(&p.payload, &table_match, ShareStrategy::Always); } else { - p.payload.take_chunked_unchecked(&table_match, IsSorted::Not, false) + build_out.gather_extend(&p.payload, &table_match, ShareStrategy::Always); }; if !payload_rechunked { - // TODO: can avoid rechunk? We have to rechunk here or else we do it - // multiple times during the gather. 
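The move from take_chunked/take_opt_chunked to builder gathers rests on one distinction: gather_extend assumes every index is a valid row, while opt_gather_extend treats an out-of-range index as "no row" and emits null, which is exactly what unmatched build rows need in an outer join. A hedged usage sketch with the names from this series (p, table_match, build_out and emit_unmatched as in the surrounding code):

unsafe {
    if emit_unmatched {
        // table_match may contain sentinel (out-of-bounds) indices; those
        // become all-null rows in the output.
        build_out.opt_gather_extend(&p.payload, &table_match, ShareStrategy::Always);
    } else {
        // Every index in table_match refers to a real row of p.payload.
        build_out.gather_extend(&p.payload, &table_match, ShareStrategy::Always);
    }
}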
payload.rechunk_mut(); payload_rechunked = true; } - let mut probe_df = - payload.take_slice_unchecked_impl(&probe_match, false); + probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always); + + if probe_out.len() >= probe_limit { + let out_df = if params.left_is_build.unwrap() { + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + let out_df = postprocess_join(out_df, params); + + } - let out_df = if params.left_is_build.unwrap() { - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df - }; - let out_df = postprocess_join(out_df, params); out_len = out_len .checked_add(out_df.height().try_into().unwrap()) From 4bf0de9b228da15753bc0d13ff6d1c207effe1de Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 4 Mar 2025 17:06:33 +0100 Subject: [PATCH 05/25] wip --- .../polars-expr/src/idx_table/row_encoded.rs | 12 +- .../src/nodes/joins/new_equi_join.rs | 129 ++++++++++++------ .../src/physical_plan/to_graph.rs | 50 +++++-- 3 files changed, 127 insertions(+), 64 deletions(-) diff --git a/crates/polars-expr/src/idx_table/row_encoded.rs b/crates/polars-expr/src/idx_table/row_encoded.rs index b8d09f448586..4aa8785f9f18 100644 --- a/crates/polars-expr/src/idx_table/row_encoded.rs +++ b/crates/polars-expr/src/idx_table/row_encoded.rs @@ -41,7 +41,7 @@ impl RowEncodedIdxTable { if let Some(idxs) = self.idx_map.get(hash, key) { for idx in &idxs[..] { // Create matches, making sure to clear top bit. - table_match.push(idx.load(Ordering::Relaxed) & !(1 << 63)); + table_match.push((idx.load(Ordering::Relaxed) & !(1 << 63)) as IdxSize); probe_match.push(key_idx); } @@ -148,10 +148,10 @@ impl IdxTable for RowEncodedIdxTable { if let Some(key) = key { match self.idx_map.entry(*hash, key) { Entry::Occupied(o) => { - o.into_mut().push(AtomicU64::new(idx)); + o.into_mut().push(AtomicU64::new(idx as u64)); }, Entry::Vacant(v) => { - v.insert(unitvec![AtomicU64::new(idx)]); + v.insert(unitvec![AtomicU64::new(idx as u64)]); }, } } else if track_unmatchable { @@ -176,10 +176,10 @@ impl IdxTable for RowEncodedIdxTable { if let Some(key) = key { match self.idx_map.entry(hash, key) { Entry::Occupied(o) => { - o.into_mut().push(AtomicU64::new(idx)); + o.into_mut().push(AtomicU64::new(idx as u64)); }, Entry::Vacant(v) => { - v.insert(unitvec![AtomicU64::new(idx)]); + v.insert(unitvec![AtomicU64::new(idx as u64)]); }, } } else if track_unmatchable { @@ -317,7 +317,7 @@ impl IdxTable for RowEncodedIdxTable { let first_idx_val = first_idx.load(Ordering::Acquire); if first_idx_val >> 63 == 0 { for idx in &idxs[..] 
{ - out.push(idx.load(Ordering::Relaxed) & !(1 << 63)); + out.push((idx.load(Ordering::Relaxed) & !(1 << 63)) as IdxSize); } } diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs index fe6b81735202..f34b227f794c 100644 --- a/crates/polars-stream/src/nodes/joins/new_equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/new_equi_join.rs @@ -5,22 +5,19 @@ use crossbeam_queue::ArrayQueue; use polars_core::prelude::*; use polars_core::schema::{Schema, SchemaExt}; use polars_utils::sync::SyncPtr; -use polars_core::series::IsSorted; -use polars_core::utils::accumulate_dataframes_vertical_unchecked; use polars_core::{config, POOL}; use polars_expr::idx_table::{new_idx_table, IdxTable}; use polars_expr::hash_keys::HashKeys; use polars_io::pl_async::get_runtime; use polars_ops::frame::{JoinArgs, JoinType, MaintainOrderJoin}; -use polars_ops::prelude::TakeChunked; use polars_ops::series::coalesce_columns; use polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::hashing::HashPartitioner; use polars_utils::itertools::Itertools; use polars_utils::pl_str::PlSmallStr; use polars_utils::{format_pl_smallstr, IdxSize}; - use arrow::array::builder::ShareStrategy; - use polars_core::frame::builder::DataFrameBuilder; +use arrow::array::builder::ShareStrategy; +use polars_core::frame::builder::DataFrameBuilder; use rayon::prelude::*; use crate::async_primitives::connector::{connector, Receiver, Sender}; @@ -668,6 +665,9 @@ impl ProbeState { probe_payload_schema = ¶ms.left_payload_schema; }; + let mut build_out = DataFrameBuilder::new(build_payload_schema.clone()); + let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone()); + while let Ok(morsel) = recv.recv().await { // Compute hashed keys and payload. 
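The casts in the row-encoded table hunks above all serve one packing trick: each stored index is an AtomicU64 whose top bit records whether the key was ever matched, so emitting a match strips the bit and unmarked_keys only tests it. A self-contained sketch of the scheme (MATCH_BIT and the helper names are illustrative; the marking helper is an assumption, as the marking site is not shown in these hunks):

use std::sync::atomic::{AtomicU64, Ordering};

const MATCH_BIT: u64 = 1 << 63;

// Strip the marker to recover the row index, as in the probe hunks above.
fn row_index(idx: &AtomicU64) -> u64 {
    idx.load(Ordering::Relaxed) & !MATCH_BIT
}

// Assumed marking operation when mark_matches is set.
fn mark_matched(idx: &AtomicU64) {
    idx.fetch_or(MATCH_BIT, Ordering::Release);
}

// unmarked_keys emits only keys whose first index still has a clear top bit.
fn is_unmatched(first_idx: &AtomicU64) -> bool {
    first_idx.load(Ordering::Acquire) >> 63 == 0
}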
let (df, seq, src_token, wait_token) = morsel.into_inner(); @@ -691,10 +691,17 @@ impl ProbeState { if params.preserve_order_probe { todo!() } else { - let mut out_frames = Vec::new(); - let mut build_out = DataFrameBuilder::new(build_payload_schema.clone()); - let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone()); - let mut out_len = 0; + let new_morsel = |mut build_df: DataFrame, mut probe_df: DataFrame| { + let out_df = if params.left_is_build.unwrap() { + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + let out_df = postprocess_join(out_df, params); + Morsel::new(out_df, seq, src_token.clone()) + }; for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { let mut offset = 0; @@ -706,7 +713,7 @@ impl ProbeState { &mut probe_match, mark_matches, emit_unmatched, - probe_limit - out_len, + probe_limit - probe_out.len() as IdxSize, ) as usize; if table_match.is_empty() { @@ -725,42 +732,19 @@ impl ProbeState { } probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always); - if probe_out.len() >= probe_limit { - let out_df = if params.left_is_build.unwrap() { - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df - }; - let out_df = postprocess_join(out_df, params); - - } - - - out_len = out_len - .checked_add(out_df.height().try_into().unwrap()) - .unwrap(); - out_frames.push(out_df); - - if out_len >= probe_limit { - out_len = 0; - let df = - accumulate_dataframes_vertical_unchecked(out_frames.drain(..)); - let out_morsel = Morsel::new(df, seq, src_token.clone()); + if probe_out.len() >= probe_limit as usize { + let out_morsel = new_morsel(build_out.freeze_reset(), probe_out.freeze_reset()); if send.send(out_morsel).await.is_err() { - break; + return Ok(max_seq); } } } } - */ - if out_len > 0 { - let df = accumulate_dataframes_vertical_unchecked(out_frames.drain(..)); - let out_morsel = Morsel::new(df, seq, src_token.clone()); + if !probe_out.is_empty() { + let out_morsel = new_morsel(build_out.freeze_reset(), probe_out.freeze_reset()); if send.send(out_morsel).await.is_err() { - break; + return Ok(max_seq); } } } @@ -771,11 +755,11 @@ impl ProbeState { Ok(max_seq) } - + fn ordered_unmatched( &mut self, - partitioner: &HashPartitioner, - params: &EquiJoinParams, + _partitioner: &HashPartitioner, + _params: &EquiJoinParams, ) -> DataFrame { todo!() } @@ -804,7 +788,66 @@ impl EmitUnmatchedState { params: &EquiJoinParams, num_pipelines: usize, ) -> PolarsResult<()> { - todo!() + let total_len: usize = self + .partitions + .iter() + .map(|p| p.hash_table.num_keys() as usize) + .sum(); + let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1); + let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines); + let morsel_size = total_len.div_ceil(morsel_count).max(1); + + let wait_group = WaitGroup::default(); + let source_token = SourceToken::new(); + let mut unmarked_idxs = Vec::new(); + while let Some(p) = self.partitions.get(self.active_partition_idx) { + loop { + // Generate a chunk of unmarked key indices. + self.offset_in_active_p += p.hash_table.unmarked_keys( + &mut unmarked_idxs, + self.offset_in_active_p as IdxSize, + morsel_size as IdxSize, + ) as usize; + if unmarked_idxs.is_empty() { + break; + } + + // Gather and create full-null counterpart. 
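The sizing arithmetic in emit_unmatched above aims for morsels near the ideal size while rounding the count up to a multiple of num_pipelines, so every worker can participate in draining the unmatched keys. A small self-contained check of the math (32_768 is an assumed value for get_ideal_morsel_size; the real value is configurable):

fn morsel_sizing(total_len: usize, ideal_size: usize, num_pipelines: usize) -> (usize, usize) {
    let ideal_morsel_count = (total_len / ideal_size).max(1);
    let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines);
    let morsel_size = total_len.div_ceil(morsel_count).max(1);
    (morsel_count, morsel_size)
}

fn main() {
    // 10_000 unmatched keys, 8 pipelines: one ideal morsel rounds up to 8.
    assert_eq!(morsel_sizing(10_000, 32_768, 8), (8, 1250));
    // 1_000_000 keys: 30 ideal morsels round up to 32 of 31_250 rows each.
    assert_eq!(morsel_sizing(1_000_000, 32_768, 8), (32, 31_250));
}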
+ let out_df = unsafe { + let mut build_df = + p.payload.take_slice_unchecked_impl(&unmarked_idxs, false); + let len = build_df.height(); + if params.left_is_build.unwrap() { + let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + let mut probe_df = DataFrame::full_null(¶ms.left_payload_schema, len); + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + } + }; + let out_df = postprocess_join(out_df, params); + + // Send and wait until consume token is consumed. + let mut morsel = Morsel::new(out_df, self.morsel_seq, source_token.clone()); + self.morsel_seq = self.morsel_seq.successor(); + morsel.set_consume_token(wait_group.token()); + if send.send(morsel).await.is_err() { + return Ok(()); + } + + wait_group.wait().await; + if source_token.stop_requested() { + return Ok(()); + } + } + + self.active_partition_idx += 1; + self.offset_in_active_p = 0; + } + + Ok(()) } } diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 7f0170677e36..994e11453314 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -9,6 +9,7 @@ use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, Expressio use polars_expr::reduce::into_reduction; use polars_expr::state::ExecutionState; use polars_mem_engine::{create_physical_plan, create_scan_predicate}; +use polars_ops::frame::MaintainOrderJoin; use polars_plan::dsl::{JoinOptions, PartitionVariant}; use polars_plan::global::_set_n_rows_for_scan; use polars_plan::plans::expr_ir::ExprIR; @@ -808,21 +809,40 @@ fn to_graph_rec<'a>( .map(|e| create_stream_expr(e, ctx, &right_input_schema)) .try_collect_vec()?; - ctx.graph.add_node( - nodes::joins::equi_join::EquiJoinNode::new( - left_input_schema, - right_input_schema, - left_key_schema, - right_key_schema, - left_key_selectors, - right_key_selectors, - args, - )?, - [ - (left_input_key, input_left.port), - (right_input_key, input_right.port), - ], - ) + // TODO: implement order-maintaining join in new join impl. 
+ if args.maintain_order == MaintainOrderJoin::None { + ctx.graph.add_node( + nodes::joins::new_equi_join::EquiJoinNode::new( + left_input_schema, + right_input_schema, + left_key_schema, + right_key_schema, + left_key_selectors, + right_key_selectors, + args, + )?, + [ + (left_input_key, input_left.port), + (right_input_key, input_right.port), + ], + ) + } else { + ctx.graph.add_node( + nodes::joins::equi_join::EquiJoinNode::new( + left_input_schema, + right_input_schema, + left_key_schema, + right_key_schema, + left_key_selectors, + right_key_selectors, + args, + )?, + [ + (left_input_key, input_left.port), + (right_input_key, input_right.port), + ], + ) + } }, #[cfg(feature = "merge_sorted")] From 329c3757a8b5274efdfa13bc4f4c5a7606ba05d2 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Wed, 5 Mar 2025 17:22:28 +0100 Subject: [PATCH 06/25] better reserve for gathers --- .../src/nodes/joins/new_equi_join.rs | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs index f34b227f794c..026f699100f1 100644 --- a/crates/polars-stream/src/nodes/joins/new_equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/new_equi_join.rs @@ -667,15 +667,30 @@ impl ProbeState { let mut build_out = DataFrameBuilder::new(build_payload_schema.clone()); let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone()); + + // A simple estimate used to size reserves. + let mut selectivity_estimate = 1.0; while let Ok(morsel) = recv.recv().await { // Compute hashed keys and payload. let (df, seq, src_token, wait_token) = morsel.into_inner(); + max_seq = seq; + + let df_height = df.height(); + if df_height == 0 { + continue; + } + let hash_keys = select_keys(&df, key_selectors, params, state).await?; let mut payload = select_payload(df, payload_selector); let mut payload_rechunked = false; // We don't eagerly rechunk because there might be no matches. - - max_seq = seq; + let mut total_matches = 0; + + // Use selectivity estimate to reserve for morsel builders. + let max_match_per_key_est = selectivity_estimate as usize + 16; + let out_est_size = ((selectivity_estimate * 1.2 * df_height as f64) as usize).min(probe_limit as usize); + build_out.reserve(out_est_size + max_match_per_key_est); + probe_out.reserve(out_est_size + max_match_per_key_est); unsafe { // Partition and probe the tables. @@ -715,10 +730,11 @@ impl ProbeState { emit_unmatched, probe_limit - probe_out.len() as IdxSize, ) as usize; - - if table_match.is_empty() { + + if probe_match.is_empty() { continue; } + total_matches += probe_match.len(); // Gather output and send. if emit_unmatched { @@ -737,6 +753,10 @@ impl ProbeState { if send.send(out_morsel).await.is_err() { return Ok(max_seq); } + // We had enough matches to need a mid-partition flush, let's assume there are a lot of + // matches and just do a large reserve. + build_out.reserve(probe_limit as usize + max_match_per_key_est); + probe_out.reserve(probe_limit as usize + max_match_per_key_est); } } } @@ -751,6 +771,9 @@ impl ProbeState { } drop(wait_token); + + // Move selectivity estimate a bit towards latest value. 
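The value nudged here is an exponential moving average of matches per probe-side row (it can exceed 1.0 for one-to-many joins), and it feeds the reserve sizing a few lines up; smoothing keeps one skewed morsel from triggering an oversized allocation. A minimal sketch of the update and the reserve hint it drives (the coefficients, the 1.2 headroom factor and the +16 slack are all taken from the patch):

// EWMA update: 80% history, 20% latest morsel.
fn update_selectivity(prev: f64, matches: usize, probe_rows: usize) -> f64 {
    0.8 * prev + 0.2 * (matches as f64 / probe_rows as f64)
}

// Builder reserve for the next morsel, capped at the flush limit, with slack
// for the matches of one final key.
fn reserve_hint(selectivity: f64, probe_rows: usize, probe_limit: usize) -> usize {
    let out_est = (selectivity * 1.2 * probe_rows as f64) as usize;
    out_est.min(probe_limit) + (selectivity as usize + 16)
}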
+ selectivity_estimate = 0.8 * selectivity_estimate + 0.2 * (total_matches as f64 / df_height as f64); } Ok(max_seq) From 23e0d72ea6c95c94af3ec2f61a9f10cf68d57563 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Wed, 5 Mar 2025 17:23:05 +0100 Subject: [PATCH 07/25] use specialized extend --- crates/polars-arrow/src/array/primitive/builder.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/crates/polars-arrow/src/array/primitive/builder.rs b/crates/polars-arrow/src/array/primitive/builder.rs index 82177674926f..688d84b75be7 100644 --- a/crates/polars-arrow/src/array/primitive/builder.rs +++ b/crates/polars-arrow/src/array/primitive/builder.rs @@ -76,11 +76,9 @@ impl StaticArrayBuilder for PrimitiveArrayBuilder { idxs: &[IdxSize], _share: ShareStrategy, ) { - self.values.reserve(idxs.len()); - for idx in idxs { - self.values - .push_unchecked(other.value_unchecked(*idx as usize)); - } + // TODO: SIMD gather kernels? + let other_values_slice = other.values().as_slice(); + self.values.extend(idxs.iter().map(|idx| *other_values_slice.get_unchecked(*idx as usize))); self.validity .gather_extend_from_opt_validity(other.validity(), idxs); } From 9a6c18ed06f5927b261b8acda8050001ec211ce6 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 6 Mar 2025 13:26:27 +0100 Subject: [PATCH 08/25] add order-preserving probes --- .../src/chunked_idx_table/row_encoded.rs | 3 - crates/polars-expr/src/hash_keys.rs | 64 +++++++- .../polars-expr/src/idx_table/row_encoded.rs | 3 - .../src/nodes/joins/equi_join.rs | 8 +- .../src/nodes/joins/new_equi_join.rs | 141 +++++++++++++----- .../src/physical_plan/to_graph.rs | 8 +- 6 files changed, 172 insertions(+), 55 deletions(-) diff --git a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs index 7cd8505252d0..98fa4e8a821b 100644 --- a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs +++ b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs @@ -70,9 +70,6 @@ impl RowEncodedChunkedIdxTable { probe_match: &mut Vec, limit: IdxSize, ) -> IdxSize { - table_match.clear(); - probe_match.clear(); - let mut keys_processed = 0; for (key_idx, hash, key) in hash_keys { let found_match = if let Some(key) = key { diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index f832e1984fd7..82cb75b92397 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -74,10 +74,25 @@ impl HashKeys { self.len() == 0 } + /// After this call partitions will be extended with the partition for each + /// hash. Nulls are assigned IdxSize::MAX or a specific partition depending + /// on whether partition_nulls is true. + pub fn gen_partitions( + &self, + partitioner: &HashPartitioner, + partitions: &mut Vec, + partition_nulls: bool, + ) { + match self { + Self::RowEncoded(s) => s.gen_partitions(partitioner, partitions, partition_nulls), + Self::Single(s) => s.gen_partitions(partitioner, partitions, partition_nulls), + } + } + /// After this call partition_idxs[p] will be extended with the indices of /// hashes that belong to partition p, and the cardinality sketches are /// updated accordingly. 
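The rename below separates two access patterns over the same hashes: gen_partitions appends one partition id per row, in row order (what the order-preserving probe needs), while gen_idxs_per_partition buckets row indices by partition for bulk probing. A hedged usage sketch, with keys and num_partitions standing in for values from the surrounding code:

let partitioner = HashPartitioner::new(num_partitions, 0);

// One partition id per row; nulls become IdxSize::MAX unless partition_nulls.
let mut partition_per_row: Vec<IdxSize> = Vec::new();
keys.gen_partitions(&partitioner, &mut partition_per_row, true);

// Row indices bucketed per partition; pass no sketches to skip cardinality
// tracking, as the probe side does.
let mut idxs_per_partition = vec![Vec::new(); num_partitions];
keys.gen_idxs_per_partition(&partitioner, &mut idxs_per_partition, &mut [], true);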
- pub fn gen_partition_idxs( + pub fn gen_idxs_per_partition( &self, partitioner: &HashPartitioner, partition_idxs: &mut [Vec], @@ -86,13 +101,13 @@ impl HashKeys { ) { if sketches.is_empty() { match self { - Self::RowEncoded(s) => s.gen_partition_idxs::( + Self::RowEncoded(s) => s.gen_idxs_per_partition::( partitioner, partition_idxs, sketches, partition_nulls, ), - Self::Single(s) => s.gen_partition_idxs::( + Self::Single(s) => s.gen_idxs_per_partition::( partitioner, partition_idxs, sketches, @@ -101,13 +116,13 @@ impl HashKeys { } } else { match self { - Self::RowEncoded(s) => s.gen_partition_idxs::( + Self::RowEncoded(s) => s.gen_idxs_per_partition::( partitioner, partition_idxs, sketches, partition_nulls, ), - Self::Single(s) => s.gen_partition_idxs::( + Self::Single(s) => s.gen_idxs_per_partition::( partitioner, partition_idxs, sketches, @@ -159,7 +174,33 @@ pub struct RowEncodedKeys { } impl RowEncodedKeys { - pub fn gen_partition_idxs( + pub fn gen_partitions( + &self, + partitioner: &HashPartitioner, + partitions: &mut Vec, + partition_nulls: bool, + ) { + partitions.reserve(self.hashes.len()); + if let Some(validity) = self.keys.validity() { + // Arbitrarily put nulls in partition 0. + let null_p = if partition_nulls { 0 } else { IdxSize::MAX }; + partitions.extend(self.hashes.values_iter().zip(validity).map(|(h, is_v)| { + if is_v { + partitioner.hash_to_partition(*h) as IdxSize + } else { + null_p + } + })) + } else { + partitions.extend( + self.hashes + .values_iter() + .map(|h| partitioner.hash_to_partition(*h) as IdxSize), + ) + } + } + + pub fn gen_idxs_per_partition( &self, partitioner: &HashPartitioner, partition_idxs: &mut [Vec], @@ -261,7 +302,16 @@ pub struct SingleKeys { } impl SingleKeys { - pub fn gen_partition_idxs( + pub fn gen_partitions( + &self, + _partitioner: &HashPartitioner, + _partitions: &mut Vec, + _partition_nulls: bool, + ) { + todo!() + } + + pub fn gen_idxs_per_partition( &self, partitioner: &HashPartitioner, partition_idxs: &mut [Vec], diff --git a/crates/polars-expr/src/idx_table/row_encoded.rs b/crates/polars-expr/src/idx_table/row_encoded.rs index 4aa8785f9f18..4ed7e5b4cf55 100644 --- a/crates/polars-expr/src/idx_table/row_encoded.rs +++ b/crates/polars-expr/src/idx_table/row_encoded.rs @@ -68,9 +68,6 @@ impl RowEncodedIdxTable { probe_match: &mut Vec, limit: IdxSize, ) -> IdxSize { - table_match.clear(); - probe_match.clear(); - let mut keys_processed = 0; for (key_idx, hash, key) in hash_keys { let found_match = if let Some(key) = key { diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 024133b2f699..c50c1db09af1 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -503,7 +503,7 @@ impl BuildState { for p in partition_idxs.iter_mut() { p.clear(); } - hash_keys.gen_partition_idxs( + hash_keys.gen_idxs_per_partition( &partitioner, &mut partition_idxs, &mut sketches, @@ -678,7 +678,7 @@ impl ProbeState { for p in partition_idxs.iter_mut() { p.clear(); } - hash_keys.gen_partition_idxs( + hash_keys.gen_idxs_per_partition( &partitioner, &mut partition_idxs, &mut [], @@ -690,6 +690,8 @@ impl ProbeState { let mut out_per_partition = Vec::with_capacity(partitioner.num_partitions()); let name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_IDX"); for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { + table_match.clear(); + probe_match.clear(); p.hash_table.probe_subset( &hash_keys, idxs_in_p, @@ 
-759,6 +761,8 @@ impl ProbeState { for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { let mut offset = 0; while offset < idxs_in_p.len() { + table_match.clear(); + probe_match.clear(); offset += p.hash_table.probe_subset( &hash_keys, &idxs_in_p[offset..], diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs index 026f699100f1..a96cc7ccb1d3 100644 --- a/crates/polars-stream/src/nodes/joins/new_equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/new_equi_join.rs @@ -516,7 +516,7 @@ impl BuildState { let mut payload = select_payload(morsel.df().clone(), payload_selector); payload.rechunk_mut(); - hash_keys.gen_partition_idxs( + hash_keys.gen_idxs_per_partition( &partitioner, &mut local.morsel_idxs_values_per_p, &mut local.sketch_per_p, @@ -644,6 +644,8 @@ impl ProbeState { ) -> PolarsResult { // TODO: shuffle after partitioning and keep probe tables thread-local. let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()]; + let mut probe_partitions = Vec::new(); + let mut materialized_idxsize_range = Vec::new(); let mut table_match = Vec::new(); let mut probe_match = Vec::new(); let mut max_seq = MorselSeq::default(); @@ -690,37 +692,96 @@ impl ProbeState { let max_match_per_key_est = selectivity_estimate as usize + 16; let out_est_size = ((selectivity_estimate * 1.2 * df_height as f64) as usize).min(probe_limit as usize); build_out.reserve(out_est_size + max_match_per_key_est); - probe_out.reserve(out_est_size + max_match_per_key_est); unsafe { - // Partition and probe the tables. - for p in partition_idxs.iter_mut() { - p.clear(); - } - hash_keys.gen_partition_idxs( - &partitioner, - &mut partition_idxs, - &mut [], - emit_unmatched, - ); + let new_morsel = |build: &mut DataFrameBuilder, probe: &mut DataFrameBuilder| { + let mut build_df = build.freeze_reset(); + let mut probe_df = probe.freeze_reset(); + let out_df = if params.left_is_build.unwrap() { + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + let out_df = postprocess_join(out_df, params); + Morsel::new(out_df, seq, src_token.clone()) + }; + if params.preserve_order_probe { - todo!() - } else { - let new_morsel = |mut build_df: DataFrame, mut probe_df: DataFrame| { - let out_df = if params.left_is_build.unwrap() { - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df + // To preserve the order we can't do bulk probes per partition and must follow + // the order of the probe morsel. We can still group probes that are + // consecutively on the same partition. 
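Order preservation rules out the bucket-then-probe approach, but consecutive probe rows that hash to the same partition can still share one probe call, which is what the scan below implements. A self-contained version of just the run grouping:

// Calls f(partition, start, end) for every maximal run of equal partition ids,
// in row order; e.g. [3, 3, 1, 1, 1, 2] yields (3, 0, 2), (1, 2, 5), (2, 5, 6).
fn for_each_partition_run(partitions: &[u32], mut f: impl FnMut(u32, usize, usize)) {
    let mut start = 0;
    while start < partitions.len() {
        let p = partitions[start];
        let mut end = start + 1;
        while partitions.get(end) == Some(&p) {
            end += 1;
        }
        f(p, start, end);
        start = end;
    }
}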
+ hash_keys.gen_partitions(&partitioner, &mut probe_partitions, emit_unmatched); + let mut probe_group_start = 0; + while probe_group_start < probe_partitions.len() { + let p_idx = probe_partitions[probe_group_start]; + let mut probe_group_end = probe_group_start + 1; + while probe_partitions.get(probe_group_end) == Some(&p_idx) { + probe_group_end += 1; + } + let Some(p) = partitions.get(p_idx as usize) else { + probe_group_start = probe_group_end; + continue; }; - let out_df = postprocess_join(out_df, params); - Morsel::new(out_df, seq, src_token.clone()) - }; + + materialized_idxsize_range.extend(materialized_idxsize_range.len() as IdxSize..probe_group_end as IdxSize); + + while probe_group_start < probe_group_end { + let matches_before_limit = probe_limit - probe_match.len() as IdxSize; + table_match.clear(); + probe_group_start += p.hash_table.probe_subset( + &hash_keys, + &materialized_idxsize_range[probe_group_start..probe_group_end], + &mut table_match, + &mut probe_match, + mark_matches, + emit_unmatched, + matches_before_limit, + ) as usize; + + if emit_unmatched { + build_out.opt_gather_extend(&p.payload, &table_match, ShareStrategy::Always); + } else { + build_out.gather_extend(&p.payload, &table_match, ShareStrategy::Always); + }; + + if probe_match.len() >= probe_limit as usize || probe_group_start == probe_partitions.len() { + if !payload_rechunked { + payload.rechunk_mut(); + payload_rechunked = true; + } + probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always); + probe_match.clear(); + let out_morsel = new_morsel(&mut build_out, &mut probe_out); + if send.send(out_morsel).await.is_err() { + return Ok(max_seq); + } + if probe_group_end != probe_partitions.len() { + // We had enough matches to need a mid-partition flush, let's assume there are a lot of + // matches and just do a large reserve. + build_out.reserve(probe_limit as usize + max_match_per_key_est); + } + } + } + } + } else { + // Partition and probe the tables. + for p in partition_idxs.iter_mut() { + p.clear(); + } + hash_keys.gen_idxs_per_partition( + &partitioner, + &mut partition_idxs, + &mut [], + emit_unmatched, + ); for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { let mut offset = 0; while offset < idxs_in_p.len() { + let matches_before_limit = probe_limit - probe_match.len() as IdxSize; + table_match.clear(); offset += p.hash_table.probe_subset( &hash_keys, &idxs_in_p[offset..], @@ -728,41 +789,45 @@ impl ProbeState { &mut probe_match, mark_matches, emit_unmatched, - probe_limit - probe_out.len() as IdxSize, + matches_before_limit, ) as usize; - if probe_match.is_empty() { + if table_match.is_empty() { continue; } - total_matches += probe_match.len(); + total_matches += table_match.len(); - // Gather output and send. 
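A subtlety shared by both probe paths around this point: build-side rows are gathered into build_out immediately at each probe call, while probe-side indices only accumulate in probe_match and are gathered in a single pass when a morsel is flushed, so the rechunk and gather of the probe payload run once per output morsel rather than once per partition. The flush, condensed from the surrounding code:

if probe_match.len() >= probe_limit as usize {
    if !payload_rechunked {
        payload.rechunk_mut();
        payload_rechunked = true;
    }
    probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always);
    probe_match.clear();
    let out_morsel = new_morsel(&mut build_out, &mut probe_out);
    if send.send(out_morsel).await.is_err() {
        return Ok(max_seq);
    }
}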
if emit_unmatched { build_out.opt_gather_extend(&p.payload, &table_match, ShareStrategy::Always); } else { build_out.gather_extend(&p.payload, &table_match, ShareStrategy::Always); }; - if !payload_rechunked { - payload.rechunk_mut(); - payload_rechunked = true; - } - probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always); - if probe_out.len() >= probe_limit as usize { - let out_morsel = new_morsel(build_out.freeze_reset(), probe_out.freeze_reset()); + if probe_match.len() >= probe_limit as usize { + if !payload_rechunked { + payload.rechunk_mut(); + payload_rechunked = true; + } + probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always); + probe_match.clear(); + let out_morsel = new_morsel(&mut build_out, &mut probe_out); if send.send(out_morsel).await.is_err() { return Ok(max_seq); } // We had enough matches to need a mid-partition flush, let's assume there are a lot of // matches and just do a large reserve. build_out.reserve(probe_limit as usize + max_match_per_key_est); - probe_out.reserve(probe_limit as usize + max_match_per_key_est); } } } - if !probe_out.is_empty() { - let out_morsel = new_morsel(build_out.freeze_reset(), probe_out.freeze_reset()); + if !probe_match.is_empty() { + if !payload_rechunked { + payload.rechunk_mut(); + } + probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always); + probe_match.clear(); + let out_morsel = new_morsel(&mut build_out, &mut probe_out); if send.send(out_morsel).await.is_err() { return Ok(max_seq); } diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 994e11453314..3297decaecf0 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -809,8 +809,12 @@ fn to_graph_rec<'a>( .map(|e| create_stream_expr(e, ctx, &right_input_schema)) .try_collect_vec()?; - // TODO: implement order-maintaining join in new join impl. - if args.maintain_order == MaintainOrderJoin::None { + // TODO: implement build-side order-maintaining join in new join impl. + let preserve_order_build = matches!( + args.maintain_order, + MaintainOrderJoin::LeftRight | MaintainOrderJoin::RightLeft + ); + if !preserve_order_build { ctx.graph.add_node( nodes::joins::new_equi_join::EquiJoinNode::new( left_input_schema, From 7c5469c79f0ff93e1dbdf32f4e0703e4eb51a49a Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 6 Mar 2025 13:27:44 +0100 Subject: [PATCH 09/25] fmt --- .../src/array/primitive/builder.rs | 5 +- crates/polars-expr/src/idx_table/mod.rs | 12 +- .../polars-expr/src/idx_table/row_encoded.rs | 25 +++- .../src/nodes/joins/equi_join.rs | 18 ++- crates/polars-stream/src/nodes/joins/mod.rs | 2 +- .../src/nodes/joins/new_equi_join.rs | 123 ++++++++++++------ 6 files changed, 125 insertions(+), 60 deletions(-) diff --git a/crates/polars-arrow/src/array/primitive/builder.rs b/crates/polars-arrow/src/array/primitive/builder.rs index 688d84b75be7..c942ae71553a 100644 --- a/crates/polars-arrow/src/array/primitive/builder.rs +++ b/crates/polars-arrow/src/array/primitive/builder.rs @@ -78,7 +78,10 @@ impl StaticArrayBuilder for PrimitiveArrayBuilder { ) { // TODO: SIMD gather kernels? 
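The TODO above asks about SIMD gathers; even without them, replacing the push-per-element loop with a single Vec::extend over an exact-size iterator lets the standard library reserve once up front instead of checking capacity per element. A self-contained analogue of the pattern:

/// Gathers src[idx] for each idx into dst with one extend call.
///
/// # Safety
/// All indices must be in-bounds for src.
unsafe fn gather_extend<T: Copy>(dst: &mut Vec<T>, src: &[T], idxs: &[u32]) {
    dst.extend(idxs.iter().map(|&i| *src.get_unchecked(i as usize)));
}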
let other_values_slice = other.values().as_slice(); - self.values.extend(idxs.iter().map(|idx| *other_values_slice.get_unchecked(*idx as usize))); + self.values.extend( + idxs.iter() + .map(|idx| *other_values_slice.get_unchecked(*idx as usize)), + ); self.validity .gather_extend_from_opt_validity(other.validity(), idxs); } diff --git a/crates/polars-expr/src/idx_table/mod.rs b/crates/polars-expr/src/idx_table/mod.rs index 7996d2f74e77..76a5093a1102 100644 --- a/crates/polars-expr/src/idx_table/mod.rs +++ b/crates/polars-expr/src/idx_table/mod.rs @@ -19,11 +19,16 @@ pub trait IdxTable: Any + Send + Sync { /// Inserts the given keys into this IdxTable. fn insert_keys(&mut self, keys: &HashKeys, track_unmatchable: bool); - + /// Inserts a subset of the given keys into this IdxTable. /// # Safety /// The provided subset indices must be in-bounds. - unsafe fn insert_keys_subset(&mut self, keys: &HashKeys, subset: &[IdxSize], track_unmatchable: bool); + unsafe fn insert_keys_subset( + &mut self, + keys: &HashKeys, + subset: &[IdxSize], + track_unmatchable: bool, + ); /// Probe the table, adding an entry to table_match and probe_match for each /// match. Will stop processing new keys once limit matches have been @@ -59,8 +64,7 @@ pub trait IdxTable: Any + Send + Sync { ) -> IdxSize; /// Get the ChunkIds for each key which was never marked during probing. - fn unmarked_keys(&self, out: &mut Vec, offset: IdxSize, limit: IdxSize) - -> IdxSize; + fn unmarked_keys(&self, out: &mut Vec, offset: IdxSize, limit: IdxSize) -> IdxSize; } pub fn new_idx_table(_key_schema: Arc) -> Box { diff --git a/crates/polars-expr/src/idx_table/row_encoded.rs b/crates/polars-expr/src/idx_table/row_encoded.rs index 4ed7e5b4cf55..80b214f71feb 100644 --- a/crates/polars-expr/src/idx_table/row_encoded.rs +++ b/crates/polars-expr/src/idx_table/row_encoded.rs @@ -132,8 +132,13 @@ impl IdxTable for RowEncodedIdxTable { let HashKeys::RowEncoded(hash_keys) = hash_keys else { unreachable!() }; - let new_idx_offset = (self.idx_offset as usize).checked_add(hash_keys.keys.len()).unwrap(); - assert!(new_idx_offset < IdxSize::MAX as usize, "overly large index in RowEncodedIdxTable"); + let new_idx_offset = (self.idx_offset as usize) + .checked_add(hash_keys.keys.len()) + .unwrap(); + assert!( + new_idx_offset < IdxSize::MAX as usize, + "overly large index in RowEncodedIdxTable" + ); for (i, (hash, key)) in hash_keys .hashes @@ -159,12 +164,22 @@ impl IdxTable for RowEncodedIdxTable { self.idx_offset = new_idx_offset as IdxSize; } - unsafe fn insert_keys_subset(&mut self, hash_keys: &HashKeys, subset: &[IdxSize], track_unmatchable: bool) { + unsafe fn insert_keys_subset( + &mut self, + hash_keys: &HashKeys, + subset: &[IdxSize], + track_unmatchable: bool, + ) { let HashKeys::RowEncoded(hash_keys) = hash_keys else { unreachable!() }; - let new_idx_offset = (self.idx_offset as usize).checked_add(subset.len()).unwrap(); - assert!(new_idx_offset < IdxSize::MAX as usize, "overly large index in RowEncodedIdxTable"); + let new_idx_offset = (self.idx_offset as usize) + .checked_add(subset.len()) + .unwrap(); + assert!( + new_idx_offset < IdxSize::MAX as usize, + "overly large index in RowEncodedIdxTable" + ); for (i, subset_idx) in subset.iter().enumerate_idx() { let hash = unsafe { hash_keys.hashes.value_unchecked(*subset_idx as usize) }; diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index c50c1db09af1..b03c4f2a07fe 100644 --- 
a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -710,7 +710,8 @@ impl ProbeState { let mut build_df = if emit_unmatched { p.payload.take_opt_chunked_unchecked(&table_match, false) } else { - p.payload.take_chunked_unchecked(&table_match, IsSorted::Not, false) + p.payload + .take_chunked_unchecked(&table_match, IsSorted::Not, false) }; if !payload_rechunked { @@ -781,7 +782,8 @@ impl ProbeState { let mut build_df = if emit_unmatched { p.payload.take_opt_chunked_unchecked(&table_match, false) } else { - p.payload.take_chunked_unchecked(&table_match, IsSorted::Not, false) + p.payload + .take_chunked_unchecked(&table_match, IsSorted::Not, false) }; if !payload_rechunked { // TODO: can avoid rechunk? We have to rechunk here or else we do it @@ -845,11 +847,13 @@ impl ProbeState { let mut unmarked_idxs = Vec::new(); unsafe { for p in self.table_per_partition.iter() { - p.hash_table.unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX); + p.hash_table + .unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX); // Gather and create full-null counterpart. let mut build_df = - p.payload.take_chunked_unchecked(&unmarked_idxs, IsSorted::Not, false); + p.payload + .take_chunked_unchecked(&unmarked_idxs, IsSorted::Not, false); let len = build_df.height(); let mut out_df = if params.left_is_build.unwrap() { let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); @@ -948,7 +952,8 @@ impl EmitUnmatchedState { // Gather and create full-null counterpart. let out_df = unsafe { let mut build_df = - p.payload.take_chunked_unchecked(&unmarked_idxs, IsSorted::Not, false); + p.payload + .take_chunked_unchecked(&unmarked_idxs, IsSorted::Not, false); let len = build_df.height(); if params.left_is_build.unwrap() { let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); @@ -1094,7 +1099,8 @@ impl EquiJoinNode { }; let left_payload_schema = Arc::new(select_schema(&left_input_schema, &left_payload_select)); - let right_payload_schema = Arc::new(select_schema(&right_input_schema, &right_payload_select)); + let right_payload_schema = + Arc::new(select_schema(&right_input_schema, &right_payload_select)); Ok(Self { state, num_pipelines: 0, diff --git a/crates/polars-stream/src/nodes/joins/mod.rs b/crates/polars-stream/src/nodes/joins/mod.rs index 14706fa4407f..ea6607737c43 100644 --- a/crates/polars-stream/src/nodes/joins/mod.rs +++ b/crates/polars-stream/src/nodes/joins/mod.rs @@ -1,3 +1,3 @@ pub mod equi_join; -pub mod new_equi_join; pub mod in_memory; +pub mod new_equi_join; diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs index a96cc7ccb1d3..8d4ee4d7a7e6 100644 --- a/crates/polars-stream/src/nodes/joins/new_equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/new_equi_join.rs @@ -1,13 +1,14 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, LazyLock}; +use arrow::array::builder::ShareStrategy; use crossbeam_queue::ArrayQueue; +use polars_core::frame::builder::DataFrameBuilder; use polars_core::prelude::*; use polars_core::schema::{Schema, SchemaExt}; - use polars_utils::sync::SyncPtr; use polars_core::{config, POOL}; -use polars_expr::idx_table::{new_idx_table, IdxTable}; use polars_expr::hash_keys::HashKeys; +use polars_expr::idx_table::{new_idx_table, IdxTable}; use polars_io::pl_async::get_runtime; use polars_ops::frame::{JoinArgs, JoinType, MaintainOrderJoin}; use polars_ops::series::coalesce_columns; @@ -15,9 +16,8 @@ use 
polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::hashing::HashPartitioner; use polars_utils::itertools::Itertools; use polars_utils::pl_str::PlSmallStr; +use polars_utils::sync::SyncPtr; use polars_utils::{format_pl_smallstr, IdxSize}; -use arrow::array::builder::ShareStrategy; -use polars_core::frame::builder::DataFrameBuilder; use rayon::prelude::*; use crate::async_primitives::connector::{connector, Receiver, Sender}; @@ -416,7 +416,9 @@ impl SampleState { let partitioner = HashPartitioner::new(num_pipelines, 0); let mut build_state = BuildState { - local_builders: (0..num_pipelines).map(|_| LocalBuilder::default()).collect(), + local_builders: (0..num_pipelines) + .map(|_| LocalBuilder::default()) + .collect(), sampled_probe_morsels, }; @@ -429,8 +431,7 @@ impl SampleState { .reinsert(num_pipelines, None, scope, &mut join_handles) .unwrap(); - for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) - { + for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) { join_handles.push(scope.spawn_task( TaskPriority::High, BuildState::partition_and_sink( @@ -548,7 +549,6 @@ impl BuildState { .collect_vec(); let num_partitions = self.local_builders[0].sketch_per_p.len(); let local_builders = &self.local_builders; - let mut probe_tables = POOL.scope(|s| { let mut probe_tables: Vec = Vec::with_capacity(num_partitions); @@ -561,8 +561,9 @@ impl BuildState { let arc_morsels_per_local_builder = Arc::clone(&arc_morsels_per_local_builder); s.spawn(move |_| { // Extract from outer arc and drop outer arc. - let morsels_per_local_builder = Arc::unwrap_or_clone(arc_morsels_per_local_builder); - + let morsels_per_local_builder = + Arc::unwrap_or_clone(arc_morsels_per_local_builder); + // Compute cardinality estimate and total amount of // payload for this partition. let mut sketch = CardinalitySketch::new(); @@ -570,29 +571,37 @@ impl BuildState { for l in local_builders { sketch.combine(&l.sketch_per_p[p]); let offsets_len = l.morsel_idxs_offsets_per_p.len(); - payload_rows += l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p]; + payload_rows += + l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p]; } - + // Allocate hash table and payload builder. let mut p_table = table.new_empty(); p_table.reserve(sketch.estimate() * 5 / 4); let mut p_payload = DataFrameBuilder::new(payload_schema.clone()); p_payload.reserve(payload_rows); - + // Build. 
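The finalize step being reformatted here is the heart of the build phase: for each partition, merge the thread-local cardinality sketches to pre-size the hash table, total the payload rows from the flattened offset array, then insert only this partition's row subset from every local builder. Condensed, with the names used above:

let mut sketch = CardinalitySketch::new();
let mut payload_rows = 0;
for l in local_builders {
    sketch.combine(&l.sketch_per_p[p]);
    let n = l.morsel_idxs_offsets_per_p.len();
    payload_rows += l.morsel_idxs_offsets_per_p[n - num_partitions + p];
}

let mut p_table = table.new_empty();
p_table.reserve(sketch.estimate() * 5 / 4); // 25% slack over the estimate
let mut p_payload = DataFrameBuilder::new(payload_schema.clone());
p_payload.reserve(payload_rows);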
for (l, l_morsels) in local_builders.iter().zip(morsels_per_local_builder) { for (i, morsel) in l_morsels.iter().enumerate() { let (_mseq, payload, keys) = morsel; unsafe { - let p_morsel_idxs_start = l.morsel_idxs_offsets_per_p[i * num_partitions + p]; - let p_morsel_idxs_stop = l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p]; - let p_morsel_idxs = &l.morsel_idxs_values_per_p[p][p_morsel_idxs_start..p_morsel_idxs_stop]; + let p_morsel_idxs_start = + l.morsel_idxs_offsets_per_p[i * num_partitions + p]; + let p_morsel_idxs_stop = + l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p]; + let p_morsel_idxs = &l.morsel_idxs_values_per_p[p] + [p_morsel_idxs_start..p_morsel_idxs_stop]; p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable); - p_payload.gather_extend(payload, p_morsel_idxs, ShareStrategy::Never); + p_payload.gather_extend( + payload, + p_morsel_idxs, + ShareStrategy::Never, + ); } } } - + unsafe { probe_table_ptr.get().add(p).write(ProbeTable { hash_table: p_table, @@ -607,7 +616,7 @@ impl BuildState { drop(arc_morsels_per_local_builder); probe_tables }); - + unsafe { // SAFETY: all entries are initialized now. probe_tables.set_len(num_partitions); @@ -669,7 +678,7 @@ impl ProbeState { let mut build_out = DataFrameBuilder::new(build_payload_schema.clone()); let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone()); - + // A simple estimate used to size reserves. let mut selectivity_estimate = 1.0; @@ -687,10 +696,11 @@ impl ProbeState { let mut payload = select_payload(df, payload_selector); let mut payload_rechunked = false; // We don't eagerly rechunk because there might be no matches. let mut total_matches = 0; - + // Use selectivity estimate to reserve for morsel builders. let max_match_per_key_est = selectivity_estimate as usize + 16; - let out_est_size = ((selectivity_estimate * 1.2 * df_height as f64) as usize).min(probe_limit as usize); + let out_est_size = ((selectivity_estimate * 1.2 * df_height as f64) as usize) + .min(probe_limit as usize); build_out.reserve(out_est_size + max_match_per_key_est); unsafe { @@ -725,8 +735,10 @@ impl ProbeState { continue; }; - materialized_idxsize_range.extend(materialized_idxsize_range.len() as IdxSize..probe_group_end as IdxSize); - + materialized_idxsize_range.extend( + materialized_idxsize_range.len() as IdxSize..probe_group_end as IdxSize, + ); + while probe_group_start < probe_group_end { let matches_before_limit = probe_limit - probe_match.len() as IdxSize; table_match.clear(); @@ -739,19 +751,33 @@ impl ProbeState { emit_unmatched, matches_before_limit, ) as usize; - + if emit_unmatched { - build_out.opt_gather_extend(&p.payload, &table_match, ShareStrategy::Always); + build_out.opt_gather_extend( + &p.payload, + &table_match, + ShareStrategy::Always, + ); } else { - build_out.gather_extend(&p.payload, &table_match, ShareStrategy::Always); + build_out.gather_extend( + &p.payload, + &table_match, + ShareStrategy::Always, + ); }; - if probe_match.len() >= probe_limit as usize || probe_group_start == probe_partitions.len() { + if probe_match.len() >= probe_limit as usize + || probe_group_start == probe_partitions.len() + { if !payload_rechunked { payload.rechunk_mut(); payload_rechunked = true; } - probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always); + probe_out.gather_extend( + &payload, + &probe_match, + ShareStrategy::Always, + ); probe_match.clear(); let out_morsel = new_morsel(&mut build_out, &mut probe_out); if send.send(out_morsel).await.is_err() { @@ -791,24 
+817,36 @@ impl ProbeState { emit_unmatched, matches_before_limit, ) as usize; - + if table_match.is_empty() { continue; } total_matches += table_match.len(); if emit_unmatched { - build_out.opt_gather_extend(&p.payload, &table_match, ShareStrategy::Always); + build_out.opt_gather_extend( + &p.payload, + &table_match, + ShareStrategy::Always, + ); } else { - build_out.gather_extend(&p.payload, &table_match, ShareStrategy::Always); + build_out.gather_extend( + &p.payload, + &table_match, + ShareStrategy::Always, + ); }; - + if probe_match.len() >= probe_limit as usize { if !payload_rechunked { payload.rechunk_mut(); payload_rechunked = true; } - probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always); + probe_out.gather_extend( + &payload, + &probe_match, + ShareStrategy::Always, + ); probe_match.clear(); let out_morsel = new_morsel(&mut build_out, &mut probe_out); if send.send(out_morsel).await.is_err() { @@ -836,14 +874,15 @@ impl ProbeState { } drop(wait_token); - + // Move selectivity estimate a bit towards latest value. - selectivity_estimate = 0.8 * selectivity_estimate + 0.2 * (total_matches as f64 / df_height as f64); + selectivity_estimate = + 0.8 * selectivity_estimate + 0.2 * (total_matches as f64 / df_height as f64); } Ok(max_seq) } - + fn ordered_unmatched( &mut self, _partitioner: &HashPartitioner, @@ -902,8 +941,7 @@ impl EmitUnmatchedState { // Gather and create full-null counterpart. let out_df = unsafe { - let mut build_df = - p.payload.take_slice_unchecked_impl(&unmarked_idxs, false); + let mut build_df = p.payload.take_slice_unchecked_impl(&unmarked_idxs, false); let len = build_df.height(); if params.left_is_build.unwrap() { let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); @@ -939,7 +977,6 @@ impl EmitUnmatchedState { } } - enum EquiJoinState { Sample(SampleState), Build(BuildState), @@ -1050,7 +1087,8 @@ impl EquiJoinNode { }; let left_payload_schema = Arc::new(select_schema(&left_input_schema, &left_payload_select)); - let right_payload_schema = Arc::new(select_schema(&right_input_schema, &right_payload_select)); + let right_payload_schema = + Arc::new(select_schema(&right_input_schema, &right_payload_select)); Ok(Self { state, num_pipelines: 0, @@ -1290,8 +1328,7 @@ impl ComputeNode for EquiJoinNode { .local_builders .resize_with(self.num_pipelines, Default::default); let partitioner = HashPartitioner::new(self.num_pipelines, 0); - for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) - { + for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) { join_handles.push(scope.spawn_task( TaskPriority::High, BuildState::partition_and_sink( From 99bbf73a9abbbf65343c37b854cacfc2c855cebf Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 6 Mar 2025 13:29:27 +0100 Subject: [PATCH 10/25] clippy --- crates/polars-expr/src/hash_keys.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index 82cb75b92397..1674eb705fd1 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -302,6 +302,7 @@ pub struct SingleKeys { } impl SingleKeys { + #[allow(clippy::ptr_arg)] // Remove when implemented. 
pub fn gen_partitions( &self, _partitioner: &HashPartitioner, From 24637dd5c832e75f0a18c7d93966e4168f86feb6 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 6 Mar 2025 13:46:07 +0100 Subject: [PATCH 11/25] more clippy garbage --- crates/polars-expr/src/idx_table/row_encoded.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/polars-expr/src/idx_table/row_encoded.rs b/crates/polars-expr/src/idx_table/row_encoded.rs index 80b214f71feb..881a78cb30bc 100644 --- a/crates/polars-expr/src/idx_table/row_encoded.rs +++ b/crates/polars-expr/src/idx_table/row_encoded.rs @@ -1,3 +1,5 @@ +#![allow(clippy::unnecessary_cast)] // Clippy doesn't recognize that IdxSize and u64 can be different. + use std::sync::atomic::{AtomicU64, Ordering}; use arrow::array::Array; From 07bd7680ba2e48e6155abf2fcac5a2177670b00b Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 6 Mar 2025 17:29:21 +0100 Subject: [PATCH 12/25] try to solve object builder --- crates/polars-arrow/src/bitmap/builder.rs | 34 +++++++ .../src/chunked_array/object/builder.rs | 93 +++++++++++++++++++ .../src/chunked_array/object/registry.rs | 12 ++- crates/polars-core/src/series/builder.rs | 8 +- 4 files changed, 144 insertions(+), 3 deletions(-) diff --git a/crates/polars-arrow/src/bitmap/builder.rs b/crates/polars-arrow/src/bitmap/builder.rs index 6a4bf6013ec6..fff539531a88 100644 --- a/crates/polars-arrow/src/bitmap/builder.rs +++ b/crates/polars-arrow/src/bitmap/builder.rs @@ -251,6 +251,13 @@ impl BitmapBuilder { self.extend_from_slice(slice, bm_offset + start, length); } + pub fn subslice_extend_from_opt_validity(&mut self, bitmap: Option<&Bitmap>, start: usize, length: usize) { + match bitmap { + Some(bm) => self.subslice_extend_from_bitmap(bm, start, length), + None => self.extend_constant(length, true), + } + } + /// # Safety /// The indices must be in-bounds. pub unsafe fn gather_extend_from_slice( @@ -308,6 +315,33 @@ impl BitmapBuilder { self.opt_gather_extend_from_slice(slice, offset, length, idxs); } + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn gather_extend_from_opt_validity(&mut self, bitmap: Option<&Bitmap>, idxs: &[IdxSize], length: usize) { + if let Some(bm) = bitmap { + let (slice, offset, sl_length) = bm.as_slice(); + debug_assert_eq!(sl_length, length); + self.gather_extend_from_slice(slice, offset, length, idxs); + } else { + self.extend_constant(length, true); + } + } + + pub fn opt_gather_extend_from_opt_validity(&mut self, bitmap: Option<&Bitmap>, idxs: &[IdxSize], length: usize) { + if let Some(bm) = bitmap { + let (slice, offset, sl_length) = bm.as_slice(); + debug_assert_eq!(sl_length, length); + self.opt_gather_extend_from_slice(slice, offset, sl_length, idxs); + } else { + unsafe { + self.reserve(idxs.len()); + for idx in idxs { + self.push_unchecked((*idx as usize) < length); + } + } + } + } + /// # Safety /// May only be called once at the end. 
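In the no-bitmap branches of the gather helpers above, validity of a gathered bit reduces to a bounds check: a source without a validity bitmap is all-valid, so an in-range index yields true and an out-of-range index (the "no match" sentinel) yields false. A tiny self-contained check of that rule:

fn opt_gather_validity(src_len: usize, idxs: &[u64]) -> Vec<bool> {
    idxs.iter().map(|&i| (i as usize) < src_len).collect()
}

fn main() {
    // Source of length 3 with no validity bitmap; index 7 is out of range.
    assert_eq!(opt_gather_validity(3, &[0, 7, 2]), vec![true, false, true]);
}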
unsafe fn finish(&mut self) { diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index 7010ab59da49..6bba5bf32c39 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -1,4 +1,6 @@ +use arrow::array::builder::{ArrayBuilder, ShareStrategy}; use arrow::bitmap::BitmapBuilder; +use polars_utils::vec::PushUnchecked; use super::*; use crate::utils::get_iter_capacity; @@ -176,3 +178,94 @@ pub(crate) fn object_series_to_arrow_array(s: &Series) -> ArrayRef { let arr = arr.as_any().downcast_ref::>().unwrap(); arr.values().to_boxed() } + + +impl ArrayBuilder for ObjectChunkedBuilder { + fn dtype(&self) -> &ArrowDataType { + &ArrowDataType::FixedSizeBinary(size_of::()) + } + + fn reserve(&mut self, additional: usize) { + self.bitmask_builder.reserve(additional); + self.values.reserve(additional); + } + + fn freeze(self) -> Box { + Box::new(ObjectArray { + values: self.values.into(), + validity: self.bitmask_builder.into_opt_validity(), + }) + } + + fn freeze_reset(&mut self) -> Box { + Box::new(ObjectArray { + values: core::mem::take(&mut self.values).into(), + validity: core::mem::take(&mut self.bitmask_builder).into_opt_validity(), + }) + } + + fn len(&self) -> usize { + self.values.len() + } + + fn extend_nulls(&mut self, length: usize) { + self.values.resize(self.values.len() + length, T::default()); + self.bitmask_builder.extend_constant(length, false); + } + + fn subslice_extend( + &mut self, + other: &dyn Array, + start: usize, + length: usize, + _share: ShareStrategy, + ) { + let other: &ObjectArray = other.as_any().downcast_ref().unwrap(); + self.values + .extend_from_slice(&other.values[start..start + length]); + self.bitmask_builder + .subslice_extend_from_opt_validity(other.validity(), start, length); + } + + fn subslice_extend_repeated( + &mut self, + other: &dyn Array, + start: usize, + length: usize, + repeats: usize, + share: ShareStrategy, + ) { + for _ in 0..repeats { + self.subslice_extend(other, start, length, share) + } + } + + unsafe fn gather_extend(&mut self, other: &dyn Array, idxs: &[IdxSize], _share: ShareStrategy) { + let other: &ObjectArray = other.as_any().downcast_ref().unwrap(); + let other_values_slice = other.values.as_slice(); + self.values.extend( + idxs.iter() + .map(|idx| other_values_slice.get_unchecked(*idx as usize).clone()), + ); + self.bitmask_builder + .gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + } + + fn opt_gather_extend(&mut self, other: &dyn Array, idxs: &[IdxSize], _share: ShareStrategy) { + let other: &ObjectArray = other.as_any().downcast_ref().unwrap(); + let other_values_slice = other.values.as_slice(); + self.values.reserve(idxs.len()); + unsafe { + for idx in idxs { + let val = if (*idx as usize) < other.len() { + other_values_slice.get_unchecked(*idx as usize).clone() + } else { + T::default() + }; + self.values.push_unchecked(val); + } + } + self.bitmask_builder + .opt_gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + } +} \ No newline at end of file diff --git a/crates/polars-core/src/chunked_array/object/registry.rs b/crates/polars-core/src/chunked_array/object/registry.rs index 7364cf0361d7..bff4b05f9012 100644 --- a/crates/polars-core/src/chunked_array/object/registry.rs +++ b/crates/polars-core/src/chunked_array/object/registry.rs @@ -8,6 +8,7 @@ use std::ops::Deref; use std::sync::{Arc, RwLock}; use arrow::array::ArrayRef; +use 
arrow::array::builder::ArrayBuilder; use arrow::datatypes::ArrowDataType; use once_cell::sync::Lazy; use polars_utils::pl_str::PlSmallStr; @@ -40,7 +41,9 @@ static GLOBAL_OBJECT_REGISTRY: Lazy>> = Lazy::new( /// This trait can be registered, after which that global registration /// can be used to materialize object types -pub trait AnonymousObjectBuilder { +pub trait AnonymousObjectBuilder : ArrayBuilder { + fn as_array_builder(self: Box) -> Box; + /// # Safety /// Expect `ObjectArray` arrays. unsafe fn from_chunks(self: Box, chunks: Vec) -> Series; @@ -73,12 +76,17 @@ pub trait AnonymousObjectBuilder { } impl AnonymousObjectBuilder for ObjectChunkedBuilder { - // Expect ObjectArray arrays. + /// # Safety + /// Expects ObjectArray arrays. unsafe fn from_chunks(self: Box, chunks: Vec) -> Series { ObjectChunked::::new_with_compute_len(Arc::new(self.field().clone()), chunks) .into_series() } + fn as_array_builder(self: Box) -> Box { + self + } + fn append_null(&mut self) { self.append_null() } diff --git a/crates/polars-core/src/series/builder.rs b/crates/polars-core/src/series/builder.rs index 1d1f92de2b5e..73459ffae1fb 100644 --- a/crates/polars-core/src/series/builder.rs +++ b/crates/polars-core/src/series/builder.rs @@ -1,6 +1,7 @@ use arrow::array::builder::{make_builder, ArrayBuilder, ShareStrategy}; use polars_utils::IdxSize; +use crate::chunked_array::object::registry::get_object_builder; use crate::prelude::*; use crate::utils::Container; @@ -12,7 +13,12 @@ pub struct SeriesBuilder { impl SeriesBuilder { pub fn new(dtype: DataType) -> Self { - let builder = make_builder(&dtype.to_physical().to_arrow(CompatLevel::newest())); + let builder = if matches!(dtype, DataType::Object(_)) { + // FIXME: get rid of this hack. + get_object_builder(PlSmallStr::EMPTY, 0).as_array_builder() + } else { + make_builder(&dtype.to_physical().to_arrow(CompatLevel::newest())) + }; Self { dtype, builder } } From dc9da504a13a9031a7f74d977c253e3e769c793a Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 7 Mar 2025 10:18:19 +0100 Subject: [PATCH 13/25] fmt --- crates/polars-arrow/src/bitmap/builder.rs | 21 ++++++++++++++++--- .../src/chunked_array/object/builder.rs | 10 +++++---- .../src/chunked_array/object/registry.rs | 6 +++--- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/crates/polars-arrow/src/bitmap/builder.rs b/crates/polars-arrow/src/bitmap/builder.rs index fff539531a88..1ea13f865f43 100644 --- a/crates/polars-arrow/src/bitmap/builder.rs +++ b/crates/polars-arrow/src/bitmap/builder.rs @@ -251,7 +251,12 @@ impl BitmapBuilder { self.extend_from_slice(slice, bm_offset + start, length); } - pub fn subslice_extend_from_opt_validity(&mut self, bitmap: Option<&Bitmap>, start: usize, length: usize) { + pub fn subslice_extend_from_opt_validity( + &mut self, + bitmap: Option<&Bitmap>, + start: usize, + length: usize, + ) { match bitmap { Some(bm) => self.subslice_extend_from_bitmap(bm, start, length), None => self.extend_constant(length, true), @@ -317,7 +322,12 @@ impl BitmapBuilder { /// # Safety /// The indices must be in-bounds. 
- pub unsafe fn gather_extend_from_opt_validity(&mut self, bitmap: Option<&Bitmap>, idxs: &[IdxSize], length: usize) { + pub unsafe fn gather_extend_from_opt_validity( + &mut self, + bitmap: Option<&Bitmap>, + idxs: &[IdxSize], + length: usize, + ) { if let Some(bm) = bitmap { let (slice, offset, sl_length) = bm.as_slice(); debug_assert_eq!(sl_length, length); @@ -327,7 +337,12 @@ impl BitmapBuilder { } } - pub fn opt_gather_extend_from_opt_validity(&mut self, bitmap: Option<&Bitmap>, idxs: &[IdxSize], length: usize) { + pub fn opt_gather_extend_from_opt_validity( + &mut self, + bitmap: Option<&Bitmap>, + idxs: &[IdxSize], + length: usize, + ) { if let Some(bm) = bitmap { let (slice, offset, sl_length) = bm.as_slice(); debug_assert_eq!(sl_length, length); diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index 6bba5bf32c39..05e522c80150 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -179,7 +179,6 @@ pub(crate) fn object_series_to_arrow_array(s: &Series) -> ArrayRef { arr.values().to_boxed() } - impl ArrayBuilder for ObjectChunkedBuilder { fn dtype(&self) -> &ArrowDataType { &ArrowDataType::FixedSizeBinary(size_of::()) @@ -265,7 +264,10 @@ impl ArrayBuilder for ObjectChunkedBuilder { self.values.push_unchecked(val); } } - self.bitmask_builder - .opt_gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + self.bitmask_builder.opt_gather_extend_from_opt_validity( + other.validity(), + idxs, + other.len(), + ); } -} \ No newline at end of file +} diff --git a/crates/polars-core/src/chunked_array/object/registry.rs b/crates/polars-core/src/chunked_array/object/registry.rs index bff4b05f9012..c16c2cc92780 100644 --- a/crates/polars-core/src/chunked_array/object/registry.rs +++ b/crates/polars-core/src/chunked_array/object/registry.rs @@ -7,8 +7,8 @@ use std::fmt::{Debug, Formatter}; use std::ops::Deref; use std::sync::{Arc, RwLock}; -use arrow::array::ArrayRef; use arrow::array::builder::ArrayBuilder; +use arrow::array::ArrayRef; use arrow::datatypes::ArrowDataType; use once_cell::sync::Lazy; use polars_utils::pl_str::PlSmallStr; @@ -41,9 +41,9 @@ static GLOBAL_OBJECT_REGISTRY: Lazy>> = Lazy::new( /// This trait can be registered, after which that global registration /// can be used to materialize object types -pub trait AnonymousObjectBuilder : ArrayBuilder { +pub trait AnonymousObjectBuilder: ArrayBuilder { fn as_array_builder(self: Box) -> Box; - + /// # Safety /// Expect `ObjectArray` arrays. 
unsafe fn from_chunks(self: Box, chunks: Vec) -> Series; From f88accccc4040d71e33e36b155dc757208a46c16 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 7 Mar 2025 12:33:45 +0100 Subject: [PATCH 14/25] make local morsel builder dropping more efficient and panic-leak-proof --- crates/polars-stream/Cargo.toml | 1 + .../src/nodes/joins/new_equi_join.rs | 61 +++++++++----- crates/polars-utils/src/lib.rs | 1 + crates/polars-utils/src/sparse_init_vec.rs | 82 +++++++++++++++++++ 4 files changed, 125 insertions(+), 20 deletions(-) create mode 100644 crates/polars-utils/src/sparse_init_vec.rs diff --git a/crates/polars-stream/Cargo.toml b/crates/polars-stream/Cargo.toml index dde5bc4cf73b..731e373ceff1 100644 --- a/crates/polars-stream/Cargo.toml +++ b/crates/polars-stream/Cargo.toml @@ -11,6 +11,7 @@ description = "Private crate for the streaming execution engine for the Polars D [dependencies] arrow = { workspace = true } atomic-waker = { workspace = true } +crossbeam-channel = { workspace = true } crossbeam-deque = { workspace = true } crossbeam-queue = { workspace = true } crossbeam-utils = { workspace = true } diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs index 8d4ee4d7a7e6..4ee34d4383a5 100644 --- a/crates/polars-stream/src/nodes/joins/new_equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/new_equi_join.rs @@ -16,7 +16,7 @@ use polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::hashing::HashPartitioner; use polars_utils::itertools::Itertools; use polars_utils::pl_str::PlSmallStr; -use polars_utils::sync::SyncPtr; +use polars_utils::sparse_init_vec::SparseInitVec; use polars_utils::{format_pl_smallstr, IdxSize}; use rayon::prelude::*; @@ -541,24 +541,28 @@ impl BuildState { }; // To reduce maximum memory usage we want to drop the morsels - // as soon as they're processed, so we move into Arcs. + // as soon as they're processed, so we move into Arcs. The drops might + // also be expensive, so instead of directly dropping we put that on + // a work queue. let morsels_per_local_builder = self .local_builders .iter_mut() .map(|b| Arc::new(core::mem::take(&mut b.morsels))) .collect_vec(); + let (morsel_drop_q_send, morsel_drop_q_recv) = crossbeam_channel::bounded(morsels_per_local_builder.len()); let num_partitions = self.local_builders[0].sketch_per_p.len(); let local_builders = &self.local_builders; + let probe_tables: SparseInitVec = SparseInitVec::with_capacity(num_partitions); - let mut probe_tables = POOL.scope(|s| { - let mut probe_tables: Vec = Vec::with_capacity(num_partitions); - let probe_table_ptr = unsafe { SyncPtr::new(probe_tables.as_mut_ptr()) }; - + POOL.scope(|s| { // Wrap in outer Arc to move to each thread, performing the // expensive clone on that thread. let arc_morsels_per_local_builder = Arc::new(morsels_per_local_builder); for p in 0..num_partitions { let arc_morsels_per_local_builder = Arc::clone(&arc_morsels_per_local_builder); + let morsel_drop_q_send = morsel_drop_q_send.clone(); + let morsel_drop_q_recv = morsel_drop_q_recv.clone(); + let probe_tables = &probe_tables; s.spawn(move |_| { // Extract from outer arc and drop outer arc. let morsels_per_local_builder = @@ -582,7 +586,13 @@ impl BuildState { p_payload.reserve(payload_rows); // Build. + let mut skip_drop_attempt = false; for (l, l_morsels) in local_builders.iter().zip(morsels_per_local_builder) { + // Try to help with dropping the processed morsels. 
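+                        // try_recv never blocks: it drops at most one queued
+                        // batch here, so helping out cannot stall this build.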
+                        if !skip_drop_attempt {
+                            drop(morsel_drop_q_recv.try_recv());
+                        }
+
                         for (i, morsel) in l_morsels.iter().enumerate() {
                             let (_mseq, payload, keys) = morsel;
                             unsafe {
                                 let p_morsel_idxs_start =
                                     l.morsel_idxs_offsets_per_p[i * num_partitions + p];
                                 let p_morsel_idxs_stop =
                                     l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                 let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                     [p_morsel_idxs_start..p_morsel_idxs_stop];
                                 p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable);
                                 p_payload.gather_extend(
                                     payload,
                                     p_morsel_idxs,
                                     ShareStrategy::Never,
                                 );
                             }
                         }
+
+                        if let Some(l) = Arc::into_inner(l_morsels) {
+                            // If we're the last thread to process this set of morsels we're
+                            // probably falling behind the rest. Since the drop can be quite
+                            // expensive we skip a drop attempt, hoping someone else will
+                            // pick up the slack.
+                            morsel_drop_q_send.send(l).unwrap();
+                            skip_drop_attempt = true;
+                        } else {
+                            skip_drop_attempt = false;
+                        }
                     }
-
-                    unsafe {
-                        probe_table_ptr.get().add(p).write(ProbeTable {
-                            hash_table: p_table,
-                            payload: p_payload.freeze(),
-                        });
+
+                    // We're done, help others out by doing drops.
+                    drop(morsel_drop_q_send); // So we don't deadlock.
+                    while let Ok(l_morsels) = morsel_drop_q_recv.recv() {
+                        drop(l_morsels);
                     }
+
+                    probe_tables.try_set(p, ProbeTable {
+                        hash_table: p_table,
+                        payload: p_payload.freeze(),
+                    }).ok().unwrap();
                 });
             }
 
             // Drop outer arc after spawning each thread so the inner arcs
-            // can get dropped as soon as they're processed.
+            // can get dropped as soon as they're processed. We also have to
+            // drop the drop queue sender so we don't deadlock waiting for it
+            // to end.
             drop(arc_morsels_per_local_builder);
-            probe_tables
+            drop(morsel_drop_q_send);
         });
 
-        unsafe {
-            // SAFETY: all entries are initialized now.
-            probe_tables.set_len(num_partitions);
-        }
-
         ProbeState {
-            table_per_partition: probe_tables,
+            table_per_partition: probe_tables.try_assume_init().ok().unwrap(),
             max_seq_sent: MorselSeq::default(),
             sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels),
         }
     }
diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs
index f0126825a537..88b456f0e33f 100644
--- a/crates/polars-utils/src/lib.rs
+++ b/crates/polars-utils/src/lib.rs
@@ -31,6 +31,7 @@ pub mod priority;
 pub mod select;
 pub mod slice;
 pub mod sort;
+pub mod sparse_init_vec;
 pub mod sync;
 #[cfg(feature = "sysinfo")]
 pub mod sys;
diff --git a/crates/polars-utils/src/sparse_init_vec.rs b/crates/polars-utils/src/sparse_init_vec.rs
new file mode 100644
index 000000000000..12cef03cb2ea
--- /dev/null
+++ b/crates/polars-utils/src/sparse_init_vec.rs
@@ -0,0 +1,82 @@
+use std::sync::atomic::{AtomicUsize, AtomicU8, Ordering};
+
+
+pub struct SparseInitVec<T> {
+    ptr: *mut T,
+    len: usize,
+    cap: usize,
+
+    num_init: AtomicUsize,
+    init_mask: Vec<AtomicU8>,
+}
+
+unsafe impl<T: Send> Send for SparseInitVec<T> { }
+unsafe impl<T: Send> Sync for SparseInitVec<T> { }
+
+impl<T> SparseInitVec<T> {
+    pub fn with_capacity(len: usize) -> Self {
+        let init_mask = (0..len.div_ceil(8)).map(|_| AtomicU8::new(0)).collect();
+        let mut storage = Vec::with_capacity(len);
+        let cap = storage.capacity();
+        let ptr = storage.as_mut_ptr();
+        core::mem::forget(storage);
+        Self {
+            len,
+            cap,
+            ptr,
+            num_init: AtomicUsize::new(0),
+            init_mask,
+        }
+    }
+
+    pub fn try_set(&self, idx: usize, value: T) -> Result<(), T> {
+        unsafe {
+            if idx >= self.len {
+                return Err(value);
+            }
+
+            // SAFETY: we use Relaxed orderings as we only ever read data back in methods that take
+            // self mutably or owned, already implying synchronization.
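+            // fetch_or below atomically claims the slot's bit: if the bit was
+            // already set, another thread owns this index and the value is
+            // handed back to the caller instead of being written twice.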
+            let init_mask_byte = self.init_mask.get_unchecked(idx / 8);
+            let bit_mask = 1 << (idx % 8);
+            if init_mask_byte.fetch_or(bit_mask, Ordering::Relaxed) & bit_mask != 0 {
+                return Err(value);
+            }
+
+            self.ptr.add(idx).write(value);
+            self.num_init.fetch_add(1, Ordering::Relaxed);
+        }
+
+        Ok(())
+    }
+
+    pub fn try_assume_init(mut self) -> Result<Vec<T>, Self> {
+        unsafe {
+            if *self.num_init.get_mut() == self.len {
+                let ret = Vec::from_raw_parts(self.ptr, self.len, self.cap);
+                drop(core::mem::take(&mut self.init_mask));
+                core::mem::forget(self);
+                Ok(ret)
+            } else {
+                Err(self)
+            }
+        }
+    }
+}
+
+impl<T> Drop for SparseInitVec<T> {
+    fn drop(&mut self) {
+        unsafe {
+            // Make sure storage gets dropped even if element drop panics.
+            let _storage = Vec::from_raw_parts(self.ptr, 0, self.cap);
+
+            for idx in 0..self.len {
+                let init_mask_byte = self.init_mask.get_unchecked_mut(idx / 8);
+                let bit_mask = 1 << (idx % 8);
+                if *init_mask_byte.get_mut() & bit_mask != 0 {
+                    self.ptr.add(idx).drop_in_place();
+                }
+            }
+        }
+    }
+}
From 20406c89a1334f52cd57888c7d4d44524de8d874 Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Fri, 7 Mar 2025 12:43:19 +0100
Subject: [PATCH 15/25] fix preserve order probe bug

---
 Cargo.lock                                            | 1 +
 crates/polars-stream/src/nodes/joins/new_equi_join.rs | 1 +
 2 files changed, 2 insertions(+)

diff --git a/Cargo.lock b/Cargo.lock
index b26c1405cf82..53d086eabae4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3516,6 +3516,7 @@ name = "polars-stream"
 version = "0.46.0"
 dependencies = [
  "atomic-waker",
+ "crossbeam-channel",
  "crossbeam-deque",
  "crossbeam-queue",
  "crossbeam-utils",
diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs
index 4ee34d4383a5..8230d8bca072 100644
--- a/crates/polars-stream/src/nodes/joins/new_equi_join.rs
+++ b/crates/polars-stream/src/nodes/joins/new_equi_join.rs
@@ -743,6 +743,7 @@ impl ProbeState {
                 // To preserve the order we can't do bulk probes per partition and must follow
                 // the order of the probe morsel. We can still group probes that are
                 // consecutively on the same partition.
+                probe_partitions.clear();
                 hash_keys.gen_partitions(&partitioner, &mut probe_partitions, emit_unmatched);
                 let mut probe_group_start = 0;
                 while probe_group_start < probe_partitions.len() {
From b8c56feee13627e84175ae18ee5965f8bc10ec19 Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Fri, 7 Mar 2025 14:21:29 +0100
Subject: [PATCH 16/25] comment

---
 crates/polars-stream/src/nodes/joins/new_equi_join.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs
index 8230d8bca072..ce3b2281d4b5 100644
--- a/crates/polars-stream/src/nodes/joins/new_equi_join.rs
+++ b/crates/polars-stream/src/nodes/joins/new_equi_join.rs
@@ -623,7 +623,7 @@ impl BuildState {
                     }
 
                     // We're done, help others out by doing drops.
-                    drop(morsel_drop_q_send); // So we don't deadlock.
+                    drop(morsel_drop_q_send); // So we don't deadlock trying to receive from ourselves.
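+                    // Once every sender clone is dropped and the queue is
+                    // drained, recv() returns Err and the loop below ends.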
while let Ok(l_morsels) = morsel_drop_q_recv.recv() { drop(l_morsels); } From 6bec9c0cc6d78bb3cef9bd16832e41967e4533b9 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 7 Mar 2025 17:15:14 +0100 Subject: [PATCH 17/25] order-preserving build --- .../src/nodes/joins/new_equi_join.rs | 175 +++++++++++++++++- .../src/physical_plan/to_graph.rs | 48 ++--- 2 files changed, 182 insertions(+), 41 deletions(-) diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs index ce3b2281d4b5..657e5d1d5eeb 100644 --- a/crates/polars-stream/src/nodes/joins/new_equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/new_equi_join.rs @@ -1,3 +1,5 @@ +use std::cmp::Reverse; +use std::collections::BinaryHeap; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, LazyLock}; @@ -16,6 +18,7 @@ use polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::hashing::HashPartitioner; use polars_utils::itertools::Itertools; use polars_utils::pl_str::PlSmallStr; +use polars_utils::priority::Priority; use polars_utils::sparse_init_vec::SparseInitVec; use polars_utils::{format_pl_smallstr, IdxSize}; use rayon::prelude::*; @@ -532,7 +535,100 @@ impl BuildState { Ok(()) } - fn finalize(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState { + fn finalize_ordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState { + let track_unmatchable = params.emit_unmatched_build(); + let payload_schema = if params.left_is_build.unwrap() { + ¶ms.left_payload_schema + } else { + ¶ms.right_payload_schema + }; + + let num_partitions = self.local_builders[0].sketch_per_p.len(); + let local_builders = &self.local_builders; + let probe_tables: SparseInitVec = SparseInitVec::with_capacity(num_partitions); + + POOL.scope(|s| { + for p in 0..num_partitions { + let probe_tables = &probe_tables; + s.spawn(move |_| { + // TODO: every thread does an identical linearize, we can do a single parallel one. + let mut kmerge = BinaryHeap::with_capacity(local_builders.len()); + let mut cur_idx_per_loc = vec![0; local_builders.len()]; + + // Compute cardinality estimate and total amount of + // payload for this partition, and initialize k-way merge. + let mut sketch = CardinalitySketch::new(); + let mut payload_rows = 0; + for (l_idx, l) in local_builders.iter().enumerate() { + let Some((seq, _, _)) = l.morsels.get(0) else { continue }; + kmerge.push(Priority(Reverse(seq), l_idx)); + + sketch.combine(&l.sketch_per_p[p]); + let offsets_len = l.morsel_idxs_offsets_per_p.len(); + payload_rows += + l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p]; + } + + // Allocate hash table and payload builder. + let mut p_table = table.new_empty(); + p_table.reserve(sketch.estimate() * 5 / 4); + let mut p_payload = DataFrameBuilder::new(payload_schema.clone()); + p_payload.reserve(payload_rows); + + let mut p_seq_ids = Vec::new(); + if track_unmatchable { + p_seq_ids.reserve(payload_rows); + } + + // Linearize and build. 
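+                    // K-way merge: repeatedly pop the builder whose next morsel
+                    // has the smallest sequence number, so rows enter the table
+                    // in global morsel order. E.g. builders holding seqs [0, 2]
+                    // and [1, 3] yield 0, 1, 2, 3.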
+ unsafe { + let mut norm_seq_id = 0 as IdxSize; + while let Some(Priority(Reverse(mut seq), l_idx)) = kmerge.pop() { + let l = local_builders.get_unchecked(l_idx); + let idx_in_l = *cur_idx_per_loc.get_unchecked(l_idx); + *cur_idx_per_loc.get_unchecked_mut(l_idx) += 1; + if let Some((next_seq, _, _)) = l.morsels.get(idx_in_l + 1) { + kmerge.push(Priority(Reverse(next_seq), l_idx)); + } + + let (_mseq, payload, keys) = l.morsels.get_unchecked(idx_in_l); + let p_morsel_idxs_start = + l.morsel_idxs_offsets_per_p[idx_in_l * num_partitions + p]; + let p_morsel_idxs_stop = + l.morsel_idxs_offsets_per_p[(idx_in_l + 1) * num_partitions + p]; + let p_morsel_idxs = &l.morsel_idxs_values_per_p[p] + [p_morsel_idxs_start..p_morsel_idxs_stop]; + p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable); + p_payload.gather_extend( + payload, + p_morsel_idxs, + ShareStrategy::Never, + ); + + if track_unmatchable { + p_seq_ids.resize(p_payload.len(), norm_seq_id); + norm_seq_id += 1; + } + } + } + + probe_tables.try_set(p, ProbeTable { + hash_table: p_table, + payload: p_payload.freeze(), + seq_ids: p_seq_ids, + }).ok().unwrap(); + }); + } + }); + + ProbeState { + table_per_partition: probe_tables.try_assume_init().ok().unwrap(), + max_seq_sent: MorselSeq::default(), + sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels), + } + } + + fn finalize_unordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState { let track_unmatchable = params.emit_unmatched_build(); let payload_schema = if params.left_is_build.unwrap() { ¶ms.left_payload_schema @@ -631,6 +727,7 @@ impl BuildState { probe_tables.try_set(p, ProbeTable { hash_table: p_table, payload: p_payload.freeze(), + seq_ids: Vec::new(), }).ok().unwrap(); }); } @@ -654,6 +751,7 @@ impl BuildState { struct ProbeTable { hash_table: Box, payload: DataFrame, + seq_ids: Vec, } struct ProbeState { @@ -907,10 +1005,68 @@ impl ProbeState { fn ordered_unmatched( &mut self, - _partitioner: &HashPartitioner, - _params: &EquiJoinParams, + params: &EquiJoinParams, ) -> DataFrame { - todo!() + // TODO: parallelize this operator. + + let build_payload_schema = if params.left_is_build.unwrap() { + ¶ms.left_payload_schema + } else { + ¶ms.right_payload_schema + }; + + let mut unmarked_idxs = Vec::new(); + let mut linearized_idxs = Vec::new(); + + for (p_idx, p) in self.table_per_partition.iter().enumerate_idx() { + p.hash_table + .unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX); + linearized_idxs.extend(unmarked_idxs.iter().map(|i| { + (unsafe { *p.seq_ids.get_unchecked(*i as usize) }, p_idx, *i) + })); + } + + linearized_idxs.sort_by_key(|(seq_id, _, _)| *seq_id); + + unsafe { + let mut build_out = DataFrameBuilder::new(build_payload_schema.clone()); + build_out.reserve(linearized_idxs.len()); + + // Group indices from the same partition. + let mut group_start = 0; + let mut gather_idxs = Vec::new(); + while group_start < linearized_idxs.len() { + gather_idxs.clear(); + + let (_seq, p_idx, idx_in_p) = linearized_idxs[group_start]; + gather_idxs.push(idx_in_p); + let mut group_end = group_start + 1; + while group_end < linearized_idxs.len() && linearized_idxs[group_end].1 == p_idx { + gather_idxs.push(linearized_idxs[group_end].2); + group_end += 1; + } + + build_out.gather_extend( + &self.table_per_partition[p_idx as usize].payload, + &gather_idxs, + ShareStrategy::Never, // Don't keep entire table alive for unmatched indices. 
+ ); + + group_start = group_end; + } + + let mut build_df = build_out.freeze(); + let out_df = if params.left_is_build.unwrap() { + let probe_df = DataFrame::full_null(¶ms.right_payload_schema, build_df.height()); + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + let mut probe_df = DataFrame::full_null(¶ms.left_payload_schema, build_df.height()); + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + postprocess_join(out_df, params) + } } } @@ -1173,9 +1329,12 @@ impl ComputeNode for EquiJoinNode { // If we are building and the build input is done, transition to probing. if let EquiJoinState::Build(build_state) = &mut self.state { if recv[build_idx] == PortState::Done { - self.state = EquiJoinState::Probe( - build_state.finalize(&self.params, self.table.as_deref().unwrap()), - ); + let probe_state = if self.params.preserve_order_build { + build_state.finalize_ordered(&self.params, self.table.as_deref().unwrap()) + } else { + build_state.finalize_unordered(&self.params, self.table.as_deref().unwrap()) + }; + self.state = EquiJoinState::Probe(probe_state); } } @@ -1187,7 +1346,7 @@ impl ComputeNode for EquiJoinNode { if self.params.emit_unmatched_build() { if self.params.preserve_order_build { let partitioner = HashPartitioner::new(self.num_pipelines, 0); - let unmatched = probe_state.ordered_unmatched(&partitioner, &self.params); + let unmatched = probe_state.ordered_unmatched(&self.params); let mut src = InMemorySourceNode::new( Arc::new(unmatched), probe_state.max_seq_sent.successor(), diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 3297decaecf0..4d8685e2a638 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -814,39 +814,21 @@ fn to_graph_rec<'a>( args.maintain_order, MaintainOrderJoin::LeftRight | MaintainOrderJoin::RightLeft ); - if !preserve_order_build { - ctx.graph.add_node( - nodes::joins::new_equi_join::EquiJoinNode::new( - left_input_schema, - right_input_schema, - left_key_schema, - right_key_schema, - left_key_selectors, - right_key_selectors, - args, - )?, - [ - (left_input_key, input_left.port), - (right_input_key, input_right.port), - ], - ) - } else { - ctx.graph.add_node( - nodes::joins::equi_join::EquiJoinNode::new( - left_input_schema, - right_input_schema, - left_key_schema, - right_key_schema, - left_key_selectors, - right_key_selectors, - args, - )?, - [ - (left_input_key, input_left.port), - (right_input_key, input_right.port), - ], - ) - } + ctx.graph.add_node( + nodes::joins::new_equi_join::EquiJoinNode::new( + left_input_schema, + right_input_schema, + left_key_schema, + right_key_schema, + left_key_selectors, + right_key_selectors, + args, + )?, + [ + (left_input_key, input_left.port), + (right_input_key, input_right.port), + ], + ) }, #[cfg(feature = "merge_sorted")] From 69a13b2a95a26951e86e9fe65cdc64d02231d7cc Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 7 Mar 2025 17:16:55 +0100 Subject: [PATCH 18/25] remove old join impl --- .../src/nodes/joins/equi_join.rs | 785 ++++---- crates/polars-stream/src/nodes/joins/mod.rs | 3 +- .../src/nodes/joins/new_equi_join.rs | 1582 ----------------- .../src/physical_plan/to_graph.rs | 7 +- 4 files changed, 478 insertions(+), 1899 deletions(-) delete mode 100644 crates/polars-stream/src/nodes/joins/new_equi_join.rs diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs 
b/crates/polars-stream/src/nodes/joins/equi_join.rs index b03c4f2a07fe..c1bcd6beeaa2 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -1,22 +1,25 @@ +use std::cmp::Reverse; +use std::collections::BinaryHeap; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, LazyLock}; +use arrow::array::builder::ShareStrategy; use crossbeam_queue::ArrayQueue; +use polars_core::frame::builder::DataFrameBuilder; use polars_core::prelude::*; use polars_core::schema::{Schema, SchemaExt}; -use polars_core::series::IsSorted; -use polars_core::utils::accumulate_dataframes_vertical_unchecked; use polars_core::{config, POOL}; -use polars_expr::chunked_idx_table::{new_chunked_idx_table, ChunkedIdxTable}; use polars_expr::hash_keys::HashKeys; +use polars_expr::idx_table::{new_idx_table, IdxTable}; use polars_io::pl_async::get_runtime; use polars_ops::frame::{JoinArgs, JoinType, MaintainOrderJoin}; -use polars_ops::prelude::TakeChunked; use polars_ops::series::coalesce_columns; use polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::hashing::HashPartitioner; use polars_utils::itertools::Itertools; use polars_utils::pl_str::PlSmallStr; +use polars_utils::priority::Priority; +use polars_utils::sparse_init_vec::SparseInitVec; use polars_utils::{format_pl_smallstr, IdxSize}; use rayon::prelude::*; @@ -311,7 +314,7 @@ impl SampleState { recv: &[PortState], num_pipelines: usize, params: &mut EquiJoinParams, - table: &mut Option>, + table: &mut Option>, ) -> PolarsResult> { let left_saturated = self.left_len >= *SAMPLE_LIMIT; let right_saturated = self.right_len >= *SAMPLE_LIMIT; @@ -401,9 +404,9 @@ impl SampleState { // Transition to building state. params.left_is_build = Some(left_is_build); *table = Some(if left_is_build { - new_chunked_idx_table(params.left_key_schema.clone()) + new_idx_table(params.left_key_schema.clone()) } else { - new_chunked_idx_table(params.right_key_schema.clone()) + new_idx_table(params.right_key_schema.clone()) }); let mut sampled_build_morsels = @@ -416,7 +419,9 @@ impl SampleState { let partitioner = HashPartitioner::new(num_pipelines, 0); let mut build_state = BuildState { - partitions_per_worker: (0..num_pipelines).map(|_| Vec::new()).collect(), + local_builders: (0..num_pipelines) + .map(|_| LocalBuilder::default()) + .collect(), sampled_probe_morsels, }; @@ -429,13 +434,12 @@ impl SampleState { .reinsert(num_pipelines, None, scope, &mut join_handles) .unwrap(); - for (worker_ps, recv) in build_state.partitions_per_worker.iter_mut().zip(receivers) - { + for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) { join_handles.push(scope.spawn_task( TaskPriority::High, BuildState::partition_and_sink( recv, - worker_ps, + local_builder, partitioner.clone(), params, &state, @@ -457,30 +461,48 @@ impl SampleState { } #[derive(Default)] -struct BuildPartition { - hash_keys: Vec, - frames: Vec<(MorselSeq, DataFrame)>, - sketch: Option, +struct LocalBuilder { + // The complete list of morsels and their computed hashes seen by this builder. + morsels: Vec<(MorselSeq, DataFrame, HashKeys)>, + + // A cardinality sketch per partition for the keys seen by this builder. 
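+    // Combined across all builders at finalize time to size each partition's
+    // hash table.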
+ sketch_per_p: Vec, + + // morsel_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i] + // for partition p, where start, stop are: + // let start = morsel_idxs_offsets[i * num_partitions + p]; + // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p]; + morsel_idxs_values_per_p: Vec>, + morsel_idxs_offsets_per_p: Vec, } #[derive(Default)] struct BuildState { - partitions_per_worker: Vec>, + local_builders: Vec, sampled_probe_morsels: BufferedStream, } impl BuildState { async fn partition_and_sink( mut recv: Receiver, - partitions: &mut Vec, + local: &mut LocalBuilder, partitioner: HashPartitioner, params: &EquiJoinParams, state: &ExecutionState, ) -> PolarsResult<()> { let track_unmatchable = params.emit_unmatched_build(); - let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()]; - partitions.resize_with(partitioner.num_partitions(), BuildPartition::default); - let mut sketches = vec![CardinalitySketch::default(); partitioner.num_partitions()]; + local + .sketch_per_p + .resize_with(partitioner.num_partitions(), Default::default); + local + .morsel_idxs_values_per_p + .resize_with(partitioner.num_partitions(), Default::default); + + if local.morsel_idxs_offsets_per_p.is_empty() { + local + .morsel_idxs_offsets_per_p + .resize(partitioner.num_partitions(), 0); + } let (key_selectors, payload_selector); if params.left_is_build.unwrap() { @@ -493,140 +515,243 @@ impl BuildState { while let Ok(morsel) = recv.recv().await { // Compute hashed keys and payload. We must rechunk the payload for - // later chunked gathers. + // later gathers. let hash_keys = select_keys(morsel.df(), key_selectors, params, state).await?; let mut payload = select_payload(morsel.df().clone(), payload_selector); payload.rechunk_mut(); - payload._deshare_views_mut(); - unsafe { - for p in partition_idxs.iter_mut() { - p.clear(); - } - hash_keys.gen_idxs_per_partition( - &partitioner, - &mut partition_idxs, - &mut sketches, - track_unmatchable, - ); - for (p, idxs_in_p) in partitions.iter_mut().zip(&partition_idxs) { - let payload_for_partition = payload.take_slice_unchecked_impl(idxs_in_p, false); - p.hash_keys.push(hash_keys.gather(idxs_in_p)); - p.frames.push((morsel.seq(), payload_for_partition)); - } - } - } + hash_keys.gen_idxs_per_partition( + &partitioner, + &mut local.morsel_idxs_values_per_p, + &mut local.sketch_per_p, + track_unmatchable, + ); - for (p, sketch) in sketches.into_iter().enumerate() { - partitions[p].sketch = Some(sketch); + local + .morsel_idxs_offsets_per_p + .extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len())); + local.morsels.push((morsel.seq(), payload, hash_keys)); } - Ok(()) } - fn finalize(&mut self, params: &EquiJoinParams, table: &dyn ChunkedIdxTable) -> ProbeState { - // Transpose. - let num_workers = self.partitions_per_worker.len(); - let num_partitions = self.partitions_per_worker[0].len(); - let mut results_per_partition = (0..num_partitions) - .map(|_| Vec::with_capacity(num_workers)) - .collect_vec(); - for worker in self.partitions_per_worker.drain(..) 
{ - for (p, result) in worker.into_iter().enumerate() { - results_per_partition[p].push(result); - } - } + fn finalize_ordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState { + let track_unmatchable = params.emit_unmatched_build(); + let payload_schema = if params.left_is_build.unwrap() { + ¶ms.left_payload_schema + } else { + ¶ms.right_payload_schema + }; - POOL.install(|| { - let track_unmatchable = params.emit_unmatched_build(); - let table_per_partition: Vec<_> = results_per_partition - .into_par_iter() - .with_max_len(1) - .map(|results| { - // Estimate sizes and cardinality. + let num_partitions = self.local_builders[0].sketch_per_p.len(); + let local_builders = &self.local_builders; + let probe_tables: SparseInitVec = SparseInitVec::with_capacity(num_partitions); + + POOL.scope(|s| { + for p in 0..num_partitions { + let probe_tables = &probe_tables; + s.spawn(move |_| { + // TODO: every thread does an identical linearize, we can do a single parallel one. + let mut kmerge = BinaryHeap::with_capacity(local_builders.len()); + let mut cur_idx_per_loc = vec![0; local_builders.len()]; + + // Compute cardinality estimate and total amount of + // payload for this partition, and initialize k-way merge. let mut sketch = CardinalitySketch::new(); - let mut num_frames = 0; - for result in &results { - sketch.combine(result.sketch.as_ref().unwrap()); - num_frames += result.frames.len(); + let mut payload_rows = 0; + for (l_idx, l) in local_builders.iter().enumerate() { + let Some((seq, _, _)) = l.morsels.get(0) else { continue }; + kmerge.push(Priority(Reverse(seq), l_idx)); + + sketch.combine(&l.sketch_per_p[p]); + let offsets_len = l.morsel_idxs_offsets_per_p.len(); + payload_rows += + l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p]; } - // Build table for this partition. - let mut combined_frames = Vec::with_capacity(num_frames); - let mut chunk_seq_ids = Vec::with_capacity(num_frames); - let mut table = table.new_empty(); - table.reserve(sketch.estimate() * 5 / 4); - if params.preserve_order_build { - let mut combined = Vec::with_capacity(num_frames); - for result in results { - for (hash_keys, (seq, frame)) in - result.hash_keys.into_iter().zip(result.frames) - { - combined.push((seq, hash_keys, frame)); + // Allocate hash table and payload builder. + let mut p_table = table.new_empty(); + p_table.reserve(sketch.estimate() * 5 / 4); + let mut p_payload = DataFrameBuilder::new(payload_schema.clone()); + p_payload.reserve(payload_rows); + + let mut p_seq_ids = Vec::new(); + if track_unmatchable { + p_seq_ids.reserve(payload_rows); + } + + // Linearize and build. 
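+                    // norm_seq_id renumbers the merged morsels 0..n; seq_ids
+                    // tags each payload row with the merged morsel it came
+                    // from, so unmatched rows can later be emitted in order.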
+ unsafe { + let mut norm_seq_id = 0 as IdxSize; + while let Some(Priority(Reverse(seq), l_idx)) = kmerge.pop() { + let l = local_builders.get_unchecked(l_idx); + let idx_in_l = *cur_idx_per_loc.get_unchecked(l_idx); + *cur_idx_per_loc.get_unchecked_mut(l_idx) += 1; + if let Some((next_seq, _, _)) = l.morsels.get(idx_in_l + 1) { + kmerge.push(Priority(Reverse(next_seq), l_idx)); + } + + let (_mseq, payload, keys) = l.morsels.get_unchecked(idx_in_l); + let p_morsel_idxs_start = + l.morsel_idxs_offsets_per_p[idx_in_l * num_partitions + p]; + let p_morsel_idxs_stop = + l.morsel_idxs_offsets_per_p[(idx_in_l + 1) * num_partitions + p]; + let p_morsel_idxs = &l.morsel_idxs_values_per_p[p] + [p_morsel_idxs_start..p_morsel_idxs_stop]; + p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable); + p_payload.gather_extend( + payload, + p_morsel_idxs, + ShareStrategy::Never, + ); + + if track_unmatchable { + p_seq_ids.resize(p_payload.len(), norm_seq_id); + norm_seq_id += 1; } } + } - combined.sort_unstable_by_key(|c| c.0); - for (seq, hash_keys, frame) in combined { - // Zero-sized chunks can get deleted, so skip entirely to avoid messing - // up the chunk counter. - if frame.height() == 0 { - continue; - } + probe_tables.try_set(p, ProbeTable { + hash_table: p_table, + payload: p_payload.freeze(), + seq_ids: p_seq_ids, + }).ok().unwrap(); + }); + } + }); + + ProbeState { + table_per_partition: probe_tables.try_assume_init().ok().unwrap(), + max_seq_sent: MorselSeq::default(), + sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels), + } + } + + fn finalize_unordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState { + let track_unmatchable = params.emit_unmatched_build(); + let payload_schema = if params.left_is_build.unwrap() { + ¶ms.left_payload_schema + } else { + ¶ms.right_payload_schema + }; - table.insert_key_chunk(hash_keys, track_unmatchable); - combined_frames.push(frame); - chunk_seq_ids.push(seq); + // To reduce maximum memory usage we want to drop the morsels + // as soon as they're processed, so we move into Arcs. The drops might + // also be expensive, so instead of directly dropping we put that on + // a work queue. + let morsels_per_local_builder = self + .local_builders + .iter_mut() + .map(|b| Arc::new(core::mem::take(&mut b.morsels))) + .collect_vec(); + let (morsel_drop_q_send, morsel_drop_q_recv) = crossbeam_channel::bounded(morsels_per_local_builder.len()); + let num_partitions = self.local_builders[0].sketch_per_p.len(); + let local_builders = &self.local_builders; + let probe_tables: SparseInitVec = SparseInitVec::with_capacity(num_partitions); + + POOL.scope(|s| { + // Wrap in outer Arc to move to each thread, performing the + // expensive clone on that thread. + let arc_morsels_per_local_builder = Arc::new(morsels_per_local_builder); + for p in 0..num_partitions { + let arc_morsels_per_local_builder = Arc::clone(&arc_morsels_per_local_builder); + let morsel_drop_q_send = morsel_drop_q_send.clone(); + let morsel_drop_q_recv = morsel_drop_q_recv.clone(); + let probe_tables = &probe_tables; + s.spawn(move |_| { + // Extract from outer arc and drop outer arc. + let morsels_per_local_builder = + Arc::unwrap_or_clone(arc_morsels_per_local_builder); + + // Compute cardinality estimate and total amount of + // payload for this partition. 
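+                    // The trailing num_partitions entries of
+                    // morsel_idxs_offsets_per_p are running totals, so this
+                    // reads each builder's total row count for partition p.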
+ let mut sketch = CardinalitySketch::new(); + let mut payload_rows = 0; + for l in local_builders { + sketch.combine(&l.sketch_per_p[p]); + let offsets_len = l.morsel_idxs_offsets_per_p.len(); + payload_rows += + l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p]; + } + + // Allocate hash table and payload builder. + let mut p_table = table.new_empty(); + p_table.reserve(sketch.estimate() * 5 / 4); + let mut p_payload = DataFrameBuilder::new(payload_schema.clone()); + p_payload.reserve(payload_rows); + + // Build. + let mut skip_drop_attempt = false; + for (l, l_morsels) in local_builders.iter().zip(morsels_per_local_builder) { + // Try to help with dropping the processed morsels. + if !skip_drop_attempt { + drop(morsel_drop_q_recv.try_recv()); } - } else { - for result in results { - for (hash_keys, (_, frame)) in - result.hash_keys.into_iter().zip(result.frames) - { - // Zero-sized chunks can get deleted, so skip entirely to avoid messing - // up the chunk counter. - if frame.height() == 0 { - continue; - } - table.insert_key_chunk(hash_keys, track_unmatchable); - combined_frames.push(frame); + for (i, morsel) in l_morsels.iter().enumerate() { + let (_mseq, payload, keys) = morsel; + unsafe { + let p_morsel_idxs_start = + l.morsel_idxs_offsets_per_p[i * num_partitions + p]; + let p_morsel_idxs_stop = + l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p]; + let p_morsel_idxs = &l.morsel_idxs_values_per_p[p] + [p_morsel_idxs_start..p_morsel_idxs_stop]; + p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable); + p_payload.gather_extend( + payload, + p_morsel_idxs, + ShareStrategy::Never, + ); } } - } - - let df = if combined_frames.is_empty() { - if params.left_is_build.unwrap() { - DataFrame::empty_with_schema(¶ms.left_payload_schema) + + if let Some(l) = Arc::into_inner(l_morsels) { + // If we're the last thread to process this set of morsels we're probably + // falling behind the rest, since the drop can be quite expensive we skip + // a drop attempt hoping someone else will pick up the slack. + morsel_drop_q_send.send(l).unwrap(); + skip_drop_attempt = true; } else { - DataFrame::empty_with_schema(¶ms.right_payload_schema) + skip_drop_attempt = false; } - } else { - accumulate_dataframes_vertical_unchecked(combined_frames) - }; - ProbeTable { - hash_table: table, - payload: df, - chunk_seq_ids, } - }) - .collect(); + + // We're done, help others out by doing drops. + drop(morsel_drop_q_send); // So we don't deadlock trying to receive from ourselves. + while let Ok(l_morsels) = morsel_drop_q_recv.recv() { + drop(l_morsels); + } - ProbeState { - table_per_partition, - max_seq_sent: MorselSeq::default(), - sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels), + probe_tables.try_set(p, ProbeTable { + hash_table: p_table, + payload: p_payload.freeze(), + seq_ids: Vec::new(), + }).ok().unwrap(); + }); } - }) + + // Drop outer arc after spawning each thread so the inner arcs + // can get dropped as soon as they're processed. We also have to + // drop the drop queue sender so we don't deadlock waiting for it + // to end. + drop(arc_morsels_per_local_builder); + drop(morsel_drop_q_send); + }); + + ProbeState { + table_per_partition: probe_tables.try_assume_init().ok().unwrap(), + max_seq_sent: MorselSeq::default(), + sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels), + } } } struct ProbeTable { - // Important that df is not rechunked, the chunks it was inserted with - // into the table must be preserved for chunked gathers. 
- hash_table: Box, + hash_table: Box, payload: DataFrame, - chunk_seq_ids: Vec, + seq_ids: Vec, } struct ProbeState { @@ -647,6 +772,8 @@ impl ProbeState { ) -> PolarsResult { // TODO: shuffle after partitioning and keep probe tables thread-local. let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()]; + let mut probe_partitions = Vec::new(); + let mut materialized_idxsize_range = Vec::new(); let mut table_match = Vec::new(); let mut probe_match = Vec::new(); let mut max_seq = MorselSeq::default(); @@ -655,115 +782,152 @@ impl ProbeState { let mark_matches = params.emit_unmatched_build(); let emit_unmatched = params.emit_unmatched_probe(); - let (key_selectors, payload_selector); + let (key_selectors, payload_selector, build_payload_schema, probe_payload_schema); if params.left_is_build.unwrap() { - payload_selector = ¶ms.right_payload_select; key_selectors = ¶ms.right_key_selectors; + payload_selector = ¶ms.right_payload_select; + build_payload_schema = ¶ms.left_payload_schema; + probe_payload_schema = ¶ms.right_payload_schema; } else { - payload_selector = ¶ms.left_payload_select; key_selectors = ¶ms.left_key_selectors; + payload_selector = ¶ms.left_payload_select; + build_payload_schema = ¶ms.right_payload_schema; + probe_payload_schema = ¶ms.left_payload_schema; }; + let mut build_out = DataFrameBuilder::new(build_payload_schema.clone()); + let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone()); + + // A simple estimate used to size reserves. + let mut selectivity_estimate = 1.0; + while let Ok(morsel) = recv.recv().await { // Compute hashed keys and payload. let (df, seq, src_token, wait_token) = morsel.into_inner(); + max_seq = seq; + + let df_height = df.height(); + if df_height == 0 { + continue; + } + let hash_keys = select_keys(&df, key_selectors, params, state).await?; let mut payload = select_payload(df, payload_selector); let mut payload_rechunked = false; // We don't eagerly rechunk because there might be no matches. + let mut total_matches = 0; - max_seq = seq; + // Use selectivity estimate to reserve for morsel builders. + let max_match_per_key_est = selectivity_estimate as usize + 16; + let out_est_size = ((selectivity_estimate * 1.2 * df_height as f64) as usize) + .min(probe_limit as usize); + build_out.reserve(out_est_size + max_match_per_key_est); unsafe { - // Partition and probe the tables. - for p in partition_idxs.iter_mut() { - p.clear(); - } - hash_keys.gen_idxs_per_partition( - &partitioner, - &mut partition_idxs, - &mut [], - emit_unmatched, - ); - if params.preserve_order_probe { - // TODO: non-sort based implementation, can directly scatter - // after finding matches for each partition. 
- let mut out_per_partition = Vec::with_capacity(partitioner.num_partitions()); - let name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_IDX"); - for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { - table_match.clear(); - probe_match.clear(); - p.hash_table.probe_subset( - &hash_keys, - idxs_in_p, - &mut table_match, - &mut probe_match, - mark_matches, - emit_unmatched, - IdxSize::MAX, - ); + let new_morsel = |build: &mut DataFrameBuilder, probe: &mut DataFrameBuilder| { + let mut build_df = build.freeze_reset(); + let mut probe_df = probe.freeze_reset(); + let out_df = if params.left_is_build.unwrap() { + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + let out_df = postprocess_join(out_df, params); + Morsel::new(out_df, seq, src_token.clone()) + }; - if table_match.is_empty() { - continue; + if params.preserve_order_probe { + // To preserve the order we can't do bulk probes per partition and must follow + // the order of the probe morsel. We can still group probes that are + // consecutively on the same partition. + probe_partitions.clear(); + hash_keys.gen_partitions(&partitioner, &mut probe_partitions, emit_unmatched); + let mut probe_group_start = 0; + while probe_group_start < probe_partitions.len() { + let p_idx = probe_partitions[probe_group_start]; + let mut probe_group_end = probe_group_start + 1; + while probe_partitions.get(probe_group_end) == Some(&p_idx) { + probe_group_end += 1; } - - // Gather output and add to buffer. - let mut build_df = if emit_unmatched { - p.payload.take_opt_chunked_unchecked(&table_match, false) - } else { - p.payload - .take_chunked_unchecked(&table_match, IsSorted::Not, false) + let Some(p) = partitions.get(p_idx as usize) else { + probe_group_start = probe_group_end; + continue; }; - if !payload_rechunked { - // TODO: can avoid rechunk? We have to rechunk here or else we do it - // multiple times during the gather. 
- payload.rechunk_mut(); - payload_rechunked = true; - } - let mut probe_df = payload.take_slice_unchecked_impl(&probe_match, false); + materialized_idxsize_range.extend( + materialized_idxsize_range.len() as IdxSize..probe_group_end as IdxSize, + ); - let mut out_df = if params.left_is_build.unwrap() { - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df - }; + while probe_group_start < probe_group_end { + let matches_before_limit = probe_limit - probe_match.len() as IdxSize; + table_match.clear(); + probe_group_start += p.hash_table.probe_subset( + &hash_keys, + &materialized_idxsize_range[probe_group_start..probe_group_end], + &mut table_match, + &mut probe_match, + mark_matches, + emit_unmatched, + matches_before_limit, + ) as usize; - let idxs_ca = - IdxCa::from_vec(name.clone(), core::mem::take(&mut probe_match)); - out_df.with_column_unchecked(idxs_ca.into_column()); - out_per_partition.push(out_df); - } + if emit_unmatched { + build_out.opt_gather_extend( + &p.payload, + &table_match, + ShareStrategy::Always, + ); + } else { + build_out.gather_extend( + &p.payload, + &table_match, + ShareStrategy::Always, + ); + }; - if !out_per_partition.is_empty() { - let sort_options = SortMultipleOptions { - descending: vec![false], - nulls_last: vec![false], - multithreaded: false, - maintain_order: true, - limit: None, - }; - let mut out_df = - accumulate_dataframes_vertical_unchecked(out_per_partition); - out_df.sort_in_place([name.clone()], sort_options).unwrap(); - out_df.drop_in_place(&name).unwrap(); - out_df = postprocess_join(out_df, params); - - // TODO: break in smaller morsels. - let out_morsel = Morsel::new(out_df, seq, src_token.clone()); - if send.send(out_morsel).await.is_err() { - break; + if probe_match.len() >= probe_limit as usize + || probe_group_start == probe_partitions.len() + { + if !payload_rechunked { + payload.rechunk_mut(); + payload_rechunked = true; + } + probe_out.gather_extend( + &payload, + &probe_match, + ShareStrategy::Always, + ); + probe_match.clear(); + let out_morsel = new_morsel(&mut build_out, &mut probe_out); + if send.send(out_morsel).await.is_err() { + return Ok(max_seq); + } + if probe_group_end != probe_partitions.len() { + // We had enough matches to need a mid-partition flush, let's assume there are a lot of + // matches and just do a large reserve. + build_out.reserve(probe_limit as usize + max_match_per_key_est); + } + } } } } else { - let mut out_frames = Vec::new(); - let mut out_len = 0; + // Partition and probe the tables. + for p in partition_idxs.iter_mut() { + p.clear(); + } + hash_keys.gen_idxs_per_partition( + &partitioner, + &mut partition_idxs, + &mut [], + emit_unmatched, + ); + for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { let mut offset = 0; while offset < idxs_in_p.len() { + let matches_before_limit = probe_limit - probe_match.len() as IdxSize; table_match.clear(); - probe_match.clear(); offset += p.hash_table.probe_subset( &hash_keys, &idxs_in_p[offset..], @@ -771,66 +935,69 @@ impl ProbeState { &mut probe_match, mark_matches, emit_unmatched, - probe_limit - out_len, + matches_before_limit, ) as usize; if table_match.is_empty() { continue; } - - // Gather output and send. 
- let mut build_df = if emit_unmatched { - p.payload.take_opt_chunked_unchecked(&table_match, false) + total_matches += table_match.len(); + + if emit_unmatched { + build_out.opt_gather_extend( + &p.payload, + &table_match, + ShareStrategy::Always, + ); } else { - p.payload - .take_chunked_unchecked(&table_match, IsSorted::Not, false) + build_out.gather_extend( + &p.payload, + &table_match, + ShareStrategy::Always, + ); }; - if !payload_rechunked { - // TODO: can avoid rechunk? We have to rechunk here or else we do it - // multiple times during the gather. - payload.rechunk_mut(); - payload_rechunked = true; - } - let mut probe_df = - payload.take_slice_unchecked_impl(&probe_match, false); - let out_df = if params.left_is_build.unwrap() { - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df - }; - let out_df = postprocess_join(out_df, params); - - out_len = out_len - .checked_add(out_df.height().try_into().unwrap()) - .unwrap(); - out_frames.push(out_df); - - if out_len >= probe_limit { - out_len = 0; - let df = - accumulate_dataframes_vertical_unchecked(out_frames.drain(..)); - let out_morsel = Morsel::new(df, seq, src_token.clone()); + if probe_match.len() >= probe_limit as usize { + if !payload_rechunked { + payload.rechunk_mut(); + payload_rechunked = true; + } + probe_out.gather_extend( + &payload, + &probe_match, + ShareStrategy::Always, + ); + probe_match.clear(); + let out_morsel = new_morsel(&mut build_out, &mut probe_out); if send.send(out_morsel).await.is_err() { - break; + return Ok(max_seq); } + // We had enough matches to need a mid-partition flush, let's assume there are a lot of + // matches and just do a large reserve. + build_out.reserve(probe_limit as usize + max_match_per_key_est); } } } - if out_len > 0 { - let df = accumulate_dataframes_vertical_unchecked(out_frames.drain(..)); - let out_morsel = Morsel::new(df, seq, src_token.clone()); + if !probe_match.is_empty() { + if !payload_rechunked { + payload.rechunk_mut(); + } + probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always); + probe_match.clear(); + let out_morsel = new_morsel(&mut build_out, &mut probe_out); if send.send(out_morsel).await.is_err() { - break; + return Ok(max_seq); } } } } drop(wait_token); + + // Move selectivity estimate a bit towards latest value. + selectivity_estimate = + 0.8 * selectivity_estimate + 0.2 * (total_matches as f64 / df_height as f64); } Ok(max_seq) @@ -838,66 +1005,67 @@ impl ProbeState { fn ordered_unmatched( &mut self, - partitioner: &HashPartitioner, params: &EquiJoinParams, ) -> DataFrame { - let mut out_per_partition = Vec::with_capacity(partitioner.num_partitions()); - let seq_name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_SEQ"); - let idx_name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_IDX"); + // TODO: parallelize this operator. + + let build_payload_schema = if params.left_is_build.unwrap() { + ¶ms.left_payload_schema + } else { + ¶ms.right_payload_schema + }; + let mut unmarked_idxs = Vec::new(); - unsafe { - for p in self.table_per_partition.iter() { - p.hash_table - .unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX); + let mut linearized_idxs = Vec::new(); - // Gather and create full-null counterpart. 
- let mut build_df = - p.payload - .take_chunked_unchecked(&unmarked_idxs, IsSorted::Not, false); - let len = build_df.height(); - let mut out_df = if params.left_is_build.unwrap() { - let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - let mut probe_df = DataFrame::full_null(¶ms.left_payload_schema, len); - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df - }; + for (p_idx, p) in self.table_per_partition.iter().enumerate_idx() { + p.hash_table + .unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX); + linearized_idxs.extend(unmarked_idxs.iter().map(|i| { + (unsafe { *p.seq_ids.get_unchecked(*i as usize) }, p_idx, *i) + })); + } + + linearized_idxs.sort_by_key(|(seq_id, _, _)| *seq_id); - // The indices are not ordered globally, but within each chunk they are, so sorting - // by chunk sequence id, breaking ties by inner chunk idx works. - let (chunk_seqs, idx_in_chunk) = unmarked_idxs - .iter() - .map(|chunk_id| { - let (chunk, idx_in_chunk) = chunk_id.extract(); - (p.chunk_seq_ids[chunk as usize].to_u64(), idx_in_chunk) - }) - .unzip(); + unsafe { + let mut build_out = DataFrameBuilder::new(build_payload_schema.clone()); + build_out.reserve(linearized_idxs.len()); + + // Group indices from the same partition. + let mut group_start = 0; + let mut gather_idxs = Vec::new(); + while group_start < linearized_idxs.len() { + gather_idxs.clear(); + + let (_seq, p_idx, idx_in_p) = linearized_idxs[group_start]; + gather_idxs.push(idx_in_p); + let mut group_end = group_start + 1; + while group_end < linearized_idxs.len() && linearized_idxs[group_end].1 == p_idx { + gather_idxs.push(linearized_idxs[group_end].2); + group_end += 1; + } + + build_out.gather_extend( + &self.table_per_partition[p_idx as usize].payload, + &gather_idxs, + ShareStrategy::Never, // Don't keep entire table alive for unmatched indices. + ); - let chunk_seqs_ca = UInt64Chunked::from_vec(seq_name.clone(), chunk_seqs); - let idxs_ca = IdxCa::from_vec(idx_name.clone(), idx_in_chunk); - out_df.with_column_unchecked(chunk_seqs_ca.into_column()); - out_df.with_column_unchecked(idxs_ca.into_column()); - out_per_partition.push(out_df); + group_start = group_end; } - // Sort by chunk sequence id, then by inner chunk idx. - let sort_options = SortMultipleOptions { - descending: vec![false], - nulls_last: vec![false], - multithreaded: true, - maintain_order: false, - limit: None, + let mut build_df = build_out.freeze(); + let out_df = if params.left_is_build.unwrap() { + let probe_df = DataFrame::full_null(¶ms.right_payload_schema, build_df.height()); + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + let mut probe_df = DataFrame::full_null(¶ms.left_payload_schema, build_df.height()); + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df }; - let mut out_df = accumulate_dataframes_vertical_unchecked(out_per_partition); - out_df - .sort_in_place([seq_name.clone(), idx_name.clone()], sort_options) - .unwrap(); - out_df.drop_in_place(&seq_name).unwrap(); - out_df.drop_in_place(&idx_name).unwrap(); - out_df = postprocess_join(out_df, params); - out_df + postprocess_join(out_df, params) } } } @@ -951,9 +1119,7 @@ impl EmitUnmatchedState { // Gather and create full-null counterpart. 
let out_df = unsafe { - let mut build_df = - p.payload - .take_chunked_unchecked(&unmarked_idxs, IsSorted::Not, false); + let mut build_df = p.payload.take_slice_unchecked_impl(&unmarked_idxs, false); let len = build_df.height(); if params.left_is_build.unwrap() { let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); @@ -1038,7 +1204,7 @@ pub struct EquiJoinNode { state: EquiJoinState, params: EquiJoinParams, num_pipelines: usize, - table: Option>, + table: Option>, } impl EquiJoinNode { @@ -1065,9 +1231,9 @@ impl EquiJoinNode { let table = left_is_build.map(|lib| { if lib { - new_chunked_idx_table(left_key_schema.clone()) + new_idx_table(left_key_schema.clone()) } else { - new_chunked_idx_table(right_key_schema.clone()) + new_idx_table(right_key_schema.clone()) } }); @@ -1163,9 +1329,12 @@ impl ComputeNode for EquiJoinNode { // If we are building and the build input is done, transition to probing. if let EquiJoinState::Build(build_state) = &mut self.state { if recv[build_idx] == PortState::Done { - self.state = EquiJoinState::Probe( - build_state.finalize(&self.params, self.table.as_deref().unwrap()), - ); + let probe_state = if self.params.preserve_order_build { + build_state.finalize_ordered(&self.params, self.table.as_deref().unwrap()) + } else { + build_state.finalize_unordered(&self.params, self.table.as_deref().unwrap()) + }; + self.state = EquiJoinState::Probe(probe_state); } } @@ -1176,8 +1345,7 @@ impl ComputeNode for EquiJoinNode { if samples_consumed && recv[probe_idx] == PortState::Done { if self.params.emit_unmatched_build() { if self.params.preserve_order_build { - let partitioner = HashPartitioner::new(self.num_pipelines, 0); - let unmatched = probe_state.ordered_unmatched(&partitioner, &self.params); + let unmatched = probe_state.ordered_unmatched(&self.params); let mut src = InMemorySourceNode::new( Arc::new(unmatched), probe_state.max_seq_sent.successor(), @@ -1337,16 +1505,15 @@ impl ComputeNode for EquiJoinNode { let receivers = recv_ports[build_idx].take().unwrap().parallel(); build_state - .partitions_per_worker - .resize_with(self.num_pipelines, Vec::new); + .local_builders + .resize_with(self.num_pipelines, Default::default); let partitioner = HashPartitioner::new(self.num_pipelines, 0); - for (worker_ps, recv) in build_state.partitions_per_worker.iter_mut().zip(receivers) - { + for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) { join_handles.push(scope.spawn_task( TaskPriority::High, BuildState::partition_and_sink( recv, - worker_ps, + local_builder, partitioner.clone(), &self.params, state, diff --git a/crates/polars-stream/src/nodes/joins/mod.rs b/crates/polars-stream/src/nodes/joins/mod.rs index ea6607737c43..eb8dd4e5e833 100644 --- a/crates/polars-stream/src/nodes/joins/mod.rs +++ b/crates/polars-stream/src/nodes/joins/mod.rs @@ -1,3 +1,2 @@ -pub mod equi_join; pub mod in_memory; -pub mod new_equi_join; +pub mod equi_join; diff --git a/crates/polars-stream/src/nodes/joins/new_equi_join.rs b/crates/polars-stream/src/nodes/joins/new_equi_join.rs deleted file mode 100644 index 657e5d1d5eeb..000000000000 --- a/crates/polars-stream/src/nodes/joins/new_equi_join.rs +++ /dev/null @@ -1,1582 +0,0 @@ -use std::cmp::Reverse; -use std::collections::BinaryHeap; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, LazyLock}; - -use arrow::array::builder::ShareStrategy; -use crossbeam_queue::ArrayQueue; -use polars_core::frame::builder::DataFrameBuilder; -use polars_core::prelude::*; -use 
polars_core::schema::{Schema, SchemaExt}; -use polars_core::{config, POOL}; -use polars_expr::hash_keys::HashKeys; -use polars_expr::idx_table::{new_idx_table, IdxTable}; -use polars_io::pl_async::get_runtime; -use polars_ops::frame::{JoinArgs, JoinType, MaintainOrderJoin}; -use polars_ops::series::coalesce_columns; -use polars_utils::cardinality_sketch::CardinalitySketch; -use polars_utils::hashing::HashPartitioner; -use polars_utils::itertools::Itertools; -use polars_utils::pl_str::PlSmallStr; -use polars_utils::priority::Priority; -use polars_utils::sparse_init_vec::SparseInitVec; -use polars_utils::{format_pl_smallstr, IdxSize}; -use rayon::prelude::*; - -use crate::async_primitives::connector::{connector, Receiver, Sender}; -use crate::async_primitives::wait_group::WaitGroup; -use crate::expression::StreamExpr; -use crate::morsel::{get_ideal_morsel_size, SourceToken}; -use crate::nodes::compute_node_prelude::*; -use crate::nodes::in_memory_source::InMemorySourceNode; - -static SAMPLE_LIMIT: LazyLock = LazyLock::new(|| { - std::env::var("POLARS_JOIN_SAMPLE_LIMIT") - .map(|limit| limit.parse().unwrap()) - .unwrap_or(10_000_000) -}); - -// If one side is this much bigger than the other side we'll always use the -// smaller side as the build side without checking cardinalities. -const LOPSIDED_SAMPLE_FACTOR: usize = 10; - -/// A payload selector contains for each column whether that column should be -/// included in the payload, and if yes with what name. -fn compute_payload_selector( - this: &Schema, - other: &Schema, - this_key_schema: &Schema, - is_left: bool, - args: &JoinArgs, -) -> PolarsResult>> { - let should_coalesce = args.should_coalesce(); - - this.iter_names() - .enumerate() - .map(|(i, c)| { - let selector = if should_coalesce && this_key_schema.contains(c) { - if is_left != (args.how == JoinType::Right) { - Some(c.clone()) - } else if args.how == JoinType::Full { - // We must keep the right-hand side keycols around for - // coalescing. - Some(format_pl_smallstr!("__POLARS_COALESCE_KEYCOL{i}")) - } else { - None - } - } else if !other.contains(c) || is_left { - Some(c.clone()) - } else { - let suffixed = format_pl_smallstr!("{}{}", c, args.suffix()); - if other.contains(&suffixed) { - polars_bail!(Duplicate: "column with name '{suffixed}' already exists\n\n\ - You may want to try:\n\ - - renaming the column prior to joining\n\ - - using the `suffix` parameter to specify a suffix different to the default one ('_right')") - } - Some(suffixed) - }; - Ok(selector) - }) - .collect() -} - -/// Fixes names and does coalescing of columns post-join. -fn postprocess_join(df: DataFrame, params: &EquiJoinParams) -> DataFrame { - if params.args.how == JoinType::Full && params.args.should_coalesce() { - // TODO: don't do string-based column lookups for each dataframe, pre-compute coalesce indices. 
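// A sketch of what the coalescing below does, using plain Options rather than
// Polars columns (illustrative only): for a full join the left key is kept
// where present, and the preserved right-hand key column fills the gaps.
fn coalesce_keys(left: &[Option<i32>], right: &[Option<i32>]) -> Vec<Option<i32>> {
    left.iter().zip(right).map(|(l, r)| l.or(*r)).collect()
}
// coalesce_keys(&[Some(1), None], &[Some(9), Some(2)]) == [Some(1), Some(2)]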
- let mut key_idx = 0; - df.get_columns() - .iter() - .filter_map(|c| { - if let Some((key_name, _)) = params.left_key_schema.get_at_index(key_idx) { - if c.name() == key_name { - let other = df - .column(&format_pl_smallstr!("__POLARS_COALESCE_KEYCOL{key_idx}")) - .unwrap(); - key_idx += 1; - return Some(coalesce_columns(&[c.clone(), other.clone()]).unwrap()); - } - } - - if c.name().starts_with("__POLARS_COALESCE_KEYCOL") { - return None; - } - - Some(c.clone()) - }) - .collect() - } else { - df - } -} - -fn select_schema(schema: &Schema, selector: &[Option]) -> Schema { - schema - .iter_fields() - .zip(selector) - .filter_map(|(f, name)| Some(f.with_name(name.clone()?))) - .collect() -} - -async fn select_keys( - df: &DataFrame, - key_selectors: &[StreamExpr], - params: &EquiJoinParams, - state: &ExecutionState, -) -> PolarsResult { - let mut key_columns = Vec::new(); - for (i, selector) in key_selectors.iter().enumerate() { - // We use key columns entirely by position, and allow duplicate names, - // so just assign arbitrary unique names. - let unique_name = format_pl_smallstr!("__POLARS_KEYCOL_{i}"); - let s = selector.evaluate(df, state).await?; - key_columns.push(s.into_column().with_name(unique_name)); - } - let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?; - Ok(HashKeys::from_df( - &keys, - params.random_state.clone(), - params.args.nulls_equal, - true, - )) -} - -fn select_payload(df: DataFrame, selector: &[Option]) -> DataFrame { - // Maintain height of zero-width dataframes. - if df.width() == 0 { - return df; - } - - df.take_columns() - .into_iter() - .zip(selector) - .filter_map(|(c, name)| Some(c.with_name(name.clone()?))) - .collect() -} - -fn estimate_cardinality( - morsels: &[Morsel], - key_selectors: &[StreamExpr], - params: &EquiJoinParams, - state: &ExecutionState, -) -> PolarsResult { - // TODO: parallelize. - let mut sketch = CardinalitySketch::new(); - for morsel in morsels { - let hash_keys = - get_runtime().block_on(select_keys(morsel.df(), key_selectors, params, state))?; - hash_keys.sketch_cardinality(&mut sketch); - } - Ok(sketch.estimate()) -} - -struct BufferedStream { - morsels: ArrayQueue, - post_buffer_offset: MorselSeq, -} - -impl BufferedStream { - pub fn new(morsels: Vec, start_offset: MorselSeq) -> Self { - // Relabel so we can insert into parallel streams later. - let mut seq = start_offset; - let queue = ArrayQueue::new(morsels.len().max(1)); - for mut morsel in morsels { - morsel.set_seq(seq); - queue.push(morsel).unwrap(); - seq = seq.successor(); - } - - Self { - morsels: queue, - post_buffer_offset: seq, - } - } - - pub fn is_empty(&self) -> bool { - self.morsels.is_empty() - } - - #[expect(clippy::needless_lifetimes)] - pub fn reinsert<'s, 'env>( - &'s self, - num_pipelines: usize, - recv_port: Option>, - scope: &'s TaskScope<'s, 'env>, - join_handles: &mut Vec>>, - ) -> Option>> { - let receivers = if let Some(p) = recv_port { - p.parallel().into_iter().map(Some).collect_vec() - } else { - (0..num_pipelines).map(|_| None).collect_vec() - }; - - let source_token = SourceToken::new(); - let mut out = Vec::new(); - for orig_recv in receivers { - let (mut new_send, new_recv) = connector(); - out.push(new_recv); - let source_token = source_token.clone(); - join_handles.push(scope.spawn_task(TaskPriority::High, async move { - // Act like an InMemorySource node until cached morsels are consumed. 
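// The relabelling scheme in miniature, with u64 stand-ins for MorselSeq
// (illustrative only): buffered morsels are renumbered 0..n, and morsels that
// arrive afterwards are shifted past them, so the merged stream still has
// strictly increasing sequence numbers.
fn relabel(buffered: &mut [u64], live: &mut [u64]) {
    let mut seq = 0;
    for s in buffered.iter_mut() {
        *s = seq;
        seq += 1;
    }
    // seq is now the post-buffer offset applied to the live stream.
    for s in live.iter_mut() {
        *s += seq;
    }
}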
- let wait_group = WaitGroup::default(); - loop { - let Some(mut morsel) = self.morsels.pop() else { - break; - }; - morsel.replace_source_token(source_token.clone()); - morsel.set_consume_token(wait_group.token()); - if new_send.send(morsel).await.is_err() { - return Ok(()); - } - wait_group.wait().await; - // TODO: Unfortunately we can't actually stop here without - // re-buffering morsels from the stream that comes after. - // if source_token.stop_requested() { - // break; - // } - } - - if let Some(mut recv) = orig_recv { - while let Ok(mut morsel) = recv.recv().await { - if source_token.stop_requested() { - morsel.source_token().stop(); - } - morsel.set_seq(morsel.seq().offset_by(self.post_buffer_offset)); - if new_send.send(morsel).await.is_err() { - break; - } - } - } - Ok(()) - })); - } - Some(out) - } -} - -impl Default for BufferedStream { - fn default() -> Self { - Self { - morsels: ArrayQueue::new(1), - post_buffer_offset: MorselSeq::default(), - } - } -} - -impl Drop for BufferedStream { - fn drop(&mut self) { - POOL.install(|| { - // Parallel drop as the state might be quite big. - (0..self.morsels.len()) - .into_par_iter() - .for_each(|_| drop(self.morsels.pop())); - }) - } -} - -#[derive(Default)] -struct SampleState { - left: Vec, - left_len: usize, - right: Vec, - right_len: usize, -} - -impl SampleState { - async fn sink( - mut recv: Receiver, - morsels: &mut Vec, - len: &mut usize, - this_final_len: Arc, - other_final_len: Arc, - ) -> PolarsResult<()> { - while let Ok(mut morsel) = recv.recv().await { - *len += morsel.df().height(); - if *len >= *SAMPLE_LIMIT - || *len - >= other_final_len - .load(Ordering::Relaxed) - .saturating_mul(LOPSIDED_SAMPLE_FACTOR) - { - morsel.source_token().stop(); - } - - drop(morsel.take_consume_token()); - morsels.push(morsel); - } - this_final_len.store(*len, Ordering::Relaxed); - Ok(()) - } - - fn try_transition_to_build( - &mut self, - recv: &[PortState], - num_pipelines: usize, - params: &mut EquiJoinParams, - table: &mut Option>, - ) -> PolarsResult> { - let left_saturated = self.left_len >= *SAMPLE_LIMIT; - let right_saturated = self.right_len >= *SAMPLE_LIMIT; - let left_done = recv[0] == PortState::Done || left_saturated; - let right_done = recv[1] == PortState::Done || right_saturated; - #[expect(clippy::nonminimal_bool)] - let stop_sampling = (left_done && right_done) - || (left_done && self.right_len >= LOPSIDED_SAMPLE_FACTOR * self.left_len) - || (right_done && self.left_len >= LOPSIDED_SAMPLE_FACTOR * self.right_len); - if !stop_sampling { - return Ok(None); - } - - if config::verbose() { - eprintln!( - "choosing equi-join build side, sample lengths are: {} vs. {}", - self.left_len, self.right_len - ); - } - - let estimate_cardinalities = || { - let execution_state = ExecutionState::new(); - let left_cardinality = estimate_cardinality( - &self.left, - ¶ms.left_key_selectors, - params, - &execution_state, - )?; - let right_cardinality = estimate_cardinality( - &self.right, - ¶ms.right_key_selectors, - params, - &execution_state, - )?; - let norm_left_factor = self.left_len.min(*SAMPLE_LIMIT) as f64 / self.left_len as f64; - let norm_right_factor = - self.right_len.min(*SAMPLE_LIMIT) as f64 / self.right_len as f64; - let norm_left_cardinality = (left_cardinality as f64 * norm_left_factor) as usize; - let norm_right_cardinality = (right_cardinality as f64 * norm_right_factor) as usize; - if config::verbose() { - eprintln!("estimated cardinalities are: {norm_left_cardinality} vs. 
{norm_right_cardinality}"); - } - PolarsResult::Ok((norm_left_cardinality, norm_right_cardinality)) - }; - - let left_is_build = match (left_saturated, right_saturated) { - (false, false) => { - if self.left_len * LOPSIDED_SAMPLE_FACTOR < self.right_len - || self.left_len > self.right_len * LOPSIDED_SAMPLE_FACTOR - { - // Don't bother estimating cardinality, just choose smaller as it's highly - // imbalanced. - self.left_len < self.right_len - } else { - let (lc, rc) = estimate_cardinalities()?; - // Let's assume for now that per element building a - // table is 3x more expensive than a probe, with - // unique keys getting an additional 3x factor for - // having to update the hash table in addition to the probe. - let left_build_cost = self.left_len * 3 + 3 * lc; - let left_probe_cost = self.left_len; - let right_build_cost = self.right_len * 3 + 3 * rc; - let right_probe_cost = self.right_len; - left_build_cost + right_probe_cost < left_probe_cost + right_build_cost - } - }, - - // Choose the unsaturated side, the saturated side could be - // arbitrarily big. - (false, true) => true, - (true, false) => false, - - // Estimate cardinality and choose smaller. - (true, true) => { - let (lc, rc) = estimate_cardinalities()?; - lc < rc - }, - }; - - if config::verbose() { - eprintln!( - "build side chosen: {}", - if left_is_build { "left" } else { "right" } - ); - } - - // Transition to building state. - params.left_is_build = Some(left_is_build); - *table = Some(if left_is_build { - new_idx_table(params.left_key_schema.clone()) - } else { - new_idx_table(params.right_key_schema.clone()) - }); - - let mut sampled_build_morsels = - BufferedStream::new(core::mem::take(&mut self.left), MorselSeq::default()); - let mut sampled_probe_morsels = - BufferedStream::new(core::mem::take(&mut self.right), MorselSeq::default()); - if !left_is_build { - core::mem::swap(&mut sampled_build_morsels, &mut sampled_probe_morsels); - } - - let partitioner = HashPartitioner::new(num_pipelines, 0); - let mut build_state = BuildState { - local_builders: (0..num_pipelines) - .map(|_| LocalBuilder::default()) - .collect(), - sampled_probe_morsels, - }; - - // Simulate the sample build morsels flowing into the build side. - if !sampled_build_morsels.is_empty() { - let state = ExecutionState::new(); - crate::async_executor::task_scope(|scope| { - let mut join_handles = Vec::new(); - let receivers = sampled_build_morsels - .reinsert(num_pipelines, None, scope, &mut join_handles) - .unwrap(); - - for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) { - join_handles.push(scope.spawn_task( - TaskPriority::High, - BuildState::partition_and_sink( - recv, - local_builder, - partitioner.clone(), - params, - &state, - ), - )); - } - - polars_io::pl_async::get_runtime().block_on(async move { - for handle in join_handles { - handle.await?; - } - PolarsResult::Ok(()) - }) - })?; - } - - Ok(Some(build_state)) - } -} - -#[derive(Default)] -struct LocalBuilder { - // The complete list of morsels and their computed hashes seen by this builder. - morsels: Vec<(MorselSeq, DataFrame, HashKeys)>, - - // A cardinality sketch per partition for the keys seen by this builder. 
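// The build-side cost model above as a standalone function (illustrative
// only). The factors come straight from the comment: building costs roughly
// 3x a probe per row, and every unique key pays another 3x for the
// hash-table insert.
fn left_is_cheaper_build(left_len: usize, left_card: usize,
                         right_len: usize, right_card: usize) -> bool {
    let left_build_cost = left_len * 3 + 3 * left_card;
    let right_build_cost = right_len * 3 + 3 * right_card;
    // Compare: build left + probe right vs. build right + probe left.
    left_build_cost + right_len < right_build_cost + left_len
}
// E.g. 1_000 rows/10 keys left vs. 1_000_000 rows/10 keys right:
// 3_030 + 1_000_000 < 3_000_030 + 1_000, so the left side gets built.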
-    sketch_per_p: Vec<CardinalitySketch>,
-
-    // morsel_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i]
-    // for partition p, where start, stop are:
-    // let start = morsel_idxs_offsets[i * num_partitions + p];
-    // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p];
-    morsel_idxs_values_per_p: Vec<Vec<IdxSize>>,
-    morsel_idxs_offsets_per_p: Vec<usize>,
-}
-
-#[derive(Default)]
-struct BuildState {
-    local_builders: Vec<LocalBuilder>,
-    sampled_probe_morsels: BufferedStream,
-}
-
-impl BuildState {
-    async fn partition_and_sink(
-        mut recv: Receiver<Morsel>,
-        local: &mut LocalBuilder,
-        partitioner: HashPartitioner,
-        params: &EquiJoinParams,
-        state: &ExecutionState,
-    ) -> PolarsResult<()> {
-        let track_unmatchable = params.emit_unmatched_build();
-        local
-            .sketch_per_p
-            .resize_with(partitioner.num_partitions(), Default::default);
-        local
-            .morsel_idxs_values_per_p
-            .resize_with(partitioner.num_partitions(), Default::default);
-
-        if local.morsel_idxs_offsets_per_p.is_empty() {
-            local
-                .morsel_idxs_offsets_per_p
-                .resize(partitioner.num_partitions(), 0);
-        }
-
-        let (key_selectors, payload_selector);
-        if params.left_is_build.unwrap() {
-            payload_selector = &params.left_payload_select;
-            key_selectors = &params.left_key_selectors;
-        } else {
-            payload_selector = &params.right_payload_select;
-            key_selectors = &params.right_key_selectors;
-        };
-
-        while let Ok(morsel) = recv.recv().await {
-            // Compute hashed keys and payload. We must rechunk the payload for
-            // later gathers.
-            let hash_keys = select_keys(morsel.df(), key_selectors, params, state).await?;
-            let mut payload = select_payload(morsel.df().clone(), payload_selector);
-            payload.rechunk_mut();
-
-            hash_keys.gen_idxs_per_partition(
-                &partitioner,
-                &mut local.morsel_idxs_values_per_p,
-                &mut local.sketch_per_p,
-                track_unmatchable,
-            );
-
-            local
-                .morsel_idxs_offsets_per_p
-                .extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len()));
-            local.morsels.push((morsel.seq(), payload, hash_keys));
-        }
-        Ok(())
-    }
-
-    fn finalize_ordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState {
-        let track_unmatchable = params.emit_unmatched_build();
-        let payload_schema = if params.left_is_build.unwrap() {
-            &params.left_payload_schema
-        } else {
-            &params.right_payload_schema
-        };
-
-        let num_partitions = self.local_builders[0].sketch_per_p.len();
-        let local_builders = &self.local_builders;
-        let probe_tables: SparseInitVec<ProbeTable> = SparseInitVec::with_capacity(num_partitions);
-
-        POOL.scope(|s| {
-            for p in 0..num_partitions {
-                let probe_tables = &probe_tables;
-                s.spawn(move |_| {
-                    // TODO: every thread does an identical linearize, we can do a single parallel one.
-                    let mut kmerge = BinaryHeap::with_capacity(local_builders.len());
-                    let mut cur_idx_per_loc = vec![0; local_builders.len()];
-
-                    // Compute cardinality estimate and total amount of
-                    // payload for this partition, and initialize k-way merge.
-                    let mut sketch = CardinalitySketch::new();
-                    let mut payload_rows = 0;
-                    for (l_idx, l) in local_builders.iter().enumerate() {
-                        let Some((seq, _, _)) = l.morsels.get(0) else { continue };
-                        kmerge.push(Priority(Reverse(seq), l_idx));
-
-                        sketch.combine(&l.sketch_per_p[p]);
-                        let offsets_len = l.morsel_idxs_offsets_per_p.len();
-                        payload_rows +=
-                            l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p];
-                    }
-
-                    // Allocate hash table and payload builder.
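// How the flattened offset table above is read back, sketched with plain Vecs
// (illustrative only). offsets begins with num_partitions zeros and grows by
// num_partitions entries per morsel, so morsel i's indices for partition p
// form a subslice of that partition's value vector.
fn partition_slice<'a>(
    values_per_p: &'a [Vec<u32>],
    offsets: &[usize],
    num_partitions: usize,
    morsel: usize,
    p: usize,
) -> &'a [u32] {
    let start = offsets[morsel * num_partitions + p];
    let stop = offsets[(morsel + 1) * num_partitions + p];
    &values_per_p[p][start..stop]
}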
- let mut p_table = table.new_empty(); - p_table.reserve(sketch.estimate() * 5 / 4); - let mut p_payload = DataFrameBuilder::new(payload_schema.clone()); - p_payload.reserve(payload_rows); - - let mut p_seq_ids = Vec::new(); - if track_unmatchable { - p_seq_ids.reserve(payload_rows); - } - - // Linearize and build. - unsafe { - let mut norm_seq_id = 0 as IdxSize; - while let Some(Priority(Reverse(mut seq), l_idx)) = kmerge.pop() { - let l = local_builders.get_unchecked(l_idx); - let idx_in_l = *cur_idx_per_loc.get_unchecked(l_idx); - *cur_idx_per_loc.get_unchecked_mut(l_idx) += 1; - if let Some((next_seq, _, _)) = l.morsels.get(idx_in_l + 1) { - kmerge.push(Priority(Reverse(next_seq), l_idx)); - } - - let (_mseq, payload, keys) = l.morsels.get_unchecked(idx_in_l); - let p_morsel_idxs_start = - l.morsel_idxs_offsets_per_p[idx_in_l * num_partitions + p]; - let p_morsel_idxs_stop = - l.morsel_idxs_offsets_per_p[(idx_in_l + 1) * num_partitions + p]; - let p_morsel_idxs = &l.morsel_idxs_values_per_p[p] - [p_morsel_idxs_start..p_morsel_idxs_stop]; - p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable); - p_payload.gather_extend( - payload, - p_morsel_idxs, - ShareStrategy::Never, - ); - - if track_unmatchable { - p_seq_ids.resize(p_payload.len(), norm_seq_id); - norm_seq_id += 1; - } - } - } - - probe_tables.try_set(p, ProbeTable { - hash_table: p_table, - payload: p_payload.freeze(), - seq_ids: p_seq_ids, - }).ok().unwrap(); - }); - } - }); - - ProbeState { - table_per_partition: probe_tables.try_assume_init().ok().unwrap(), - max_seq_sent: MorselSeq::default(), - sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels), - } - } - - fn finalize_unordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState { - let track_unmatchable = params.emit_unmatched_build(); - let payload_schema = if params.left_is_build.unwrap() { - ¶ms.left_payload_schema - } else { - ¶ms.right_payload_schema - }; - - // To reduce maximum memory usage we want to drop the morsels - // as soon as they're processed, so we move into Arcs. The drops might - // also be expensive, so instead of directly dropping we put that on - // a work queue. - let morsels_per_local_builder = self - .local_builders - .iter_mut() - .map(|b| Arc::new(core::mem::take(&mut b.morsels))) - .collect_vec(); - let (morsel_drop_q_send, morsel_drop_q_recv) = crossbeam_channel::bounded(morsels_per_local_builder.len()); - let num_partitions = self.local_builders[0].sketch_per_p.len(); - let local_builders = &self.local_builders; - let probe_tables: SparseInitVec = SparseInitVec::with_capacity(num_partitions); - - POOL.scope(|s| { - // Wrap in outer Arc to move to each thread, performing the - // expensive clone on that thread. - let arc_morsels_per_local_builder = Arc::new(morsels_per_local_builder); - for p in 0..num_partitions { - let arc_morsels_per_local_builder = Arc::clone(&arc_morsels_per_local_builder); - let morsel_drop_q_send = morsel_drop_q_send.clone(); - let morsel_drop_q_recv = morsel_drop_q_recv.clone(); - let probe_tables = &probe_tables; - s.spawn(move |_| { - // Extract from outer arc and drop outer arc. - let morsels_per_local_builder = - Arc::unwrap_or_clone(arc_morsels_per_local_builder); - - // Compute cardinality estimate and total amount of - // payload for this partition. 
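// The k-way merge above, reduced to its essence (illustrative only): every
// builder's morsel list is already sorted by sequence number, so a min-heap
// of (next sequence, builder index) pairs walks all lists in global order.
use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn kway_merge(lists: &[Vec<u64>]) -> Vec<u64> {
    let mut heap = BinaryHeap::new();
    let mut cursor = vec![0usize; lists.len()];
    for (w, l) in lists.iter().enumerate() {
        if let Some(first) = l.first() {
            heap.push(Reverse((*first, w))); // Reverse makes the max-heap a min-heap.
        }
    }
    let mut out = Vec::new();
    while let Some(Reverse((seq, w))) = heap.pop() {
        out.push(seq);
        cursor[w] += 1;
        if let Some(next) = lists[w].get(cursor[w]) {
            heap.push(Reverse((*next, w)));
        }
    }
    out
}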
- let mut sketch = CardinalitySketch::new(); - let mut payload_rows = 0; - for l in local_builders { - sketch.combine(&l.sketch_per_p[p]); - let offsets_len = l.morsel_idxs_offsets_per_p.len(); - payload_rows += - l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p]; - } - - // Allocate hash table and payload builder. - let mut p_table = table.new_empty(); - p_table.reserve(sketch.estimate() * 5 / 4); - let mut p_payload = DataFrameBuilder::new(payload_schema.clone()); - p_payload.reserve(payload_rows); - - // Build. - let mut skip_drop_attempt = false; - for (l, l_morsels) in local_builders.iter().zip(morsels_per_local_builder) { - // Try to help with dropping the processed morsels. - if !skip_drop_attempt { - drop(morsel_drop_q_recv.try_recv()); - } - - for (i, morsel) in l_morsels.iter().enumerate() { - let (_mseq, payload, keys) = morsel; - unsafe { - let p_morsel_idxs_start = - l.morsel_idxs_offsets_per_p[i * num_partitions + p]; - let p_morsel_idxs_stop = - l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p]; - let p_morsel_idxs = &l.morsel_idxs_values_per_p[p] - [p_morsel_idxs_start..p_morsel_idxs_stop]; - p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable); - p_payload.gather_extend( - payload, - p_morsel_idxs, - ShareStrategy::Never, - ); - } - } - - if let Some(l) = Arc::into_inner(l_morsels) { - // If we're the last thread to process this set of morsels we're probably - // falling behind the rest, since the drop can be quite expensive we skip - // a drop attempt hoping someone else will pick up the slack. - morsel_drop_q_send.send(l).unwrap(); - skip_drop_attempt = true; - } else { - skip_drop_attempt = false; - } - } - - // We're done, help others out by doing drops. - drop(morsel_drop_q_send); // So we don't deadlock trying to receive from ourselves. - while let Ok(l_morsels) = morsel_drop_q_recv.recv() { - drop(l_morsels); - } - - probe_tables.try_set(p, ProbeTable { - hash_table: p_table, - payload: p_payload.freeze(), - seq_ids: Vec::new(), - }).ok().unwrap(); - }); - } - - // Drop outer arc after spawning each thread so the inner arcs - // can get dropped as soon as they're processed. We also have to - // drop the drop queue sender so we don't deadlock waiting for it - // to end. - drop(arc_morsels_per_local_builder); - drop(morsel_drop_q_send); - }); - - ProbeState { - table_per_partition: probe_tables.try_assume_init().ok().unwrap(), - max_seq_sent: MorselSeq::default(), - sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels), - } - } -} - -struct ProbeTable { - hash_table: Box, - payload: DataFrame, - seq_ids: Vec, -} - -struct ProbeState { - table_per_partition: Vec, - max_seq_sent: MorselSeq, - sampled_probe_morsels: BufferedStream, -} - -impl ProbeState { - /// Returns the max morsel sequence sent. - async fn partition_and_probe( - mut recv: Receiver, - mut send: Sender, - partitions: &[ProbeTable], - partitioner: HashPartitioner, - params: &EquiJoinParams, - state: &ExecutionState, - ) -> PolarsResult { - // TODO: shuffle after partitioning and keep probe tables thread-local. 
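// The cooperative-drop trick above in isolation, assuming the
// crossbeam-channel crate (function name illustrative): the last owner of an
// Arc would pay for the whole drop, so instead of dropping inline it hands
// the value to a bounded queue that less-busy threads drain.
use std::sync::Arc;

fn offload_drop<T>(
    arc: Arc<T>,
    send: &crossbeam_channel::Sender<T>,
    recv: &crossbeam_channel::Receiver<T>,
) {
    // First help with any drops other threads have queued up.
    drop(recv.try_recv());
    if let Some(inner) = Arc::into_inner(arc) {
        // We were the last owner; queue the expensive drop for someone else.
        let _ = send.send(inner);
    }
}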
- let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()]; - let mut probe_partitions = Vec::new(); - let mut materialized_idxsize_range = Vec::new(); - let mut table_match = Vec::new(); - let mut probe_match = Vec::new(); - let mut max_seq = MorselSeq::default(); - - let probe_limit = get_ideal_morsel_size() as IdxSize; - let mark_matches = params.emit_unmatched_build(); - let emit_unmatched = params.emit_unmatched_probe(); - - let (key_selectors, payload_selector, build_payload_schema, probe_payload_schema); - if params.left_is_build.unwrap() { - key_selectors = ¶ms.right_key_selectors; - payload_selector = ¶ms.right_payload_select; - build_payload_schema = ¶ms.left_payload_schema; - probe_payload_schema = ¶ms.right_payload_schema; - } else { - key_selectors = ¶ms.left_key_selectors; - payload_selector = ¶ms.left_payload_select; - build_payload_schema = ¶ms.right_payload_schema; - probe_payload_schema = ¶ms.left_payload_schema; - }; - - let mut build_out = DataFrameBuilder::new(build_payload_schema.clone()); - let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone()); - - // A simple estimate used to size reserves. - let mut selectivity_estimate = 1.0; - - while let Ok(morsel) = recv.recv().await { - // Compute hashed keys and payload. - let (df, seq, src_token, wait_token) = morsel.into_inner(); - max_seq = seq; - - let df_height = df.height(); - if df_height == 0 { - continue; - } - - let hash_keys = select_keys(&df, key_selectors, params, state).await?; - let mut payload = select_payload(df, payload_selector); - let mut payload_rechunked = false; // We don't eagerly rechunk because there might be no matches. - let mut total_matches = 0; - - // Use selectivity estimate to reserve for morsel builders. - let max_match_per_key_est = selectivity_estimate as usize + 16; - let out_est_size = ((selectivity_estimate * 1.2 * df_height as f64) as usize) - .min(probe_limit as usize); - build_out.reserve(out_est_size + max_match_per_key_est); - - unsafe { - let new_morsel = |build: &mut DataFrameBuilder, probe: &mut DataFrameBuilder| { - let mut build_df = build.freeze_reset(); - let mut probe_df = probe.freeze_reset(); - let out_df = if params.left_is_build.unwrap() { - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df - }; - let out_df = postprocess_join(out_df, params); - Morsel::new(out_df, seq, src_token.clone()) - }; - - if params.preserve_order_probe { - // To preserve the order we can't do bulk probes per partition and must follow - // the order of the probe morsel. We can still group probes that are - // consecutively on the same partition. 
- probe_partitions.clear(); - hash_keys.gen_partitions(&partitioner, &mut probe_partitions, emit_unmatched); - let mut probe_group_start = 0; - while probe_group_start < probe_partitions.len() { - let p_idx = probe_partitions[probe_group_start]; - let mut probe_group_end = probe_group_start + 1; - while probe_partitions.get(probe_group_end) == Some(&p_idx) { - probe_group_end += 1; - } - let Some(p) = partitions.get(p_idx as usize) else { - probe_group_start = probe_group_end; - continue; - }; - - materialized_idxsize_range.extend( - materialized_idxsize_range.len() as IdxSize..probe_group_end as IdxSize, - ); - - while probe_group_start < probe_group_end { - let matches_before_limit = probe_limit - probe_match.len() as IdxSize; - table_match.clear(); - probe_group_start += p.hash_table.probe_subset( - &hash_keys, - &materialized_idxsize_range[probe_group_start..probe_group_end], - &mut table_match, - &mut probe_match, - mark_matches, - emit_unmatched, - matches_before_limit, - ) as usize; - - if emit_unmatched { - build_out.opt_gather_extend( - &p.payload, - &table_match, - ShareStrategy::Always, - ); - } else { - build_out.gather_extend( - &p.payload, - &table_match, - ShareStrategy::Always, - ); - }; - - if probe_match.len() >= probe_limit as usize - || probe_group_start == probe_partitions.len() - { - if !payload_rechunked { - payload.rechunk_mut(); - payload_rechunked = true; - } - probe_out.gather_extend( - &payload, - &probe_match, - ShareStrategy::Always, - ); - probe_match.clear(); - let out_morsel = new_morsel(&mut build_out, &mut probe_out); - if send.send(out_morsel).await.is_err() { - return Ok(max_seq); - } - if probe_group_end != probe_partitions.len() { - // We had enough matches to need a mid-partition flush, let's assume there are a lot of - // matches and just do a large reserve. - build_out.reserve(probe_limit as usize + max_match_per_key_est); - } - } - } - } - } else { - // Partition and probe the tables. - for p in partition_idxs.iter_mut() { - p.clear(); - } - hash_keys.gen_idxs_per_partition( - &partitioner, - &mut partition_idxs, - &mut [], - emit_unmatched, - ); - - for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { - let mut offset = 0; - while offset < idxs_in_p.len() { - let matches_before_limit = probe_limit - probe_match.len() as IdxSize; - table_match.clear(); - offset += p.hash_table.probe_subset( - &hash_keys, - &idxs_in_p[offset..], - &mut table_match, - &mut probe_match, - mark_matches, - emit_unmatched, - matches_before_limit, - ) as usize; - - if table_match.is_empty() { - continue; - } - total_matches += table_match.len(); - - if emit_unmatched { - build_out.opt_gather_extend( - &p.payload, - &table_match, - ShareStrategy::Always, - ); - } else { - build_out.gather_extend( - &p.payload, - &table_match, - ShareStrategy::Always, - ); - }; - - if probe_match.len() >= probe_limit as usize { - if !payload_rechunked { - payload.rechunk_mut(); - payload_rechunked = true; - } - probe_out.gather_extend( - &payload, - &probe_match, - ShareStrategy::Always, - ); - probe_match.clear(); - let out_morsel = new_morsel(&mut build_out, &mut probe_out); - if send.send(out_morsel).await.is_err() { - return Ok(max_seq); - } - // We had enough matches to need a mid-partition flush, let's assume there are a lot of - // matches and just do a large reserve. 
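// The flush-at-limit pattern above, stripped of the join specifics
// (illustrative only): matches accumulate into a buffer and are emitted as
// soon as the buffer reaches the probe limit, so output morsels stay close to
// the ideal size no matter how skewed the matches are.
fn drain_in_batches(
    matches: impl Iterator<Item = u32>,
    limit: usize,
    mut emit: impl FnMut(&[u32]),
) {
    let mut buf = Vec::with_capacity(limit);
    for m in matches {
        buf.push(m);
        if buf.len() >= limit {
            emit(&buf);
            buf.clear();
        }
    }
    if !buf.is_empty() {
        emit(&buf);
    }
}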
- build_out.reserve(probe_limit as usize + max_match_per_key_est); - } - } - } - - if !probe_match.is_empty() { - if !payload_rechunked { - payload.rechunk_mut(); - } - probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always); - probe_match.clear(); - let out_morsel = new_morsel(&mut build_out, &mut probe_out); - if send.send(out_morsel).await.is_err() { - return Ok(max_seq); - } - } - } - } - - drop(wait_token); - - // Move selectivity estimate a bit towards latest value. - selectivity_estimate = - 0.8 * selectivity_estimate + 0.2 * (total_matches as f64 / df_height as f64); - } - - Ok(max_seq) - } - - fn ordered_unmatched( - &mut self, - params: &EquiJoinParams, - ) -> DataFrame { - // TODO: parallelize this operator. - - let build_payload_schema = if params.left_is_build.unwrap() { - ¶ms.left_payload_schema - } else { - ¶ms.right_payload_schema - }; - - let mut unmarked_idxs = Vec::new(); - let mut linearized_idxs = Vec::new(); - - for (p_idx, p) in self.table_per_partition.iter().enumerate_idx() { - p.hash_table - .unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX); - linearized_idxs.extend(unmarked_idxs.iter().map(|i| { - (unsafe { *p.seq_ids.get_unchecked(*i as usize) }, p_idx, *i) - })); - } - - linearized_idxs.sort_by_key(|(seq_id, _, _)| *seq_id); - - unsafe { - let mut build_out = DataFrameBuilder::new(build_payload_schema.clone()); - build_out.reserve(linearized_idxs.len()); - - // Group indices from the same partition. - let mut group_start = 0; - let mut gather_idxs = Vec::new(); - while group_start < linearized_idxs.len() { - gather_idxs.clear(); - - let (_seq, p_idx, idx_in_p) = linearized_idxs[group_start]; - gather_idxs.push(idx_in_p); - let mut group_end = group_start + 1; - while group_end < linearized_idxs.len() && linearized_idxs[group_end].1 == p_idx { - gather_idxs.push(linearized_idxs[group_end].2); - group_end += 1; - } - - build_out.gather_extend( - &self.table_per_partition[p_idx as usize].payload, - &gather_idxs, - ShareStrategy::Never, // Don't keep entire table alive for unmatched indices. - ); - - group_start = group_end; - } - - let mut build_df = build_out.freeze(); - let out_df = if params.left_is_build.unwrap() { - let probe_df = DataFrame::full_null(¶ms.right_payload_schema, build_df.height()); - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - let mut probe_df = DataFrame::full_null(¶ms.left_payload_schema, build_df.height()); - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df - }; - postprocess_join(out_df, params) - } - } -} - -impl Drop for ProbeState { - fn drop(&mut self) { - POOL.install(|| { - // Parallel drop as the state might be quite big. 
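// The selectivity estimate above is a plain exponential moving average: each
// morsel's matches-per-row ratio nudges the running value by 20%.
fn update_selectivity(est: f64, total_matches: usize, df_height: usize) -> f64 {
    0.8 * est + 0.2 * (total_matches as f64 / df_height as f64)
}
// Starting from 1.0, three morsels at ratio 0.5 give 0.9, 0.82, 0.756.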
- self.table_per_partition.par_drain(..).for_each(drop); - }) - } -} - -struct EmitUnmatchedState { - partitions: Vec, - active_partition_idx: usize, - offset_in_active_p: usize, - morsel_seq: MorselSeq, -} - -impl EmitUnmatchedState { - async fn emit_unmatched( - &mut self, - mut send: Sender, - params: &EquiJoinParams, - num_pipelines: usize, - ) -> PolarsResult<()> { - let total_len: usize = self - .partitions - .iter() - .map(|p| p.hash_table.num_keys() as usize) - .sum(); - let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1); - let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines); - let morsel_size = total_len.div_ceil(morsel_count).max(1); - - let wait_group = WaitGroup::default(); - let source_token = SourceToken::new(); - let mut unmarked_idxs = Vec::new(); - while let Some(p) = self.partitions.get(self.active_partition_idx) { - loop { - // Generate a chunk of unmarked key indices. - self.offset_in_active_p += p.hash_table.unmarked_keys( - &mut unmarked_idxs, - self.offset_in_active_p as IdxSize, - morsel_size as IdxSize, - ) as usize; - if unmarked_idxs.is_empty() { - break; - } - - // Gather and create full-null counterpart. - let out_df = unsafe { - let mut build_df = p.payload.take_slice_unchecked_impl(&unmarked_idxs, false); - let len = build_df.height(); - if params.left_is_build.unwrap() { - let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - let mut probe_df = DataFrame::full_null(¶ms.left_payload_schema, len); - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df - } - }; - let out_df = postprocess_join(out_df, params); - - // Send and wait until consume token is consumed. - let mut morsel = Morsel::new(out_df, self.morsel_seq, source_token.clone()); - self.morsel_seq = self.morsel_seq.successor(); - morsel.set_consume_token(wait_group.token()); - if send.send(morsel).await.is_err() { - return Ok(()); - } - - wait_group.wait().await; - if source_token.stop_requested() { - return Ok(()); - } - } - - self.active_partition_idx += 1; - self.offset_in_active_p = 0; - } - - Ok(()) - } -} - -enum EquiJoinState { - Sample(SampleState), - Build(BuildState), - Probe(ProbeState), - EmitUnmatchedBuild(EmitUnmatchedState), - EmitUnmatchedBuildInOrder(InMemorySourceNode), - Done, -} - -struct EquiJoinParams { - left_is_build: Option, - preserve_order_build: bool, - preserve_order_probe: bool, - left_key_schema: Arc, - left_key_selectors: Vec, - right_key_schema: Arc, - right_key_selectors: Vec, - left_payload_select: Vec>, - right_payload_select: Vec>, - left_payload_schema: Arc, - right_payload_schema: Arc, - args: JoinArgs, - random_state: PlRandomState, -} - -impl EquiJoinParams { - /// Should we emit unmatched rows from the build side? - fn emit_unmatched_build(&self) -> bool { - if self.left_is_build.unwrap() { - self.args.how == JoinType::Left || self.args.how == JoinType::Full - } else { - self.args.how == JoinType::Right || self.args.how == JoinType::Full - } - } - - /// Should we emit unmatched rows from the probe side? 
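// A worked example of the morsel sizing above: the unmatched-row count is
// split so that every pipeline receives a whole number of morsels.
fn morsel_size(total_len: usize, ideal_morsel_size: usize, num_pipelines: usize) -> usize {
    let ideal_morsel_count = (total_len / ideal_morsel_size).max(1);
    let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines);
    total_len.div_ceil(morsel_count).max(1)
}
// morsel_size(100_000, 32_768, 4): 3 ideal morsels round up to 4 of 25_000 rows each.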
- fn emit_unmatched_probe(&self) -> bool { - if self.left_is_build.unwrap() { - self.args.how == JoinType::Right || self.args.how == JoinType::Full - } else { - self.args.how == JoinType::Left || self.args.how == JoinType::Full - } - } -} - -pub struct EquiJoinNode { - state: EquiJoinState, - params: EquiJoinParams, - num_pipelines: usize, - table: Option>, -} - -impl EquiJoinNode { - pub fn new( - left_input_schema: Arc, - right_input_schema: Arc, - left_key_schema: Arc, - right_key_schema: Arc, - left_key_selectors: Vec, - right_key_selectors: Vec, - args: JoinArgs, - ) -> PolarsResult { - let left_is_build = match args.maintain_order { - MaintainOrderJoin::None => { - if *SAMPLE_LIMIT == 0 { - Some(true) - } else { - None - } - }, - MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => Some(false), - MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => Some(true), - }; - - let table = left_is_build.map(|lib| { - if lib { - new_idx_table(left_key_schema.clone()) - } else { - new_idx_table(right_key_schema.clone()) - } - }); - - let preserve_order_probe = args.maintain_order != MaintainOrderJoin::None; - let preserve_order_build = matches!( - args.maintain_order, - MaintainOrderJoin::LeftRight | MaintainOrderJoin::RightLeft - ); - - let left_payload_select = compute_payload_selector( - &left_input_schema, - &right_input_schema, - &left_key_schema, - true, - &args, - )?; - let right_payload_select = compute_payload_selector( - &right_input_schema, - &left_input_schema, - &right_key_schema, - false, - &args, - )?; - - let state = if left_is_build.is_some() { - EquiJoinState::Build(BuildState::default()) - } else { - EquiJoinState::Sample(SampleState::default()) - }; - - let left_payload_schema = Arc::new(select_schema(&left_input_schema, &left_payload_select)); - let right_payload_schema = - Arc::new(select_schema(&right_input_schema, &right_payload_select)); - Ok(Self { - state, - num_pipelines: 0, - params: EquiJoinParams { - left_is_build, - preserve_order_build, - preserve_order_probe, - left_key_schema, - left_key_selectors, - right_key_schema, - right_key_selectors, - left_payload_select, - right_payload_select, - left_payload_schema, - right_payload_schema, - args, - random_state: PlRandomState::new(), - }, - table, - }) - } -} - -impl ComputeNode for EquiJoinNode { - fn name(&self) -> &str { - "equi_join" - } - - fn initialize(&mut self, num_pipelines: usize) { - self.num_pipelines = num_pipelines; - } - - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { - assert!(recv.len() == 2 && send.len() == 1); - - // If the output doesn't want any more data, transition to being done. - if send[0] == PortState::Done { - self.state = EquiJoinState::Done; - } - - // If we are sampling and both sides are done/filled, transition to building. - if let EquiJoinState::Sample(sample_state) = &mut self.state { - if let Some(build_state) = sample_state.try_transition_to_build( - recv, - self.num_pipelines, - &mut self.params, - &mut self.table, - )? { - self.state = EquiJoinState::Build(build_state); - } - } - - let build_idx = if self.params.left_is_build == Some(true) { - 0 - } else { - 1 - }; - let probe_idx = 1 - build_idx; - - // If we are building and the build input is done, transition to probing. 
- if let EquiJoinState::Build(build_state) = &mut self.state { - if recv[build_idx] == PortState::Done { - let probe_state = if self.params.preserve_order_build { - build_state.finalize_ordered(&self.params, self.table.as_deref().unwrap()) - } else { - build_state.finalize_unordered(&self.params, self.table.as_deref().unwrap()) - }; - self.state = EquiJoinState::Probe(probe_state); - } - } - - // If we are probing and the probe input is done, emit unmatched if - // necessary, otherwise we're done. - if let EquiJoinState::Probe(probe_state) = &mut self.state { - let samples_consumed = probe_state.sampled_probe_morsels.is_empty(); - if samples_consumed && recv[probe_idx] == PortState::Done { - if self.params.emit_unmatched_build() { - if self.params.preserve_order_build { - let partitioner = HashPartitioner::new(self.num_pipelines, 0); - let unmatched = probe_state.ordered_unmatched(&self.params); - let mut src = InMemorySourceNode::new( - Arc::new(unmatched), - probe_state.max_seq_sent.successor(), - ); - src.initialize(self.num_pipelines); - self.state = EquiJoinState::EmitUnmatchedBuildInOrder(src); - } else { - self.state = EquiJoinState::EmitUnmatchedBuild(EmitUnmatchedState { - partitions: core::mem::take(&mut probe_state.table_per_partition), - active_partition_idx: 0, - offset_in_active_p: 0, - morsel_seq: probe_state.max_seq_sent.successor(), - }); - } - } else { - self.state = EquiJoinState::Done; - } - } - } - - // Finally, check if we are done emitting unmatched keys. - if let EquiJoinState::EmitUnmatchedBuild(emit_state) = &mut self.state { - if emit_state.active_partition_idx >= emit_state.partitions.len() { - self.state = EquiJoinState::Done; - } - } - - match &mut self.state { - EquiJoinState::Sample(sample_state) => { - send[0] = PortState::Blocked; - if recv[0] != PortState::Done { - recv[0] = if sample_state.left_len < *SAMPLE_LIMIT { - PortState::Ready - } else { - PortState::Blocked - }; - } - if recv[1] != PortState::Done { - recv[1] = if sample_state.right_len < *SAMPLE_LIMIT { - PortState::Ready - } else { - PortState::Blocked - }; - } - }, - EquiJoinState::Build(_) => { - send[0] = PortState::Blocked; - if recv[build_idx] != PortState::Done { - recv[build_idx] = PortState::Ready; - } - if recv[probe_idx] != PortState::Done { - recv[probe_idx] = PortState::Blocked; - } - }, - EquiJoinState::Probe(probe_state) => { - if recv[probe_idx] != PortState::Done { - core::mem::swap(&mut send[0], &mut recv[probe_idx]); - } else { - let samples_consumed = probe_state.sampled_probe_morsels.is_empty(); - send[0] = if samples_consumed { - PortState::Done - } else { - PortState::Ready - }; - } - recv[build_idx] = PortState::Done; - }, - EquiJoinState::EmitUnmatchedBuild(_) => { - send[0] = PortState::Ready; - recv[build_idx] = PortState::Done; - recv[probe_idx] = PortState::Done; - }, - EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => { - recv[build_idx] = PortState::Done; - recv[probe_idx] = PortState::Done; - src_node.update_state(&mut [], &mut send[0..1])?; - if send[0] == PortState::Done { - self.state = EquiJoinState::Done; - } - }, - EquiJoinState::Done => { - send[0] = PortState::Done; - recv[0] = PortState::Done; - recv[1] = PortState::Done; - }, - } - Ok(()) - } - - fn is_memory_intensive_pipeline_blocker(&self) -> bool { - matches!( - self.state, - EquiJoinState::Sample { .. } | EquiJoinState::Build { .. 
} - ) - } - - fn spawn<'env, 's>( - &'env mut self, - scope: &'s TaskScope<'s, 'env>, - recv_ports: &mut [Option>], - send_ports: &mut [Option>], - state: &'s ExecutionState, - join_handles: &mut Vec>>, - ) { - assert!(recv_ports.len() == 2); - assert!(send_ports.len() == 1); - - let build_idx = if self.params.left_is_build == Some(true) { - 0 - } else { - 1 - }; - let probe_idx = 1 - build_idx; - - match &mut self.state { - EquiJoinState::Sample(sample_state) => { - assert!(send_ports[0].is_none()); - let left_final_len = Arc::new(AtomicUsize::new(if recv_ports[0].is_none() { - sample_state.left_len - } else { - usize::MAX - })); - let right_final_len = Arc::new(AtomicUsize::new(if recv_ports[1].is_none() { - sample_state.right_len - } else { - usize::MAX - })); - - if let Some(left_recv) = recv_ports[0].take() { - join_handles.push(scope.spawn_task( - TaskPriority::High, - SampleState::sink( - left_recv.serial(), - &mut sample_state.left, - &mut sample_state.left_len, - left_final_len.clone(), - right_final_len.clone(), - ), - )); - } - if let Some(right_recv) = recv_ports[1].take() { - join_handles.push(scope.spawn_task( - TaskPriority::High, - SampleState::sink( - right_recv.serial(), - &mut sample_state.right, - &mut sample_state.right_len, - right_final_len, - left_final_len, - ), - )); - } - }, - EquiJoinState::Build(build_state) => { - assert!(send_ports[0].is_none()); - assert!(recv_ports[probe_idx].is_none()); - let receivers = recv_ports[build_idx].take().unwrap().parallel(); - - build_state - .local_builders - .resize_with(self.num_pipelines, Default::default); - let partitioner = HashPartitioner::new(self.num_pipelines, 0); - for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) { - join_handles.push(scope.spawn_task( - TaskPriority::High, - BuildState::partition_and_sink( - recv, - local_builder, - partitioner.clone(), - &self.params, - state, - ), - )); - } - }, - EquiJoinState::Probe(probe_state) => { - assert!(recv_ports[build_idx].is_none()); - let senders = send_ports[0].take().unwrap().parallel(); - let receivers = probe_state - .sampled_probe_morsels - .reinsert( - self.num_pipelines, - recv_ports[probe_idx].take(), - scope, - join_handles, - ) - .unwrap(); - - let partitioner = HashPartitioner::new(self.num_pipelines, 0); - let probe_tasks = receivers - .into_iter() - .zip(senders) - .map(|(recv, send)| { - scope.spawn_task( - TaskPriority::High, - ProbeState::partition_and_probe( - recv, - send, - &probe_state.table_per_partition, - partitioner.clone(), - &self.params, - state, - ), - ) - }) - .collect_vec(); - - let max_seq_sent = &mut probe_state.max_seq_sent; - join_handles.push(scope.spawn_task(TaskPriority::High, async move { - for probe_task in probe_tasks { - *max_seq_sent = (*max_seq_sent).max(probe_task.await?); - } - Ok(()) - })); - }, - EquiJoinState::EmitUnmatchedBuild(emit_state) => { - assert!(recv_ports[build_idx].is_none()); - assert!(recv_ports[probe_idx].is_none()); - let send = send_ports[0].take().unwrap().serial(); - join_handles.push(scope.spawn_task( - TaskPriority::Low, - emit_state.emit_unmatched(send, &self.params, self.num_pipelines), - )); - }, - EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => { - assert!(recv_ports[build_idx].is_none()); - assert!(recv_ports[probe_idx].is_none()); - src_node.spawn(scope, &mut [], send_ports, state, join_handles); - }, - EquiJoinState::Done => unreachable!(), - } - } -} diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs 
b/crates/polars-stream/src/physical_plan/to_graph.rs index 4d8685e2a638..e9fb61127f3d 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -809,13 +809,8 @@ fn to_graph_rec<'a>( .map(|e| create_stream_expr(e, ctx, &right_input_schema)) .try_collect_vec()?; - // TODO: implement build-side order-maintaining join in new join impl. - let preserve_order_build = matches!( - args.maintain_order, - MaintainOrderJoin::LeftRight | MaintainOrderJoin::RightLeft - ); ctx.graph.add_node( - nodes::joins::new_equi_join::EquiJoinNode::new( + nodes::joins::equi_join::EquiJoinNode::new( left_input_schema, right_input_schema, left_key_schema, From c25619538bed46362c70460d4ba87f523a785f84 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 7 Mar 2025 17:17:19 +0100 Subject: [PATCH 19/25] fmt --- .../src/nodes/joins/equi_join.rs | 80 +++++++++++-------- crates/polars-stream/src/nodes/joins/mod.rs | 2 +- crates/polars-utils/src/sparse_init_vec.rs | 19 +++-- 3 files changed, 56 insertions(+), 45 deletions(-) diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index c1bcd6beeaa2..e0a44f443573 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -546,7 +546,7 @@ impl BuildState { let num_partitions = self.local_builders[0].sketch_per_p.len(); let local_builders = &self.local_builders; let probe_tables: SparseInitVec = SparseInitVec::with_capacity(num_partitions); - + POOL.scope(|s| { for p in 0..num_partitions { let probe_tables = &probe_tables; @@ -560,7 +560,9 @@ impl BuildState { let mut sketch = CardinalitySketch::new(); let mut payload_rows = 0; for (l_idx, l) in local_builders.iter().enumerate() { - let Some((seq, _, _)) = l.morsels.get(0) else { continue }; + let Some((seq, _, _)) = l.morsels.get(0) else { + continue; + }; kmerge.push(Priority(Reverse(seq), l_idx)); sketch.combine(&l.sketch_per_p[p]); @@ -574,7 +576,7 @@ impl BuildState { p_table.reserve(sketch.estimate() * 5 / 4); let mut p_payload = DataFrameBuilder::new(payload_schema.clone()); p_payload.reserve(payload_rows); - + let mut p_seq_ids = Vec::new(); if track_unmatchable { p_seq_ids.reserve(payload_rows); @@ -590,7 +592,7 @@ impl BuildState { if let Some((next_seq, _, _)) = l.morsels.get(idx_in_l + 1) { kmerge.push(Priority(Reverse(next_seq), l_idx)); } - + let (_mseq, payload, keys) = l.morsels.get_unchecked(idx_in_l); let p_morsel_idxs_start = l.morsel_idxs_offsets_per_p[idx_in_l * num_partitions + p]; @@ -599,12 +601,8 @@ impl BuildState { let p_morsel_idxs = &l.morsel_idxs_values_per_p[p] [p_morsel_idxs_start..p_morsel_idxs_stop]; p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable); - p_payload.gather_extend( - payload, - p_morsel_idxs, - ShareStrategy::Never, - ); - + p_payload.gather_extend(payload, p_morsel_idxs, ShareStrategy::Never); + if track_unmatchable { p_seq_ids.resize(p_payload.len(), norm_seq_id); norm_seq_id += 1; @@ -612,11 +610,17 @@ impl BuildState { } } - probe_tables.try_set(p, ProbeTable { - hash_table: p_table, - payload: p_payload.freeze(), - seq_ids: p_seq_ids, - }).ok().unwrap(); + probe_tables + .try_set( + p, + ProbeTable { + hash_table: p_table, + payload: p_payload.freeze(), + seq_ids: p_seq_ids, + }, + ) + .ok() + .unwrap(); }); } }); @@ -645,7 +649,8 @@ impl BuildState { .iter_mut() .map(|b| Arc::new(core::mem::take(&mut b.morsels))) .collect_vec(); - let (morsel_drop_q_send, 
morsel_drop_q_recv) = crossbeam_channel::bounded(morsels_per_local_builder.len()); + let (morsel_drop_q_send, morsel_drop_q_recv) = + crossbeam_channel::bounded(morsels_per_local_builder.len()); let num_partitions = self.local_builders[0].sketch_per_p.len(); let local_builders = &self.local_builders; let probe_tables: SparseInitVec = SparseInitVec::with_capacity(num_partitions); @@ -706,7 +711,7 @@ impl BuildState { ); } } - + if let Some(l) = Arc::into_inner(l_morsels) { // If we're the last thread to process this set of morsels we're probably // falling behind the rest, since the drop can be quite expensive we skip @@ -717,18 +722,24 @@ impl BuildState { skip_drop_attempt = false; } } - + // We're done, help others out by doing drops. drop(morsel_drop_q_send); // So we don't deadlock trying to receive from ourselves. while let Ok(l_morsels) = morsel_drop_q_recv.recv() { drop(l_morsels); } - probe_tables.try_set(p, ProbeTable { - hash_table: p_table, - payload: p_payload.freeze(), - seq_ids: Vec::new(), - }).ok().unwrap(); + probe_tables + .try_set( + p, + ProbeTable { + hash_table: p_table, + payload: p_payload.freeze(), + seq_ids: Vec::new(), + }, + ) + .ok() + .unwrap(); }); } @@ -1003,10 +1014,7 @@ impl ProbeState { Ok(max_seq) } - fn ordered_unmatched( - &mut self, - params: &EquiJoinParams, - ) -> DataFrame { + fn ordered_unmatched(&mut self, params: &EquiJoinParams) -> DataFrame { // TODO: parallelize this operator. let build_payload_schema = if params.left_is_build.unwrap() { @@ -1021,11 +1029,13 @@ impl ProbeState { for (p_idx, p) in self.table_per_partition.iter().enumerate_idx() { p.hash_table .unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX); - linearized_idxs.extend(unmarked_idxs.iter().map(|i| { - (unsafe { *p.seq_ids.get_unchecked(*i as usize) }, p_idx, *i) - })); + linearized_idxs.extend( + unmarked_idxs + .iter() + .map(|i| (unsafe { *p.seq_ids.get_unchecked(*i as usize) }, p_idx, *i)), + ); } - + linearized_idxs.sort_by_key(|(seq_id, _, _)| *seq_id); unsafe { @@ -1045,7 +1055,7 @@ impl ProbeState { gather_idxs.push(linearized_idxs[group_end].2); group_end += 1; } - + build_out.gather_extend( &self.table_per_partition[p_idx as usize].payload, &gather_idxs, @@ -1057,11 +1067,13 @@ impl ProbeState { let mut build_df = build_out.freeze(); let out_df = if params.left_is_build.unwrap() { - let probe_df = DataFrame::full_null(¶ms.right_payload_schema, build_df.height()); + let probe_df = + DataFrame::full_null(¶ms.right_payload_schema, build_df.height()); build_df.hstack_mut_unchecked(probe_df.get_columns()); build_df } else { - let mut probe_df = DataFrame::full_null(¶ms.left_payload_schema, build_df.height()); + let mut probe_df = + DataFrame::full_null(¶ms.left_payload_schema, build_df.height()); probe_df.hstack_mut_unchecked(build_df.get_columns()); probe_df }; diff --git a/crates/polars-stream/src/nodes/joins/mod.rs b/crates/polars-stream/src/nodes/joins/mod.rs index eb8dd4e5e833..f5304162d56a 100644 --- a/crates/polars-stream/src/nodes/joins/mod.rs +++ b/crates/polars-stream/src/nodes/joins/mod.rs @@ -1,2 +1,2 @@ -pub mod in_memory; pub mod equi_join; +pub mod in_memory; diff --git a/crates/polars-utils/src/sparse_init_vec.rs b/crates/polars-utils/src/sparse_init_vec.rs index 12cef03cb2ea..16de05bad2c7 100644 --- a/crates/polars-utils/src/sparse_init_vec.rs +++ b/crates/polars-utils/src/sparse_init_vec.rs @@ -1,17 +1,16 @@ -use std::sync::atomic::{AtomicUsize, AtomicU8, Ordering}; - +use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering}; pub struct SparseInitVec { 
     ptr: *mut T,
     len: usize,
     cap: usize,
-
+
     num_init: AtomicUsize,
     init_mask: Vec<AtomicU8>,
 }
 
-unsafe impl<T: Send> Send for SparseInitVec<T> { }
-unsafe impl<T: Send> Sync for SparseInitVec<T> { }
+unsafe impl<T: Send> Send for SparseInitVec<T> {}
+unsafe impl<T: Send> Sync for SparseInitVec<T> {}
 
 impl<T> SparseInitVec<T> {
     pub fn with_capacity(len: usize) -> Self {
@@ -28,7 +27,7 @@ impl<T> SparseInitVec<T> {
             init_mask,
         }
     }
-
+
     pub fn try_set(&self, idx: usize, value: T) -> Result<(), T> {
         unsafe {
             if idx >= self.len {
@@ -42,14 +41,14 @@ impl<T> SparseInitVec<T> {
             if init_mask_byte.fetch_or(bit_mask, Ordering::Relaxed) & bit_mask != 0 {
                 return Err(value);
             }
-
+
             self.ptr.add(idx).write(value);
             self.num_init.fetch_add(1, Ordering::Relaxed);
         }
-
+
         Ok(())
     }
-
+
     pub fn try_assume_init(mut self) -> Result<Vec<T>, Self> {
         unsafe {
             if *self.num_init.get_mut() == self.len {
@@ -69,7 +68,7 @@ impl<T> Drop for SparseInitVec<T> {
         unsafe {
             // Make sure storage gets dropped even if element drop panics.
             let _storage = Vec::from_raw_parts(self.ptr, 0, self.cap);
-
+
             for idx in 0..self.len {
                 let init_mask_byte = self.init_mask.get_unchecked_mut(idx / 8);
                 let bit_mask = 1 << (idx % 8);

From 988cbeaf7bd392601324e3f530a82f09ba75b213 Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Fri, 7 Mar 2025 17:19:13 +0100
Subject: [PATCH 20/25] remove chunkedidxtable

---
 .../polars-expr/src/chunked_idx_table/mod.rs  |  65 ----
 .../src/chunked_idx_table/row_encoded.rs      | 308 ------------------
 2 files changed, 373 deletions(-)
 delete mode 100644 crates/polars-expr/src/chunked_idx_table/mod.rs
 delete mode 100644 crates/polars-expr/src/chunked_idx_table/row_encoded.rs

diff --git a/crates/polars-expr/src/chunked_idx_table/mod.rs b/crates/polars-expr/src/chunked_idx_table/mod.rs
deleted file mode 100644
index 948e34effad0..000000000000
--- a/crates/polars-expr/src/chunked_idx_table/mod.rs
+++ /dev/null
@@ -1,65 +0,0 @@
-use std::any::Any;
-
-use polars_core::prelude::*;
-use polars_utils::index::ChunkId;
-use polars_utils::IdxSize;
-
-use crate::hash_keys::HashKeys;
-
-mod row_encoded;
-
-pub trait ChunkedIdxTable: Any + Send + Sync {
-    /// Creates a new empty ChunkedIdxTable similar to this one.
-    fn new_empty(&self) -> Box<dyn ChunkedIdxTable>;
-
-    /// Reserves space for the given number additional keys.
-    fn reserve(&mut self, additional: usize);
-
-    /// Returns the number of unique keys in this ChunkedIdxTable.
-    fn num_keys(&self) -> IdxSize;
-
-    /// Inserts the given key chunk into this ChunkedIdxTable.
-    fn insert_key_chunk(&mut self, keys: HashKeys, track_unmatchable: bool);
-
-    /// Probe the table, updating table_match and probe_match with
-    /// (ChunkId, IdxSize) pairs for each match. Will stop processing new keys
-    /// once limit matches have been generated, returning the number of keys
-    /// processed.
-    ///
-    /// If mark_matches is true, matches are marked in the table as such.
-    ///
-    /// If emit_unmatched is true, for keys that do not have a match we emit a
-    /// match with ChunkId::null() on the table match.
-    fn probe(
-        &self,
-        hash_keys: &HashKeys,
-        table_match: &mut Vec<ChunkId<32>>,
-        probe_match: &mut Vec<IdxSize>,
-        mark_matches: bool,
-        emit_unmatched: bool,
-        limit: IdxSize,
-    ) -> IdxSize;
-
-    /// The same as probe, except it will only apply to the specified subset of keys.
-    /// # Safety
-    /// The provided subset indices must be in-bounds.
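// How SparseInitVec is driven by the finalize paths above, sketched with
// std::thread::scope (the module path polars_utils::sparse_init_vec, the
// T: Send bounds, and the String values are assumptions for illustration):
// each worker claims exactly one slot with try_set, and once all slots are
// filled the vector is materialized with try_assume_init.
fn fill_slots(num_partitions: usize) -> Vec<String> {
    use polars_utils::sparse_init_vec::SparseInitVec;

    let out: SparseInitVec<String> = SparseInitVec::with_capacity(num_partitions);
    std::thread::scope(|s| {
        for p in 0..num_partitions {
            let out = &out;
            s.spawn(move || {
                // try_set fails if the slot was already claimed.
                out.try_set(p, format!("partition {p}")).ok().unwrap();
            });
        }
    });
    out.try_assume_init().ok().unwrap()
}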
-    #[allow(clippy::too_many_arguments)]
-    unsafe fn probe_subset(
-        &self,
-        hash_keys: &HashKeys,
-        subset: &[IdxSize],
-        table_match: &mut Vec<ChunkId<32>>,
-        probe_match: &mut Vec<IdxSize>,
-        mark_matches: bool,
-        emit_unmatched: bool,
-        limit: IdxSize,
-    ) -> IdxSize;
-
-    /// Get the ChunkIds for each key which was never marked during probing.
-    fn unmarked_keys(&self, out: &mut Vec<ChunkId<32>>, offset: IdxSize, limit: IdxSize)
-        -> IdxSize;
-}
-
-pub fn new_chunked_idx_table(_key_schema: Arc<Schema>) -> Box<dyn ChunkedIdxTable> {
-    Box::new(row_encoded::RowEncodedChunkedIdxTable::new())
-}
diff --git a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs
deleted file mode 100644
index 98fa4e8a821b..000000000000
--- a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs
+++ /dev/null
@@ -1,308 +0,0 @@
-use std::sync::atomic::{AtomicU64, Ordering};
-
-use arrow::array::Array;
-use polars_utils::idx_map::bytes_idx_map::{BytesIndexMap, Entry};
-use polars_utils::idx_vec::UnitVec;
-use polars_utils::itertools::Itertools;
-use polars_utils::unitvec;
-
-use super::*;
-use crate::hash_keys::HashKeys;
-
-#[derive(Default)]
-pub struct RowEncodedChunkedIdxTable {
-    // These AtomicU64s actually are ChunkIds, but we use the top bit of the
-    // first chunk in each to mark keys during probing.
-    idx_map: BytesIndexMap<UnitVec<AtomicU64>>,
-    chunk_ctr: u32,
-    null_keys: Vec<ChunkId<32>>,
-}
-
-impl RowEncodedChunkedIdxTable {
-    pub fn new() -> Self {
-        Self {
-            idx_map: BytesIndexMap::new(),
-            chunk_ctr: 0,
-            null_keys: Vec::new(),
-        }
-    }
-}
-
-impl RowEncodedChunkedIdxTable {
-    #[inline(always)]
-    fn probe_one<const MARK_MATCHES: bool>(
-        &self,
-        key_idx: IdxSize,
-        hash: u64,
-        key: &[u8],
-        table_match: &mut Vec<ChunkId<32>>,
-        probe_match: &mut Vec<IdxSize>,
-    ) -> bool {
-        if let Some(chunk_ids) = self.idx_map.get(hash, key) {
-            for chunk_id in &chunk_ids[..] {
-                // Create matches, making sure to clear top bit.
-                let raw_chunk_id = chunk_id.load(Ordering::Relaxed);
-                let chunk_id = ChunkId::from_inner(raw_chunk_id & !(1 << 63));
-                table_match.push(chunk_id);
-                probe_match.push(key_idx);
-            }
-
-            // Mark if necessary. This action is idempotent so doesn't
-            // need any synchronization on the load, nor does it need a
-            // fetch_or to do it atomically.
diff --git a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs
deleted file mode 100644
index 98fa4e8a821b..000000000000
--- a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs
+++ /dev/null
@@ -1,308 +0,0 @@
-use std::sync::atomic::{AtomicU64, Ordering};
-
-use arrow::array::Array;
-use polars_utils::idx_map::bytes_idx_map::{BytesIndexMap, Entry};
-use polars_utils::idx_vec::UnitVec;
-use polars_utils::itertools::Itertools;
-use polars_utils::unitvec;
-
-use super::*;
-use crate::hash_keys::HashKeys;
-
-#[derive(Default)]
-pub struct RowEncodedChunkedIdxTable {
-    // These AtomicU64s actually are ChunkIds, but we use the top bit of the
-    // first chunk in each to mark keys during probing.
-    idx_map: BytesIndexMap<UnitVec<AtomicU64>>,
-    chunk_ctr: u32,
-    null_keys: Vec<ChunkId<32>>,
-}
-
-impl RowEncodedChunkedIdxTable {
-    pub fn new() -> Self {
-        Self {
-            idx_map: BytesIndexMap::new(),
-            chunk_ctr: 0,
-            null_keys: Vec::new(),
-        }
-    }
-}
-
-impl RowEncodedChunkedIdxTable {
-    #[inline(always)]
-    fn probe_one<const MARK_MATCHES: bool>(
-        &self,
-        key_idx: IdxSize,
-        hash: u64,
-        key: &[u8],
-        table_match: &mut Vec<ChunkId<32>>,
-        probe_match: &mut Vec<IdxSize>,
-    ) -> bool {
-        if let Some(chunk_ids) = self.idx_map.get(hash, key) {
-            for chunk_id in &chunk_ids[..] {
-                // Create matches, making sure to clear top bit.
-                let raw_chunk_id = chunk_id.load(Ordering::Relaxed);
-                let chunk_id = ChunkId::from_inner(raw_chunk_id & !(1 << 63));
-                table_match.push(chunk_id);
-                probe_match.push(key_idx);
-            }
-
-            // Mark if necessary. This action is idempotent so doesn't
-            // need any synchronization on the load, nor does it need a
-            // fetch_or to do it atomically.
-            if MARK_MATCHES {
-                let first_chunk_id = unsafe { chunk_ids.get_unchecked(0) };
-                let first_chunk_val = first_chunk_id.load(Ordering::Relaxed);
-                if first_chunk_val >> 63 == 0 {
-                    first_chunk_id.store(first_chunk_val | (1 << 63), Ordering::Release);
-                }
-            }
-            true
-        } else {
-            false
-        }
-    }
-
-    fn probe_impl<'a, const MARK_MATCHES: bool, const EMIT_UNMATCHED: bool>(
-        &self,
-        hash_keys: impl Iterator<Item = (IdxSize, u64, Option<&'a [u8]>)>,
-        table_match: &mut Vec<ChunkId<32>>,
-        probe_match: &mut Vec<IdxSize>,
-        limit: IdxSize,
-    ) -> IdxSize {
-        let mut keys_processed = 0;
-        for (key_idx, hash, key) in hash_keys {
-            let found_match = if let Some(key) = key {
-                self.probe_one::<MARK_MATCHES>(key_idx, hash, key, table_match, probe_match)
-            } else {
-                false
-            };
-
-            if EMIT_UNMATCHED && !found_match {
-                table_match.push(ChunkId::null());
-                probe_match.push(key_idx);
-            }
-
-            keys_processed += 1;
-            if table_match.len() >= limit as usize {
-                break;
-            }
-        }
-        keys_processed
-    }
-
-    fn probe_dispatch<'a>(
-        &self,
-        hash_keys: impl Iterator<Item = (IdxSize, u64, Option<&'a [u8]>)>,
-        table_match: &mut Vec<ChunkId<32>>,
-        probe_match: &mut Vec<IdxSize>,
-        mark_matches: bool,
-        emit_unmatched: bool,
-        limit: IdxSize,
-    ) -> IdxSize {
-        match (mark_matches, emit_unmatched) {
-            (false, false) => {
-                self.probe_impl::<false, false>(hash_keys, table_match, probe_match, limit)
-            },
-            (false, true) => {
-                self.probe_impl::<false, true>(hash_keys, table_match, probe_match, limit)
-            },
-            (true, false) => {
-                self.probe_impl::<true, false>(hash_keys, table_match, probe_match, limit)
-            },
-            (true, true) => {
-                self.probe_impl::<true, true>(hash_keys, table_match, probe_match, limit)
-            },
        }
-    }
-}
-
-impl ChunkedIdxTable for RowEncodedChunkedIdxTable {
-    fn new_empty(&self) -> Box<dyn ChunkedIdxTable> {
-        Box::new(Self::new())
-    }
-
-    fn reserve(&mut self, additional: usize) {
-        self.idx_map.reserve(additional);
-    }
-
-    fn num_keys(&self) -> IdxSize {
-        self.idx_map.len()
-    }
-
-    fn insert_key_chunk(&mut self, hash_keys: HashKeys, track_unmatchable: bool) {
-        let HashKeys::RowEncoded(hash_keys) = hash_keys else {
-            unreachable!()
-        };
-        if hash_keys.keys.len() >= 1 << 31 {
-            panic!("overly large chunk in RowEncodedChunkedIdxTable");
-        }
-
-        for (i, (hash, key)) in hash_keys
-            .hashes
-            .values_iter()
-            .zip(hash_keys.keys.iter())
-            .enumerate_idx()
-        {
-            let chunk_id = ChunkId::<32>::store(self.chunk_ctr as IdxSize, i);
-            if let Some(key) = key {
-                let chunk_id = AtomicU64::new(chunk_id.into_inner());
-                match self.idx_map.entry(*hash, key) {
-                    Entry::Occupied(o) => {
-                        o.into_mut().push(chunk_id);
-                    },
-                    Entry::Vacant(v) => {
-                        v.insert(unitvec![chunk_id]);
-                    },
-                }
-            } else if track_unmatchable {
-                self.null_keys.push(chunk_id);
-            }
-        }
-
-        self.chunk_ctr = self.chunk_ctr.checked_add(1).unwrap();
-    }
-
-    fn probe(
-        &self,
-        hash_keys: &HashKeys,
-        table_match: &mut Vec<ChunkId<32>>,
-        probe_match: &mut Vec<IdxSize>,
-        mark_matches: bool,
-        emit_unmatched: bool,
-        limit: IdxSize,
-    ) -> IdxSize {
-        let HashKeys::RowEncoded(hash_keys) = hash_keys else {
-            unreachable!()
-        };
-
-        if hash_keys.keys.has_nulls() {
-            let iter = hash_keys
-                .hashes
-                .values_iter()
-                .copied()
-                .zip(hash_keys.keys.iter())
-                .enumerate_idx()
-                .map(|(i, (h, k))| (i, h, k));
-            self.probe_dispatch(
-                iter,
-                table_match,
-                probe_match,
-                mark_matches,
-                emit_unmatched,
-                limit,
-            )
-        } else {
-            let iter = hash_keys
-                .hashes
-                .values_iter()
-                .copied()
-                .zip(hash_keys.keys.values_iter().map(Some))
-                .enumerate_idx()
-                .map(|(i, (h, k))| (i, h, k));
-            self.probe_dispatch(
-                iter,
-                table_match,
-                probe_match,
-                mark_matches,
-                emit_unmatched,
-                limit,
-            )
-        }
-    }
-
-    unsafe fn probe_subset(
-        &self,
-        hash_keys: &HashKeys,
-        subset: &[IdxSize],
-        table_match: &mut Vec<ChunkId<32>>,
-        probe_match: &mut Vec<IdxSize>,
-        mark_matches: bool,
-        emit_unmatched: bool,
-        limit: IdxSize,
-    ) -> IdxSize {
-        let HashKeys::RowEncoded(hash_keys) = hash_keys else {
-            unreachable!()
-        };
-
-        if hash_keys.keys.has_nulls() {
-            let iter = subset.iter().map(|i| {
-                (
-                    *i,
-                    hash_keys.hashes.value_unchecked(*i as usize),
-                    hash_keys.keys.get_unchecked(*i as usize),
-                )
-            });
-            self.probe_dispatch(
-                iter,
-                table_match,
-                probe_match,
-                mark_matches,
-                emit_unmatched,
-                limit,
-            )
-        } else {
-            let iter = subset.iter().map(|i| {
-                (
-                    *i,
-                    hash_keys.hashes.value_unchecked(*i as usize),
-                    Some(hash_keys.keys.value_unchecked(*i as usize)),
-                )
-            });
-            self.probe_dispatch(
-                iter,
-                table_match,
-                probe_match,
-                mark_matches,
-                emit_unmatched,
-                limit,
-            )
-        }
-    }
-
-    fn unmarked_keys(
-        &self,
-        out: &mut Vec<ChunkId<32>>,
-        mut offset: IdxSize,
-        limit: IdxSize,
-    ) -> IdxSize {
-        out.clear();
-
-        let mut keys_processed = 0;
-        if (offset as usize) < self.null_keys.len() {
-            out.extend(
-                self.null_keys[offset as usize..]
-                    .iter()
-                    .copied()
-                    .take(limit as usize),
-            );
-            keys_processed += out.len() as IdxSize;
-            offset += out.len() as IdxSize;
-            if out.len() >= limit as usize {
-                return keys_processed;
-            }
-        }
-
-        offset -= self.null_keys.len() as IdxSize;
-
-        while let Some((_, _, chunk_ids)) = self.idx_map.get_index(offset) {
-            let first_chunk_id = unsafe { chunk_ids.get_unchecked(0) };
-            let first_chunk_val = first_chunk_id.load(Ordering::Acquire);
-            if first_chunk_val >> 63 == 0 {
-                for chunk_id in &chunk_ids[..] {
-                    let raw_chunk_id = chunk_id.load(Ordering::Relaxed);
-                    let chunk_id = ChunkId::from_inner(raw_chunk_id & !(1 << 63));
-                    out.push(chunk_id);
-                }
-            }
-
-            keys_processed += 1;
-            offset += 1;
-            if out.len() >= limit as usize {
-                break;
-            }
-        }
-
-        keys_processed
-    }
-}
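One detail of the file deleted above worth keeping in mind: it stored ChunkIds as AtomicU64s and reused the top bit of the first id per key as a "this key was matched" flag, which unmarked_keys later scans to emit unmatched rows for left/full joins. Marking is idempotent, so a plain load-check-store suffices and no fetch_or is needed. A toy of the same trick with std atomics, not the polars types:

use std::sync::atomic::{AtomicU64, Ordering};

const MARK: u64 = 1 << 63;

// Idempotent: every racing writer stores the same marked value.
fn mark_matched(id: &AtomicU64) {
    let v = id.load(Ordering::Relaxed);
    if v & MARK == 0 {
        id.store(v | MARK, Ordering::Release);
    }
}

fn main() {
    let ids = [AtomicU64::new(17), AtomicU64::new(42)];
    mark_matched(&ids[0]);
    let unmarked: Vec<u64> = ids
        .iter()
        .map(|id| id.load(Ordering::Acquire))
        .filter(|v| v & MARK == 0)
        .map(|v| v & !MARK) // Strip the flag bit to recover the real id.
        .collect();
    assert_eq!(unmarked, vec![42]); // Only the never-matched key remains.
}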
From b4e653c33009756fbeb6eb8927da8e0f9607b66d Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Fri, 7 Mar 2025 17:19:16 +0100
Subject: [PATCH 21/25] clippy

---
 crates/polars-expr/src/lib.rs                      | 1 -
 crates/polars-stream/src/nodes/joins/equi_join.rs  | 4 ++--
 crates/polars-stream/src/physical_plan/to_graph.rs | 1 -
 3 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/crates/polars-expr/src/lib.rs b/crates/polars-expr/src/lib.rs
index ef6e96db1b1a..2683971d96de 100644
--- a/crates/polars-expr/src/lib.rs
+++ b/crates/polars-expr/src/lib.rs
@@ -1,4 +1,3 @@
-pub mod chunked_idx_table;
 mod expressions;
 pub mod groups;
 pub mod hash_keys;
diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs
index e0a44f443573..42af13315033 100644
--- a/crates/polars-stream/src/nodes/joins/equi_join.rs
+++ b/crates/polars-stream/src/nodes/joins/equi_join.rs
@@ -560,7 +560,7 @@ impl BuildState {
         let mut sketch = CardinalitySketch::new();
         let mut payload_rows = 0;
         for (l_idx, l) in local_builders.iter().enumerate() {
-            let Some((seq, _, _)) = l.morsels.get(0) else {
+            let Some((seq, _, _)) = l.morsels.first() else {
                 continue;
             };
             kmerge.push(Priority(Reverse(seq), l_idx));
@@ -585,7 +585,7 @@ impl BuildState {
         // Linearize and build.
         unsafe {
             let mut norm_seq_id = 0 as IdxSize;
-            while let Some(Priority(Reverse(seq), l_idx)) = kmerge.pop() {
+            while let Some(Priority(Reverse(_seq), l_idx)) = kmerge.pop() {
                 let l = local_builders.get_unchecked(l_idx);
                 let idx_in_l = *cur_idx_per_loc.get_unchecked(l_idx);
                 *cur_idx_per_loc.get_unchecked_mut(l_idx) += 1;
diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs
index e9fb61127f3d..7f0170677e36 100644
--- a/crates/polars-stream/src/physical_plan/to_graph.rs
+++ b/crates/polars-stream/src/physical_plan/to_graph.rs
@@ -9,7 +9,6 @@ use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, Expressio
 use polars_expr::reduce::into_reduction;
 use polars_expr::state::ExecutionState;
 use polars_mem_engine::{create_physical_plan, create_scan_predicate};
-use polars_ops::frame::MaintainOrderJoin;
 use polars_plan::dsl::{JoinOptions, PartitionVariant};
 use polars_plan::global::_set_n_rows_for_scan;
 use polars_plan::plans::expr_ir::ExprIR;
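The seq binding silenced above belongs to the build-side linearization: each local builder holds its morsels in ascending sequence order, and a k-way merge keyed on the next sequence number restores one global order across builders. A self-contained sketch of that merge, with a plain BinaryHeap standing in for the kmerge structure used here and bare u64 sequence numbers standing in for the (MorselSeq, DataFrame, HashKeys) triples:

use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn kmerge(lists: Vec<Vec<u64>>) -> Vec<u64> {
    let mut cur = vec![0usize; lists.len()];
    let mut heap = BinaryHeap::new();
    for (l_idx, l) in lists.iter().enumerate() {
        if let Some(seq) = l.first() {
            heap.push(Reverse((*seq, l_idx))); // Reverse turns the max-heap into a min-heap.
        }
    }
    let mut out = Vec::new();
    while let Some(Reverse((seq, l_idx))) = heap.pop() {
        out.push(seq);
        cur[l_idx] += 1;
        // Refill the heap with this list's next sequence number, if any.
        if let Some(next) = lists[l_idx].get(cur[l_idx]) {
            heap.push(Reverse((*next, l_idx)));
        }
    }
    out
}

fn main() {
    assert_eq!(kmerge(vec![vec![0, 3, 4], vec![1, 2, 5]]), vec![0, 1, 2, 3, 4, 5]);
}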
From cea57438ae9b48f7c67d422a1374050ab6c9467f Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Fri, 7 Mar 2025 17:22:44 +0100
Subject: [PATCH 22/25] feature flag

---
 crates/polars-core/src/series/builder.rs | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/crates/polars-core/src/series/builder.rs b/crates/polars-core/src/series/builder.rs
index 73459ffae1fb..71119f91171b 100644
--- a/crates/polars-core/src/series/builder.rs
+++ b/crates/polars-core/src/series/builder.rs
@@ -1,6 +1,7 @@
 use arrow::array::builder::{make_builder, ArrayBuilder, ShareStrategy};
 use polars_utils::IdxSize;
 
+#[cfg(feature = "object")]
 use crate::chunked_array::object::registry::get_object_builder;
 use crate::prelude::*;
 use crate::utils::Container;
@@ -13,12 +14,14 @@ pub struct SeriesBuilder {
 
 impl SeriesBuilder {
     pub fn new(dtype: DataType) -> Self {
-        let builder = if matches!(dtype, DataType::Object(_)) {
-            // FIXME: get rid of this hack.
-            get_object_builder(PlSmallStr::EMPTY, 0).as_array_builder()
-        } else {
-            make_builder(&dtype.to_physical().to_arrow(CompatLevel::newest()))
-        };
+        // FIXME: get rid of this hack.
+        #[cfg(feature = "object")]
+        if matches!(dtype, DataType::Object(_)) {
+            let builder = get_object_builder(PlSmallStr::EMPTY, 0).as_array_builder();
+            return Self { dtype, builder };
+        }
+
+        let builder = make_builder(&dtype.to_physical().to_arrow(CompatLevel::newest()));
         Self { dtype, builder }
     }
 
From ea984e2afbbe2309b1c0ff1e5458ce6dead0e14a Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Fri, 7 Mar 2025 17:24:50 +0100
Subject: [PATCH 23/25] doc comment

---
 crates/polars-core/src/chunked_array/object/registry.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/polars-core/src/chunked_array/object/registry.rs b/crates/polars-core/src/chunked_array/object/registry.rs
index c16c2cc92780..6a1d30a97821 100644
--- a/crates/polars-core/src/chunked_array/object/registry.rs
+++ b/crates/polars-core/src/chunked_array/object/registry.rs
@@ -77,7 +77,7 @@ pub trait AnonymousObjectBuilder: ArrayBuilder {
 
 impl<T: PolarsObject> AnonymousObjectBuilder for ObjectChunkedBuilder<T> {
     /// # Safety
-    /// Expects ObjectArray arrays.
+    /// Expects `ObjectArray` arrays.
     unsafe fn from_chunks(self: Box<Self>, chunks: Vec<ArrayRef>) -> Series {
         ObjectChunked::<T>::new_with_compute_len(Arc::new(self.field().clone()), chunks)
             .into_series()
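The builder.rs change in patch 22 swaps an if/else for a cfg-gated early return, so the object-registry path (like the get_object_builder import above it) compiles out entirely when the object feature is off. The shape of that pattern in isolation, with a hypothetical special feature standing in for object:

fn make_builder(kind: u8) -> String {
    // With the feature disabled, this whole statement is removed at compile
    // time, leaving only the unconditional fallback below.
    #[cfg(feature = "special")]
    if kind == 1 {
        return "special builder".to_string();
    }
    format!("default builder for kind {kind}")
}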
From 387d9a19a337ba3ec20bdcb027cea5072167e0a4 Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Fri, 7 Mar 2025 19:06:11 +0100
Subject: [PATCH 24/25] parallelize cardinality estimation

---
 .../polars-stream/src/nodes/joins/equi_join.rs | 83 +++++++++++--------
 1 file changed, 48 insertions(+), 35 deletions(-)

diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs
index 42af13315033..e7416f65e8eb 100644
--- a/crates/polars-stream/src/nodes/joins/equi_join.rs
+++ b/crates/polars-stream/src/nodes/joins/equi_join.rs
@@ -160,15 +160,48 @@ fn estimate_cardinality(
     key_selectors: &[StreamExpr],
     params: &EquiJoinParams,
     state: &ExecutionState,
-) -> PolarsResult<usize> {
-    // TODO: parallelize.
-    let mut sketch = CardinalitySketch::new();
-    for morsel in morsels {
-        let hash_keys =
-            get_runtime().block_on(select_keys(morsel.df(), key_selectors, params, state))?;
-        hash_keys.sketch_cardinality(&mut sketch);
+) -> PolarsResult<f64> {
+    let sample_limit = *SAMPLE_LIMIT;
+    if morsels.is_empty() || sample_limit == 0 {
+        return Ok(0.0);
     }
-    Ok(sketch.estimate())
+
+    let mut total_height = 0;
+    let mut to_process_end = 0;
+    while to_process_end < morsels.len() && total_height < sample_limit {
+        total_height += morsels[to_process_end].df().height();
+        to_process_end += 1;
+    }
+    let last_morsel_idx = to_process_end - 1;
+    let last_morsel_len = morsels[last_morsel_idx].df().height();
+    let last_morsel_slice = last_morsel_len - total_height.saturating_sub(sample_limit);
+    let runtime = get_runtime();
+
+    POOL.install(|| {
+        let sample_cardinality = morsels[..to_process_end]
+            .par_iter()
+            .enumerate()
+            .try_fold(
+                || CardinalitySketch::new(),
+                |mut sketch, (morsel_idx, morsel)| {
+                    let sliced;
+                    let df = if morsel_idx == last_morsel_idx {
+                        sliced = morsel.df().slice(0, last_morsel_slice);
+                        &sliced
+                    } else {
+                        &morsel.df()
+                    };
+                    let hash_keys =
+                        runtime.block_on(select_keys(df, key_selectors, params, state))?;
+                    hash_keys.sketch_cardinality(&mut sketch);
+                    PolarsResult::Ok(sketch)
+                },
+            )
+            .map(|sketch| PolarsResult::Ok(sketch?.estimate()))
+            .try_reduce_with(|a, b| Ok(a + b))
+            .unwrap()?;
+        Ok(sample_cardinality as f64 / total_height.min(sample_limit) as f64)
+    })
 }
 
 struct BufferedStream {
@@ -349,38 +382,18 @@ impl SampleState {
                     params,
                     &execution_state,
                 )?;
-                let norm_left_factor = self.left_len.min(*SAMPLE_LIMIT) as f64 / self.left_len as f64;
-                let norm_right_factor =
-                    self.right_len.min(*SAMPLE_LIMIT) as f64 / self.right_len as f64;
-                let norm_left_cardinality = (left_cardinality as f64 * norm_left_factor) as usize;
-                let norm_right_cardinality = (right_cardinality as f64 * norm_right_factor) as usize;
                 if config::verbose() {
-                    eprintln!("estimated cardinalities are: {norm_left_cardinality} vs. {norm_right_cardinality}");
+                    eprintln!(
+                        "estimated cardinalities are: {left_cardinality} vs. {right_cardinality}"
+                    );
                 }
-                PolarsResult::Ok((norm_left_cardinality, norm_right_cardinality))
+                PolarsResult::Ok((left_cardinality, right_cardinality))
             };
 
             let left_is_build = match (left_saturated, right_saturated) {
-                (false, false) => {
-                    if self.left_len * LOPSIDED_SAMPLE_FACTOR < self.right_len
-                        || self.left_len > self.right_len * LOPSIDED_SAMPLE_FACTOR
-                    {
-                        // Don't bother estimating cardinality, just choose smaller as it's highly
-                        // imbalanced.
-                        self.left_len < self.right_len
-                    } else {
-                        let (lc, rc) = estimate_cardinalities()?;
-                        // Let's assume for now that per element building a
-                        // table is 3x more expensive than a probe, with
-                        // unique keys getting an additional 3x factor for
-                        // having to update the hash table in addition to the probe.
-                        let left_build_cost = self.left_len * 3 + 3 * lc;
-                        let left_probe_cost = self.left_len;
-                        let right_build_cost = self.right_len * 3 + 3 * rc;
-                        let right_probe_cost = self.right_len;
-                        left_build_cost + right_probe_cost < left_probe_cost + right_build_cost
-                    }
-                },
+                // Don't bother estimating cardinality, just choose smaller side as
+                // we have everything in-memory anyway.
+                (false, false) => self.left_len < self.right_len,
 
                 // Choose the unsaturated side, the saturated side could be
                 // arbitrarily big.

From 0abac7bdb1d8bd3ef5c8ee2b5f395786faf3a1b5 Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Fri, 7 Mar 2025 19:08:07 +0100
Subject: [PATCH 25/25] clippy

---
 crates/polars-stream/src/nodes/joins/equi_join.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs
index e7416f65e8eb..14a37dab855f 100644
--- a/crates/polars-stream/src/nodes/joins/equi_join.rs
+++ b/crates/polars-stream/src/nodes/joins/equi_join.rs
@@ -182,14 +182,14 @@ fn estimate_cardinality(
             .par_iter()
             .enumerate()
             .try_fold(
-                || CardinalitySketch::new(),
+                CardinalitySketch::new,
                 |mut sketch, (morsel_idx, morsel)| {
                     let sliced;
                     let df = if morsel_idx == last_morsel_idx {
                         sliced = morsel.df().slice(0, last_morsel_slice);
                         &sliced
                     } else {
-                        &morsel.df()
+                        morsel.df()
                     };
                     let hash_keys =
                         runtime.block_on(select_keys(df, key_selectors, params, state))?;
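Where patches 24 and 25 land: estimate_cardinality now hashes only the first SAMPLE_LIMIT rows (slicing the final morsel down to the limit), folds one CardinalitySketch per rayon job, merges the partial estimates, and returns distinct keys per sampled row as an f64, so the SampleState caller no longer renormalizes. A toy version of the same sampling scheme, with rayon and a HashSet standing in for the mergeable sketch:

use std::collections::HashSet;

use rayon::prelude::*;

fn estimate_cardinality(chunks: &[Vec<u64>], sample_limit: usize) -> f64 {
    // Take whole chunks until at least `sample_limit` rows are covered.
    let mut total = 0;
    let mut end = 0;
    while end < chunks.len() && total < sample_limit {
        total += chunks[end].len();
        end += 1;
    }
    if total == 0 {
        return 0.0;
    }
    // Trim the final chunk so at most `sample_limit` rows get hashed.
    let last_idx = end - 1;
    let last_take = chunks[last_idx].len() - total.saturating_sub(sample_limit);

    let distinct = chunks[..end]
        .par_iter()
        .enumerate()
        .fold(HashSet::new, |mut set, (idx, chunk)| {
            let take = if idx == last_idx { last_take } else { chunk.len() };
            set.extend(chunk[..take].iter().copied()); // Per-job partial "sketch".
            set
        })
        .reduce(HashSet::new, |mut a, b| {
            a.extend(b); // Merge the partials into one estimate.
            a
        })
        .len();
    distinct as f64 / total.min(sample_limit) as f64
}

fn main() {
    let chunks = vec![vec![1, 2, 2, 3], vec![3, 3, 4, 4]];
    assert_eq!(estimate_cardinality(&chunks, 8), 0.5); // 4 distinct over 8 rows.
}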