From aea3c336c1562d1a1218c9edee8885e408357cc5 Mon Sep 17 00:00:00 2001
From: Marshall Crumiller
Date: Sun, 22 Sep 2024 15:50:44 -0400
Subject: [PATCH] Fix join arg checks

---
 py-polars/polars/lazyframe/frame.py          | 18 ++++++--
 py-polars/tests/unit/operations/test_join.py | 46 ++++++++++++++++++++
 2 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py
index dc39ab3c198c..cf98067d1809 100644
--- a/py-polars/polars/lazyframe/frame.py
+++ b/py-polars/polars/lazyframe/frame.py
@@ -4502,6 +4502,17 @@ def join(
             msg = f"expected `other` join table to be a LazyFrame, not a {type(other).__name__!r}"
             raise TypeError(msg)
 
+        uses_on = on is not None
+        uses_left_on = left_on is not None
+        uses_right_on = right_on is not None
+        uses_lr_on = uses_left_on or uses_right_on
+        if uses_on and uses_lr_on:
+            msg = "cannot use 'on' in conjunction with 'left_on' or 'right_on'"
+            raise ValueError(msg)
+        elif uses_left_on != uses_right_on:
+            msg = "'left_on' requires corresponding 'right_on'"
+            raise ValueError(msg)
+
         if how == "outer":
             how = "full"
             issue_deprecation_warning(
@@ -4515,9 +4526,8 @@ def join(
                 "Use of `how='outer_coalesce'` should be replaced with `how='full', coalesce=True`.",
                 version="0.20.29",
             )
-
         elif how == "cross":
-            if left_on is not None or right_on is not None:
+            if uses_on or uses_lr_on:
                 msg = "cross join should not pass join keys"
                 raise ValueError(msg)
             return self._from_pyldf(
@@ -4534,11 +4544,11 @@ def join(
                 )
             )
 
-        if on is not None:
+        if uses_on:
             pyexprs = parse_into_list_of_expressions(on)
             pyexprs_left = pyexprs
             pyexprs_right = pyexprs
-        elif left_on is not None and right_on is not None:
+        elif uses_lr_on:
             pyexprs_left = parse_into_list_of_expressions(left_on)
             pyexprs_right = parse_into_list_of_expressions(right_on)
         else:
diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py
index 0a9d7ab2d9fd..68220cf81551 100644
--- a/py-polars/tests/unit/operations/test_join.py
+++ b/py-polars/tests/unit/operations/test_join.py
@@ -1036,3 +1036,49 @@ def test_join_coalesce_not_supported_warning() -> None:
     )
 
     assert_frame_equal(expect, got, check_row_order=False)
+
+
+@pytest.mark.parametrize(
+    ("on_args"),
+    [
+        {"on": "a", "left_on": "a"},
+        {"on": "a", "right_on": "a"},
+        {"on": "a", "left_on": "a", "right_on": "a"},
+    ],
+)
+def test_join_on_and_left_right_on(on_args: dict[str, str]) -> None:
+    df1 = pl.DataFrame({"a": [1], "b": [2]})
+    df2 = pl.DataFrame({"a": [1], "c": [3]})
+    msg = "cannot use 'on' in conjunction with 'left_on' or 'right_on'"
+    with pytest.raises(ValueError, match=msg):
+        df1.join(df2, **on_args)  # type: ignore[arg-type]
+
+
+@pytest.mark.parametrize(
+    ("on_args"),
+    [
+        {"left_on": "a"},
+        {"right_on": "a"},
+    ],
+)
+def test_join_only_left_or_right_on(on_args: dict[str, str]) -> None:
+    df1 = pl.DataFrame({"a": [1]})
+    df2 = pl.DataFrame({"a": [1]})
+    msg = "'left_on' requires corresponding 'right_on'"
+    with pytest.raises(ValueError, match=msg):
+        df1.join(df2, **on_args)  # type: ignore[arg-type]
+
+
+@pytest.mark.parametrize(
+    ("on_args"),
+    [
+        {"on": "a"},
+        {"left_on": "a", "right_on": "a"},
+    ],
+)
+def test_cross_join_no_on_keys(on_args: dict[str, str]) -> None:
+    df1 = pl.DataFrame({"a": [1, 2]})
+    df2 = pl.DataFrame({"b": [3, 4]})
+    msg = "cross join should not pass join keys"
+    with pytest.raises(ValueError, match=msg):
+        df1.join(df2, how="cross", **on_args)  # type: ignore[arg-type]