Skip to content

Commit

Permalink
fix: pandas was raising when index name and column names overlapped i…
Browse files Browse the repository at this point in the history
…n groupby (#1908)
  • Loading branch information
MarcoGorelli authored Feb 1, 2025
1 parent 5a8668d commit 4ce9b0f
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
[![PyPI version](https://badge.fury.io/py/narwhals.svg)](https://badge.fury.io/py/narwhals)
[![Downloads](https://static.pepy.tech/badge/narwhals/month)](https://pepy.tech/project/narwhals)
[![Trusted publishing](https://img.shields.io/badge/Trusted_publishing-Provides_attestations-bright_green)](https://peps.python.org/pep-0740/)
[![PYPI - Types](https://img.shields.io/pypi/types/narwhals)](https://pypi.org/project/narwhals)

Extremely lightweight and extensible compatibility layer between dataframe libraries!

Expand Down
7 changes: 5 additions & 2 deletions docs/pandas_like_concepts/pandas_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Narwhals aims to accommodate both!

Let's learn about what Narwhals promises.

## 1. Narwhals will preserve your index for dataframe operations
## 1. Narwhals will preserve your index for common dataframe operations

```python exec="1" source="above" session="ex1"
import narwhals as nw
Expand All @@ -39,7 +39,10 @@ print(my_func(df))
```

Note how the result still has the original index - Narwhals did not modify
it.
it. Narwhals will preserve your original index for most common dataframe
operations. However, Narwhals will _not_ preserve the original index for
`DataFrame.group_by`, because there, overlapping index and column names
raise errors.

## 2. Index alignment follows the left-hand-rule

Expand Down
19 changes: 7 additions & 12 deletions narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,27 @@ def __init__(
) -> None:
self._df = df
self._keys = keys
# Drop index to avoid potential collisions:
# https://github.com/narwhals-dev/narwhals/issues/1907.
native_frame = df._native_frame.reset_index(drop=True)
if (
self._df._implementation is Implementation.PANDAS
and self._df._backend_version < (1, 1)
): # pragma: no cover
if (
not drop_null_keys
and select_columns_by_name(
self._df._native_frame,
self._keys,
self._df._backend_version,
self._df._implementation,
)
.isna()
.any()
.any()
and self._df.simple_select(*self._keys)._native_frame.isna().any().any()
):
msg = "Grouping by null values is not supported in pandas < 1.0.0"
msg = "Grouping by null values is not supported in pandas < 1.1.0"
raise NotImplementedError(msg)
self._grouped = self._df._native_frame.groupby(
self._grouped = native_frame.groupby(
list(self._keys),
sort=False,
as_index=True,
observed=True,
)
else:
self._grouped = self._df._native_frame.groupby(
self._grouped = native_frame.groupby(
list(self._keys),
sort=False,
as_index=True,
Expand Down
9 changes: 2 additions & 7 deletions narwhals/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,14 @@ def simple_select(
) -> CompliantDataFrame: ... # `select` where all args are column names


class CompliantLazyFramme(Protocol):
def __narwhals_lazyframe__(self) -> CompliantDataFrame: ...
class CompliantLazyFrame(Protocol):
def __narwhals_lazyframe__(self) -> CompliantLazyFrame: ...
def __narwhals_namespace__(self) -> Any: ...
def simple_select(
self, *column_names: str
) -> CompliantLazyFrame: ... # `select` where all args are column names


class CompliantLazyFrame(Protocol):
def __narwhals_lazyframe__(self) -> CompliantLazyFrame: ...
def __narwhals_namespace__(self) -> Any: ...


CompliantSeriesT_co = TypeVar(
"CompliantSeriesT_co", bound=CompliantSeries, covariant=True
)
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def sqlframe_pyspark_lazy_constructor(
"pyspark": pyspark_lazy_constructor, # type: ignore[dict-item]
# We've reported several bugs to sqlframe - once they address
# them, we can start testing them as part of our CI.
# "sqlframe": pyspark_lazy_constructor, # noqa: ERA001
# "sqlframe": sqlframe_pyspark_lazy_constructor, # noqa: ERA001
}
GPU_CONSTRUCTORS: dict[str, Callable[[Any], IntoFrame]] = {"cudf": cudf_constructor}

Expand Down
14 changes: 14 additions & 0 deletions tests/group_by_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,17 @@ def test_group_by_expr(constructor: Constructor) -> None:
df = nw.from_native(constructor({"a": [1, 1, 3], "b": [4, 5, 6]}))
with pytest.raises(NotImplementedError, match=r"not \(yet\?\) supported"):
df.group_by(nw.col("a")).agg(nw.col("b").mean()) # type: ignore[arg-type]


def test_pandas_group_by_index_and_column_overlap() -> None:
df = pd.DataFrame(
{"a": [1, 1, 2], "b": [4, 5, 6]}, index=pd.Index([0, 1, 2], name="a")
)
result = nw.from_native(df, eager_only=True).group_by("a").agg(nw.col("b").mean())
expected = {"a": [1, 2], "b": [4.5, 6.0]}
assert_equal_data(result, expected)

key, result = next(iter(nw.from_native(df, eager_only=True).group_by("a")))
assert key == (1,)
expected_native = pd.DataFrame({"a": [1, 1], "b": [4, 5]})
pd.testing.assert_frame_equal(result.to_native(), expected_native)

0 comments on commit 4ce9b0f

Please sign in to comment.