Skip to content

Commit

Permalink
chore: simplify imports (#1838)
Browse files (browse the repository at this point in the history)
  • Loading branch information
MarcoGorelli authored Jan 20, 2025
1 parent 04274b3 commit 5ca7688
Show file tree
Hide file tree
Showing 33 changed files with 108 additions and 599 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@ jobs:
- name: Run doctests
# reprs differ between versions, so we only run doctests on the latest Python
if: matrix.python-version == '3.13'
- run: pytest narwhals --doctest-modules
+ run: pytest narwhals/*.py --doctest-modules
21 changes: 3 additions & 18 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from typing import Sequence
from typing import overload

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import broadcast_series
from narwhals._arrow.utils import convert_str_slice_to_int_slice
from narwhals._arrow.utils import native_to_narwhals_dtype
Expand All @@ -31,7 +34,6 @@
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import Self

from narwhals._arrow.group_by import ArrowGroupBy
Expand Down Expand Up @@ -289,8 +291,6 @@ def columns(self: Self) -> list[str]:
return self._native_frame.schema.names # type: ignore[no-any-return]

def select(self: Self, *exprs: IntoArrowExpr, **named_exprs: IntoArrowExpr) -> Self:
import pyarrow as pa

new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
if not new_series:
# return empty dataframe, like Polars does
Expand Down Expand Up @@ -469,8 +469,6 @@ def to_dict(
return {name: col.to_pylist() for name, col in names_and_values}

def with_row_index(self: Self, name: str) -> Self:
import pyarrow as pa

df = self._native_frame
cols = self.columns

Expand Down Expand Up @@ -506,8 +504,6 @@ def filter(self: Self, *predicates: IntoArrowExpr, **constraints: Any) -> Self:
return self._from_native_frame(self._native_frame.filter(mask_native))

def null_count(self: Self) -> Self:
import pyarrow as pa

df = self._native_frame
names_and_values = zip(df.column_names, df.columns)

Expand Down Expand Up @@ -583,7 +579,6 @@ def write_parquet(self: Self, file: Any) -> None:
pp.write_table(self._native_frame, file)

def write_csv(self: Self, file: Any) -> Any:
import pyarrow as pa
import pyarrow.csv as pa_csv

pa_table = self._native_frame
Expand All @@ -594,9 +589,6 @@ def write_csv(self: Self, file: Any) -> Any:
return pa_csv.write_csv(pa_table, file)

def is_duplicated(self: Self) -> ArrowSeries:
import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries

columns = self.columns
Expand Down Expand Up @@ -631,8 +623,6 @@ def is_duplicated(self: Self) -> ArrowSeries:
return res.fill_null(res.null_count() > 1, strategy=None, limit=None)

def is_unique(self: Self) -> ArrowSeries:
import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries

is_duplicated = self.is_duplicated()._native_series
Expand All @@ -654,8 +644,6 @@ def unique(
# The param `maintain_order` is only here for compatibility with the Polars API
# and has no effect on the output.
import numpy as np # ignore-banned-import
import pyarrow as pa
import pyarrow.compute as pc

df = self._native_frame
check_column_exists(self.columns, subset)
Expand Down Expand Up @@ -693,7 +681,6 @@ def sample(
seed: int | None,
) -> Self:
import numpy as np # ignore-banned-import
import pyarrow.compute as pc

frame = self._native_frame
num_rows = len(self)
Expand All @@ -713,8 +700,6 @@ def unpivot(
variable_name: str | None,
value_name: str | None,
) -> Self:
import pyarrow as pa

native_frame = self._native_frame
variable_name = variable_name if variable_name is not None else "variable"
value_name = value_name if value_name is not None else "value"
Expand Down
12 changes: 3 additions & 9 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
from typing import Iterator
from typing import Sequence

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
from narwhals.exceptions import AnonymousExprError
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import remove_prefix

if TYPE_CHECKING:
- import pyarrow as pa
- import pyarrow.compute as pc
from typing_extensions import Self

from narwhals._arrow.dataframe import ArrowDataFrame
Expand All @@ -41,8 +42,6 @@ class ArrowGroupBy:
def __init__(
self: Self, df: ArrowDataFrame, keys: list[str], *, drop_null_keys: bool
) -> None:
import pyarrow as pa

if drop_null_keys:
self._df = df.drop_nulls(keys)
else:
Expand Down Expand Up @@ -74,9 +73,6 @@ def agg(
)

def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
import pyarrow as pa
import pyarrow.compute as pc

col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns)
null_token = "__null_token_value__" # noqa: S105

Expand Down Expand Up @@ -114,8 +110,6 @@ def agg_arrow(
from_dataframe: Callable[[Any], ArrowDataFrame],
backend_version: tuple[int, ...],
) -> ArrowDataFrame:
import pyarrow.compute as pc

all_simple_aggs = True
for expr in exprs:
if not (
Expand Down
14 changes: 3 additions & 11 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from typing import Literal
from typing import Sequence

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.selectors import ArrowSelectorNamespace
Expand Down Expand Up @@ -85,8 +88,6 @@ def _create_series_from_scalar(
)

def _create_compliant_series(self: Self, value: Any) -> ArrowSeries:
import pyarrow as pa

from narwhals._arrow.series import ArrowSeries

return ArrowSeries(
Expand Down Expand Up @@ -266,8 +267,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
)

def min_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr:
import pyarrow.compute as pc

parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
Expand Down Expand Up @@ -295,8 +294,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
)

def max_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr:
import pyarrow.compute as pc

parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
Expand Down Expand Up @@ -369,8 +366,6 @@ def concat_str(
separator: str,
ignore_nulls: bool,
) -> ArrowExpr:
import pyarrow.compute as pc

parsed_exprs = [
*parse_into_exprs(*exprs, namespace=self),
*parse_into_exprs(*more_exprs, namespace=self),
Expand Down Expand Up @@ -428,9 +423,6 @@ def __init__(
self._version = version

def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
import pyarrow as pa
import pyarrow.compute as pc

from narwhals._expression_parsing import parse_into_expr

plx = df.__narwhals_namespace__()
Expand Down
Loading

0 comments on commit 5ca7688

Please sign in to comment.