# Copyright (C) 2023 National Research Council Canada.
#
# This file is part of vardial-2023.
#
# vardial-2023 is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# vardial-2023 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# vardial-2023. If not, see https://www.gnu.org/licenses/.
import argparse
import logging
import os
import pickle
from difflib import ndiff

import numpy as np
from scipy.sparse import find, tril
from tqdm import tqdm

from utils import load_lines
DOC="""
Count near-duplicates in dataset.
"""
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
def main(args):
logger.info(f"Loading data from {args.path_pickle}...")
with open(args.path_pickle, 'rb') as f:
data = pickle.load(f)
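    # The pickle is assumed to hold a sparse pairwise-similarity matrix over
    # the unique texts ("matrix"), plus those unique texts themselves, in row
    # order ("labels").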
mat = data["matrix"]
mat_labels = data["labels"]
mat = mat.tocoo()
logger.info(f"Type of matrix: {type(mat)}")
logger.info(f"Shape of matrix: {mat.shape}")
logger.info(f"Nb nnz: {mat.nnz}")
logger.info(f"Loading texts from {args.path_texts} and labels from {args.path_labels}...")
texts = load_lines(args.path_texts)
labels = load_lines(args.path_labels)
assert len(texts) == len(labels)
    # Map each unique text to the list of labels it occurs with
    text_to_labels = {}
    for x, y in zip(texts, labels):
        text_to_labels.setdefault(x, []).append(y)
    # Identify near duplicates. Keep only the strictly lower triangle so that
    # each unordered pair is counted once and the unit self-similarities on
    # the diagonal are excluded.
    logger.info(f"Identifying near duplicates with sim>={args.min_sim}...")
    mat = tril(mat, k=-1)
    rows, cols, vals = find(mat)
    # (row, col, similarity) triplets for all pairs above the threshold
    nd = [(rows[i], cols[i], vals[i]) for i in np.where(vals >= args.min_sim)[0]]
logger.info("Counting near duplicates that have different sets of unique labels...")
ambig = set()
pbar = tqdm(total=len(nd))
for (i,j,s) in nd:
li = set(text_to_labels[mat_labels[i]])
lj = set(text_to_labels[mat_labels[j]])
if len(li.symmetric_difference(lj)):
ambig.add((i,j,s))
pbar.update(1)
pbar.close()
logger.info(f"Nb near duplicate pairs: {len(nd)}")
logger.info(f"Nb ambiguous duplicate pairs: {len(ambig)}/{len(nd)}")
# Count frequency of edits, and format them for printing
logger.info("Counting frequency of edits...")
ambig = sorted(ambig, key=lambda x:x[2], reverse=True)
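    # ambig is now a list ordered by descending similarity, so the closest
    # pairs come first in the written output.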
edit2freq = {}
pretty_ambig = []
pbar = tqdm(total=len(ambig))
delim = " " if args.ndiff_type == "token" else ""
for ix,(i,j,sim) in enumerate(ambig):
# Do diff and format edit blocks
ops = []
prev_op = None
block = []
if args.ndiff_type == "char":
diff = ndiff(mat_labels[i],mat_labels[j])
elif args.ndiff_type == "token":
tokens_i = [x+" " for x in mat_labels[i].split(" ")]
tokens_j = [x+" " for x in mat_labels[j].split(" ")]
diff = ndiff(tokens_i, tokens_j)
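        # ndiff yields one line per element, prefixed with "- " (only in the
        # first sequence), "+ " (only in the second), "  " (common), or "? "
        # (intraline hints). E.g. ndiff("cat", "cart") yields
        # ["  c", "  a", "+ r", "  t"].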
        for x in diff:
            op = x[0]
            if op == "?":
                continue  # hint lines are annotations, not edits
            if op != prev_op:
                if len(block):
                    pretty_op = "=" if prev_op == " " else prev_op
                    pretty_block = delim.join(block)
                    ops.append((pretty_op, pretty_block))
                    block = []
                prev_op = op
            block.append(x[2:])
if len(block):
pretty_op = "=" if prev_op == " " else prev_op
pretty_block = delim.join(block)
ops.append((pretty_op, pretty_block))
# Store message for printing
msg = f"************* Example {ix+1} ****************\n"
msg += f"- Sim={sim:.5f}\n"
msg += f"- Diff:\n"
for (symbol, string) in ops:
msg += f" {symbol} [{string}]\n"
msg += f"- Text {i}: {mat_labels[i]}\n"
msg += f"- Labels of text {i}: {text_to_labels[mat_labels[i]]}\n"
msg += f"- Text {j}: {mat_labels[j]}\n"
msg += f"- Labels of text {j}: {text_to_labels[mat_labels[j]]}\n"
pretty_ambig.append(msg)
# Concatenate edit ops, and count them
block = []
edits = []
for (symbol, string) in ops:
if symbol == "=":
if len(block):
edits.append(tuple(block))
block = []
else:
block.append(symbol)
block.append(string)
if len(block):
edits.append(tuple(block))
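        # Each edit is a flat tuple alternating op symbol and text block,
        # e.g. ("-", "color", "+", "colour") for a single substitution.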
for edit in edits:
if edit not in edit2freq:
edit2freq[edit] = 0
edit2freq[edit] += 1
pbar.update(1)
pbar.close()
# Print most frequent edits
logger.info("Most frequent edits:")
    k = 10
    topk = sorted(edit2freq.keys(), key=edit2freq.get, reverse=True)[:k]
    for rank, edit in enumerate(topk):
        freq = edit2freq[edit]
        logger.info(f"  {rank+1}. Freq={freq}")
        # Step through the (symbol, string) pairs of the flat edit tuple
        for pos in range(0, len(edit), 2):
            symbol = edit[pos]
            string = edit[pos + 1]
            logger.info(f"    {symbol} [{string}]")
# Write ambiguous near duplicates
    if args.write_to:
        if args.seed is not None:
            # Seed numpy's RNG, since np.random.choice performs the sampling
            np.random.seed(args.seed)
        if args.ndiff_sample_size:
            logger.info(f"Sampling {args.ndiff_sample_size} ndiff examples at random...")
            assert args.ndiff_sample_size <= len(pretty_ambig)
            sample = np.random.choice(len(pretty_ambig), args.ndiff_sample_size, replace=False)
            sample.sort()
            pretty_ambig = [pretty_ambig[i] for i in sample]
logger.info(f"Writing ambiguous near duplicates to {args.write_to}...")
with open(args.write_to, 'w') as f:
for msg in pretty_ambig:
f.write(msg + "\n")
logger.info("Done.")
return
if __name__ == "__main__":
p = argparse.ArgumentParser(description=DOC)
p.add_argument("path_pickle", help="Path of pickle file containing similarity matrix and list of unique texts")
p.add_argument("path_texts", help="Path of text file containing texts")
p.add_argument("path_labels", help="Path of text file containing labels")
p.add_argument("--min_sim", "-m", type=float, default=0, help="minimum similarity for a text pair to be considered near duplicates")
p.add_argument("--write_to", "-w", type=str, help="Optional path of file to write ambiguous near duplicates to.")
p.add_argument("--ndiff_type", "-n", choices=["char", "token"], default="char", help="Type of ndiff used to highligh differences (if --write_to is specified)")
p.add_argument("--ndiff_sample_size", "-i", type=int, help="Number of ndiffs to sample (if --write_to is specified)")
p.add_argument("--seed", "-s", help="Seed for RNG (used for sampling ndiff outputs")
args = p.parse_args()
    assert 0 <= args.min_sim < 1
if args.write_to:
assert not os.path.exists(args.write_to)
if args.ndiff_sample_size:
assert args.write_to
    if args.seed is not None:
        assert args.write_to
main(args)