Merge dev into samukweku/refactor_expand_grid

ericmjl authored Jul 5, 2024
2 parents dee2e89 + bbb5891 commit 361e871
Showing 3 changed files with 70 additions and 29 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@
 - [ENH] Added a `clean_names` method for polars - it can be used to clean the column names, or clean column values. Issue #1343 @samukweku
 - [ENH] Improved performance for non-equi joins when using numba - @samukweku PR #1341
 - [ENH] pandas Index, Series, DataFrame now supported in the `complete` method. - PR #1369 @samukweku
+- [ENH] Improve performance for `first/last` in `conditional_join`, when the join columns in the right dataframe are sorted. - PR #1382 @samukweku
 
 ## [v0.27.0] - 2024-03-21
 
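For context, a minimal sketch of the user-facing case this entry targets, assuming pyjanitor's documented `conditional_join` accessor (the frames and column names below are hypothetical, not taken from the repository):

    import pandas as pd
    import janitor  # noqa: F401, registers the conditional_join accessor

    df = pd.DataFrame({"value": [2, 5, 7]})
    # Both join columns in the right dataframe are sorted ascending,
    # the precondition for the fast path this commit adds.
    right = pd.DataFrame({"start": [1, 3, 6], "end": [4, 6, 9]})

    # keep="first" returns only the earliest matching right row per left
    # row; with sorted right columns this is now computed directly,
    # instead of materialising every match and pruning afterwards.
    out = df.conditional_join(
        right,
        ("value", "start", ">="),
        ("value", "end", "<="),
        keep="first",
    )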
87 changes: 60 additions & 27 deletions janitor/functions/conditional_join.py
@@ -508,7 +508,10 @@ def _conditional_join_compute(
     for condition in conditions:
         left_on, right_on, op = condition
         _conditional_join_type_check(
-            df[left_on], right[right_on], op, use_numba
+            left_column=df[left_on],
+            right_column=right[right_on],
+            op=op,
+            use_numba=use_numba,
         )
         if op == _JoinOperator.STRICTLY_EQUAL.value:
             eq_check = True
@@ -520,19 +523,25 @@
     if (len(conditions) > 1) or eq_check:
         if eq_check:
             result = _multiple_conditional_join_eq(
-                df,
-                right,
-                conditions,
-                keep,
-                use_numba,
-                force,
+                df=df,
+                right=right,
+                conditions=conditions,
+                keep=keep,
+                use_numba=use_numba,
+                force=force,
             )
         elif le_lt_check:
             result = _multiple_conditional_join_le_lt(
-                df, right, conditions, keep, use_numba
+                df=df,
+                right=right,
+                conditions=conditions,
+                keep=keep,
+                use_numba=use_numba,
             )
         else:
-            result = _multiple_conditional_join_ne(df, right, conditions, keep)
+            result = _multiple_conditional_join_ne(
+                df=df, right=right, conditions=conditions, keep=keep
+            )
     else:
         left_on, right_on, op = conditions[0]
         if use_numba:
@@ -558,14 +567,16 @@
     if return_matching_indices:
         return result
 
+    left_index, right_index = result
     return _create_frame(
-        df,
-        right,
-        *result,
-        how,
-        df_columns,
-        right_columns,
-        indicator,
+        df=df,
+        right=right,
+        left_index=left_index,
+        right_index=right_index,
+        how=how,
+        df_columns=df_columns,
+        right_columns=right_columns,
+        indicator=indicator,
     )
 
 
@@ -630,9 +641,9 @@ def _multiple_conditional_join_ne(
     left_on, right_on, op = first
 
     indices = _generic_func_cond_join(
-        df[left_on],
-        right[right_on],
-        op,
+        left=df[left_on],
+        right=right[right_on],
+        op=op,
         multiple_conditions=False,
         keep="all",
     )
@@ -755,7 +766,9 @@ def _multiple_conditional_join_eq(
         )
         if not right_is_sorted:
             right_df = right_df.sort_values(right_columns)
-        indices = _numba_equi_join(left_df, right_df, eqs, ge_gt, le_lt)
+        indices = _numba_equi_join(
+            df=left_df, right=right_df, eqs=eqs, ge_gt=ge_gt, le_lt=le_lt
+        )
         if not rest or (indices is None):
             return indices
 
@@ -909,7 +922,25 @@ def _multiple_conditional_join_le_lt(
         if condition not in (ge_gt, le_lt)
     ]
 
-    indices = _range_indices(df, right, ge_gt, le_lt)
+    if conditions:
+        _keep = None
+    else:
+        first = ge_gt[1]
+        second = le_lt[1]
+        right_is_sorted = (
+            right[first].is_monotonic_increasing
+            & right[second].is_monotonic_increasing
+        )
+        if right_is_sorted:
+            _keep = keep
+        else:
+            _keep = None
+
+    indices = _range_indices(
+        df=df, right=right, first=ge_gt, second=le_lt, keep=_keep
+    )
+    if _keep:
+        return indices
 
     # no optimised path
    # blow up the rows and prune
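The guard added above enables the shortcut only when no other conditions remain and both right-hand join columns are monotonic increasing. A small numeric sketch (hypothetical values, not part of the diff) of why sortedness makes this safe: each left row's matches then form one contiguous run of right rows, so the first and last match can be read off directly:

    import numpy as np

    # Hypothetical right-hand interval bounds, both monotonic increasing:
    start = np.array([1, 3, 6])
    end = np.array([4, 6, 9])
    value = 5  # a single left-hand value

    # Rows with value >= start form a prefix, and because `end` is also
    # sorted, rows with value <= end form a suffix; their intersection
    # is one contiguous slice of right rows.
    pos_ge = np.searchsorted(start, value, side="right")  # prefix ends at 2
    pos_le = np.searchsorted(end, value, side="left")     # suffix starts at 1
    matches = np.arange(pos_le, pos_ge)                   # -> array([1])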
@@ -926,9 +957,9 @@
     left_on, right_on, op = ge_gt
 
     indices = _generic_func_cond_join(
-        df[left_on],
-        right[right_on],
-        op,
+        left=df[left_on],
+        right=right[right_on],
+        op=op,
         multiple_conditions=False,
         keep="all",
     )
@@ -951,6 +982,7 @@ def _range_indices(
     right: pd.DataFrame,
     first: tuple,
     second: tuple,
+    keep: str,
 ) -> Union[tuple[np.ndarray, np.ndarray], None]:
     """
     Retrieve index positions for range/interval joins.
@@ -1019,8 +1051,6 @@
     # this is solved by getting the cumulative max
     # thus ensuring that the first match is obtained
     # via a binary search
-    # this allows us to avoid the less efficient linear search
-    # of using a for loop with a break to get the first match
     outcome = _generic_func_cond_join(
         left=left_c,
         right=right_c.cummax(),
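As the surviving comments note, the cumulative maximum restores monotonicity so a binary search can locate the first candidate match; the two deleted comment lines described the linear scan this replaces. A standalone sketch of the trick (hypothetical values):

    import numpy as np

    # Non-monotonic right-hand values: a plain binary search is invalid.
    right_vals = np.array([4, 2, 7, 5, 9])
    value = 5

    # The cumulative max is non-decreasing, so np.searchsorted (a binary
    # search) finds the first position from which matches can occur,
    # with no for-loop-with-break linear scan.
    cummax = np.maximum.accumulate(right_vals)               # [4, 4, 7, 7, 9]
    first_pos = np.searchsorted(cummax, value, side="left")  # -> 2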
@@ -1055,7 +1085,10 @@
         # this also implies that the intervals
         # do not overlap on the right side
         return left_index, right_index[starts]
-
+    if keep == "first":
+        return left_index, right_index[starts]
+    if keep == "last":
+        return left_index, right_index[ends - 1]
     right_index = [right_index[start:end] for start, end in zip(starts, ends)]
     right_index = np.concatenate(right_index)
     left_index = left_index.repeat(repeater)
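The new `keep` branches in `_range_indices` rely on `starts`/`ends` marking, for each left row, a half-open window into the (sorted) right index. A minimal sketch of that bookkeeping (hypothetical arrays, mirroring the diff rather than copied from it):

    import numpy as np

    # Right index positions, plus one half-open match window per left row:
    right_index = np.array([10, 11, 12, 13, 14])
    starts = np.array([0, 2])  # first matching right position per left row
    ends = np.array([3, 5])    # one past the last matching position

    # keep="first": the earliest match is simply the window start.
    first_matches = right_index[starts]   # -> [10, 12]
    # keep="last": the latest match sits just before the window end.
    last_matches = right_index[ends - 1]  # -> [12, 14]

    # keep="all" (the unoptimised path): expand every window, prune later.
    all_matches = np.concatenate(
        [right_index[s:e] for s, e in zip(starts, ends)]
    )  # -> [10, 11, 12, 12, 13, 14]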
11 changes: 9 additions & 2 deletions janitor/functions/utils.py
@@ -813,6 +813,10 @@ def _less_than_indices(
 
     if multiple_conditions:
         return left_index, right_index, search_indices
+    if right_is_sorted and (keep == "last"):
+        indexer = np.empty_like(search_indices)
+        indexer[:] = len_right - 1
+        return left_index, right_index[indexer]
     if right_is_sorted and (keep == "first"):
         if any_nulls:
             return left_index, right_index[search_indices]
@@ -902,6 +906,9 @@ def _greater_than_indices(
 
     if multiple_conditions:
         return left_index, right_index, search_indices
+    if right_is_sorted and (keep == "first"):
+        indexer = np.zeros_like(search_indices)
+        return left_index, right_index[indexer]
     if right_is_sorted and (keep == "last"):
         if any_nulls:
             return left_index, right_index[search_indices - 1]
@@ -1043,9 +1050,9 @@ def _keep_output(keep: str, left: np.ndarray, right: np.ndarray):
     grouped = pd.Series(right).groupby(left)
     if keep == "first":
         grouped = grouped.min()
-        return grouped.index, grouped.array
+        return grouped.index, grouped._values
     grouped = grouped.max()
-    return grouped.index, grouped.array
+    return grouped.index, grouped._values
 
 
 class col:
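The two new branches in `_less_than_indices` and `_greater_than_indices` exploit the same fact: on a sorted right column, every left row's last match for a `<`-style join is the final right row, and every first match for a `>`-style join is row zero, so a constant-fill indexer replaces the groupby reduction in `_keep_output`. A sketch under those assumptions (hypothetical data):

    import numpy as np

    right_vals = np.array([1, 3, 5, 7])  # sorted right-hand join column
    left_vals = np.array([0, 2, 4])      # left values joined on `<`

    # First right position each left value is strictly less than:
    search_indices = np.searchsorted(right_vals, left_vals, side="right")

    # keep="last": matches for `<` run from search_indices to the end,
    # so the last match is always the final right row, a constant fill.
    indexer = np.empty_like(search_indices)
    indexer[:] = len(right_vals) - 1      # -> [3, 3, 3]

    # keep="first" for `>` is symmetric: matches form a prefix, so the
    # first match is always right row 0.
    first_indexer = np.zeros_like(search_indices)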
