From 2918efadf6ab5dc293faa3ad0a36efd904191cda Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Tue, 24 Sep 2024 15:22:09 +0200 Subject: [PATCH] fix: AnyValue Series from Categorical/Enum Fixes #18874. --- .../src/chunked_array/logical/mod.rs | 2 + crates/polars-core/src/series/any_value.rs | 99 ++++++++++++++++--- .../constructors/test_any_value_fallbacks.py | 20 +++- 3 files changed, 108 insertions(+), 13 deletions(-) diff --git a/crates/polars-core/src/chunked_array/logical/mod.rs b/crates/polars-core/src/chunked_array/logical/mod.rs index d4e6d4eb84aa..0baa286e9f1d 100644 --- a/crates/polars-core/src/chunked_array/logical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/mod.rs @@ -16,6 +16,8 @@ mod duration; pub use duration::*; #[cfg(feature = "dtype-categorical")] pub mod categorical; +#[cfg(feature = "dtype-categorical")] +pub mod enum_; #[cfg(feature = "dtype-time")] mod time; diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs index 1a8c1df02b65..1f6562b5397a 100644 --- a/crates/polars-core/src/series/any_value.rs +++ b/crates/polars-core/src/series/any_value.rs @@ -2,8 +2,6 @@ use std::fmt::Write; use arrow::bitmap::MutableBitmap; -#[cfg(feature = "dtype-categorical")] -use crate::chunked_array::cast::CastOptions; #[cfg(feature = "object")] use crate::chunked_array::object::registry::ObjectRegistry; use crate::prelude::*; @@ -146,9 +144,9 @@ impl Series { #[cfg(feature = "dtype-duration")] DataType::Duration(tu) => any_values_to_duration(values, *tu, strict)?.into_series(), #[cfg(feature = "dtype-categorical")] - dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => { - any_values_to_categorical(values, dt, strict)? - }, + dt @ DataType::Categorical(_, _) => any_values_to_categorical(values, dt, strict)?, + #[cfg(feature = "dtype-categorical")] + dt @ DataType::Enum(_, _) => any_values_to_enum(values, dt, strict)?, #[cfg(feature = "dtype-decimal")] DataType::Decimal(precision, scale) => { any_values_to_decimal(values, *precision, *scale, strict)?.into_series() @@ -445,14 +443,91 @@ fn any_values_to_categorical( dtype: &DataType, strict: bool, ) -> PolarsResult { - // TODO: Handle AnyValues of type Categorical/Enum. - // TODO: Avoid materializing to String before casting to Categorical/Enum. - let ca = any_values_to_string(values, strict)?; - if strict { - ca.into_series().strict_cast(dtype) - } else { - ca.cast_with_options(dtype, CastOptions::NonStrict) + let ordering = match dtype { + DataType::Categorical(_, ordering) => ordering, + _ => panic!("any_values_to_categorical with dtype={dtype:?}"), + }; + + let mut builder = CategoricalChunkedBuilder::new(PlSmallStr::EMPTY, values.len(), *ordering); + + let mut owned = String::new(); // Amortize allocations. + for av in values { + match av { + AnyValue::String(s) => builder.append_value(s), + AnyValue::StringOwned(s) => builder.append_value(s), + + AnyValue::Enum(s, rev, _) => builder.append_value(rev.get(*s)), + AnyValue::EnumOwned(s, rev, _) => builder.append_value(rev.get(*s)), + + AnyValue::Categorical(s, rev, _) => builder.append_value(rev.get(*s)), + AnyValue::CategoricalOwned(s, rev, _) => builder.append_value(rev.get(*s)), + + AnyValue::Binary(_) | AnyValue::BinaryOwned(_) if !strict => builder.append_null(), + AnyValue::Null => builder.append_null(), + + av => { + if strict { + return Err(invalid_value_error(&DataType::String, av)); + } + + owned.clear(); + write!(owned, "{av}").unwrap(); + builder.append_value(&owned); + }, + } + } + + let ca = builder.finish(); + + Ok(ca.into_series()) +} + +#[cfg(feature = "dtype-categorical")] +fn any_values_to_enum(values: &[AnyValue], dtype: &DataType, strict: bool) -> PolarsResult { + use self::enum_::EnumChunkedBuilder; + + let (rev, ordering) = match dtype { + DataType::Enum(rev, ordering) => (rev.clone(), ordering), + _ => panic!("any_values_to_categorical with dtype={dtype:?}"), + }; + + let Some(rev) = rev else { + polars_bail!(nyi = "Not yet possible to create enum series without a rev-map"); + }; + + let mut builder = + EnumChunkedBuilder::new(PlSmallStr::EMPTY, values.len(), rev, *ordering, strict); + + let mut owned = String::new(); // Amortize allocations. + for av in values { + match av { + AnyValue::String(s) => builder.append_str(s)?, + AnyValue::StringOwned(s) => builder.append_str(s)?, + + AnyValue::Enum(s, rev, _) => builder.append_enum(*s, rev)?, + AnyValue::EnumOwned(s, rev, _) => builder.append_enum(*s, rev)?, + + AnyValue::Categorical(s, rev, _) => builder.append_str(rev.get(*s))?, + AnyValue::CategoricalOwned(s, rev, _) => builder.append_str(rev.get(*s))?, + + AnyValue::Binary(_) | AnyValue::BinaryOwned(_) if !strict => builder.append_null(), + AnyValue::Null => builder.append_null(), + + av => { + if strict { + return Err(invalid_value_error(&DataType::String, av)); + } + + owned.clear(); + write!(owned, "{av}").unwrap(); + builder.append_str(&owned)? + }, + }; } + + let ca = builder.finish(); + + Ok(ca.into_series()) } #[cfg(feature = "dtype-decimal")] diff --git a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py index 0f3c40f21925..8e6f8581072f 100644 --- a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py +++ b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py @@ -11,6 +11,7 @@ import polars as pl from polars._utils.wrap import wrap_s from polars.polars import PySeries +from polars.testing import assert_frame_equal if TYPE_CHECKING: from polars._typing import PolarsDataType @@ -379,7 +380,9 @@ def test_fallback_with_dtype_strict_failure_enum_casting() -> None: dtype = pl.Enum(["a", "b"]) values = ["a", "b", "c", None] - with pytest.raises(TypeError, match="conversion from `str` to `enum` failed"): + with pytest.raises( + TypeError, match="cannot append 'c' to enum without that variant" + ): PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True) @@ -391,3 +394,18 @@ def test_fallback_with_dtype_strict_failure_decimal_precision() -> None: TypeError, match="decimal precision 3 can't fit values with 5 digits" ): PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True) + + +def test_categorical_lit_18874() -> None: + with pl.StringCache(): + assert_frame_equal( + pl.DataFrame( + {"a": [1, 2, 3]}, + ).with_columns(b=pl.lit("foo").cast(pl.Categorical)), + pl.DataFrame( + [ + pl.Series("a", [1, 2, 3]), + pl.Series("b", ["foo"] * 3, pl.Categorical), + ] + ), + )