Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Hold string cache in new streaming engine and fix row-encoding #21039

Merged
merged 7 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 71 additions & 52 deletions crates/polars-core/src/chunked_array/ops/row_encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn encode_rows_vertical_par_unordered_broadcast_nulls(
///
/// This should be given the logical type in order to communicate Polars datatype information down
/// into the row encoding / decoding.
pub fn get_row_encoding_context(dtype: &DataType) -> Option<RowEncodingContext> {
pub fn get_row_encoding_context(dtype: &DataType, ordered: bool) -> Option<RowEncodingContext> {
match dtype {
DataType::Boolean
| DataType::UInt8
Expand Down Expand Up @@ -108,67 +108,86 @@ pub fn get_row_encoding_context(dtype: &DataType) -> Option<RowEncodingContext>
},

#[cfg(feature = "dtype-array")]
DataType::Array(dtype, _) => get_row_encoding_context(dtype),
DataType::List(dtype) => get_row_encoding_context(dtype),
DataType::Array(dtype, _) => get_row_encoding_context(dtype, ordered),
DataType::List(dtype) => get_row_encoding_context(dtype, ordered),
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(revmap, ordering) | DataType::Enum(revmap, ordering) => {
let revmap = revmap.as_ref().unwrap();

let (num_known_categories, lexical_sort_idxs) = match revmap.as_ref() {
RevMapping::Global(map, _, _) => {
let num_known_categories = map.keys().max().copied().map_or(0, |m| m + 1);

// @TODO: This should probably be cached.
let lexical_sort_idxs =
matches!(ordering, CategoricalOrdering::Lexical).then(|| {
let read_map = crate::STRING_CACHE.read_map();
let payloads = read_map.get_current_payloads();
assert!(payloads.len() >= num_known_categories as usize);

let mut idxs = (0..num_known_categories).collect::<Vec<u32>>();
idxs.sort_by_key(|&k| payloads[k as usize].as_str());
let mut sort_idxs = vec![0; num_known_categories as usize];
for (i, idx) in idxs.into_iter().enumerate_u32() {
sort_idxs[idx as usize] = i;
}
sort_idxs
});

(num_known_categories, lexical_sort_idxs)
let is_enum = dtype.is_enum();
let ctx = match revmap {
Some(revmap) => {
let (num_known_categories, lexical_sort_idxs) = match revmap.as_ref() {
RevMapping::Global(map, _, _) => {
let num_known_categories =
map.keys().max().copied().map_or(0, |m| m + 1);

// @TODO: This should probably be cached.
let lexical_sort_idxs = (ordered
&& matches!(ordering, CategoricalOrdering::Lexical))
.then(|| {
let read_map = crate::STRING_CACHE.read_map();
let payloads = read_map.get_current_payloads();
assert!(payloads.len() >= num_known_categories as usize);

let mut idxs = (0..num_known_categories).collect::<Vec<u32>>();
idxs.sort_by_key(|&k| payloads[k as usize].as_str());
let mut sort_idxs = vec![0; num_known_categories as usize];
for (i, idx) in idxs.into_iter().enumerate_u32() {
sort_idxs[idx as usize] = i;
}
sort_idxs
});

(num_known_categories, lexical_sort_idxs)
},
RevMapping::Local(values, _) => {
// @TODO: This should probably be cached.
let lexical_sort_idxs = (ordered
&& matches!(ordering, CategoricalOrdering::Lexical))
.then(|| {
assert_eq!(values.null_count(), 0);
let values: Vec<&str> = values.values_iter().collect();

let mut idxs = (0..values.len() as u32).collect::<Vec<u32>>();
idxs.sort_by_key(|&k| values[k as usize]);
let mut sort_idxs = vec![0; values.len()];
for (i, idx) in idxs.into_iter().enumerate_u32() {
sort_idxs[idx as usize] = i;
}
sort_idxs
});

(values.len() as u32, lexical_sort_idxs)
},
};

RowEncodingCategoricalContext {
num_known_categories,
is_enum,
lexical_sort_idxs,
}
},
RevMapping::Local(values, _) => {
// @TODO: This should probably be cached.
let lexical_sort_idxs =
matches!(ordering, CategoricalOrdering::Lexical).then(|| {
assert_eq!(values.null_count(), 0);
let values: Vec<&str> = values.values_iter().collect();

let mut idxs = (0..values.len() as u32).collect::<Vec<u32>>();
idxs.sort_by_key(|&k| values[k as usize]);
let mut sort_idxs = vec![0; values.len()];
for (i, idx) in idxs.into_iter().enumerate_u32() {
sort_idxs[idx as usize] = i;
}
sort_idxs
});

(values.len() as u32, lexical_sort_idxs)
None => {
let num_known_categories = u32::MAX;

if matches!(ordering, CategoricalOrdering::Lexical) && ordered {
panic!("lexical ordering not yet supported if rev-map not given");
}
RowEncodingCategoricalContext {
num_known_categories,
is_enum,
lexical_sort_idxs: None,
}
},
};

let ctx = RowEncodingCategoricalContext {
num_known_categories,
is_enum: matches!(dtype, DataType::Enum(_, _)),
lexical_sort_idxs,
};
Some(RowEncodingContext::Categorical(ctx))
},
#[cfg(feature = "dtype-struct")]
DataType::Struct(fs) => {
let mut ctxts = Vec::new();

for (i, f) in fs.iter().enumerate() {
if let Some(ctxt) = get_row_encoding_context(f.dtype()) {
if let Some(ctxt) = get_row_encoding_context(f.dtype(), ordered) {
ctxts.reserve(fs.len());
ctxts.extend(std::iter::repeat_n(None, i));
ctxts.push(Some(ctxt));
Expand All @@ -183,7 +202,7 @@ pub fn get_row_encoding_context(dtype: &DataType) -> Option<RowEncodingContext>
ctxts.extend(
fs[ctxts.len()..]
.iter()
.map(|f| get_row_encoding_context(f.dtype())),
.map(|f| get_row_encoding_context(f.dtype(), ordered)),
);

Some(RowEncodingContext::Struct(ctxts))
Expand Down Expand Up @@ -214,7 +233,7 @@ pub fn _get_rows_encoded_unordered(by: &[Column]) -> PolarsResult<RowsEncoded> {
let by = by.as_materialized_series();
let arr = by.to_physical_repr().rechunk().chunks()[0].to_boxed();
let opt = RowEncodingOptions::new_unsorted();
let ctxt = get_row_encoding_context(by.dtype());
let ctxt = get_row_encoding_context(by.dtype(), false);

cols.push(arr);
opts.push(opt);
Expand Down Expand Up @@ -245,7 +264,7 @@ pub fn _get_rows_encoded(
let by = by.as_materialized_series();
let arr = by.to_physical_repr().rechunk().chunks()[0].to_boxed();
let opt = RowEncodingOptions::new_sorted(*desc, *null_last);
let ctxt = get_row_encoding_context(by.dtype());
let ctxt = get_row_encoding_context(by.dtype(), true);

cols.push(arr);
opts.push(opt);
Expand Down
29 changes: 20 additions & 9 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,15 +509,26 @@ impl Series {
},

#[cfg(feature = "dtype-categorical")]
(D::UInt32, D::Categorical(revmap, ordering)) => Ok(unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
self.u32().unwrap().clone(),
revmap.as_ref().unwrap().clone(),
false,
*ordering,
)
}
.into_series()),
(D::UInt32, D::Categorical(revmap, ordering)) => match revmap {
Some(revmap) => Ok(unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
self.u32().unwrap().clone(),
revmap.clone(),
false,
*ordering,
)
}
.into_series()),
// In the streaming engine this is `None` and the global string cache is turned on
// for the duration of the query.
None => Ok(unsafe {
CategoricalChunked::from_global_indices_unchecked(
self.u32().unwrap().clone(),
*ordering,
)
.into_series()
}),
},
#[cfg(feature = "dtype-categorical")]
(D::UInt32, D::Enum(revmap, ordering)) => Ok(unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-expr/src/groups/row_encoded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl RowEncodedHashGrouper {
let ctxts = self
.key_schema
.iter()
.map(|(_, dt)| get_row_encoding_context(dt))
.map(|(_, dt)| get_row_encoding_context(dt, false))
.collect::<Vec<_>>();
let fields = vec![RowEncodingOptions::new_unsorted(); key_dtypes.len()];
let key_columns =
Expand Down
1 change: 1 addition & 0 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,7 @@ impl LazyFrame {
payload,
});

let _hold = StringCacheHolder::hold();
let f = || {
polars_stream::run_query(stream_lp_top, alp_plan.lp_arena, &mut alp_plan.expr_arena)
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl Eval {
let mut dicts = Vec::with_capacity(self.key_columns_expr.len());
for phys_e in self.key_columns_expr.iter() {
let s = phys_e.evaluate(chunk, &context.execution_state)?;
dicts.push(get_row_encoding_context(s.dtype()));
dicts.push(get_row_encoding_context(s.dtype(), false));
let s = s.to_physical_repr().into_owned();
let s = prepare_key(&s, chunk);
keys_columns.push(s.to_arrow(0, CompatLevel::newest()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ impl<const FIXED: bool> AggHashTable<FIXED> {
.output_schema
.iter_values()
.take(self.num_keys)
.map(get_row_encoding_context)
.map(|dt| get_row_encoding_context(dt, false))
.collect::<Vec<_>>();
let fields = vec![Default::default(); self.num_keys];
let key_columns =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl<K: ExtraPayload> GenericBuild<K> {
let s = phys_e.evaluate(chunk, &context.execution_state)?;
let arr = s.to_physical_repr().rechunk().array_ref(0).clone();
self.join_columns.push(arr);
ctxts.push(get_row_encoding_context(s.dtype()));
ctxts.push(get_row_encoding_context(s.dtype(), false));
}
let rows_encoded = polars_row::convert_columns_no_order(
self.join_columns[0].len(), // @NOTE: does not work for ZFS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl RowValues {
names.push(s.name().to_string());
}
self.join_columns_material.push(s.array_ref(0).clone());
ctxts.push(get_row_encoding_context(s.dtype()));
ctxts.push(get_row_encoding_context(s.dtype(), false));
}

// We determine the indices of the columns that have to be removed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl SortSinkMultiple {
.iter()
.map(|i| {
let (_, dtype) = schema.get_at_index(*i).unwrap();
get_row_encoding_context(dtype)
get_row_encoding_context(dtype, true)
})
.collect::<Vec<_>>();

Expand Down
4 changes: 3 additions & 1 deletion crates/polars-python/src/series/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,9 @@ impl PySeries {

let dicts = dtypes
.iter()
.map(|(_, dtype)| get_row_encoding_context(&dtype.0))
.map(|(_, dt)| dt)
.zip(opts.iter())
.map(|(dtype, opts)| get_row_encoding_context(&dtype.0, opts.is_ordered()))
.collect::<Vec<_>>();

// Get the BinaryOffset array.
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-row/src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ impl RowEncodingOptions {
Self::NO_ORDER
}

pub fn is_ordered(self) -> bool {
!self.contains(Self::NO_ORDER)
}

pub fn null_sentinel(self) -> u8 {
if self.contains(Self::NULLS_LAST) {
0xFF
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ def test_fallback_with_dtype_strict_failure_decimal_precision() -> None:


@pytest.mark.usefixtures("test_global_and_local")
@pytest.mark.may_fail_auto_streaming
def test_categorical_lit_18874() -> None:
assert_frame_equal(
pl.DataFrame(
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,6 +1741,7 @@ def __arrow_c_array__(self, requested_schema: object = None) -> object:
return self.arrow_obj.__arrow_c_array__(requested_schema)


@pytest.mark.may_fail_auto_streaming
def test_pycapsule_interface(df: pl.DataFrame) -> None:
df = df.rechunk()
pyarrow_table = df.to_arrow()
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1902,6 +1902,7 @@ def test_empty_projection() -> None:
assert empty_df.shape == (0, 0)


@pytest.mark.may_fail_auto_streaming
def test_fill_null() -> None:
df = pl.DataFrame({"a": [1, 2], "b": [3, None]})
assert_frame_equal(df.fill_null(4), pl.DataFrame({"a": [1, 2], "b": [3, 4]}))
Expand Down
2 changes: 2 additions & 0 deletions py-polars/tests/unit/dataframe/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ def test_df_serde_roundtrip_json(df: pl.DataFrame) -> None:
assert_frame_equal(result, df, categorical_as_str=True)


@pytest.mark.may_fail_auto_streaming
def test_df_serde(df: pl.DataFrame) -> None:
serialized = df.serialize()
assert isinstance(serialized, bytes)
result = pl.DataFrame.deserialize(io.BytesIO(serialized))
assert_frame_equal(result, df)


@pytest.mark.may_fail_auto_streaming
def test_df_serde_json_stringio(df: pl.DataFrame) -> None:
serialized = df.serialize(format="json")
assert isinstance(serialized, str)
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/datatypes/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ def test_categorical_fill_null_existing_category() -> None:


@pytest.mark.usefixtures("test_global_and_local")
@pytest.mark.may_fail_auto_streaming
def test_categorical_fill_null_stringcache() -> None:
df = pl.LazyFrame(
{"index": [1, 2, 3], "cat": ["a", "b", None]},
Expand Down Expand Up @@ -849,6 +850,7 @@ def test_cat_preserve_lexical_ordering_on_concat() -> None:


@pytest.mark.usefixtures("test_global_and_local")
@pytest.mark.may_fail_auto_streaming
def test_cat_append_lexical_sorted_flag() -> None:
df = pl.DataFrame({"x": [0, 1, 1], "y": ["B", "B", "A"]}).with_columns(
pl.col("y").cast(pl.Categorical(ordering="lexical"))
Expand Down Expand Up @@ -932,7 +934,6 @@ def test_categorical_unique() -> None:
assert s.unique().sort().to_list() == [None, "a", "b"]


@pytest.mark.may_fail_auto_streaming
@pytest.mark.usefixtures("test_global_and_local")
def test_categorical_unique_20539() -> None:
df = pl.DataFrame({"number": [1, 1, 2, 2, 3], "letter": ["a", "b", "b", "c", "c"]})
Expand All @@ -953,7 +954,6 @@ def test_categorical_unique_20539() -> None:
}


@pytest.mark.may_fail_auto_streaming
@pytest.mark.usefixtures("test_global_and_local")
def test_categorical_prefill() -> None:
# https://github.com/pola-rs/polars/pull/20547#issuecomment-2569473443
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,7 @@ class Number(*EnumBase): # type: ignore[misc]
assert_series_equal(expected, s)


@pytest.mark.may_fail_auto_streaming
def test_read_enum_from_csv() -> None:
df = pl.DataFrame(
{
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ def test_list_list_sum_exception_12935() -> None:
pl.Series([[1], [2]]).sum()


@pytest.mark.may_fail_auto_streaming
def test_null_list_categorical_16405() -> None:
df = pl.DataFrame(
[(None, "foo")],
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/functions/test_when_then.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def test_object_when_then_4702() -> None:
}


@pytest.mark.may_fail_auto_streaming
def test_comp_categorical_lit_dtype() -> None:
df = pl.DataFrame(
data={"column": ["a", "b", "e"], "values": [1, 5, 9]},
Expand Down
1 change: 0 additions & 1 deletion py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,6 @@ def test_csv_categorical_lifetime() -> None:
assert (df["a"] == df["b"]).to_list() == [False, False, None]


@pytest.mark.may_fail_auto_streaming
def test_csv_categorical_categorical_merge() -> None:
N = 50
f = io.BytesIO()
Expand Down
Loading
Loading