diff --git a/code_bert_score/score.py b/code_bert_score/score.py index e74a6a8..b65a178 100644 --- a/code_bert_score/score.py +++ b/code_bert_score/score.py @@ -164,10 +164,10 @@ def score( baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float() else: baselines = torch.from_numpy(pd.read_csv(baseline_path).to_numpy())[:, 1:].unsqueeze(1).float() - - all_preds = (all_preds - baselines) / (1 - baselines) else: - print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr) + # print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr) + baselines = 0.5 + all_preds = (all_preds - baselines) / (1 - baselines) out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2], all_preds[..., 3] # P, R, F, F3 @@ -252,9 +252,11 @@ def plot_example( baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv") if os.path.isfile(baseline_path): baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float() - sim = (sim - baselines[2].item()) / (1 - baselines[2].item()) + baseline = baselines[2].item() else: - print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr) + # print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr) + baseline = 0.5 + sim = (sim - baseline) / (1 - baseline) import matplotlib.pyplot as plt from mpl_toolkits.axes_grid1 import make_axes_locatable diff --git a/example.py b/example.py index 21a90bd..381c229 100644 --- a/example.py +++ b/example.py @@ -56,16 +56,19 @@ def print_results(predictions, refs, pred_results): with open('idf_dicts/java_idf.pkl', 'rb') as f: java_idf = pickle.load(f) - # pred_results = code_bert_score.score([''],['a'], sources=["a"], lang="python") - # pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf) - # print_results(predictions, refs, pred_results) + pred_results = code_bert_score.score([''],['a'], sources=["a"], lang="python") + pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf) + print_results(predictions, refs, pred_results) - # print('When providing the context: "find the index of target in this.elements"') - # pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf, sources=['find the index of target in this.elements'] * 2) - # print_results(predictions, refs, pred_results) + print('When providing the context: "find the index of target in this.elements"') + pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf, sources=['find the index of target in this.elements'] * 2) + print_results(predictions, refs, pred_results) with open('idf_dicts/python_idf.pkl', 'rb') as f: python_idf = pickle.load(f) pred_results = code_bert_score.score(cands=['math.sqrt(x)'], refs=[['x ** 0.5']], no_punc=True, lang='python', idf=python_idf) + print(pred_results) + + pred_results = code_bert_score.score(cands=['math.sqrt(x)'], refs=[['x ** 0.5']], rescale_with_baseline=True, lang='en') print(pred_results) \ No newline at end of file