From 8f37672868f69a1bc4080c732f5f3f2a4d662ca7 Mon Sep 17 00:00:00 2001 From: JBGreisman Date: Fri, 16 Aug 2024 11:52:43 -0400 Subject: [PATCH 1/3] Add test to verify rs.DataSet call signatures against pandas --- tests/test_dataset_signatures.py | 41 ++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 tests/test_dataset_signatures.py diff --git a/tests/test_dataset_signatures.py b/tests/test_dataset_signatures.py new file mode 100644 index 00000000..63fcfe8d --- /dev/null +++ b/tests/test_dataset_signatures.py @@ -0,0 +1,41 @@ +from inspect import signature + +import pandas as pd +import pytest +from pandas.testing import assert_frame_equal + + +def test_reset_index_signature(dataset_hkl): + """ + Test call signature of rs.DataSet.reset_index() matches call signature of + pd.DataFrame.reset_index() using default parameters + """ + df = pd.DataFrame(dataset_hkl) + sig = signature(pd.DataFrame.reset_index) + bsig = sig.bind(df) + bsig.apply_defaults() + + expected = df.reset_index(*bsig.args[1:], **bsig.kwargs) + result = dataset_hkl.reset_index(*bsig.args[1:], **bsig.kwargs) + result = pd.DataFrame(result) + + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("names", ["H", "K", ["H", "K"]]) +def test_set_index_signature(dataset_hkl, names): + """ + Test call signature of rs.DataSet.set_index() matches call signature of + pd.DataFrame.set_index() using default parameters + """ + ds = dataset_hkl.reset_index() + df = pd.DataFrame(ds) + sig = signature(pd.DataFrame.set_index) + bsig = sig.bind(df, names) + bsig.apply_defaults() + + expected = df.set_index(*bsig.args[1:], **bsig.kwargs) + result = ds.set_index(*bsig.args[1:], **bsig.kwargs) + result = pd.DataFrame(result) + + assert_frame_equal(result, expected) From 35402591e0e6d0d2d532ed1e0e2de5fefe1d67d9 Mon Sep 17 00:00:00 2001 From: JBGreisman Date: Fri, 16 Aug 2024 12:16:51 -0400 Subject: [PATCH 2/3] Add minimal example to tests from GH#223 --- tests/test_dataset_signatures.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_dataset_signatures.py b/tests/test_dataset_signatures.py index 63fcfe8d..fe780d78 100644 --- a/tests/test_dataset_signatures.py +++ b/tests/test_dataset_signatures.py @@ -4,6 +4,18 @@ import pytest from pandas.testing import assert_frame_equal +import reciprocalspaceship as rs + + +def test_reset_index_dataseries(): + """ + Minimal example from GH#223 + """ + result = rs.DataSeries(range(10)).reset_index() + expected = pd.Series(range(10)).reset_index() + expected = rs.DataSet(expected) + assert_frame_equal(result, expected) + def test_reset_index_signature(dataset_hkl): """ From 92c32c3266be4732af9549be8e96d1d984fb8045 Mon Sep 17 00:00:00 2001 From: JBGreisman Date: Fri, 16 Aug 2024 12:17:15 -0400 Subject: [PATCH 3/3] Fixes #223: Correct the call signature to rs.DataSet.reset_index() --- reciprocalspaceship/dataset.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/reciprocalspaceship/dataset.py b/reciprocalspaceship/dataset.py index b4a681d3..04c56f19 100644 --- a/reciprocalspaceship/dataset.py +++ b/reciprocalspaceship/dataset.py @@ -258,7 +258,14 @@ def set_index( ) def reset_index( - self, level=None, drop=False, inplace=False, col_level=0, col_fill="" + self, + level=None, + drop=False, + inplace=False, + col_level=0, + col_fill="", + allow_duplicates=lib.no_default, + names=None, ): """ Reset the index or a specific level of a MultiIndex. @@ -281,6 +288,12 @@ def reset_index( col_fill : object If the columns have multiple levels, determines how the other levels are named. If None then the index name is repeated. + allow_duplicates : bool + Allow duplicate column labels to be created. + names : int, str, tuple, list + Using the given string, rename the DataSet column which contains the + index data. If the DataSet has a MultiIndex, this has to be a list or + tuple with length equal to the number of levels. Returns ------- @@ -317,6 +330,8 @@ def _handle_cached_dtypes(dataset, columns, drop): inplace=inplace, col_level=col_level, col_fill=col_fill, + allow_duplicates=allow_duplicates, + names=names, ) _handle_cached_dtypes(self, columns, drop) return @@ -327,6 +342,8 @@ def _handle_cached_dtypes(dataset, columns, drop): inplace=inplace, col_level=col_level, col_fill=col_fill, + allow_duplicates=allow_duplicates, + names=names, ) dataset._index_dtypes = dataset._index_dtypes.copy() dataset = _handle_cached_dtypes(dataset, columns, drop)