diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 30d96649762f..f4c87ce0ad22 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -572,6 +572,52 @@ impl DataType { } } + /// Try to get the maximum value for this datatype. + pub fn max(&self) -> PolarsResult { + use DataType::*; + let v = match self { + #[cfg(feature = "dtype-i8")] + Int8 => Scalar::from(i8::MAX), + #[cfg(feature = "dtype-i16")] + Int16 => Scalar::from(i16::MAX), + Int32 => Scalar::from(i32::MAX), + Int64 => Scalar::from(i64::MAX), + #[cfg(feature = "dtype-u8")] + UInt8 => Scalar::from(u8::MAX), + #[cfg(feature = "dtype-u16")] + UInt16 => Scalar::from(u16::MAX), + UInt32 => Scalar::from(u32::MAX), + UInt64 => Scalar::from(u64::MAX), + Float32 => Scalar::from(f32::INFINITY), + Float64 => Scalar::from(f64::INFINITY), + dt => polars_bail!(ComputeError: "cannot determine upper bound for dtype `{}`", dt), + }; + Ok(v) + } + + /// Try to get the minimum value for this datatype. + pub fn min(&self) -> PolarsResult { + use DataType::*; + let v = match self { + #[cfg(feature = "dtype-i8")] + Int8 => Scalar::from(i8::MIN), + #[cfg(feature = "dtype-i16")] + Int16 => Scalar::from(i16::MIN), + Int32 => Scalar::from(i32::MIN), + Int64 => Scalar::from(i64::MIN), + #[cfg(feature = "dtype-u8")] + UInt8 => Scalar::from(u8::MIN), + #[cfg(feature = "dtype-u16")] + UInt16 => Scalar::from(u16::MIN), + UInt32 => Scalar::from(u32::MIN), + UInt64 => Scalar::from(u64::MIN), + Float32 => Scalar::from(f32::NEG_INFINITY), + Float64 => Scalar::from(f64::NEG_INFINITY), + dt => polars_bail!(ComputeError: "cannot determine lower bound for dtype `{}`", dt), + }; + Ok(v) + } + /// Convert to an Arrow data type. #[inline] pub fn to_arrow(&self, compat_level: CompatLevel) -> ArrowDataType { diff --git a/crates/polars-plan/src/dsl/function_expr/bounds.rs b/crates/polars-plan/src/dsl/function_expr/bounds.rs index 77c8a6f3ef5f..ae0f36a0956e 100644 --- a/crates/polars-plan/src/dsl/function_expr/bounds.rs +++ b/crates/polars-plan/src/dsl/function_expr/bounds.rs @@ -2,50 +2,12 @@ use super::*; pub(super) fn upper_bound(s: &Column) -> PolarsResult { let name = s.name().clone(); - use DataType::*; - let s = match s.dtype().to_physical() { - #[cfg(feature = "dtype-i8")] - Int8 => Column::new_scalar(name, Scalar::from(i8::MAX), 1), - #[cfg(feature = "dtype-i16")] - Int16 => Column::new_scalar(name, Scalar::from(i16::MAX), 1), - Int32 => Column::new_scalar(name, Scalar::from(i32::MAX), 1), - Int64 => Column::new_scalar(name, Scalar::from(i64::MAX), 1), - #[cfg(feature = "dtype-u8")] - UInt8 => Column::new_scalar(name, Scalar::from(u8::MAX), 1), - #[cfg(feature = "dtype-u16")] - UInt16 => Column::new_scalar(name, Scalar::from(u16::MAX), 1), - UInt32 => Column::new_scalar(name, Scalar::from(u32::MAX), 1), - UInt64 => Column::new_scalar(name, Scalar::from(u64::MAX), 1), - Float32 => Column::new_scalar(name, Scalar::from(f32::INFINITY), 1), - Float64 => Column::new_scalar(name, Scalar::from(f64::INFINITY), 1), - dt => polars_bail!( - ComputeError: "cannot determine upper bound for dtype `{}`", dt, - ), - }; - Ok(s) + let scalar = s.dtype().to_physical().max()?; + Ok(Column::new_scalar(name, scalar, 1)) } pub(super) fn lower_bound(s: &Column) -> PolarsResult { let name = s.name().clone(); - use DataType::*; - let s = match s.dtype().to_physical() { - #[cfg(feature = "dtype-i8")] - Int8 => Column::new_scalar(name, Scalar::from(i8::MIN), 1), - #[cfg(feature = "dtype-i16")] - Int16 => Column::new_scalar(name, Scalar::from(i16::MIN), 1), - Int32 => Column::new_scalar(name, Scalar::from(i32::MIN), 1), - Int64 => Column::new_scalar(name, Scalar::from(i64::MIN), 1), - #[cfg(feature = "dtype-u8")] - UInt8 => Column::new_scalar(name, Scalar::from(u8::MIN), 1), - #[cfg(feature = "dtype-u16")] - UInt16 => Column::new_scalar(name, Scalar::from(u16::MIN), 1), - UInt32 => Column::new_scalar(name, Scalar::from(u32::MIN), 1), - UInt64 => Column::new_scalar(name, Scalar::from(u64::MIN), 1), - Float32 => Column::new_scalar(name, Scalar::from(f32::NEG_INFINITY), 1), - Float64 => Column::new_scalar(name, Scalar::from(f64::NEG_INFINITY), 1), - dt => polars_bail!( - ComputeError: "cannot determine lower bound for dtype `{}`", dt, - ), - }; - Ok(s) + let scalar = s.dtype().to_physical().min()?; + Ok(Column::new_scalar(name, scalar, 1)) } diff --git a/crates/polars-python/src/datatypes.rs b/crates/polars-python/src/datatypes.rs index a31a2301f866..ea7686a29ec6 100644 --- a/crates/polars-python/src/datatypes.rs +++ b/crates/polars-python/src/datatypes.rs @@ -1,10 +1,12 @@ use polars::prelude::*; use polars_core::utils::arrow::array::Utf8ViewArray; +use polars_lazy::dsl; use pyo3::prelude::*; +use crate::error::PyPolarsErr; #[cfg(feature = "object")] use crate::object::OBJECT_NAME; -use crate::Wrap; +use crate::{PyExpr, Wrap}; // Don't change the order of these! #[repr(u8)] @@ -117,3 +119,15 @@ impl<'py> FromPyObject<'py> for PyDataType { Ok(dt.0.into()) } } + +#[pyfunction] +pub fn _get_dtype_max(dt: Wrap) -> PyResult { + let v = dt.0.max().map_err(PyPolarsErr::from)?; + Ok(dsl::lit(v).into()) +} + +#[pyfunction] +pub fn _get_dtype_min(dt: Wrap) -> PyResult { + let v = dt.0.min().map_err(PyPolarsErr::from)?; + Ok(dsl::lit(v).into()) +} diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index 64eaf13ea7b4..5543f629a620 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -12,6 +12,7 @@ import polars.functions as F with contextlib.suppress(ImportError): # Module not available when building docs + import polars.polars as plr from polars.polars import dtype_str_repr as _dtype_str_repr if TYPE_CHECKING: @@ -238,6 +239,44 @@ def to_python(self) -> PythonDataType: class NumericType(DataType): """Base class for numeric data types.""" + @classmethod + def max(cls) -> pl.Expr: + """ + Return a literal expression representing the maximum value of this data type. + + Examples + -------- + >>> pl.select(pl.Int8.max() == 127) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ bool │ + ╞═════════╡ + │ true │ + └─────────┘ + """ + return pl.Expr._from_pyexpr(plr._get_dtype_max(cls)) + + @classmethod + def min(cls) -> pl.Expr: + """ + Return a literal expression representing the minimum value of this data type. + + Examples + -------- + >>> pl.select(pl.Int8.min() == -128) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ bool │ + ╞═════════╡ + │ true │ + └─────────┘ + """ + return pl.Expr._from_pyexpr(plr._get_dtype_min(cls)) + class IntegerType(NumericType): """Base class for integer data types.""" diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 859609828d19..f73577319545 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -20,7 +20,7 @@ use polars_python::lazygroupby::PyLazyGroupBy; use polars_python::series::PySeries; #[cfg(feature = "sql")] use polars_python::sql::PySQLContext; -use polars_python::{exceptions, functions}; +use polars_python::{datatypes, exceptions, functions}; use pyo3::prelude::*; use pyo3::{wrap_pyfunction, wrap_pymodule}; @@ -279,6 +279,12 @@ fn polars(py: Python, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(functions::escape_regex)) .unwrap(); + // Dtype helpers + m.add_wrapped(wrap_pyfunction!(datatypes::_get_dtype_max)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(datatypes::_get_dtype_min)) + .unwrap(); + // Exceptions - Errors m.add( "PolarsError", diff --git a/py-polars/tests/unit/test_datatypes.py b/py-polars/tests/unit/test_datatypes.py index 4d604f2964e9..ed4b8cd1dd61 100644 --- a/py-polars/tests/unit/test_datatypes.py +++ b/py-polars/tests/unit/test_datatypes.py @@ -202,3 +202,28 @@ def test_struct_field_iter() -> None: def test_raise_invalid_namespace() -> None: with pytest.raises(pl.exceptions.InvalidOperationError): pl.select(pl.lit(1.5).str.replace("1", "2")) + + +@pytest.mark.parametrize( + ("dtype", "lower", "upper"), + [ + (pl.Int8, -128, 127), + (pl.UInt8, 0, 255), + (pl.Int16, -32768, 32767), + (pl.UInt16, 0, 65535), + (pl.Int32, -2147483648, 2147483647), + (pl.UInt32, 0, 4294967295), + (pl.Int64, -9223372036854775808, 9223372036854775807), + (pl.UInt64, 0, 18446744073709551615), + (pl.Float32, float("-inf"), float("inf")), + (pl.Float64, float("-inf"), float("inf")), + ], +) +def test_max_min( + dtype: datatypes.IntegerType | datatypes.Float32 | datatypes.Float64, + upper: int | float, + lower: int | float, +) -> None: + df = pl.select(min=dtype.min(), max=dtype.max()) + assert df.to_series(0).item() == lower + assert df.to_series(1).item() == upper