Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unit tests for analysis module #604

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,9 @@ process:
lang: en # sample in which language
tokenization: false # whether to use model to tokenize documents
substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove
- replace_content_mapper: # replace all content in the text that matches a specific regular expression pattern with a designated replacement string.
pattern: null # regular expression pattern(s) to search for within text
repl: '' # replacement string(s), default is empty string
- sdxl_prompt2prompt_mapper: # use the generative model SDXL and image editing technique Prompt-to-Prompt to generate pairs of similar images.
hf_diffusion: 'stabilityai/stable-diffusion-xl-base-1.0' # model name of the SDXL model on huggingface
num_inference_steps: 50 # the larger the value, the better the image generation quality
Expand Down
2 changes: 2 additions & 0 deletions data_juicer/analysis/column_wise_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def get_row_col(total_num, factor=2):
for each stat figure
:return: "best" number of rows and columns, and the grid list
"""
if factor <= 0 or total_num <= 0:
return 0, 0, []
n = total_num * factor # actual number of figures
now_col = factor # search from the minimum number of columns
now_row = total_num
Expand Down
10 changes: 8 additions & 2 deletions data_juicer/analysis/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import seaborn as sns


def draw_heatmap(data, xlabels, ylables=None, figsize=None, triangle=False):
def draw_heatmap(data,
xlabels,
ylables='auto',
figsize=None,
triangle=False,
show=False):
"""
Draw heatmap of input data with special labels.

Expand Down Expand Up @@ -38,5 +43,6 @@ def draw_heatmap(data, xlabels, ylables=None, figsize=None, triangle=False):
annot_kws={'size': 8})
plt.subplots_adjust(left=.1, right=0.95, bottom=0.22, top=0.95)
fig = plt.gcf()
plt.show()
if show:
plt.show()
return fig
2 changes: 2 additions & 0 deletions data_juicer/analysis/overall_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def analyze(self, percentiles=[], num_proc=1, skip_export=False):
pool = Pool(num_proc)
for col_name in all_columns:
this_col = self.refine_single_column(stats_and_meta[col_name])
if this_col is None:
continue
res = pool.apply_async(_single_column_analysis,
kwds={
'col': this_col,
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import yaml
from jsonargparse import (ActionConfigFile, ArgumentParser, Namespace,
dict_to_namespace, namespace_to_dict)
from jsonargparse.typehints import ActionTypeHint
from jsonargparse._typehints import ActionTypeHint
from jsonargparse.typing import ClosedUnitInterval, NonNegativeInt, PositiveInt
from loguru import logger

Expand Down
4 changes: 3 additions & 1 deletion data_juicer/utils/unittest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def find_corresponding_test_file(file_path):
test_file = file_path.replace('data_juicer', 'tests')
basename = os.path.basename(test_file)
dir = os.path.dirname(test_file)
test_file = os.path.join(dir, 'test_' + basename)
if not basename.startswith('test_') and basename != 'run.py':
basename = 'test_' + basename
test_file = os.path.join(dir, basename)
if os.path.exists(test_file):
return test_file
else:
Expand Down
32 changes: 32 additions & 0 deletions tests/analysis/test_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os
import unittest

import torch.distributions

from data_juicer.analysis.collector import TextTokenDistCollector

from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase

class TextTokenDistCollectorTest(DataJuicerTestCaseBase):
    """Tests for ``TextTokenDistCollector`` on the bundled demo dataset."""

    # demo dataset shipped with the repository:
    # <repo_root>/demos/data/demo-dataset.jsonl
    test_data_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        '..', '..', 'demos', 'data', 'demo-dataset.jsonl')

    # HF model whose tokenizer is used to tokenize the text column
    tokenizer_model = 'EleutherAI/pythia-6.9b-deduped'

    @classmethod
    def tearDownClass(cls) -> None:
        # presumably releases the cached tokenizer model after the class
        # finishes — confirm against DataJuicerTestCaseBase.tearDownClass
        super().tearDownClass(cls.tokenizer_model)

    def test_basic_func(self):
        """Collecting over the 'text' column yields a Categorical."""
        token_collector = TextTokenDistCollector(self.tokenizer_model)
        token_dist = token_collector.collect(self.test_data_path, 'text')
        self.assertIsInstance(token_dist, torch.distributions.Categorical)


# Allow running this test module directly: `python test_collector.py`.
if __name__ == '__main__':
    unittest.main()
195 changes: 195 additions & 0 deletions tests/analysis/test_column_wise_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import os
import unittest
import pandas as pd

from data_juicer.core.data import NestedDataset
from data_juicer.analysis.column_wise_analysis import get_row_col, ColumnWiseAnalysis
from data_juicer.utils.constant import DEFAULT_PREFIX, Fields

from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase

class GetRowColFuncTest(DataJuicerTestCaseBase):
    """Tests for the ``get_row_col`` grid-layout helper."""

    def _run_test_data(self, data):
        # ``data`` maps ``(total_num, factor)`` argument tuples to the
        # expected ``(rows, cols, grid)`` result of ``get_row_col``
        for arguments, expected in data.items():
            self.assertEqual(get_row_col(*arguments), expected)

    def test_normal_func(self):
        """Typical positive inputs produce a filled grid."""
        self._run_test_data({
            (4, 2): (2, 2, [(0, 0), (0, 1), (1, 0), (1, 1)]),
            (3, 3): (3, 1, [(0, 0), (1, 0), (2, 0)]),
        })

    def test_marginal_total_num(self):
        """Zero or negative total_num degrades to an empty grid."""
        self._run_test_data({
            (1, 1): (1, 1, [(0, 0)]),
            (0, 1): (0, 0, []),
            (-1, 1): (0, 0, []),
        })

    def test_marginal_factor(self):
        """Zero or negative factor degrades to an empty grid."""
        self._run_test_data({
            (4, 0): (0, 0, []),
            (4, -1): (0, 0, []),
            (4, 1): (2, 2, [(0, 0), (0, 1), (1, 0), (1, 1)]),
        })


class ColumnWiseAnalysisTest(DataJuicerTestCaseBase):
    """Tests for ``ColumnWiseAnalysis`` construction and figure export."""

    def setUp(self) -> None:
        super().setUp()
        # three samples holding both tagged (DEFAULT_PREFIX) and untagged
        # meta keys, plus numeric / list / string stats columns
        data_list = [
            {
                Fields.meta: {
                    f'{DEFAULT_PREFIX}meta_str1': 'human',
                    f'{DEFAULT_PREFIX}meta_str2': 'sft',
                    'meta_str3': 'code',
                },
                Fields.stats: {
                    'stats_num_list': [4, 5, 6],
                    'stats_num': 3.1,
                    'stats_str': 'zh',
                }
            },
            {
                Fields.meta: {
                    f'{DEFAULT_PREFIX}meta_str1': 'assistant',
                    f'{DEFAULT_PREFIX}meta_str2': 'rlhf',
                    'meta_str3': 'math',
                },
                Fields.stats: {
                    'stats_num_list': [7, 8, 9],
                    'stats_num': 4.1,
                    'stats_str': 'en',
                }
            },
            {
                Fields.meta: {
                    f'{DEFAULT_PREFIX}meta_str1': 'system',
                    f'{DEFAULT_PREFIX}meta_str2': 'dpo',
                    'meta_str3': 'reasoning',
                },
                Fields.stats: {
                    'stats_num_list': [10, 11, 12],
                    'stats_num': 5.1,
                    'stats_str': 'fr',
                }
            },
        ]
        self.dataset_3_sample = NestedDataset.from_list(data_list)

        # a fourth sample for a slightly larger dataset variant
        data_list.append({
            Fields.meta: {
                f'{DEFAULT_PREFIX}meta_str1': 'robot',
                f'{DEFAULT_PREFIX}meta_str2': 'sft',
                'meta_str3': 'edu',
            },
            Fields.stats: {
                'stats_num_list': [13, 14, 15],
                'stats_num': 2.1,
                'stats_str': 'it',
            }
        })
        self.dataset_4_sample = NestedDataset.from_list(data_list)
        self.temp_output_path = 'tmp/test_column_wise_analysis/'

    def tearDown(self):
        # remove any figures exported by analyze() during the test; use
        # shutil.rmtree instead of shelling out to `rm -rf` for portability
        if os.path.exists(self.temp_output_path):
            import shutil
            shutil.rmtree(self.temp_output_path, ignore_errors=True)
        super().tearDown()

    def test_init(self):
        """Constructor keeps only tagged meta columns and sets defaults."""
        column_wise_analysis = ColumnWiseAnalysis(
            self.dataset_3_sample, self.temp_output_path)
        # test if the non-tagged meta columns are removed
        self.assertNotIn('meta_str3', column_wise_analysis.meta.columns)
        self.assertIn(f'{DEFAULT_PREFIX}meta_str1',
                      column_wise_analysis.meta.columns)
        self.assertIn(f'{DEFAULT_PREFIX}meta_str2',
                      column_wise_analysis.meta.columns)
        # when overall_result is not given, it is computed as a DataFrame
        self.assertIsInstance(column_wise_analysis.overall_result,
                              pd.DataFrame)
        self.assertEqual(column_wise_analysis.save_stats_in_one_file, True)

        # a user-specified overall_result is stored as-is
        column_wise_analysis = ColumnWiseAnalysis(
            self.dataset_3_sample, self.temp_output_path,
            overall_result='temp_placeholder')
        self.assertEqual(column_wise_analysis.overall_result,
                         'temp_placeholder')
        # test for save_stats_in_one_file is False
        column_wise_analysis = ColumnWiseAnalysis(
            self.dataset_3_sample, self.temp_output_path,
            save_stats_in_one_file=False)
        self.assertEqual(column_wise_analysis.save_stats_in_one_file, False)

        # test number of rows and columns of stats and meta
        self.assertEqual(len(column_wise_analysis.stats), 3)
        self.assertEqual(len(column_wise_analysis.meta), 3)
        self.assertEqual(len(column_wise_analysis.stats.columns), 3)
        self.assertEqual(len(column_wise_analysis.meta.columns), 2)

    def test_basic_analyze(self):
        """analyze() with defaults exports one combined figure."""
        column_wise_analysis_3_sample = ColumnWiseAnalysis(
            self.dataset_3_sample, self.temp_output_path)
        column_wise_analysis_3_sample.analyze()
        self.assertTrue(
            os.path.exists(
                os.path.join(self.temp_output_path, 'all-stats.png')))

    def test_skip_export(self):
        """analyze(skip_export=True) writes no per-column figures."""
        column_wise_analysis_4_sample = ColumnWiseAnalysis(
            self.dataset_4_sample, self.temp_output_path)
        column_wise_analysis_4_sample.analyze(skip_export=True)
        for stats in column_wise_analysis_4_sample.stats.columns:
            self.assertFalse(os.path.exists(
                os.path.join(self.temp_output_path, f'{stats}-hist.png')))
            self.assertFalse(os.path.exists(
                os.path.join(self.temp_output_path, f'{stats}-box.png')))
        for meta in column_wise_analysis_4_sample.meta.columns:
            self.assertFalse(os.path.exists(
                os.path.join(self.temp_output_path, f'{meta}-hist.png')))
            self.assertFalse(os.path.exists(
                os.path.join(self.temp_output_path, f'{meta}-wordcloud.png')))

    def test_not_save_stats_in_one_file(self):
        """save_stats_in_one_file=False exports one figure per column."""
        column_wise_analysis_3_sample = ColumnWiseAnalysis(
            self.dataset_3_sample, self.temp_output_path,
            save_stats_in_one_file=False)
        column_wise_analysis_3_sample.analyze()
        for stats in column_wise_analysis_3_sample.stats.columns:
            self.assertTrue(os.path.exists(
                os.path.join(self.temp_output_path, f'{stats}-hist.png')))
            # numeric columns get a box plot, string columns a wordcloud
            self.assertTrue(os.path.exists(
                os.path.join(self.temp_output_path, f'{stats}-box.png'))
                or os.path.exists(
                os.path.join(self.temp_output_path,
                             f'{stats}-wordcloud.png')))
        for meta in column_wise_analysis_3_sample.meta.columns:
            self.assertTrue(os.path.exists(
                os.path.join(self.temp_output_path, f'{meta}-hist.png')))
            self.assertTrue(os.path.exists(
                os.path.join(self.temp_output_path, f'{meta}-wordcloud.png')))

    def test_show_percentiles(self):
        """analyze(show_percentiles=True) still exports per-column figures."""
        column_wise_analysis_3_sample = ColumnWiseAnalysis(
            self.dataset_3_sample, self.temp_output_path,
            save_stats_in_one_file=False)
        column_wise_analysis_3_sample.analyze(show_percentiles=True)
        for stats in column_wise_analysis_3_sample.stats.columns:
            self.assertTrue(os.path.exists(
                os.path.join(self.temp_output_path, f'{stats}-hist.png')))
            self.assertTrue(os.path.exists(
                os.path.join(self.temp_output_path, f'{stats}-box.png'))
                or os.path.exists(
                os.path.join(self.temp_output_path,
                             f'{stats}-wordcloud.png')))
        for meta in column_wise_analysis_3_sample.meta.columns:
            self.assertTrue(os.path.exists(
                os.path.join(self.temp_output_path, f'{meta}-hist.png')))
            self.assertTrue(os.path.exists(
                os.path.join(self.temp_output_path, f'{meta}-wordcloud.png')))


# Allow running this test module directly:
# `python test_column_wise_analysis.py`.
if __name__ == '__main__':
    unittest.main()
Loading