Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unit tests for analysis module #604

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions data_juicer/analysis/column_wise_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def get_row_col(total_num, factor=2):
for each stat figure
:return: "best" number of rows and columns, and the grid list
"""
if factor <= 0 or total_num <= 0:
return 0, 0, []
n = total_num * factor # actual number of figures
now_col = factor # search from the minimum number of columns
now_row = total_num
Expand Down
10 changes: 8 additions & 2 deletions data_juicer/analysis/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import seaborn as sns


def draw_heatmap(data, xlabels, ylables=None, figsize=None, triangle=False):
def draw_heatmap(data,
xlabels,
ylables='auto',
figsize=None,
triangle=False,
show=False):
"""
Draw heatmap of input data with special labels.

Expand Down Expand Up @@ -38,5 +43,6 @@ def draw_heatmap(data, xlabels, ylables=None, figsize=None, triangle=False):
annot_kws={'size': 8})
plt.subplots_adjust(left=.1, right=0.95, bottom=0.22, top=0.95)
fig = plt.gcf()
plt.show()
if show:
plt.show()
return fig
2 changes: 2 additions & 0 deletions data_juicer/analysis/overall_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def analyze(self, percentiles=[], num_proc=1, skip_export=False):
pool = Pool(num_proc)
for col_name in all_columns:
this_col = self.refine_single_column(stats_and_meta[col_name])
if this_col is None:
continue
res = pool.apply_async(_single_column_analysis,
kwds={
'col': this_col,
Expand Down
4 changes: 3 additions & 1 deletion data_juicer/utils/unittest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def find_corresponding_test_file(file_path):
test_file = file_path.replace('data_juicer', 'tests')
basename = os.path.basename(test_file)
dir = os.path.dirname(test_file)
test_file = os.path.join(dir, 'test_' + basename)
if not basename.startswith('test_'):
basename = 'test_' + basename
test_file = os.path.join(dir, basename)
if os.path.exists(test_file):
return test_file
else:
Expand Down
32 changes: 32 additions & 0 deletions tests/analysis/test_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os
import unittest

import torch.distributions

from data_juicer.analysis.collector import TextTokenDistCollector

from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase

class TextTokenDistCollectorTest(DataJuicerTestCaseBase):
    """Exercise ``TextTokenDistCollector`` against the demo dataset."""

    # Demo dataset located two directories above this test file.
    test_data_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        '..',
        '..',
        'demos',
        'data',
        'demo-dataset.jsonl',
    )

    # Name of the tokenizer model handed to the collector.
    tokenizer_model = 'EleutherAI/pythia-6.9b-deduped'

    @classmethod
    def tearDownClass(cls) -> None:
        """Forward the tokenizer model name to the base-class teardown."""
        super().tearDownClass(cls.tokenizer_model)

    def test_init(self):
        """Collecting over the ``text`` field yields a ``Categorical``."""
        token_collector = TextTokenDistCollector(self.tokenizer_model)
        token_dist = token_collector.collect(self.test_data_path, 'text')
        self.assertIsInstance(token_dist, torch.distributions.Categorical)


# Run the tests in this module when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
195 changes: 195 additions & 0 deletions tests/analysis/test_column_wise_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import os
import unittest
import pandas as pd

from data_juicer.core.data import NestedDataset
from data_juicer.analysis.column_wise_analysis import get_row_col, ColumnWiseAnalysis
from data_juicer.utils.constant import DEFAULT_PREFIX, Fields

from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase

class GetRowColFuncTest(DataJuicerTestCaseBase):
    """Table-driven checks of the grid layout returned by ``get_row_col``."""

    def _run_test_data(self, data):
        """Compare ``get_row_col(*args)`` with the expected value per entry.

        :param data: dict mapping an argument tuple ``(total_num, factor)``
            to the expected ``(rows, cols, grid list)`` result.
        """
        for args, truth in data.items():
            # One sub-test per case: a failure then names the offending
            # arguments and the remaining cases are still executed.
            with self.subTest(args=args):
                self.assertEqual(get_row_col(*args), truth)

    def test_normal_func(self):
        """Regular positive inputs."""
        test_data = {
            (4, 2): (2, 2, [(0, 0), (0, 1), (1, 0), (1, 1)]),
            (3, 3): (3, 1, [(0, 0), (1, 0), (2, 0)]),
        }
        self._run_test_data(test_data)

    def test_marginal_total_num(self):
        """Boundary values of ``total_num``: one, zero and negative."""
        test_data = {
            (1, 1): (1, 1, [(0, 0)]),
            (0, 1): (0, 0, []),
            (-1, 1): (0, 0, []),
        }
        self._run_test_data(test_data)

    def test_marginal_factor(self):
        """Boundary values of ``factor``: zero, negative and one."""
        test_data = {
            (4, 0): (0, 0, []),
            (4, -1): (0, 0, []),
            (4, 1): (2, 2, [(0, 0), (0, 1), (1, 0), (1, 1)]),
        }
        self._run_test_data(test_data)


class ColumnWiseAnalysisTest(DataJuicerTestCaseBase):
    """Tests for ``ColumnWiseAnalysis`` construction and figure export."""

    @staticmethod
    def _make_sample(tag1, tag2, untagged, num_list, num, text):
        """Build one sample dict with a meta part and a stats part.

        ``tag1`` and ``tag2`` are stored under keys carrying
        ``DEFAULT_PREFIX``; ``untagged`` is stored under the plain key
        ``meta_str3``.
        """
        return {
            Fields.meta: {
                f'{DEFAULT_PREFIX}meta_str1': tag1,
                f'{DEFAULT_PREFIX}meta_str2': tag2,
                'meta_str3': untagged,
            },
            Fields.stats: {
                'stats_num_list': num_list,
                'stats_num': num,
                'stats_str': text,
            },
        }

    def setUp(self) -> None:
        """Create a 3-sample and a 4-sample dataset plus the output path."""
        data_list = [
            self._make_sample('human', 'sft', 'code', [4, 5, 6], 3.1, 'zh'),
            self._make_sample('assistant', 'rlhf', 'math', [7, 8, 9], 4.1,
                              'en'),
            self._make_sample('system', 'dpo', 'reasoning', [10, 11, 12],
                              5.1, 'fr'),
        ]
        self.dataset_3_sample = NestedDataset.from_list(data_list)

        # The 4-sample dataset is the 3-sample list plus one extra entry.
        data_list.append(
            self._make_sample('robot', 'sft', 'edu', [13, 14, 15], 2.1,
                              'it'))
        self.dataset_4_sample = NestedDataset.from_list(data_list)
        self.temp_output_path = 'tmp/test_column_wise_analysis/'

    def tearDown(self):
        """Remove the temporary output directory if a test created it."""
        # Local import keeps this block self-contained; shutil.rmtree avoids
        # spawning a shell and also works where `rm` is unavailable.
        import shutil

        if os.path.exists(self.temp_output_path):
            # Best-effort cleanup, like the unchecked `rm -rf` it replaces.
            shutil.rmtree(self.temp_output_path, ignore_errors=True)

    def _figure_exists(self, file_name):
        """Return whether ``file_name`` exists under the output path."""
        return os.path.exists(os.path.join(self.temp_output_path, file_name))

    def _assert_per_column_figures_exist(self, analysis):
        """Assert that every column of ``analysis`` produced its figures.

        Each stats column needs a histogram and either a box plot or a word
        cloud; each meta column needs a histogram and a word cloud.
        """
        for stats in analysis.stats.columns:
            self.assertTrue(self._figure_exists(f'{stats}-hist.png'))
            self.assertTrue(
                self._figure_exists(f'{stats}-box.png')
                or self._figure_exists(f'{stats}-wordcloud.png'))
        for meta in analysis.meta.columns:
            self.assertTrue(self._figure_exists(f'{meta}-hist.png'))
            self.assertTrue(self._figure_exists(f'{meta}-wordcloud.png'))

    def test_init(self):
        """Constructor filters meta columns and stores its options."""
        column_wise_analysis = ColumnWiseAnalysis(
            self.dataset_3_sample, self.temp_output_path)
        # test if the non-tag columns are removed
        self.assertNotIn('meta_str3', column_wise_analysis.meta.columns)
        self.assertIn(f'{DEFAULT_PREFIX}meta_str1',
                      column_wise_analysis.meta.columns)
        self.assertIn(f'{DEFAULT_PREFIX}meta_str2',
                      column_wise_analysis.meta.columns)
        # without an explicit overall_result the attribute is a DataFrame
        self.assertIsInstance(column_wise_analysis.overall_result,
                              pd.DataFrame)
        self.assertEqual(column_wise_analysis.save_stats_in_one_file, True)

        # test for specify overall_result
        column_wise_analysis = ColumnWiseAnalysis(
            self.dataset_3_sample,
            self.temp_output_path,
            overall_result='temp_palceholder')
        self.assertEqual(column_wise_analysis.overall_result,
                         'temp_palceholder')
        # test for save_stats_in_one_file is False
        column_wise_analysis = ColumnWiseAnalysis(
            self.dataset_3_sample,
            self.temp_output_path,
            save_stats_in_one_file=False)
        self.assertEqual(column_wise_analysis.save_stats_in_one_file, False)

        # test number of stats and meta
        self.assertEqual(len(column_wise_analysis.stats), 3)
        self.assertEqual(len(column_wise_analysis.meta), 3)
        self.assertEqual(len(column_wise_analysis.stats.columns), 3)
        self.assertEqual(len(column_wise_analysis.meta.columns), 2)

    def test_basic_analyze(self):
        """A default ``analyze()`` call writes ``all-stats.png``."""
        column_wise_analysis_3_sample = ColumnWiseAnalysis(
            self.dataset_3_sample, self.temp_output_path)
        column_wise_analysis_3_sample.analyze()
        self.assertTrue(self._figure_exists('all-stats.png'))

    def test_skip_export(self):
        """With ``skip_export=True`` no per-column figure file exists."""
        column_wise_analysis_4_sample = ColumnWiseAnalysis(
            self.dataset_4_sample, self.temp_output_path)
        column_wise_analysis_4_sample.analyze(skip_export=True)
        for stats in column_wise_analysis_4_sample.stats.columns:
            self.assertFalse(self._figure_exists(f'{stats}-hist.png'))
            self.assertFalse(self._figure_exists(f'{stats}-box.png'))
        for meta in column_wise_analysis_4_sample.meta.columns:
            self.assertFalse(self._figure_exists(f'{meta}-hist.png'))
            self.assertFalse(self._figure_exists(f'{meta}-wordcloud.png'))

    def test_not_save_stats_in_one_file(self):
        """``save_stats_in_one_file=False`` writes one figure set per column."""
        column_wise_analysis_3_sample = ColumnWiseAnalysis(
            self.dataset_3_sample,
            self.temp_output_path,
            save_stats_in_one_file=False)
        column_wise_analysis_3_sample.analyze()
        self._assert_per_column_figures_exist(column_wise_analysis_3_sample)

    def test_save_stats_in_one_file(self):
        """``analyze(show_percentiles=True)`` still writes per-column figures.

        NOTE: despite the method name, the analysis is built with
        ``save_stats_in_one_file=False``; the assertions below rely on the
        per-column files being written.
        """
        column_wise_analysis_3_sample = ColumnWiseAnalysis(
            self.dataset_3_sample,
            self.temp_output_path,
            save_stats_in_one_file=False)
        column_wise_analysis_3_sample.analyze(show_percentiles=True)
        self._assert_per_column_figures_exist(column_wise_analysis_3_sample)


# Run the tests in this module when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Loading
Loading