Skip to content

Commit

Permalink
feat: spark date and datetime data types, allow for dtype in `nw.lit…
Browse files Browse the repository at this point in the history
…` for pyspark (#1913)


---------

Co-authored-by: osoucy <osoucy.transactions@gmail.com>
Co-authored-by: Olivier Soucy <62808289+osoucy@users.noreply.github.com>
Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>
  • Loading branch information
4 people authored Feb 2, 2025
1 parent 53e780c commit e0f37bf
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
14 changes: 9 additions & 5 deletions narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from narwhals._spark_like.selectors import SparkLikeSelectorNamespace
from narwhals._spark_like.utils import ExprKind
from narwhals._spark_like.utils import n_ary_operation_expr_kind
from narwhals._spark_like.utils import narwhals_to_native_dtype
from narwhals.typing import CompliantNamespace

if TYPE_CHECKING:
Expand Down Expand Up @@ -80,12 +81,15 @@ def nth(self: Self, *column_indices: int) -> SparkLikeExpr:
)

def lit(self: Self, value: object, dtype: DType | None) -> SparkLikeExpr:
if dtype is not None:
msg = "todo"
raise NotImplementedError(msg)

def _lit(df: SparkLikeLazyFrame) -> list[Column]:
return [df._F.lit(value)]
column = df._F.lit(value)
if dtype:
native_dtype = narwhals_to_native_dtype(
dtype, version=self._version, spark_types=df._native_dtypes
)
column = column.cast(native_dtype)

return [column]

return SparkLikeExpr(
call=_lit,
Expand Down
17 changes: 13 additions & 4 deletions narwhals/_spark_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ def native_to_narwhals_dtype(
return dtypes.Boolean()
if isinstance(dtype, spark_types.DateType):
return dtypes.Date()
if isinstance(dtype, (spark_types.TimestampType, spark_types.TimestampNTZType)):
if isinstance(dtype, spark_types.TimestampNTZType):
return dtypes.Datetime()
if isinstance(dtype, spark_types.TimestampType):
return dtypes.Datetime(time_zone="UTC")
if isinstance(dtype, spark_types.DecimalType): # pragma: no cover
# TODO(unassigned): cover this in dtypes_test.py
return dtypes.Decimal()
Expand Down Expand Up @@ -100,9 +102,16 @@ def narwhals_to_native_dtype(
return spark_types.StringType()
if isinstance_or_issubclass(dtype, dtypes.Boolean):
return spark_types.BooleanType()
if isinstance_or_issubclass(dtype, (dtypes.Date, dtypes.Datetime)):
msg = "Converting to Date or Datetime dtype is not supported yet"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Date):
return spark_types.DateType()
if isinstance_or_issubclass(dtype, dtypes.Datetime):
dt_time_zone = getattr(dtype, "time_zone", None)
if dt_time_zone is None:
return spark_types.TimestampNTZType()
if dt_time_zone != "UTC": # pragma: no cover
msg = f"Only UTC time zone is supported for PySpark, got: {dt_time_zone}"
raise ValueError(msg)
return spark_types.TimestampType()
if isinstance_or_issubclass(dtype, dtypes.List): # pragma: no cover
inner = narwhals_to_native_dtype(
dtype.inner, # type: ignore[union-attr]
Expand Down
6 changes: 2 additions & 4 deletions tests/expr_and_series/lit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,10 @@
[(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])],
)
def test_lit(
request: pytest.FixtureRequest,
constructor: Constructor,
dtype: DType | None,
expected_lit: list[Any],
) -> None:
if "pyspark" in str(constructor) and dtype is not None:
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
df_raw = constructor(data)
df = nw.from_native(df_raw).lazy()
Expand Down Expand Up @@ -83,6 +80,7 @@ def test_lit_out_name(constructor: Constructor) -> None:
("right_scalar", nw.col("a") + 1, [2, 4, 3]),
("left_scalar_with_agg", 1 + nw.col("a").mean(), [3]),
("right_scalar_with_agg", nw.col("a").mean() - 1, [1]),
("lit_compare", nw.col("a") == nw.lit(3), [False, True, False]),
],
)
def test_lit_operation_in_select(
Expand Down Expand Up @@ -130,7 +128,7 @@ def test_lit_operation_in_with_columns(

@pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow")
def test_date_lit(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "dask" in str(constructor) or "pyspark" in str(constructor):
if "dask" in str(constructor):
# https://github.com/dask/dask/issues/11637
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor({"a": [1]}))
Expand Down

0 comments on commit e0f37bf

Please sign in to comment.