split code / read fpp=1
Signed-off-by: Praateek <praateekm@gmail.com>
praateekmahajan committed Feb 7, 2025
1 parent f8040b5 commit 82f0c6c
Showing 5 changed files with 163 additions and 47 deletions.
6 changes: 3 additions & 3 deletions nemo_curator/modules/exact_dedup.py
@@ -181,9 +181,9 @@ def identify_duplicates(self, dataset: DocumentDataset) -> DocumentDataset:
return DocumentDataset.read_parquet(
write_path,
backend=backend,
blocksize="1024MiB",
files_per_partition=None,
split_row_groups=False,
# we read with FPP=1 so that groups are read whole (and don't span partitions)
files_per_partition=1,
blocksize=None,
)

def remove(
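For context, here is a minimal sketch with plain dask (hypothetical path, not the NeMo Curator DocumentDataset API) of the difference between the two read modes: a blocksize-based read may split or coalesce parquet files, so a duplicate group written to one file can end up spread over several partitions, while a one-file-per-partition read (what files_per_partition=1 requests) keeps each file, and therefore each whole group, in a single partition. This assumes a recent dask where read_parquet accepts blocksize, split_row_groups, and aggregate_files, and that the upstream write keeps each group within a single file (see the shuffle added in connectedcomponents.py below).

```python
import dask.dataframe as dd

# Hypothetical path to the duplicates parquet written by the dedup stage.
path = "/tmp/duplicates_parquet"

# Size-based partitioning: files may be split or merged to hit ~1 GiB partitions,
# so rows of one group can land in different partitions.
by_blocksize = dd.read_parquet(path, blocksize="1024MiB")

# One partition per file: file boundaries (and the whole groups inside them)
# are preserved, which is what the per-partition dedup logic relies on.
by_file = dd.read_parquet(path, split_row_groups=False, aggregate_files=False)
```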
2 changes: 2 additions & 0 deletions nemo_curator/modules/fuzzy_dedup/connectedcomponents.py
@@ -125,6 +125,8 @@ def _run_connected_components(
f"# rows in labels_df = {len(labels_df)}"
)
assert num_nodes == len(labels_df)
# Ensure all docs in the same group are in the same partition
labels_df = labels_df.shuffle(on=["group"], ignore_index=True)
labels_df.to_parquet(output_path, write_index=False, overwrite=True)
Comms.destroy()
self._logger.info(
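A minimal standalone illustration (plain pandas/dask, made-up data) of what the added shuffle guarantees: after shuffle(on=["group"]), every row of a given group sits in the same partition, so the parquet files written afterwards (one per partition) never split a group.

```python
import pandas as pd
import dask.dataframe as dd

df = pd.DataFrame({"id": ["a1", "b1", "a2", "b2"], "group": ["a", "b", "a", "b"]})
ddf = dd.from_pandas(df, npartitions=2)  # group "a" starts out split across partitions

# Hash-shuffle on the group key: all rows of a group end up in one partition
# (several groups may share a partition, but no group is split).
shuffled = ddf.shuffle(on=["group"], ignore_index=True)

# Inspect which groups each partition holds after the shuffle.
print(shuffled.map_partitions(lambda p: p["group"].drop_duplicates()).compute())
```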
6 changes: 3 additions & 3 deletions nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
@@ -250,9 +250,9 @@ def identify_duplicates(
return DocumentDataset.read_parquet(
cc_path,
backend="cudf",
blocksize="1024MiB",
files_per_partition=None,
split_row_groups=False,
# we read with FPP=1 so that groups are read whole (and don't span partitions)
files_per_partition=1,
blocksize=None,
)

def remove(
59 changes: 45 additions & 14 deletions nemo_curator/utils/duplicates_removal.py
@@ -1,11 +1,53 @@
from typing import List, Union

import dask.dataframe as dd


def deduplicate_groups(
duplicates: dd.DataFrame, group_field: str, perform_shuffle: bool
) -> dd.DataFrame:
if perform_shuffle:
# Redistribute data across partitions so that all duplicates are in the same partition
duplicates_shuffled = duplicates.shuffle(on=[group_field], ignore_index=True)
else:
duplicates_shuffled = duplicates

duplicates_to_remove = (
duplicates_shuffled
# For each partition, keep only the duplicated rows (excluding first occurrence)
.map_partitions(lambda x: x[x[group_field].duplicated(keep="first")]).drop(
columns=group_field
)
)
return duplicates_to_remove


def left_anti_join(
left: dd.DataFrame,
right: dd.DataFrame,
left_on: Union[str, List[str]],
right_on: Union[str, List[str]],
):
assert left_on != right_on, "left_on and right_on cannot be the same"
merge = left.merge(
right=right,
how="left",
broadcast=True, # Broadcast smaller DataFrame to all partitions
left_on=left_on,
right_on=right_on,
)

# Keep only rows that did not match duplicates_to_remove, i.e. drop the duplicates
removed_result = merge[merge[right_on].isna()].drop(columns=[right_on])
return removed_result


def remove_duplicates(
left: dd.DataFrame,
duplicates: dd.DataFrame,
id_field: str,
group_field: str,
perform_shuffle: bool = False,
) -> dd.DataFrame:
if left.npartitions < duplicates.npartitions:
msg = (
@@ -18,25 +60,14 @@ def remove_duplicates(
new_id_field = f"{id_field}_new"

duplicates_to_remove = (
duplicates
# Redistribute data across partitions so that all duplicates are in same partition
.shuffle(on=[group_field], ignore_index=True)
# For each partition, keep only the duplicated rows (excluding first occurrence)
.map_partitions(lambda x: x[x[group_field].duplicated(keep="first")]).drop(
columns=group_field
)
deduplicate_groups(duplicates, group_field, perform_shuffle)
# Rename the ID field to avoid conflicts in the upcoming merge
.rename(columns={id_field: new_id_field})[[new_id_field]]
)

merge = left.merge(
return left_anti_join(
left=left,
right=duplicates_to_remove,
how="left",
broadcast=True, # Broadcast smaller DataFrame to all partitions
left_on=id_field,
right_on=new_id_field,
)

# This effectively removes all rows that were not in duplicates_to_remove
removed_result = merge[merge[new_id_field].isna()].drop(columns=[new_id_field])
return removed_result
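A short usage sketch of the refactored helper (made-up data, pandas backend; remove_duplicates, its parameters, and the module path are taken from the diff above): deduplicate_groups keeps one survivor per group, left_anti_join drops the remaining IDs from the original dataset, and remove_duplicates chains the two.

```python
import pandas as pd
import dask.dataframe as dd

from nemo_curator.utils.duplicates_removal import remove_duplicates

docs = dd.from_pandas(
    pd.DataFrame({"id": ["a1", "a2", "b1"], "text": ["x", "y", "z"]}), npartitions=1
)
dupes = dd.from_pandas(
    pd.DataFrame({"id": ["a1", "a2"], "group": ["a", "a"]}), npartitions=1
)

# With perform_shuffle=True the duplicates are collocated by group first,
# so exactly one row per duplicate group survives the removal.
kept = remove_duplicates(
    left=docs,
    duplicates=dupes,
    id_field="id",
    group_field="group",
    perform_shuffle=True,
)
print(kept.compute())  # one of a1/a2 plus b1
```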
137 changes: 110 additions & 27 deletions tests/test_duplicates_removal.py
@@ -1,4 +1,4 @@
import random
from typing import Literal

import pandas as pd
import pytest
Expand All @@ -11,8 +11,6 @@
def ids():
# Dataset has id a0...a9, b0...b9, c0...c9, d0...d9
l = [f"{group}{i}" for group in ["a", "b", "c", "d"] for i in range(10)]
# We shuffle it to make sure all duplicates are not in the same partition
random.shuffle(l)
return l


Expand All @@ -31,17 +29,38 @@ def sample_data(ids):
def duplicate_data(ids):
# In each group we want to keep only the first occurrence (e.g. a1, b1, c1, d1)
df = pd.DataFrame([{"id": _id, "group": _id[0]} for _id in ids])
# Shuffle to make sure all duplicates are not in the same partition
return dd.from_pandas(df, npartitions=2)


@pytest.mark.parametrize(
"backend",
[
"pandas",
pytest.param("cudf", marks=pytest.mark.gpu),
],
)
@pytest.mark.parametrize("perform_shuffle", [False, True])
def test_remove_duplicates_basic(
sample_data: dd.DataFrame, duplicate_data: dd.DataFrame
backend: Literal["cudf", "pandas"],
perform_shuffle: bool,
sample_data: dd.DataFrame,
duplicate_data: dd.DataFrame,
):
if perform_shuffle:
# We shuffle the data to make sure that duplicates are not in the same partition
duplicate_data = duplicate_data.sample(frac=1).reset_index(drop=True)

sample_data = sample_data.to_backend(backend)
duplicate_data = duplicate_data.to_backend(backend)

# Test basic duplicate removal functionality
result = remove_duplicates(
left=sample_data, duplicates=duplicate_data, id_field="id", group_field="group"
)
left=sample_data,
duplicates=duplicate_data,
id_field="id",
group_field="group",
perform_shuffle=perform_shuffle,
).to_backend("pandas")

result = result.compute()

Expand All @@ -52,22 +71,60 @@ def test_remove_duplicates_basic(
assert set(result["id"].apply(lambda x: x[0]).tolist()) == set(["a", "b", "c", "d"])


def test_remove_duplicates_all_duplicates(ids: list[str], sample_data: dd.DataFrame):
@pytest.mark.parametrize(
"backend",
[
"pandas",
pytest.param("cudf", marks=pytest.mark.gpu),
],
)
@pytest.mark.parametrize("perform_shuffle", [False, True])
def test_remove_duplicates_all_duplicates(
backend: Literal["cudf", "pandas"],
perform_shuffle: bool,
ids: list[str],
sample_data: dd.DataFrame,
):

duplicates = dd.from_pandas(
pd.DataFrame({"id": ids, "group": [1] * len(ids)}), npartitions=2
)
sample_data = sample_data.to_backend(backend)
duplicates = duplicates.to_backend(backend)

result = remove_duplicates(
left=sample_data, duplicates=duplicates, id_field="id", group_field="group"
)
left=sample_data,
duplicates=duplicates,
id_field="id",
group_field="group",
perform_shuffle=perform_shuffle,
).to_backend("pandas")

result = result.compute()
assert list(result.columns) == ["id", "text"]
# Should keep only one of the occurrences
assert len(result) == 1


def test_not_remove_duplicates_unique(ids: list[str], sample_data: dd.DataFrame):
result = result.compute()
if perform_shuffle:
assert len(result) == 1
else:
# If we don't shuffle and both partitions contain rows from the same group,
# each partition keeps 1 row after "deduplication",
# so after the left-anti join we'd be left with 2 rows
assert len(result) == 2


@pytest.mark.parametrize(
"backend",
[
"pandas",
pytest.param("cudf", marks=pytest.mark.gpu),
],
)
@pytest.mark.parametrize("perform_shuffle", [False, True])
def test_not_remove_duplicates_unique(
backend: Literal["cudf", "pandas"],
perform_shuffle: bool,
ids: list[str],
sample_data: dd.DataFrame,
):
# We create a dataset where first 30 ids are in one group
# Next 9 ids are in distinct groups
# And last id is not mentioned in duplicates
@@ -81,21 +138,45 @@ def test_not_remove_duplicates_unique(ids: list[str], sample_data: dd.DataFrame)
),
npartitions=2,
)
sample_data = sample_data.to_backend(backend)
duplicates = duplicates.to_backend(backend)
if perform_shuffle:
# We shuffle the data to make sure that duplicates are not in the same partition
duplicates = duplicates.sample(frac=1, random_state=42).reset_index(drop=True)

result = remove_duplicates(
left=sample_data, duplicates=duplicates, id_field="id", group_field="group"
)
left=sample_data,
duplicates=duplicates,
id_field="id",
group_field="group",
perform_shuffle=perform_shuffle,
).to_backend("pandas")

result = result.compute()
assert list(result.columns) == ["id", "text"]
# It has 1 row from the first group of 30
# 9 rows from the 9 distinct groups
# And 1 row from the last group which is not included in set of duplicates
assert len(result) == 1 + 9 + 1
# The last 10 ids should be in the result, there would be one more from the first 30
assert set(ids[30:]).issubset(set(result["id"].tolist()))


def test_remove_duplicates_raise_error():
if perform_shuffle:
# Since we've performed a shuffle, we know groups are collocated and there are 3 groups
# 1. 1 row from the first group of 30
# 2. 9 rows from the 9 distinct groups
# 3. And 1 row from the last group which is not included in set of duplicates
assert len(result) == 1 + 9 + 1
# The last 10 ids should be in the result, there would be one more from the first 30
assert set(ids[30:]).issubset(set(result["id"].tolist()))
else:
# If we don't shuffle, we'd be left with 2 partitions both having rows from group 1
assert len(result) == 2 + 9 + 1


@pytest.mark.parametrize(
"backend",
[
"pandas",
pytest.param("cudf", marks=pytest.mark.gpu),
],
)
def test_remove_duplicates_raise_error(
backend: Literal["cudf", "pandas"],
):
# Create sample dataframes with specific partition counts
df1 = dd.from_pandas(
pd.DataFrame({"id": ["a1", "a2", "a3"], "text": ["text1", "text2", "text3"]}),
@@ -108,6 +189,8 @@ def test_remove_duplicates_raise_error():
),
npartitions=3,
) # duplicates dataset with 3 partitions
df1 = df1.to_backend(backend)
duplicates = duplicates.to_backend(backend)

# Test that it raises ValueError when right npartitions are greater than left npartitions
with pytest.raises(ValueError) as exc_info:
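The no-shuffle expectations in the tests above come from the partition-local keep="first": a minimal sketch (plain pandas/dask, made-up data) of a single group split over two partitions, where each partition keeps its own "first" row and therefore two rows survive instead of one.

```python
import pandas as pd
import dask.dataframe as dd

# One group "g" deliberately split across two partitions (no shuffle performed).
dupes = dd.from_pandas(
    pd.DataFrame({"id": ["x1", "x2", "x3", "x4"], "group": ["g"] * 4}),
    npartitions=2,
)

# Partition-local "drop everything after the first occurrence": each of the two
# partitions keeps its own first row, so only 2 of the 4 rows get marked for
# removal and 2 rows survive, which matches the len(result) == 2 expectations above.
to_remove = dupes.map_partitions(lambda p: p[p["group"].duplicated(keep="first")])
print(len(to_remove))  # 2 (it would be 3 if the group were collocated first)
```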
