Skip to content

Commit

Permalink
feat(rust,python): Implement max/min methods for dtypes (#19494)
Browse files Browse the repository at this point in the history
  • Loading branch information
eitsupi authored Nov 12, 2024
1 parent 36e5913 commit 017508b
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 44 deletions.
46 changes: 46 additions & 0 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,52 @@ impl DataType {
}
}

/// Try to get the maximum value for this datatype.
///
/// Integer dtypes yield their type's `MAX` constant and float dtypes yield
/// positive infinity. The arms for the 8- and 16-bit integers are compiled
/// only when the matching `dtype-*` feature is enabled.
///
/// # Errors
///
/// Returns a `ComputeError` for every dtype that has no arm below.
pub fn max(&self) -> PolarsResult<Scalar> {
    use DataType::*;
    Ok(match self {
        #[cfg(feature = "dtype-i8")]
        Int8 => Scalar::from(i8::MAX),
        #[cfg(feature = "dtype-i16")]
        Int16 => Scalar::from(i16::MAX),
        Int32 => Scalar::from(i32::MAX),
        Int64 => Scalar::from(i64::MAX),
        #[cfg(feature = "dtype-u8")]
        UInt8 => Scalar::from(u8::MAX),
        #[cfg(feature = "dtype-u16")]
        UInt16 => Scalar::from(u16::MAX),
        UInt32 => Scalar::from(u32::MAX),
        UInt64 => Scalar::from(u64::MAX),
        Float32 => Scalar::from(f32::INFINITY),
        Float64 => Scalar::from(f64::INFINITY),
        // Anything else has no defined upper bound here; bail out early.
        unsupported => {
            polars_bail!(ComputeError: "cannot determine upper bound for dtype `{}`", unsupported)
        },
    })
}

/// Try to get the minimum value for this datatype.
///
/// Integer dtypes yield their type's `MIN` constant and float dtypes yield
/// negative infinity. The arms for the 8- and 16-bit integers are compiled
/// only when the matching `dtype-*` feature is enabled.
///
/// # Errors
///
/// Returns a `ComputeError` for every dtype that has no arm below.
pub fn min(&self) -> PolarsResult<Scalar> {
    use DataType::*;
    Ok(match self {
        #[cfg(feature = "dtype-i8")]
        Int8 => Scalar::from(i8::MIN),
        #[cfg(feature = "dtype-i16")]
        Int16 => Scalar::from(i16::MIN),
        Int32 => Scalar::from(i32::MIN),
        Int64 => Scalar::from(i64::MIN),
        #[cfg(feature = "dtype-u8")]
        UInt8 => Scalar::from(u8::MIN),
        #[cfg(feature = "dtype-u16")]
        UInt16 => Scalar::from(u16::MIN),
        UInt32 => Scalar::from(u32::MIN),
        UInt64 => Scalar::from(u64::MIN),
        Float32 => Scalar::from(f32::NEG_INFINITY),
        Float64 => Scalar::from(f64::NEG_INFINITY),
        // Anything else has no defined lower bound here; bail out early.
        unsupported => {
            polars_bail!(ComputeError: "cannot determine lower bound for dtype `{}`", unsupported)
        },
    })
}

/// Convert to an Arrow data type.
#[inline]
pub fn to_arrow(&self, compat_level: CompatLevel) -> ArrowDataType {
Expand Down
46 changes: 4 additions & 42 deletions crates/polars-plan/src/dsl/function_expr/bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,12 @@ use super::*;

pub(super) fn upper_bound(s: &Column) -> PolarsResult<Column> {
let name = s.name().clone();
use DataType::*;
let s = match s.dtype().to_physical() {
#[cfg(feature = "dtype-i8")]
Int8 => Column::new_scalar(name, Scalar::from(i8::MAX), 1),
#[cfg(feature = "dtype-i16")]
Int16 => Column::new_scalar(name, Scalar::from(i16::MAX), 1),
Int32 => Column::new_scalar(name, Scalar::from(i32::MAX), 1),
Int64 => Column::new_scalar(name, Scalar::from(i64::MAX), 1),
#[cfg(feature = "dtype-u8")]
UInt8 => Column::new_scalar(name, Scalar::from(u8::MAX), 1),
#[cfg(feature = "dtype-u16")]
UInt16 => Column::new_scalar(name, Scalar::from(u16::MAX), 1),
UInt32 => Column::new_scalar(name, Scalar::from(u32::MAX), 1),
UInt64 => Column::new_scalar(name, Scalar::from(u64::MAX), 1),
Float32 => Column::new_scalar(name, Scalar::from(f32::INFINITY), 1),
Float64 => Column::new_scalar(name, Scalar::from(f64::INFINITY), 1),
dt => polars_bail!(
ComputeError: "cannot determine upper bound for dtype `{}`", dt,
),
};
Ok(s)
let scalar = s.dtype().to_physical().max()?;
Ok(Column::new_scalar(name, scalar, 1))
}

pub(super) fn lower_bound(s: &Column) -> PolarsResult<Column> {
let name = s.name().clone();
use DataType::*;
let s = match s.dtype().to_physical() {
#[cfg(feature = "dtype-i8")]
Int8 => Column::new_scalar(name, Scalar::from(i8::MIN), 1),
#[cfg(feature = "dtype-i16")]
Int16 => Column::new_scalar(name, Scalar::from(i16::MIN), 1),
Int32 => Column::new_scalar(name, Scalar::from(i32::MIN), 1),
Int64 => Column::new_scalar(name, Scalar::from(i64::MIN), 1),
#[cfg(feature = "dtype-u8")]
UInt8 => Column::new_scalar(name, Scalar::from(u8::MIN), 1),
#[cfg(feature = "dtype-u16")]
UInt16 => Column::new_scalar(name, Scalar::from(u16::MIN), 1),
UInt32 => Column::new_scalar(name, Scalar::from(u32::MIN), 1),
UInt64 => Column::new_scalar(name, Scalar::from(u64::MIN), 1),
Float32 => Column::new_scalar(name, Scalar::from(f32::NEG_INFINITY), 1),
Float64 => Column::new_scalar(name, Scalar::from(f64::NEG_INFINITY), 1),
dt => polars_bail!(
ComputeError: "cannot determine lower bound for dtype `{}`", dt,
),
};
Ok(s)
let scalar = s.dtype().to_physical().min()?;
Ok(Column::new_scalar(name, scalar, 1))
}
16 changes: 15 additions & 1 deletion crates/polars-python/src/datatypes.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use polars::prelude::*;
use polars_core::utils::arrow::array::Utf8ViewArray;
use polars_lazy::dsl;
use pyo3::prelude::*;

use crate::error::PyPolarsErr;
#[cfg(feature = "object")]
use crate::object::OBJECT_NAME;
use crate::Wrap;
use crate::{PyExpr, Wrap};

// Don't change the order of these!
#[repr(u8)]
Expand Down Expand Up @@ -117,3 +119,15 @@ impl<'py> FromPyObject<'py> for PyDataType {
Ok(dt.0.into())
}
}

// Python-facing helper: wrap the maximum value of `dt` in a literal expression.
// A failure from `DataType::max` is routed through `PyPolarsErr` into a Python error.
// (Plain `//` comments on purpose: `///` would become the Python docstring.)
#[pyfunction]
pub fn _get_dtype_max(dt: Wrap<DataType>) -> PyResult<PyExpr> {
    match dt.0.max() {
        Ok(upper) => Ok(dsl::lit(upper).into()),
        Err(err) => Err(PyPolarsErr::from(err).into()),
    }
}

// Python-facing helper: wrap the minimum value of `dt` in a literal expression.
// A failure from `DataType::min` is routed through `PyPolarsErr` into a Python error.
// (Plain `//` comments on purpose: `///` would become the Python docstring.)
#[pyfunction]
pub fn _get_dtype_min(dt: Wrap<DataType>) -> PyResult<PyExpr> {
    match dt.0.min() {
        Ok(lower) => Ok(dsl::lit(lower).into()),
        Err(err) => Err(PyPolarsErr::from(err).into()),
    }
}
39 changes: 39 additions & 0 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import polars.functions as F

with contextlib.suppress(ImportError): # Module not available when building docs
import polars.polars as plr
from polars.polars import dtype_str_repr as _dtype_str_repr

if TYPE_CHECKING:
Expand Down Expand Up @@ -238,6 +239,44 @@ def to_python(self) -> PythonDataType:
class NumericType(DataType):
    """Base class for numeric data types."""

    @classmethod
    def max(cls) -> pl.Expr:
        """
        Return a literal expression representing the maximum value of this data type.

        Examples
        --------
        >>> pl.select(pl.Int8.max() == 127)
        shape: (1, 1)
        ┌─────────┐
        │ literal │
        │ ---     │
        │ bool    │
        ╞═════════╡
        │ true    │
        └─────────┘
        """
        # The bound is computed on the Rust side; wrap the returned PyExpr.
        return pl.Expr._from_pyexpr(plr._get_dtype_max(cls))

    @classmethod
    def min(cls) -> pl.Expr:
        """
        Return a literal expression representing the minimum value of this data type.

        Examples
        --------
        >>> pl.select(pl.Int8.min() == -128)
        shape: (1, 1)
        ┌─────────┐
        │ literal │
        │ ---     │
        │ bool    │
        ╞═════════╡
        │ true    │
        └─────────┘
        """
        # The bound is computed on the Rust side; wrap the returned PyExpr.
        return pl.Expr._from_pyexpr(plr._get_dtype_min(cls))


class IntegerType(NumericType):
"""Base class for integer data types."""
Expand Down
8 changes: 7 additions & 1 deletion py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use polars_python::lazygroupby::PyLazyGroupBy;
use polars_python::series::PySeries;
#[cfg(feature = "sql")]
use polars_python::sql::PySQLContext;
use polars_python::{exceptions, functions};
use polars_python::{datatypes, exceptions, functions};
use pyo3::prelude::*;
use pyo3::{wrap_pyfunction, wrap_pymodule};

Expand Down Expand Up @@ -279,6 +279,12 @@ fn polars(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(functions::escape_regex))
.unwrap();

// Dtype helpers
m.add_wrapped(wrap_pyfunction!(datatypes::_get_dtype_max))
.unwrap();
m.add_wrapped(wrap_pyfunction!(datatypes::_get_dtype_min))
.unwrap();

// Exceptions - Errors
m.add(
"PolarsError",
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/unit/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,28 @@ def test_struct_field_iter() -> None:
def test_raise_invalid_namespace() -> None:
with pytest.raises(pl.exceptions.InvalidOperationError):
pl.select(pl.lit(1.5).str.replace("1", "2"))


@pytest.mark.parametrize(
    ("dtype", "lower", "upper"),
    [
        (pl.Int8, -128, 127),
        (pl.UInt8, 0, 255),
        (pl.Int16, -32768, 32767),
        (pl.UInt16, 0, 65535),
        (pl.Int32, -2147483648, 2147483647),
        (pl.UInt32, 0, 4294967295),
        (pl.Int64, -9223372036854775808, 9223372036854775807),
        (pl.UInt64, 0, 18446744073709551615),
        (pl.Float32, float("-inf"), float("inf")),
        (pl.Float64, float("-inf"), float("inf")),
    ],
)
def test_max_min(
    dtype: datatypes.IntegerType | datatypes.Float32 | datatypes.Float64,
    lower: int | float,
    upper: int | float,
) -> None:
    """Check that `dtype.min()` / `dtype.max()` evaluate to the expected bounds."""
    # Parameters are listed in the same order as the parametrize tuple
    # ("dtype", "lower", "upper"); pytest binds them by name either way.
    df = pl.select(min=dtype.min(), max=dtype.max())
    assert df.to_series(0).item() == lower
    assert df.to_series(1).item() == upper

0 comments on commit 017508b

Please sign in to comment.