diff --git a/docs/user-guide/gpudeduplication.rst b/docs/user-guide/gpudeduplication.rst
index 990783fe..7fca4963 100644
--- a/docs/user-guide/gpudeduplication.rst
+++ b/docs/user-guide/gpudeduplication.rst
@@ -63,14 +63,19 @@ After ensuring your dataset has a unique ID field (or creating one with the code
     from nemo_curator.datasets import DocumentDataset
 
     # Initialize the deduplication object
-    ExactDups = ExactDuplicates(id_field="my_id", text_field="text")
+    exact_duplicates = ExactDuplicates(
+        id_field="my_id",
+        text_field="text",
+        perform_removal=True,
+        cache_dir="/path/to/dedup_outputs",  # Recommended to specify a cache_dir if perform_removal=True
+    )
 
     dataset = DocumentDataset.read_parquet(
         input_files="/path/to/parquet/data",
         backend="cudf",  # or "pandas" for CPU
     )
-
-    duplicate_docs = ExactDups(dataset)
+    # Users who have specified perform_removal=False can run the two steps separately as follows
+    duplicate_docs = exact_duplicates.identify_duplicates(dataset)
 
     """
     Sample output:
@@ -82,9 +87,14 @@ After ensuring your dataset has a unique ID field (or creating one with the code
     107 doc_prefix-52271  0f763a2937d57b9d96bf9f220e55f2bd
     """
 
+    deduplicated_dataset = exact_duplicates.remove(dataset, duplicate_docs)
+
+    # Users who have specified perform_removal=True can get the deduplicated dataset directly as follows
+    # deduplicated_dataset = exact_duplicates(dataset)
+
+
 .. tip::
-    A more comprehensive example, including how to remove documents from a corpus using the list of
-    duplicate IDs generated from the exact deduplication step above, can be found in `examples/exact_deduplication.py `_.
+    A more comprehensive example can be found in `examples/exact_deduplication.py `_.
 
 """"""""""""
 CLI Utility
@@ -187,6 +197,7 @@ Python API
         cache_dir="/path/to/dedup_outputs", # must be cleared between runs
         id_field="my_id",
         text_field="text",
+        perform_removal=False,  # determines whether the deduplicated dataset or the IDs of duplicates are returned
         seed=42,
         char_ngrams=24,
         num_buckets=20,
@@ -203,6 +214,7 @@ Python API
     cache_dir: /path/to/dedup_outputs
     id_field: my_id
     text_field: text
+    perform_removal: False
     seed: 42
     char_ngrams: 24
     num_buckets: 20
@@ -226,14 +238,15 @@ Python API
     from nemo_curator.datasets import DocumentDataset
 
     # Initialize the deduplication object
-    FuzzyDups = FuzzyDuplicates(config=config, logger="./")
+    fuzzy_duplicates = FuzzyDuplicates(config=config, logger="./")
 
     dataset = DocumentDataset.read_json(
         input_files="/path/to/jsonl/data",
         backend="cudf", # FuzzyDuplicates only supports datasets with the cuDF backend.
     )
 
-    duplicate_docs = FuzzyDups(dataset)
+    # Users who have specified perform_removal=False can run the two steps separately as follows
+    duplicate_docs = fuzzy_duplicates.identify_duplicates(dataset)
     """
     Sample output:
     my_id                  group
@@ -244,10 +257,15 @@ Python API
     4    doc_prefix-42050  154
     """
 
+    deduplicated_dataset = fuzzy_duplicates.remove(dataset, duplicate_docs)
+
+    # Users who have specified perform_removal=True can get the deduplicated dataset directly as follows
+    # deduplicated_dataset = fuzzy_duplicates(dataset)
+
+
 .. tip::
-
-    - A more comprehensive example for the above, including how to remove documents from a corpus using the list of
-      duplicate IDs generated from fuzzy deduplication, can be found in `examples/fuzzy_deduplication.py `_.
+    - A comprehensive example can be found in `examples/fuzzy_deduplication.py `_.
     - The default values of ``num_buckets`` and ``hashes_per_bucket`` are set to find documents with an approximately Jaccard similarity of 0.8 or above.
     - Higher ``buckets_per_shuffle`` values can lead to better performance but might lead to out of memory errors.
    - Setting the ``false_positive_check`` flag to ``False`` is ideal for optimal performance.
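For quick reference, the two call patterns documented above reduce to the sketch below. This snippet is not part of the diff; the paths and the `my_id` field are placeholders.

```python
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules import ExactDuplicates

dataset = DocumentDataset.read_parquet(
    input_files="/path/to/parquet/data",
    backend="cudf",
)

# perform_removal=True: calling the module returns the deduplicated dataset directly
exact_duplicates = ExactDuplicates(
    id_field="my_id",
    text_field="text",
    perform_removal=True,
    cache_dir="/path/to/dedup_outputs",
)
deduplicated_dataset = exact_duplicates(dataset)

# perform_removal=False: identify the duplicate IDs first, then remove them explicitly
exact_duplicates = ExactDuplicates(id_field="my_id", text_field="text", perform_removal=False)
duplicate_docs = exact_duplicates.identify_duplicates(dataset)
deduplicated_dataset = exact_duplicates.remove(dataset, duplicate_docs)
```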
diff --git a/examples/exact_deduplication.py b/examples/exact_deduplication.py
index 81a2d66c..e50893a7 100644
--- a/examples/exact_deduplication.py
+++ b/examples/exact_deduplication.py
@@ -17,8 +17,7 @@
 from nemo_curator.datasets import DocumentDataset
 from nemo_curator.modules import ExactDuplicates
-from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk
-from nemo_curator.utils.file_utils import get_all_files_paths_under
+from nemo_curator.utils.distributed_utils import get_client, write_to_disk
 from nemo_curator.utils.script_utils import ArgumentHelper
 
@@ -40,36 +39,33 @@ def main(args):
     client.run(pre_imports)
 
     t0 = time.time()
-    input_dataset = DocumentDataset.read_json(dataset_dir, backend=backend)
+    input_dataset = DocumentDataset.read_json(
+        dataset_dir, backend=backend, blocksize="1GiB", files_per_partition=None
+    )
 
     exact_dup = ExactDuplicates(
         logger=log_dir,
         id_field=dataset_id_field,
         text_field=dataset_text_field,
+        # Decides whether the output of the module is the deduplicated dataset or the duplicates
+        # If True, you should set cache_dir for better performance
+        perform_removal=False,
         # cache_dir=output_dir  # Optionally write the output to disk
     )
 
-    duplicates = exact_dup(dataset=input_dataset)
+    # When perform_removal=False, it will only call .identify_duplicates() and return the list of duplicate IDs.
+    # When perform_removal=True, exact_dup outputs the dataset with the duplicates removed.
+    # It does this by calling .identify_duplicates() and .remove() in sequence.
+    duplicates = exact_dup(
+        dataset=input_dataset
+    )  # or exact_dup.identify_duplicates(input_dataset)
 
     # If caching, result is a path to the output dataset.
     if isinstance(duplicates, str):
         duplicates = DocumentDataset.read_parquet(duplicates, backend=backend)
 
     # It's easy to apply dataframe operations to the dataset by using the underlying df.
-
-    # By default all duplicate id's are included in the result
-    # keep 1 document from each group of duplcates and mark the others to remove
-    # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.duplicated.html
-    docs_to_remove = duplicates.df.map_partitions(
-        lambda x: x[x._hashes.duplicated(keep="first")]
-    )
-
-    # When there are few duplicates we can compute the results to a list and use `isin`.
-    result = input_dataset.df[
-        ~input_dataset.df[dataset_id_field].isin(
-            docs_to_remove[dataset_id_field].compute()
-        )
-    ]
+    result = exact_dup.remove(input_dataset, duplicates)
     write_to_disk(result, output_dir, output_type="parquet")
     print(time.time() - t0)
diff --git a/examples/fuzzy_deduplication.py b/examples/fuzzy_deduplication.py
index 51344ccb..892c5222 100644
--- a/examples/fuzzy_deduplication.py
+++ b/examples/fuzzy_deduplication.py
@@ -68,6 +68,8 @@ def main(args):
         cache_dir=cache_dir,
         id_field=dataset_id_field,
         text_field=dataset_text_field,
+        # Decides whether the output of the module is the deduplicated dataset or the IDs of the duplicates
+        perform_removal=False,
         seed=42,
         char_ngrams=24,
         num_buckets=20,
@@ -77,26 +79,20 @@ def main(args):
         false_positive_check=False,
     )
     fuzzy_dup = FuzzyDuplicates(logger=log_dir, config=fuzzy_dedup_config)
-    duplicates = fuzzy_dup(dataset=input_dataset)
+
+    # When perform_removal=False, it will only call .identify_duplicates() and return the list of duplicate IDs.
+    # When perform_removal=True, fuzzy_dup outputs the dataset with the duplicates removed.
+    # It does this by calling .identify_duplicates() and .remove() in sequence.
+    duplicates = fuzzy_dup(
+        dataset=input_dataset
+    )  # or fuzzy_dup.identify_duplicates(input_dataset)
 
     if duplicates is None:
         print("No duplicates found")
         print(f"Time taken:{time.time() - t0}s")
         return
 
-    # By default all duplicate id's and the group they belong to are included in the result
-    # keep 1 document from each group of duplcates and mark the others to remove
-    # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.duplicated.html
-    docs_to_remove = duplicates.df.map_partitions(
-        lambda x: x[x.group.duplicated(keep="first")]
-    )
-
-    # When there are few duplicates we can compute the results to a list and use `isin`.
-    result = input_dataset.df[
-        ~input_dataset.df[dataset_id_field].isin(
-            docs_to_remove[dataset_id_field].compute()
-        )
-    ]
+    result = fuzzy_dup.remove(input_dataset, duplicates)
     write_to_disk(result, output_dir, output_type=filetype)
     print(f"Time taken:{time.time() - t0}s")
diff --git a/nemo_curator/modules/config.py b/nemo_curator/modules/config.py
index 50c71017..67bf06af 100644
--- a/nemo_curator/modules/config.py
+++ b/nemo_curator/modules/config.py
@@ -44,6 +44,8 @@ class FuzzyDuplicatesConfig(BaseConfig):
             but might lead to memory pressures and related errors.
         id_field: Column in the Dataset denoting document ID.
         text_field: Column in the Dataset denoting document content.
+        perform_removal: Boolean value to specify whether calling the module should remove the duplicates from
+            the original dataset, or return the list of IDs denoting duplicates.
         profile_dir: str, Default None
             If specified directory to write dask profile
         cache_dir: str, Default None
@@ -64,6 +66,7 @@ class FuzzyDuplicatesConfig(BaseConfig):
     profile_dir: Optional[str] = None
     id_field: str = "id"
     text_field: str = "text"
+    perform_removal: bool = False
 
     # Minhash + LSH Config
     seed: int = 42
@@ -131,6 +134,11 @@ def __post_init__(self):
         if not 1 <= self.buckets_per_shuffle <= self.num_buckets:
             raise ValueError("Buckets per shuffle must be between [1, num_buckets]")
 
+        if not self.perform_removal:
+            warnings.warn(
+                "The default value of perform_removal will change to True in future releases (starting with 0.8.0)."
+            )
+
 
 @dataclass
 class SemDedupConfig(BaseConfig):
diff --git a/nemo_curator/modules/exact_dedup.py b/nemo_curator/modules/exact_dedup.py
index 5d65ff1b..4bdd93b3 100644
--- a/nemo_curator/modules/exact_dedup.py
+++ b/nemo_curator/modules/exact_dedup.py
@@ -18,7 +18,6 @@
 import time
 import warnings
 from contextlib import nullcontext
-from datetime import datetime
 from hashlib import md5
 from typing import Optional, Union
 
@@ -31,6 +30,7 @@
 from nemo_curator.log import create_logger
 from nemo_curator.modules.base import BaseModule
 from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
+from nemo_curator.utils.duplicates_removal import remove_duplicates
 from nemo_curator.utils.gpu_utils import is_cudf_type
 
@@ -45,6 +45,7 @@ def __init__(
         id_field: str = "id",
         text_field: str = "text",
         hash_method: str = "md5",
+        perform_removal: bool = False,
         profile_dir: Optional[str] = None,
         cache_dir: Optional[str] = None,
     ):
@@ -66,9 +67,17 @@ def __init__(
             raise ValueError(
                 f"{hash_method} not in supported hash_methods. Choose a hash_method from {self.SUPPORTED_HASHES}"
             )
+
         self.hash_method = hash_method
         self.id_field = id_field
         self.text_field = text_field
+        self.perform_removal = perform_removal
+        if not self.perform_removal:
+            warnings.warn(
+                "The default value of perform_removal will change to True in future releases (starting with 0.8.0)."
+            )
+        if self.perform_removal and cache_dir is None:
+            warnings.warn("Setting a cache_dir is recommended when removing duplicates.")
         if cache_dir is None and profile_dir is not None:
             warnings.warn(
                 "cache_dir for intermediate outputs is required to generate profiles"
@@ -137,7 +146,7 @@ def hash_documents(
         # TODO: Generalize ty using self.hash_method
         return df.apply(lambda x: md5(x.encode()).hexdigest())
 
-    def call(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
+    def identify_duplicates(self, dataset: DocumentDataset) -> DocumentDataset:
         """
         Find document ID's for exact duplicates in a given DocumentDataset
         Parameters
         ----------
@@ -168,10 +177,38 @@ def call(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
         self._logger.info(
             f"Time taken for Exact Dedup Computation = {time.time() - t0}s and output written at {write_path}"
         )
-        if is_cudf_type(result):
-            import dask_cudf
+        backend = "cudf" if is_cudf_type(result) else "pandas"
+        return DocumentDataset.read_parquet(
+            write_path,
+            backend=backend,
+            # We read with files_per_partition=1 so that each group is read in whole (and is not split across partitions)
+            files_per_partition=1,
+            blocksize=None,
+        )
 
-            result_dataset = dask_cudf.read_parquet(write_path, split_row_groups=False)
-        else:
-            result_dataset = dd.read_parquet(write_path)
-        return DocumentDataset(result_dataset)
+    def remove(
+        self, dataset: DocumentDataset, duplicates_to_remove: Optional[DocumentDataset]
+    ) -> DocumentDataset:
+        """
+        Remove exact duplicates from a given DocumentDataset
+        Parameters
+        ----------
+        dataset: DocumentDataset
+          The input dataset from which to remove exact duplicates
+        Returns
+        -------
+        DocumentDataset containing only non-duplicate documents
+        """
+        result = remove_duplicates(
+            left=dataset.df,
+            duplicates=duplicates_to_remove.df,
+            id_field=self.id_field,
+            group_field="_hashes",
+        )
+        return DocumentDataset(result)
+
+    def call(self, dataset: DocumentDataset) -> DocumentDataset:
+        duplicates = self.identify_duplicates(dataset)
+        if self.perform_removal:
+            return self.remove(dataset, duplicates)
+        return duplicates
diff --git a/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
index 516fe9c4..cdc98e6d 100644
--- a/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
+++ b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
@@ -17,9 +17,7 @@
 import logging
 import os
 import time
-from typing import Union
-
-import dask_cudf
+from typing import Optional, Union
 
 from nemo_curator.datasets import DocumentDataset
 from nemo_curator.log import create_logger
@@ -34,6 +32,7 @@
 from nemo_curator.modules.fuzzy_dedup.minhash import MinHash
 from nemo_curator.modules.meta import Sequential
 from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
+from nemo_curator.utils.duplicates_removal import remove_duplicates
 
 
 class FuzzyDuplicates(BaseModule):
@@ -65,6 +64,7 @@ def __init__(
 
         self._logger = logger
         self.config = config
+
         self.minhash = MinHash(
             seed=self.config.seed,
             num_hashes=self.config.num_hashes,
@@ -131,7 +131,9 @@ def __init__(
             profile_dir=self.config.profile_dir,
         )
 
-    def call(self, dataset: DocumentDataset):
+    def identify_duplicates(
+        self, dataset: DocumentDataset
+    ) -> Optional[DocumentDataset]:
         """
         Parameters
         ----------
@@ -245,4 +247,41 @@ def call(self, dataset: DocumentDataset):
             print(f"Stage {stage_num}: Connected Components across buckets complete!")
             stage_num += 1
 
-        return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False))
+        return DocumentDataset.read_parquet(
+            cc_path,
+            backend="cudf",
+            # We read with files_per_partition=1 so that each group is read in whole (and is not split across partitions)
+            files_per_partition=1,
+            blocksize=None,
+        )
+
+    def remove(
+        self, dataset: DocumentDataset, duplicates_to_remove: Optional[DocumentDataset]
+    ) -> Optional[DocumentDataset]:
+        """
+        Remove fuzzy duplicates from a given DocumentDataset
+        Parameters
+        ----------
+        dataset: DocumentDataset
+          The input dataset from which to remove fuzzy duplicates
+        Returns
+        -------
+        DocumentDataset containing only non-duplicate documents
+        """
+        if not duplicates_to_remove:
+            return None
+        result = remove_duplicates(
+            left=dataset.df,
+            duplicates=duplicates_to_remove.df,
+            id_field=self.config.id_field,
+            group_field="group",
+        )
+        return DocumentDataset(result)
+
+    def call(
+        self, dataset: DocumentDataset, perform_removal: bool = False
+    ) -> DocumentDataset:
+        duplicates = self.identify_duplicates(dataset)
+        if perform_removal:
+            return self.remove(dataset, duplicates)
+        return duplicates
diff --git a/nemo_curator/utils/duplicates_removal.py b/nemo_curator/utils/duplicates_removal.py
new file mode 100644
index 00000000..ea654515
--- /dev/null
+++ b/nemo_curator/utils/duplicates_removal.py
@@ -0,0 +1,73 @@
+from typing import List, Union
+
+import dask.dataframe as dd
+
+
+def deduplicate_groups(
+    duplicates: dd.DataFrame, group_field: str, perform_shuffle: bool
+) -> dd.DataFrame:
+    if perform_shuffle:
+        # Redistribute data across partitions so that all duplicates are in same partition
+        duplicates_shuffled = duplicates.shuffle(on=[group_field], ignore_index=True)
+    else:
+        duplicates_shuffled = duplicates
+
+    duplicates_to_remove = (
+        duplicates_shuffled
+        # For each partition, keep only the duplicated rows (excluding first occurrence)
+        .map_partitions(lambda x: x[x[group_field].duplicated(keep="first")]).drop(
+            columns=group_field
+        )
+    )
+    return duplicates_to_remove
+
+
+def left_anti_join(
+    left: dd.DataFrame,
+    right: dd.DataFrame,
+    left_on: Union[str, List[str]],
+    right_on: Union[str, List[str]],
+):
+    assert left_on != right_on, "left_on and right_on cannot be the same"
+    merge = left.merge(
+        right=right,
+        how="left",
+        broadcast=True,  # Broadcast smaller DataFrame to all partitions
+        left_on=left_on,
+        right_on=right_on,
+    )
+
+    # This effectively removes all rows that were in duplicates_to_remove
+    removed_result = merge[merge[right_on].isna()].drop(columns=[right_on])
+    return removed_result
+
+
+def remove_duplicates(
+    left: dd.DataFrame,
+    duplicates: dd.DataFrame,
+    id_field: str,
+    group_field: str,
+    perform_shuffle: bool = False,
+) -> dd.DataFrame:
+    if left.npartitions < duplicates.npartitions:
+        msg = (
+            "The number of partitions in `left` is less than the number of partitions in the duplicates dataset. "
+            "This may lead to a shuffle join. Please re-read left and right with different partition sizes, or repartition left / right."
+        )
+        raise ValueError(msg)
+
+    # Create a new column name for temporary ID storage during merge
+    new_id_field = f"{id_field}_new"
+
+    duplicates_to_remove = (
+        deduplicate_groups(duplicates, group_field, perform_shuffle)
+        # Rename the ID field to avoid conflicts in the upcoming merge
+        .rename(columns={id_field: new_id_field})[[new_id_field]]
+    )
+
+    return left_anti_join(
+        left=left,
+        right=duplicates_to_remove,
+        left_on=id_field,
+        right_on=new_id_field,
+    )
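As an illustration of the helper added above (this snippet is not part of the diff; the tiny in-memory frames are invented for the example), `remove_duplicates` keeps the first row of each duplicate group and drops the remaining group members from `left` via a broadcast left anti join:

```python
import dask.dataframe as dd
import pandas as pd

from nemo_curator.utils.duplicates_removal import remove_duplicates

# Four documents; doc2 and doc3 were flagged as one duplicate group ("g1")
left = dd.from_pandas(
    pd.DataFrame(
        {"id": ["doc1", "doc2", "doc3", "doc4"], "text": ["a", "b", "b", "c"]}
    ),
    npartitions=2,
)
duplicates = dd.from_pandas(
    pd.DataFrame({"id": ["doc2", "doc3"], "group": ["g1", "g1"]}),
    npartitions=1,
)

result = remove_duplicates(
    left=left,
    duplicates=duplicates,
    id_field="id",
    group_field="group",
    perform_shuffle=True,  # co-locate each group in a single partition before deduplicating
)
print(result.compute())  # doc1, doc2, and doc4 remain; doc3 is dropped
```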
+ ) + raise ValueError(msg) + + # Create a new column name for temporary ID storage during merge + new_id_field = f"{id_field}_new" + + duplicates_to_remove = ( + deduplicate_groups(duplicates, group_field, perform_shuffle) + # Rename the ID field to avoid conflicts in the upcoming merge + .rename(columns={id_field: new_id_field})[[new_id_field]] + ) + + return left_anti_join( + left=left, + right=duplicates_to_remove, + left_on=id_field, + right_on=new_id_field, + ) diff --git a/tests/test_duplicates_removal.py b/tests/test_duplicates_removal.py new file mode 100644 index 00000000..3f406681 --- /dev/null +++ b/tests/test_duplicates_removal.py @@ -0,0 +1,208 @@ +from typing import Literal + +import pandas as pd +import pytest +from dask import dataframe as dd + +from nemo_curator.utils.duplicates_removal import remove_duplicates + + +@pytest.fixture() +def ids(): + # Dataset has id a0...a9, b0...b9, c0...c9, d0...d9 + l = [f"{group}{i}" for group in ["a", "b", "c", "d"] for i in range(10)] + return l + + +@pytest.fixture +def sample_data(ids): + df = pd.DataFrame( + { + "id": ids, + "text": [f"text for {_id}" for _id in ids], + } + ) + return dd.from_pandas(df, npartitions=4) + + +@pytest.fixture +def duplicate_data(ids): + # In each group we want to keep only the first occurrence (e.g. a1, b1, c1, d1) + df = pd.DataFrame([{"id": _id, "group": _id[0]} for _id in ids]) + return dd.from_pandas(df, npartitions=2) + + +@pytest.mark.parametrize( + "backend", + [ + "pandas", + pytest.param("cudf", marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize("perform_shuffle", [False, True]) +def test_remove_duplicates_basic( + backend: Literal["cudf", "pandas"], + perform_shuffle: bool, + sample_data: dd.DataFrame, + duplicate_data: dd.DataFrame, +): + if perform_shuffle: + # We shuffle the data to make sure that duplicates are not in the same partition + duplicate_data = duplicate_data.sample(frac=1).reset_index(drop=True) + + sample_data = sample_data.to_backend(backend) + duplicate_data = duplicate_data.to_backend(backend) + + # Test basic duplicate removal functionality + result = remove_duplicates( + left=sample_data, + duplicates=duplicate_data, + id_field="id", + group_field="group", + perform_shuffle=perform_shuffle, + ).to_backend("pandas") + + result = result.compute() + + assert list(result.columns) == ["id", "text"] + assert len(result) == 4 + # It's not guaranteed that we'll have a0, b0, c0, d0 in the result + # So we should check the first character + assert set(result["id"].apply(lambda x: x[0]).tolist()) == set(["a", "b", "c", "d"]) + + +@pytest.mark.parametrize( + "backend", + [ + "pandas", + pytest.param("cudf", marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize("perform_shuffle", [False, True]) +def test_remove_duplicates_all_duplicates( + backend: Literal["cudf", "pandas"], + perform_shuffle: bool, + ids: list[str], + sample_data: dd.DataFrame, +): + + duplicates = dd.from_pandas( + pd.DataFrame({"id": ids, "group": [1] * len(ids)}), npartitions=2 + ) + sample_data = sample_data.to_backend(backend) + duplicates = duplicates.to_backend(backend) + + result = remove_duplicates( + left=sample_data, + duplicates=duplicates, + id_field="id", + group_field="group", + perform_shuffle=perform_shuffle, + ).to_backend("pandas") + + assert list(result.columns) == ["id", "text"] + result = result.compute() + if perform_shuffle: + assert len(result) == 1 + else: + # If we don't shuffle, and both partitions have the same group + # in both partitions we'd be left with 1 row after 
"deduplication" + # and after the left-anti join we'd be left with 2 rows + assert len(result) == 2 + + +@pytest.mark.parametrize( + "backend", + [ + "pandas", + pytest.param("cudf", marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize("perform_shuffle", [False, True]) +def test_not_remove_duplicates_unique( + backend: Literal["cudf", "pandas"], + perform_shuffle: bool, + ids: list[str], + sample_data: dd.DataFrame, +): + # We create a dataset where first 30 ids are in one group + # Next 9 ids are in distinct groups + # And last id is not mentioned in duplicates + + duplicates = dd.from_pandas( + pd.DataFrame( + { + "id": ids[:30] + ids[30:39], + "group": ["group0"] * 30 + [f"group{i}" for i in range(1, 10)], + } + ), + npartitions=2, + ) + sample_data = sample_data.to_backend(backend) + duplicates = duplicates.to_backend(backend) + if perform_shuffle: + # We shuffle the data to make sure that duplicates are not in the same partition + duplicates = duplicates.sample(frac=1, random_state=42).reset_index(drop=True) + + result = remove_duplicates( + left=sample_data, + duplicates=duplicates, + id_field="id", + group_field="group", + perform_shuffle=perform_shuffle, + ).to_backend("pandas") + + result = result.compute() + assert list(result.columns) == ["id", "text"] + if perform_shuffle: + # Since we've performed a shuffle, we know groups are collacated and there are 3 groups + # 1. 1 row from the first group of 30 + # 2. 9 rows from the 9 distinct groups + # 3. And 1 row from the last group which is not included in set of duplicates + assert len(result) == 1 + 9 + 1 + # The last 10 ids should be in the result, there would be one more from the first 30 + assert set(ids[30:]).issubset(set(result["id"].tolist())) + else: + # If we don't shuffle, we'de be left with 2 partitions both having rows from group 1 + assert len(result) == 2 + 9 + 1 + + +@pytest.mark.parametrize( + "backend", + [ + "pandas", + pytest.param("cudf", marks=pytest.mark.gpu), + ], +) +def test_remove_duplicates_raise_error( + backend: Literal["cudf", "pandas"], +): + # Create sample dataframes with specific partition counts + df1 = dd.from_pandas( + pd.DataFrame({"id": ["a1", "a2", "a3"], "text": ["text1", "text2", "text3"]}), + npartitions=2, + ) # dataset with 2 partitions + + duplicates = dd.from_pandas( + pd.DataFrame( + {"id": ["a1", "a2", "a3"], "group": ["group1", "group1", "group1"]} + ), + npartitions=3, + ) # duplicates dataset with 3 partitions + df1 = df1.to_backend(backend) + duplicates = duplicates.to_backend(backend) + + # Test that it raises ValueError when right npartitions are greater than left npartitions + with pytest.raises(ValueError) as exc_info: + remove_duplicates( + left=df1, + duplicates=duplicates, + id_field="id", + group_field="group", + ) + + expected_msg = ( + "The number of partitions in `left` is less than the number of partitions in the duplicates dataset. " + "This may lead to a shuffle join. Please re-read left and right with different partition sizes, or repartition left / right." + ) + assert str(exc_info.value) == expected_msg diff --git a/tests/test_exact_dedup.py b/tests/test_exact_dedup.py index d0408073..af2b0188 100644 --- a/tests/test_exact_dedup.py +++ b/tests/test_exact_dedup.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from hashlib import md5
+
 import pandas as pd
 import pytest
 from dask import dataframe as dd
-from dask.dataframe.utils import assert_eq
 
 from nemo_curator.datasets import DocumentDataset
 from nemo_curator.modules import ExactDuplicates
@@ -47,7 +48,29 @@ def test_dup(self, exact_dedup_data, cache_result, tmpdir):
             hash_method="md5",
             cache_dir=tmpdir if cache_result else None,
         )
-        result = exact_dups(exact_dedup_data)
-        expected_df = exact_dedup_data.df.compute()
-        expected_df = expected_df[expected_df.text.duplicated(keep=False)]
-        assert_eq(result.df.id, expected_df.id, check_index=False)
+        duplicates = exact_dups.identify_duplicates(exact_dedup_data)
+        deduplicated_ds = exact_dups.remove(exact_dedup_data, duplicates)
+        deduplicated_ids_series = deduplicated_ds.df.to_backend("pandas").compute()[
+            "id"
+        ]
+        output_deduplicated_ids = set(deduplicated_ids_series.tolist())
+        assert (
+            len(output_deduplicated_ids) == 3
+            and 300 in output_deduplicated_ids
+            and len({-1, 1}.intersection(output_deduplicated_ids)) == 1
+            and len({2, 4}.intersection(output_deduplicated_ids)) == 1
+        )
+
+        duplicates_df = (
+            duplicates.df.to_backend("pandas")
+            .compute()
+            .sort_values(by="id", ignore_index=True)
+        )
+        expected_df = pd.DataFrame(
+            {
+                "id": [1, -1] + [2, 4],
+                "_hashes": [md5(b"abc").hexdigest()] * 2
+                + [md5(b"aba").hexdigest()] * 2,
+            }
+        ).sort_values(by="id", ignore_index=True)
+        pd.testing.assert_frame_equal(duplicates_df, expected_df, check_like=True)
diff --git a/tests/test_fuzzy_dedup.py b/tests/test_fuzzy_dedup.py
index e62ab91b..bb9810ce 100644
--- a/tests/test_fuzzy_dedup.py
+++ b/tests/test_fuzzy_dedup.py
@@ -347,7 +347,7 @@ def test_fuzzy_dedup(
             jaccard_threshold=jaccard_threshold,
         )
         fuzzy_duplicates = FuzzyDuplicates(config=config)
-        result = fuzzy_duplicates(fuzzy_dedup_data)
+        result = fuzzy_duplicates.identify_duplicates(fuzzy_dedup_data)
         result_df = result.df.compute()
         # Drop non duplicated docs
         result_df = result_df[result_df.group.duplicated(keep=False)]
@@ -378,20 +378,31 @@ def test_different_fields(self, fuzzy_dedup_data, tmpdir):
             char_ngrams=5,
         )
         fuzzy_duplicates = FuzzyDuplicates(config=config)
-        result = fuzzy_duplicates(fuzzy_dedup_data)
-        result_df = result.df.compute()
+        duplicates = fuzzy_duplicates.identify_duplicates(fuzzy_dedup_data)
+        deduplicated_ds = fuzzy_duplicates.remove(fuzzy_dedup_data, duplicates)
+        deduplicated_df = deduplicated_ds.df.compute()
+        output_deduplicated_ids = set(deduplicated_df["col0"].to_arrow().to_pylist())
+        assert len(deduplicated_df) == 3
+        # From each of our groups we'll have at most one document that is not duplicated
+        assert (
+            300 in output_deduplicated_ids
+            and len({-1, 4}.intersection(output_deduplicated_ids)) == 1
+            and len({1, 2}.intersection(output_deduplicated_ids)) == 1
+        )
+
         # Drop non duplicated docs
-        result_df = result_df[result_df.group.duplicated(keep=False)]
-        result_df = result_df.groupby("group")["col0"].agg(list)
+        duplicates_df = duplicates.df.compute()
+        duplicates_df = duplicates_df[duplicates_df.group.duplicated(keep=False)]
+        duplicates_df = duplicates_df.groupby("group")["col0"].agg(list)
         # Sort to maintain uniform ordering
-        result_df = result_df.list.sort_values()
-        result_df = result_df.sort_values()
+        duplicates_df = duplicates_df.list.sort_values()
+        duplicates_df = duplicates_df.sort_values()
         duplicate_docs = [[4, -1], [1, 2]]
         expected_df = cudf.Series(duplicate_docs, name="col0")
         expected_df = expected_df.list.sort_values()
         expected_df = expected_df.sort_values()
-        assert_eq(expected_df, result_df, check_index=False)
+        assert_eq(expected_df, duplicates_df, check_index=False)
 
     @pytest.mark.xfail
     def test_non_uniform_indices(
@@ -430,19 +441,29 @@ def test_non_uniform_indices(
             jaccard_threshold=0.39,
         )
         fuzzy_duplicates = FuzzyDuplicates(config=config)
-        result = fuzzy_duplicates(data)
-        result_df = result.df.compute()
+        duplicates = fuzzy_duplicates.identify_duplicates(data)
+        deduplicated_ds = fuzzy_duplicates.remove(fuzzy_dedup_data, duplicates)
+        deduplicated_df = deduplicated_ds.df.compute()
+        output_deduplicated_ids = set(deduplicated_df["col0"].to_arrow().to_pylist())
+        assert len(deduplicated_df) == 2
+        # From each of our groups we'll have at most one document that is not duplicated
+        assert (
+            len({4, -1}.intersection(output_deduplicated_ids)) == 1
+            and len({1, 2, 300}.intersection(output_deduplicated_ids)) == 1
+        )
+
+        duplicates_df = duplicates.df.compute()
         # Drop non duplicated docs
-        result_df = result_df[result_df.group.duplicated(keep=False)]
-        result_df = result_df.groupby("group").id.agg(list)
+        duplicates_df = duplicates_df[duplicates_df.group.duplicated(keep=False)]
+        duplicates_df = duplicates_df.groupby("group").id.agg(list)
         # Sort to maintain uniform ordering
-        result_df = result_df.list.sort_values()
-        result_df = result_df.sort_values()
+        duplicates_df = duplicates_df.list.sort_values()
+        duplicates_df = duplicates_df.sort_values()
         expected_df = cudf.Series(duplicate_docs, name="id")
         expected_df = expected_df.list.sort_values()
         expected_df = expected_df.sort_values()
-        assert_eq(expected_df, result_df, check_index=False)
+        assert_eq(expected_df, duplicates_df, check_index=False)
 
     @pytest.mark.parametrize("num_anchors", [1, 3, 10])
     def test_num_anchors(self, large_fuzzy_dedup_data, num_anchors, tmpdir):
@@ -494,7 +515,7 @@ def test_no_fp_check(
             jaccard_threshold=0.39,
         )
         fuzzy_duplicates = FuzzyDuplicates(config=config)
-        result = fuzzy_duplicates(fuzzy_dedup_data)
+        result = fuzzy_duplicates.identify_duplicates(fuzzy_dedup_data)
         result_df = result.df.compute()
         # Drop non duplicated docs
         result_df = result_df[result_df.group.duplicated(keep=False)]
@@ -532,7 +553,7 @@ def test_shuffle_fail_fuzzy_dedup_data(
             jaccard_threshold=0.39,
         )
         fuzzy_duplicates = FuzzyDuplicates(config=config)
-        result = fuzzy_duplicates(shuffle_fail_fuzzy_dedup_data)
+        result = fuzzy_duplicates.identify_duplicates(shuffle_fail_fuzzy_dedup_data)
         result_df = result.df.compute()
         # Drop non duplicated docs
         result_df = result_df[result_df.group.duplicated(keep=False)]
@@ -569,7 +590,7 @@ def test_fuzzy_dedup_no_duplicates(
             jaccard_threshold=0.39,
        )
         fuzzy_duplicates = FuzzyDuplicates(config=config)
-        result = fuzzy_duplicates(no_duplicates_fuzzy_dedup_data)
+        result = fuzzy_duplicates.identify_duplicates(no_duplicates_fuzzy_dedup_data)
         assert result is None
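The `perform_shuffle` branches exercised in the tests above come down to `duplicated(keep="first")` being evaluated per partition. A standalone sketch of that behaviour (toy data, independent of the fixtures used in the tests):

```python
import dask.dataframe as dd
import pandas as pd

# One duplicate group ("g") spread across two partitions
df = dd.from_pandas(
    pd.DataFrame({"id": ["a1", "a2", "a3", "a4"], "group": ["g"] * 4}),
    npartitions=2,
)

# Without a shuffle each partition keeps its own "first" row of the group,
# so two rows survive the per-partition deduplication
kept = df.map_partitions(lambda x: x[~x["group"].duplicated(keep="first")])
print(len(kept.compute()))  # 2

# After shuffling on the group column the whole group lands in one partition,
# so only a single row survives
shuffled = df.shuffle(on=["group"], ignore_index=True)
kept = shuffled.map_partitions(lambda x: x[~x["group"].duplicated(keep="first")])
print(len(kept.compute()))  # 1
```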