From fa7ec47be7c0381ab9eacd737d116113b0ecb4c3 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 27 Sep 2024 18:02:28 +0200 Subject: [PATCH] fix: Ensure same fmt in Series/AnyValue to string cast (#18982) --- .../src/compute/cast/primitive_to.rs | 2 +- crates/polars-core/src/datatypes/any_value.rs | 17 ++++++++++------- py-polars/tests/unit/operations/test_cast.py | 6 ++++++ 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/crates/polars-arrow/src/compute/cast/primitive_to.rs b/crates/polars-arrow/src/compute/cast/primitive_to.rs index 03a9c427cf4b..d017b0a8e212 100644 --- a/crates/polars-arrow/src/compute/cast/primitive_to.rs +++ b/crates/polars-arrow/src/compute/cast/primitive_to.rs @@ -13,7 +13,7 @@ use crate::offset::{Offset, Offsets}; use crate::temporal_conversions::*; use crate::types::{days_ms, f16, months_days_ns, NativeType}; -pub(super) trait SerPrimitive { +pub trait SerPrimitive { fn write(f: &mut Vec, val: Self) -> usize where Self: Sized; diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index 83d9d4f8301d..4155a9bf14e9 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; +use arrow::compute::cast::SerPrimitive; use arrow::types::PrimitiveType; -use polars_utils::format_pl_smallstr; #[cfg(feature = "dtype-categorical")] use polars_utils::sync::SyncPtr; use polars_utils::total_ord::ToTotalOrd; @@ -563,19 +563,22 @@ impl<'a> AnyValue<'a> { (AnyValue::Float64(v), DataType::Boolean) => AnyValue::Boolean(*v != f64::default()), // to string - (AnyValue::String(v), DataType::String) => { - AnyValue::StringOwned(PlSmallStr::from_str(v)) - }, + (AnyValue::String(v), DataType::String) => AnyValue::String(v), (AnyValue::StringOwned(v), DataType::String) => AnyValue::StringOwned(v.clone()), (av, DataType::String) => { + let mut tmp = vec![]; if av.is_unsigned_integer() { - AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::()?)) + let val = av.extract::()?; + SerPrimitive::write(&mut tmp, val); } else if av.is_float() { - AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::()?)) + let val = av.extract::()?; + SerPrimitive::write(&mut tmp, val); } else { - AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::()?)) + let val = av.extract::()?; + SerPrimitive::write(&mut tmp, val); } + AnyValue::StringOwned(PlSmallStr::from_str(std::str::from_utf8(&tmp).unwrap())) }, // to binary diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index cdf7688f7c32..e01912237a19 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -686,3 +686,9 @@ def test_bool_numeric_supertype(dtype: PolarsDataType) -> None: df = pl.DataFrame({"v": [1, 2, 3, 4, 5, 6]}) result = df.select((pl.col("v") < 3).sum().cast(dtype) / pl.len()) assert result.item() - 0.3333333 <= 0.00001 + + +def test_cast_consistency() -> None: + assert pl.DataFrame().with_columns(a=pl.lit(0.0)).with_columns( + b=pl.col("a").cast(pl.String), c=pl.lit(0.0).cast(pl.String) + ).to_dict(as_series=False) == {"a": [0.0], "b": ["0.0"], "c": ["0.0"]}