diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs index c9393349d931..608595f0828c 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs @@ -27,6 +27,7 @@ use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::datatypes::ArrowPrimitiveType; use datafusion_expr::EmitTo; + /// Track the accumulator null state per row: if any values for that /// group were null and if any values have been seen at all for that group. /// @@ -487,6 +488,18 @@ impl NullState { } } + /// Returns `true` if the accumulated value for the group at the given `index` is valid, + /// meaning that at least one value passing the filter has been seen for this group. + pub fn is_valid(&self, index: usize) -> bool { + self.seen_values.get_bit(index) + } + + /// Returns `true` if the accumulated value for the group at the given `index` is `null`, + /// meaning that no value passing the filter has been seen yet for this group. + pub fn is_null(&self, index: usize) -> bool { + !self.is_valid(index) + } + /// Creates the a [`NullBuffer`] representing which group_indices /// should have null values (because they never saw any values) /// for the `emit_to` rows. 
@@ -623,11 +636,12 @@ fn initialize_builder( #[cfg(test)] mod test { - use super::*; + use std::collections::HashSet; use arrow::array::UInt32Array; use rand::{rngs::ThreadRng, Rng}; - use std::collections::HashSet; + + use super::*; #[test] fn accumulate() { @@ -991,9 +1005,12 @@ mod test { // Validate the final buffer (one value per group) let expected_null_buffer = mock.expected_null_buffer(total_num_groups); + for (i, expected) in expected_null_buffer.iter().enumerate() { + assert_eq!(expected, null_state.is_valid(i)); + assert_eq!(!expected, null_state.is_null(i)) + } let null_buffer = null_state.build(EmitTo::All); - assert_eq!(null_buffer, expected_null_buffer); } }