From 05c4c4b381b47c3643ae19f84d90789f62927b56 Mon Sep 17 00:00:00 2001 From: Bruno Oliveira Date: Tue, 17 Dec 2024 10:14:49 -0300 Subject: [PATCH] Apply suggestions from code review --- src/pytest_regressions/common.py | 26 ++++++++++++++--------- src/pytest_regressions/data_regression.py | 9 ++++---- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/pytest_regressions/common.py b/src/pytest_regressions/common.py index 0655652..c12a01f 100644 --- a/src/pytest_regressions/common.py +++ b/src/pytest_regressions/common.py @@ -197,22 +197,28 @@ def make_location_message(banner: str, filename: Path, aux_files: List[str]) -> T = TypeVar("T", bound=Union[MutableSequence, MutableMapping]) -def round_digits(data: T, precision: int) -> T: +def round_digits(data: T, digits: int) -> T: """ - Recursively Round the values of any float value in a collection to the given precision. - - :param data: The collection to round. - :param precision: The number of decimal places to round to. - :return: The collection with all float values rounded to the given precision. - + Recursively round the values of any float value in a collection to the given number of digits. The rounding is done in-place. + + :param data: + The collection to round. + + :param digits: + The number of digits to round to. + + :return: + The collection with all float values rounded to the given precision. + Note that the rounding is done in-place, so this return value only exists + because we use the function recursively. """ - # change the generator depending on the collection type + # Change the generator depending on the collection type. generator = enumerate(data) if isinstance(data, MutableSequence) else data.items() for k, v in generator: if isinstance(v, (MutableSequence, MutableMapping)): - data[k] = round_digits(v, precision) + data[k] = round_digits(v, digits) elif isinstance(v, float): - data[k] = round(v, precision) + data[k] = round(v, digits) else: data[k] = v return data diff --git a/src/pytest_regressions/data_regression.py b/src/pytest_regressions/data_regression.py index e9c40c6..9bf0c9a 100644 --- a/src/pytest_regressions/data_regression.py +++ b/src/pytest_regressions/data_regression.py @@ -33,7 +33,7 @@ def check( data_dict: Dict[str, Any], basename: Optional[str] = None, fullpath: Optional["os.PathLike[str]"] = None, - precision: Optional[int] = None, + round_digits: Optional[int] = None, ) -> None: """ Checks the given dict against a previously recorded version, or generate a new file. @@ -48,14 +48,15 @@ def check( will ignore ``datadir`` fixture when reading *expected* files but will still use it to write *obtained* files. Useful if a reference file is located in the session data dir for example. - :param precision: if given, round all floats in the dict to the given number of digits. + :param round_digits: + If given, round all floats in the dict to the given number of digits. ``basename`` and ``fullpath`` are exclusive. """ __tracebackhide__ = True - if precision is not None: - round_digits(data_dict, precision) + if round_digits is not None: + round_digits(data_dict, round_digits) def dump(filename: Path) -> None: """Dump dict contents to the given filename"""