Skip to content

Commit

Permalink
Variables: Handle repeated pandas column labels when inspecting; Rewo…
Browse files Browse the repository at this point in the history
…rk variables service tests (#3426)
  • Loading branch information
seeM authored Jun 6, 2024
1 parent 1a1d5b7 commit 0ce25fc
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@
SIMPLER_NAMES = {
"pandas.core.frame.DataFrame": "pandas.DataFrame",
"pandas.core.series.Series": "pandas.Series",
"polars.dataframe.frame.DataFrame": "polars.DataFrame",
"polars.series.series.Series": "polars.Series",
}


Expand Down Expand Up @@ -761,8 +763,8 @@ def get_display_value(
print_width: Optional[int] = PRINT_WIDTH,
truncate_at: int = TRUNCATE_AT,
) -> Tuple[str, bool]:
# TODO(pyright): cast shouldn't be necessary, recheck in a future version of pyright
display_value = str(cast(Column, self.value[:100]).to_list())
display_value = _get_class_display(self.value)
display_value = f"[{len(self.value)} values] {display_value}"
return (display_value, True)


Expand Down Expand Up @@ -798,15 +800,6 @@ def to_plaintext(self) -> str:
def has_viewer(self) -> bool:
return True

def get_display_value(
self,
print_width: Optional[int] = PRINT_WIDTH,
truncate_at: int = TRUNCATE_AT,
) -> Tuple[str, bool]:
display_value = _get_class_display(self.value)
display_value = f"[{len(self.value)} values] {display_value}"
return (display_value, True)


class PandasIndexInspector(BaseColumnInspector["pd.Index"]):
CLASS_QNAME = [
Expand All @@ -829,7 +822,8 @@ def get_display_value(
if isinstance(self.value, not_none(pd_).RangeIndex):
return str(self.value), False

return super().get_display_value(print_width, truncate_at)
display_value = str(self.value[:100].to_list())
return display_value, True

def has_children(self) -> bool:
# For ranges, we don't visualize the children as they're
Expand Down Expand Up @@ -896,24 +890,12 @@ def get_length(self) -> int:
# number of rows per column is handled by ColumnInspector
return self.value.shape[1]

def get_children(self) -> Collection[Any]:
return self.value.columns

def has_viewer(self) -> bool:
return True

def is_mutable(self) -> bool:
return True


#
# Custom inspectors for specific types
#


class PandasDataFrameInspector(BaseTableInspector["pd.DataFrame", "pd.Series"]):
CLASS_QNAME = "pandas.core.frame.DataFrame"

def get_display_value(
self,
print_width: Optional[int] = PRINT_WIDTH,
Expand All @@ -926,6 +908,24 @@ def get_display_value(

return (display_value, True)


#
# Custom inspectors for specific types
#


class PandasDataFrameInspector(BaseTableInspector["pd.DataFrame", "pd.Series"]):
CLASS_QNAME = "pandas.core.frame.DataFrame"

def get_display_name(self, key: int) -> str:
return str(self.value.columns[key])

def get_children(self):
return range(self.value.shape[1])

def get_child(self, key: int) -> Any:
return self.value.iloc[:, key]

def equals(self, value: pd.DataFrame) -> bool:
return self.value.equals(value)

Expand All @@ -947,6 +947,9 @@ class PolarsDataFrameInspector(BaseTableInspector["pl.DataFrame", "pl.Series"]):
"polars.internals.dataframe.frame.DataFrame",
]

def get_children(self):
return self.value.columns

def get_display_value(
self,
print_width: Optional[int] = PRINT_WIDTH,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,26 @@ def patch_create_comm(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(comm, "create_comm", DummyComm)


def _prepare_shell(shell: PositronShell) -> None:
# TODO: For some reason these vars are in user_ns but not user_ns_hidden during tests. For now,
# manually add them to user_ns_hidden to replicate running in Positron.
shell.user_ns_hidden.update(
{
k: None
for k in [
"__name__",
"__doc__",
"__package__",
"__loader__",
"__spec__",
"_",
"__",
"___",
]
}
)


@pytest.fixture
def kernel() -> PositronIPyKernel:
"""
Expand All @@ -77,30 +97,18 @@ def kernel() -> PositronIPyKernel:
)
raise

# Prepare the shell here as well, since users of this fixture may indirectly depend on it
# e.g. the variables service.
_prepare_shell(kernel.shell)

return kernel


@pytest.fixture
def shell(kernel) -> Iterable[PositronShell]:
shell = PositronShell.instance(parent=kernel)

# TODO: For some reason these vars are in user_ns but not user_ns_hidden during tests. For now,
# manually add them to user_ns_hidden to replicate running in Positron.
shell.user_ns_hidden.update(
{
k: None
for k in [
"__name__",
"__doc__",
"__package__",
"__loader__",
"__spec__",
"_",
"__",
"___",
]
}
)
_prepare_shell(shell)

user_ns_keys = set(shell.user_ns.keys())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ def test_inspect_polars_dataframe() -> None:
rows, cols = value.shape
verify_inspector(
value=value,
display_value=f"[{rows} rows x {cols} columns] {get_type_as_str(value)}",
display_value=f"[{rows} rows x {cols} columns] polars.DataFrame",
kind=VariableKind.Table,
display_type=f"DataFrame [{rows}x{cols}]",
type_info=get_type_as_str(value),
Expand All @@ -782,7 +782,7 @@ def mutate(x):

verify_inspector(
value=value,
display_value="[0, 1]",
display_value=f"[{rows} values] polars.Series",
kind=VariableKind.Map,
display_type=f"Int64 [{rows}]",
type_info=get_type_as_str(value),
Expand All @@ -799,7 +799,7 @@ def mutate(x):
[
(pd.Series({"a": 0, "b": 1}), range(2)),
(pl.Series([0, 1]), range(2)),
(pd.DataFrame({"a": [1, 2], "b": ["3", "4"]}), ["a", "b"]),
(pd.DataFrame({"a": [1, 2], "b": ["3", "4"]}), range(2)),
(pl.DataFrame({"a": [1, 2], "b": ["3", "4"]}), ["a", "b"]),
(pd.Index([0, 1]), range(0, 2)),
(
Expand Down Expand Up @@ -827,7 +827,7 @@ def test_get_children(data: Any, expected: Iterable) -> None:
(pl.Series([0, 1]), 0, 0),
(
pd.DataFrame({"a": [1, 2], "b": ["3", "4"]}),
"a",
0,
pd.Series([1, 2], name="a"),
),
(
Expand Down
Loading

0 comments on commit 0ce25fc

Please sign in to comment.