Skip to content

Commit

Permalink
feat: add DuckDB Struct dtype (#1851)
Browse files Browse the repository at this point in the history
  • Loading branch information
raisadz authored Jan 22, 2025
1 parent 07e8024 commit 38e5261
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
11 changes: 7 additions & 4 deletions narwhals/_duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType:
def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> str:
dtypes = import_dtypes_module(version)
if isinstance_or_issubclass(dtype, dtypes.Float64):
return "FLOAT"
if isinstance_or_issubclass(dtype, dtypes.Float32):
return "DOUBLE"
if isinstance_or_issubclass(dtype, dtypes.Float32):
return "FLOAT"
if isinstance_or_issubclass(dtype, dtypes.Int64):
return "BIGINT"
if isinstance_or_issubclass(dtype, dtypes.Int32):
Expand Down Expand Up @@ -198,8 +198,11 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st
inner = narwhals_to_native_dtype(dtype.inner, version) # type: ignore[union-attr]
return f"{inner}[]"
if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover
msg = "todo"
raise NotImplementedError(msg)
inner = ", ".join(
f'"{field.name}" {narwhals_to_native_dtype(field.dtype, version)}'
for field in dtype.fields # type: ignore[union-attr]
)
return f"STRUCT({inner})"
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "todo"
raise NotImplementedError(msg)
Expand Down
9 changes: 4 additions & 5 deletions tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,7 @@ def test_cast_datetime_tz_aware(

def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "duckdb", "pyspark")
backend in str(constructor) for backend in ("dask", "modin", "cudf", "pyspark")
):
request.applymarker(pytest.mark.xfail)

Expand All @@ -246,12 +245,12 @@ def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -

data = {
"a": [
{"movie": "Cars", "rating": 4.5},
{"movie": "Toy Story", "rating": 4.9},
{"movie ": "Cars", "rating": 4.5},
{"movie ": "Toy Story", "rating": 4.9},
]
}

dtype = nw.Struct([nw.Field("movie", nw.String()), nw.Field("rating", nw.Float64())])
dtype = nw.Struct([nw.Field("movie ", nw.String()), nw.Field("rating", nw.Float64())])
result = (
nw.from_native(constructor(data)).select(nw.col("a").cast(dtype)).lazy().collect()
)
Expand Down

0 comments on commit 38e5261

Please sign in to comment.