diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index b7e3bfb78..87963c417 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -16,6 +16,7 @@ from narwhals._spark_like.selectors import SparkLikeSelectorNamespace from narwhals._spark_like.utils import ExprKind from narwhals._spark_like.utils import n_ary_operation_expr_kind +from narwhals._spark_like.utils import narwhals_to_native_dtype from narwhals.typing import CompliantNamespace if TYPE_CHECKING: @@ -80,12 +81,15 @@ def nth(self: Self, *column_indices: int) -> SparkLikeExpr: ) def lit(self: Self, value: object, dtype: DType | None) -> SparkLikeExpr: - if dtype is not None: - msg = "todo" - raise NotImplementedError(msg) - def _lit(df: SparkLikeLazyFrame) -> list[Column]: - return [df._F.lit(value)] + column = df._F.lit(value) + if dtype: + native_dtype = narwhals_to_native_dtype( + dtype, version=self._version, spark_types=df._native_dtypes + ) + column = column.cast(native_dtype) + + return [column] return SparkLikeExpr( call=_lit, diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index 3a728d5c5..7ef7f6b01 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -65,8 +65,10 @@ def native_to_narwhals_dtype( return dtypes.Boolean() if isinstance(dtype, spark_types.DateType): return dtypes.Date() - if isinstance(dtype, (spark_types.TimestampType, spark_types.TimestampNTZType)): + if isinstance(dtype, spark_types.TimestampNTZType): return dtypes.Datetime() + if isinstance(dtype, spark_types.TimestampType): + return dtypes.Datetime(time_zone="UTC") if isinstance(dtype, spark_types.DecimalType): # pragma: no cover # TODO(unassigned): cover this in dtypes_test.py return dtypes.Decimal() @@ -100,9 +102,16 @@ def narwhals_to_native_dtype( return spark_types.StringType() if isinstance_or_issubclass(dtype, dtypes.Boolean): return spark_types.BooleanType() - if isinstance_or_issubclass(dtype, (dtypes.Date, dtypes.Datetime)): - msg = "Converting to Date or Datetime dtype is not supported yet" - raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Date): + return spark_types.DateType() + if isinstance_or_issubclass(dtype, dtypes.Datetime): + dt_time_zone = getattr(dtype, "time_zone", None) + if dt_time_zone is None: + return spark_types.TimestampNTZType() + if dt_time_zone != "UTC": # pragma: no cover + msg = f"Only UTC time zone is supported for PySpark, got: {dt_time_zone}" + raise ValueError(msg) + return spark_types.TimestampType() if isinstance_or_issubclass(dtype, dtypes.List): # pragma: no cover inner = narwhals_to_native_dtype( dtype.inner, # type: ignore[union-attr] diff --git a/tests/expr_and_series/lit_test.py b/tests/expr_and_series/lit_test.py index 900592766..6907b00c7 100644 --- a/tests/expr_and_series/lit_test.py +++ b/tests/expr_and_series/lit_test.py @@ -22,13 +22,10 @@ [(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])], ) def test_lit( - request: pytest.FixtureRequest, constructor: Constructor, dtype: DType | None, expected_lit: list[Any], ) -> None: - if "pyspark" in str(constructor) and dtype is not None: - request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} df_raw = constructor(data) df = nw.from_native(df_raw).lazy() @@ -83,6 +80,7 @@ def test_lit_out_name(constructor: Constructor) -> None: ("right_scalar", nw.col("a") + 1, [2, 4, 3]), ("left_scalar_with_agg", 1 + nw.col("a").mean(), [3]), ("right_scalar_with_agg", nw.col("a").mean() - 1, [1]), + ("lit_compare", nw.col("a") == nw.lit(3), [False, True, False]), ], ) def test_lit_operation_in_select( @@ -130,7 +128,7 @@ def test_lit_operation_in_with_columns( @pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow") def test_date_lit(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "dask" in str(constructor) or "pyspark" in str(constructor): + if "dask" in str(constructor): # https://github.com/dask/dask/issues/11637 request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1]}))