From 91d5ebfc97b8c764ed467c956626c8cff55e8304 Mon Sep 17 00:00:00 2001 From: Thomas Wood Date: Wed, 22 Jan 2025 15:00:48 +0000 Subject: [PATCH] Add H-score as output value --- src/harmony/matching/default_matcher.py | 4 +- .../instrument_to_instrument_similarity.py | 61 ++++++++++++++++ src/harmony/matching/matcher.py | 39 +++++----- src/harmony/schemas/responses/text.py | 38 +++++++++- tests/test_batch.py | 7 +- tests/test_batching_in_matcher.py | 4 +- tests/test_cluster.py | 18 ++--- tests/test_convert_text.py | 2 - tests/test_crosswalk.py | 16 +++-- ...est_instrument_to_instrument_similarity.py | 72 +++++++++++++++++++ tests/test_match.py | 44 ++++++------ tests/test_match_mhc.py | 12 ++-- tests/test_match_negative_polarity.py | 30 ++++---- tests/test_pdf_tables.py | 15 ++-- 14 files changed, 267 insertions(+), 95 deletions(-) create mode 100644 src/harmony/matching/instrument_to_instrument_similarity.py create mode 100644 tests/test_instrument_to_instrument_similarity.py diff --git a/src/harmony/matching/default_matcher.py b/src/harmony/matching/default_matcher.py index 767fe9e..cba9a9a 100644 --- a/src/harmony/matching/default_matcher.py +++ b/src/harmony/matching/default_matcher.py @@ -33,6 +33,8 @@ from numpy import ndarray from sentence_transformers import SentenceTransformer +from harmony.schemas.responses.text import HarmonyMatchResponse + if ( os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) is not None and os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) != "" @@ -76,7 +78,7 @@ def match_instruments( mhc_embeddings: np.ndarray = np.zeros((0, 0)), texts_cached_vectors: dict[str, List[float]] = {}, batch_size: int = 1000, max_batches: int = 2000, is_negate: bool = True -) -> tuple: +) -> HarmonyMatchResponse: return match_instruments_with_function( instruments=instruments, query=query, diff --git a/src/harmony/matching/instrument_to_instrument_similarity.py b/src/harmony/matching/instrument_to_instrument_similarity.py new file mode 100644 index 0000000..b1f6bc0 --- /dev/null +++ b/src/harmony/matching/instrument_to_instrument_similarity.py @@ -0,0 +1,61 @@ +import operator + +import numpy as np + +from harmony.schemas.responses.text import InstrumentToInstrumentSimilarity + + +def get_precision_recall_f1(item_to_item_similarity_matrix: np.ndarray) -> tuple: + abs_similarities_between_instruments = np.abs(item_to_item_similarity_matrix) + + coord_to_sim = {} + for y in range(abs_similarities_between_instruments.shape[0]): + for x in range(abs_similarities_between_instruments.shape[1]): + coord_to_sim[(y, x)] = abs_similarities_between_instruments[y, x] + + best_matches = set() + is_used_x = set() + is_used_y = set() + for (y, x), sim in sorted(coord_to_sim.items(), key=operator.itemgetter(1), reverse=True): + if x not in is_used_x and y not in is_used_y and abs_similarities_between_instruments[(y, x)] >= 0: + best_matches.add((x, y)) + + is_used_x.add(x) + is_used_y.add(y) + + precision = len(is_used_x) / abs_similarities_between_instruments.shape[1] + recall = len(is_used_y) / abs_similarities_between_instruments.shape[0] + + f1 = np.mean((precision, recall)) + + return precision, recall, f1 + + +def get_instrument_similarity(instruments, similarity_with_polarity): + instrument_start_pos = [] + instrument_end_pos = [] + cur_start = 0 + for instr_idx in range(len(instruments)): + instrument_start_pos.append(cur_start) + instrument_end_pos.append(cur_start + len(instruments[instr_idx].questions)) + cur_start += len(instruments[instr_idx].questions) + + instrument_to_instrument_similarities = [] + + for i in range(len(instruments)): + instrument_1 = instruments[i] + for j in range(i + 1, len(instruments)): + instrument_2 = instruments[j] + item_to_item_similarity_matrix = similarity_with_polarity[instrument_start_pos[i]:instrument_end_pos[i], + instrument_start_pos[j]:instrument_end_pos[j]] + + precision, recall, f1 = get_precision_recall_f1(item_to_item_similarity_matrix) + + instrument_to_instrument_similarities.append( + InstrumentToInstrumentSimilarity(instrument_1_idx=i, instrument_2_idx=j, + instrument_1_name=instrument_1.instrument_name, + instrument_2_name=instrument_2.instrument_name, precision=precision, + recall=recall, f1=f1) + ) + + return instrument_to_instrument_similarities diff --git a/src/harmony/matching/matcher.py b/src/harmony/matching/matcher.py index c460c2b..ad7a968 100644 --- a/src/harmony/matching/matcher.py +++ b/src/harmony/matching/matcher.py @@ -24,8 +24,9 @@ SOFTWARE. """ -import statistics import heapq +import os +import statistics from collections import Counter, OrderedDict from typing import List, Callable @@ -33,6 +34,7 @@ from numpy import dot, matmul, ndarray, matrix from numpy.linalg import norm +from harmony.matching.instrument_to_instrument_similarity import get_instrument_similarity from harmony.matching.negator import negate from harmony.schemas.catalogue_instrument import CatalogueInstrument from harmony.schemas.catalogue_question import CatalogueQuestion @@ -40,24 +42,24 @@ Instrument, Question, ) +from harmony.schemas.responses.text import HarmonyMatchResponse from harmony.schemas.text_vector import TextVector -import os - # This has been tested on 16 GB RAM production server, 1000 seems a safe number (TW, 15 Dec 2024) -def get_batch_size(default=1000): +def get_batch_size(default=1000): try: batch_size = int(os.getenv("BATCH_SIZE", default)) return max(batch_size, 0) except (ValueError, TypeError): return default + + def process_items_in_batches(items, llm_function): batch_size = get_batch_size() if batch_size == 0: - return llm_function(items) - + return llm_function(items) batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)] @@ -156,12 +158,9 @@ def create_full_text_vectors( # Texts with no cached vector texts_not_cached = [x.text for x in text_vectors if not x.vector] - - # Get vectors for all texts not cached new_vectors_list: List = process_items_in_batches(texts_not_cached, vectorisation_function) - # Create a dictionary with new vectors new_vectors_dict = {} for vector, text in zip(new_vectors_list, texts_not_cached): @@ -577,7 +576,7 @@ def match_instruments_with_function( mhc_embeddings: np.ndarray = np.zeros((0, 0)), texts_cached_vectors: dict[str, List[float]] = {}, is_negate: bool = True -) -> tuple: +) -> HarmonyMatchResponse: """ Match instruments. @@ -673,9 +672,17 @@ def match_instruments_with_function( for question in all_questions: question.topics_auto = [] - return ( - all_questions, - similarity_with_polarity, - query_similarity, - new_vectors_dict - ) + instrument_to_instrument_similarities = get_instrument_similarity(instruments, similarity_with_polarity) + + return HarmonyMatchResponse(questions=all_questions, + similarity_with_polarity=similarity_with_polarity, + query_similarity=query_similarity, + new_vectors_dict=new_vectors_dict, + instrument_to_instrument_similarities=instrument_to_instrument_similarities) + # return ( + # all_questions, + # similarity_with_polarity, + # query_similarity, + # new_vectors_dict, + # instrument_to_instrument_similarities + # ) diff --git a/src/harmony/schemas/responses/text.py b/src/harmony/schemas/responses/text.py index 2c5dd1d..4b73998 100644 --- a/src/harmony/schemas/responses/text.py +++ b/src/harmony/schemas/responses/text.py @@ -25,12 +25,25 @@ ''' -from typing import List +from typing import List, Any + +import numpy as np +from pydantic import BaseModel, Field, RootModel from harmony.schemas.catalogue_instrument import CatalogueInstrument from harmony.schemas.requests.text import Instrument from harmony.schemas.requests.text import Question -from pydantic import BaseModel, Field, RootModel + +class InstrumentToInstrumentSimilarity(BaseModel): + instrument_1_idx: int = Field( + description="The index of the first instrument in the similarity pair in the list of instruments passed to Harmony (zero-indexed)") + instrument_2_idx: int = Field( + description="The index of the second instrument in the similarity pair in the list of instruments passed to Harmony (zero-indexed)") + instrument_1_name: str = Field(description="The name of the first instrument in the similarity pai") + instrument_2_name: str = Field(description="The name of the second instrument in the similarity pai") + precision: float = Field(description="The precision score of the match between Instrument 1 and Instrument 2") + recall: float = Field(description="The recall score of the match between Instrument 1 and Instrument 2") + f1: float = Field(description="The F1 score of the match between Instrument 1 and Instrument 2") class MatchResponse(BaseModel): @@ -47,6 +60,9 @@ class MatchResponse(BaseModel): description="The closest catalogue instrument matches in the catalogue for all the instruments, " "the first index contains the best match etc." ) + instrument_to_instrument_similarities: List[InstrumentToInstrumentSimilarity] = Field( + None, description="A list of similarity values (precision, recall, F1) between instruments" + ) class SearchInstrumentsResponse(BaseModel): @@ -60,3 +76,21 @@ class InstrumentList(RootModel): class CacheResponse(BaseModel): instruments: List[Instrument] = Field(description="A list of instruments") vectors: List[dict] = Field(description="A list of vectors") + + + +# For use internally in the Python library but *not* the API because the NDarrays don't serialise +class HarmonyMatchResponse(BaseModel): + questions: List[Question] = Field( + description="The questions which were matched, in an order matching the order of the matrix" + ) + similarity_with_polarity: Any = Field(description="Matrix of cosine similarity matches") + query_similarity: Any = Field( + None, description="Similarity metric between query string and items" + ) + new_vectors_dict: dict = Field( + None, description="Vectors for the cache. These should be stored by the Harmony API to reduce unnecessary calls to the LLM" + ) + instrument_to_instrument_similarities: List[InstrumentToInstrumentSimilarity] = Field( + None, description="A list of similarity values (precision, recall, F1) between instruments" + ) diff --git a/tests/test_batch.py b/tests/test_batch.py index 9795c44..ee1316e 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -27,35 +27,34 @@ import sys import unittest + import numpy sys.path.append("../src") from harmony.matching.default_matcher import convert_texts_to_vector + class createModel: def encode(self, sentences, convert_to_numpy=True): # Generate a dummy embedding with 768 dimensions for each sentence return numpy.array([[1] * 768] * len(sentences)) - model = createModel() + class TestBatching(unittest.TestCase): def test_convert_texts_to_vector_with_batching(self): # Create a list of 10 dummy texts texts = ["text" + str(i) for i in range(10)] - batch_size = 5 max_batches = 2 embeddings = convert_texts_to_vector(texts, batch_size=batch_size, max_batches=max_batches) - self.assertEqual(embeddings.shape[0], 10) - self.assertEqual(embeddings.shape[1], 384) diff --git a/tests/test_batching_in_matcher.py b/tests/test_batching_in_matcher.py index da23f12..c047faa 100644 --- a/tests/test_batching_in_matcher.py +++ b/tests/test_batching_in_matcher.py @@ -1,11 +1,9 @@ -import sys import os +import sys import unittest -import numpy sys.path.append("../src") from unittest import TestCase, mock -from harmony.matching.matcher import get_batch_size from harmony.matching.matcher import process_items_in_batches diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 882b108..64b4f23 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -31,24 +31,26 @@ sys.path.append("../src") from harmony.matching.cluster import cluster_questions -from harmony import create_instrument_from_list, import_instrument_into_harmony_web from harmony.schemas.requests.text import Instrument, Question class TestCluster(unittest.TestCase): def setUp(self): - self. all_questions_real = [Question(question_no="1", question_text="Feeling nervous, anxious, or on edge"), - Question(question_no="2", question_text="Not being able to stop or control worrying"), - Question(question_no="3", question_text="Little interest or pleasure in doing things"), - Question(question_no="4", question_text="Feeling down, depressed, or hopeless"), - Question(question_no="5", - question_text="Trouble falling/staying asleep, sleeping too much"), ] + self.all_questions_real = [Question(question_no="1", question_text="Feeling nervous, anxious, or on edge"), + Question(question_no="2", + question_text="Not being able to stop or control worrying"), + Question(question_no="3", + question_text="Little interest or pleasure in doing things"), + Question(question_no="4", question_text="Feeling down, depressed, or hopeless"), + Question(question_no="5", + question_text="Trouble falling/staying asleep, sleeping too much"), ] self.instruments = Instrument(questions=self.all_questions_real) def test_cluster(self): clusters_out, score_out = cluster_questions(self.instruments, 2, False) - assert(len(clusters_out) == 5) + assert (len(clusters_out) == 5) assert score_out + if __name__ == '__main__': unittest.main() diff --git a/tests/test_convert_text.py b/tests/test_convert_text.py index e73e863..02776ef 100644 --- a/tests/test_convert_text.py +++ b/tests/test_convert_text.py @@ -27,8 +27,6 @@ import sys import unittest -from harmony.parsing.text_parser import convert_text_to_instruments -from harmony.schemas.requests.text import RawFile, FileType sys.path.append("../src") diff --git a/tests/test_crosswalk.py b/tests/test_crosswalk.py index f9cd29d..fd5454c 100644 --- a/tests/test_crosswalk.py +++ b/tests/test_crosswalk.py @@ -83,8 +83,8 @@ def test_generate_crosswalk_table_empty(self): self.assertTrue(result.empty) def test_generate_crosswalk_table_real(self): - all_questions, similarity_with_polarity, _, _ = match_instruments(self.instruments) - result = generate_crosswalk_table(self.instruments, similarity_with_polarity, self.threshold, + match_response = match_instruments(self.instruments) + result = generate_crosswalk_table(self.instruments, match_response.similarity_with_polarity, self.threshold, is_allow_within_instrument_matches=True) expected_matches = [] @@ -94,7 +94,7 @@ def test_generate_crosswalk_table_real(self): self.assertEqual(len(result), len(expected_matches)) lower_threshold = 0.5 - result = generate_crosswalk_table(self.instruments, similarity_with_polarity, lower_threshold, + result = generate_crosswalk_table(self.instruments, match_response.similarity_with_polarity, lower_threshold, is_allow_within_instrument_matches=True) self.assertEqual(len(result), 1) @@ -106,8 +106,9 @@ def test_crosswalk_two_instruments_allow_many_to_one_matches(self): ["Feeling afraid, as if something awful might happen", "Feeling nervous, anxious, or on edge"]) instruments = [instrument_1, instrument_2] - all_questions, similarity_with_polarity, _, _ = match_instruments(instruments) - result = generate_crosswalk_table(instruments, similarity_with_polarity, 0, is_enforce_one_to_one=False) + match_response = match_instruments(instruments) + result = generate_crosswalk_table(instruments, match_response.similarity_with_polarity, 0, + is_enforce_one_to_one=False) self.assertEqual(2, len(result)) @@ -118,8 +119,9 @@ def test_crosswalk_two_instruments_enforce_one_to_one_matches(self): ["Feeling afraid, as if something awful might happen", "Feeling nervous, anxious, or on edge"]) instruments = [instrument_1, instrument_2] - all_questions, similarity_with_polarity, _, _ = match_instruments(instruments) - result = generate_crosswalk_table(instruments, similarity_with_polarity, 0, is_enforce_one_to_one=True) + match_response = match_instruments(instruments) + result = generate_crosswalk_table(instruments, match_response.similarity_with_polarity, 0, + is_enforce_one_to_one=True) self.assertEqual(1, len(result)) diff --git a/tests/test_instrument_to_instrument_similarity.py b/tests/test_instrument_to_instrument_similarity.py new file mode 100644 index 0000000..db8efa5 --- /dev/null +++ b/tests/test_instrument_to_instrument_similarity.py @@ -0,0 +1,72 @@ +''' +MIT License + +Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk). +Project: Harmony (https://harmonydata.ac.uk) +Maintainer: Thomas Wood (https://fastdatascience.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' + +import sys +import unittest + +sys.path.append("../src") + +from harmony import match_instruments +from harmony import create_instrument_from_list + + +class TestInstrumentToInstrumentSimilarity(unittest.TestCase): + + def test_same_instrument_twice(self): + gad_2 = create_instrument_from_list( + ["Feeling nervous, anxious, or on edge", "Not being able to stop or control worrying"]) + instruments = [gad_2, gad_2] + + match_response = match_instruments( + instruments) + + self.assertEqual(4, len(match_response.questions)) + self.assertEqual(4, len(match_response.similarity_with_polarity)) + self.assertEqual(1, len(match_response.instrument_to_instrument_similarities)) + self.assertEqual(1, match_response.instrument_to_instrument_similarities[0].precision) + self.assertEqual(1, match_response.instrument_to_instrument_similarities[0].recall) + self.assertEqual(1, match_response.instrument_to_instrument_similarities[0].f1) + + def test_two_instruments_one_a_subset_of_another(self): + gad_2 = create_instrument_from_list( + ["Feeling nervous, anxious, or on edge", "Not being able to stop or control worrying"]) + gad_1 = create_instrument_from_list( + ["Feeling nervous, anxious, or on edge"]) + instruments = [gad_2, gad_1] + + match_response = match_instruments( + instruments) + self.assertEqual(3, len(match_response.questions)) + self.assertEqual(3, len(match_response.similarity_with_polarity)) + self.assertEqual(1, len(match_response.instrument_to_instrument_similarities)) + self.assertEqual(1, match_response.instrument_to_instrument_similarities[0].precision) + self.assertEqual(0.5, match_response.instrument_to_instrument_similarities[0].recall) + self.assertEqual(0.75, match_response.instrument_to_instrument_similarities[0].f1) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_match.py b/tests/test_match.py index fe845ce..276c88a 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -121,36 +121,36 @@ class TestMatch(unittest.TestCase): def test_single_instrument_simple(self): - all_questions, similarity_with_polarity, query_similarity, new_vectors_dict = match_instruments([instrument_en]) - self.assertEqual(2, len(all_questions)) - self.assertEqual(2, len(similarity_with_polarity)) - self.assertLess(0.99, similarity_with_polarity[0][0]) - self.assertGreater(0.95, similarity_with_polarity[0][1]) - self.assertLess(0.99, similarity_with_polarity[1][1]) - self.assertGreater(0.95, similarity_with_polarity[1][0]) + match_response = match_instruments([instrument_en]) + self.assertEqual(2, len(match_response.questions)) + self.assertEqual(2, len(match_response.similarity_with_polarity)) + self.assertLess(0.99, match_response.similarity_with_polarity[0][0]) + self.assertGreater(0.95, match_response.similarity_with_polarity[0][1]) + self.assertLess(0.99, match_response.similarity_with_polarity[1][1]) + self.assertGreater(0.95, match_response.similarity_with_polarity[1][0]) def test_two_instruments_simple(self): - all_questions, similarity_with_polarity, query_similarity, new_vectors_dict = match_instruments( + match_response = match_instruments( [instrument_en, instrument_pt]) - self.assertEqual(4, len(all_questions)) - self.assertEqual(4, len(similarity_with_polarity)) - self.assertLess(0.99, similarity_with_polarity[0][0]) + self.assertEqual(4, len(match_response.questions)) + self.assertEqual(4, len(match_response.similarity_with_polarity)) + self.assertLess(0.99, match_response.similarity_with_polarity[0][0]) def test_single_instrument_full_metadata(self): - all_questions, similarity_with_polarity, query_similarity, new_vectors_dict = match_instruments([instrument_1]) - self.assertEqual(2, len(all_questions)) - self.assertEqual(2, len(similarity_with_polarity)) - self.assertLess(0.99, similarity_with_polarity[0][0]) - self.assertGreater(0.95, similarity_with_polarity[0][1]) - self.assertLess(0.99, similarity_with_polarity[1][1]) - self.assertGreater(0.95, similarity_with_polarity[1][0]) + match_response = match_instruments([instrument_1]) + self.assertEqual(2, len(match_response.questions)) + self.assertEqual(2, len(match_response.similarity_with_polarity)) + self.assertLess(0.99, match_response.similarity_with_polarity[0][0]) + self.assertGreater(0.95, match_response.similarity_with_polarity[0][1]) + self.assertLess(0.99, match_response.similarity_with_polarity[1][1]) + self.assertGreater(0.95, match_response.similarity_with_polarity[1][0]) def test_two_instruments_full_metadata(self): - all_questions, similarity_with_polarity, query_similarity, new_vectors_dict = match_instruments( + match_response = match_instruments( [instrument_1, instrument_2]) - self.assertEqual(4, len(all_questions)) - self.assertEqual(4, len(similarity_with_polarity)) - self.assertLess(0.99, similarity_with_polarity[0][0]) + self.assertEqual(4, len(match_response.questions)) + self.assertEqual(4, len(match_response.similarity_with_polarity)) + self.assertLess(0.99, match_response.similarity_with_polarity[0][0]) if __name__ == '__main__': diff --git a/tests/test_match_mhc.py b/tests/test_match_mhc.py index 7792598..54d829d 100644 --- a/tests/test_match_mhc.py +++ b/tests/test_match_mhc.py @@ -59,13 +59,13 @@ class TestMatchMhc(unittest.TestCase): def test_single_instrument_simple(self): - all_questions, similarity_with_polarity, query_similarity, new_vectors_dict = match_instruments([instrument_en], - mhc_questions=mhc_questions, - mhc_embeddings=mhc_embeddings, - mhc_all_metadatas=mhc_metadata) - self.assertEqual(2, len(all_questions)) + match_response = match_instruments([instrument_en], + mhc_questions=mhc_questions, + mhc_embeddings=mhc_embeddings, + mhc_all_metadatas=mhc_metadata) + self.assertEqual(2, len(match_response.questions)) - topics = all_questions[0].topics_strengths + topics = match_response.questions[0].topics_strengths top_topic = list(topics)[0] self.assertEqual("alcohol use", top_topic) self.assertLess(0.1, topics[top_topic]) diff --git a/tests/test_match_negative_polarity.py b/tests/test_match_negative_polarity.py index b98e490..1c7acaa 100644 --- a/tests/test_match_negative_polarity.py +++ b/tests/test_match_negative_polarity.py @@ -41,23 +41,23 @@ class TestMatch(unittest.TestCase): def test_single_instrument_with_negation_on_as_default(self): - all_questions, similarity_with_polarity, query_similarity, new_vectors_dict = match_instruments([instrument_en]) - self.assertEqual(2, len(all_questions)) - self.assertEqual(2, len(similarity_with_polarity)) - self.assertLess(0.99, similarity_with_polarity[0][0]) - self.assertGreater(0, similarity_with_polarity[0][1]) - self.assertLess(0.99, similarity_with_polarity[1][1]) - self.assertGreater(0, similarity_with_polarity[1][0]) + match_response = match_instruments([instrument_en]) + self.assertEqual(2, len(match_response.questions)) + self.assertEqual(2, len(match_response.similarity_with_polarity)) + self.assertLess(0.99, match_response.similarity_with_polarity[0][0]) + self.assertGreater(0, match_response.similarity_with_polarity[0][1]) + self.assertLess(0.99, match_response.similarity_with_polarity[1][1]) + self.assertGreater(0, match_response.similarity_with_polarity[1][0]) def test_single_instrument_without_negation(self): - all_questions, similarity_with_polarity, query_similarity, new_vectors_dict = match_instruments([instrument_en], - is_negate=False) - self.assertEqual(2, len(all_questions)) - self.assertEqual(2, len(similarity_with_polarity)) - self.assertLess(0.99, similarity_with_polarity[0][0]) - self.assertLess(0, similarity_with_polarity[0][1]) - self.assertLess(0.99, similarity_with_polarity[1][1]) - self.assertLess(0, similarity_with_polarity[1][0]) + match_response = match_instruments([instrument_en], + is_negate=False) + self.assertEqual(2, len(match_response.questions)) + self.assertEqual(2, len(match_response.similarity_with_polarity)) + self.assertLess(0.99, match_response.similarity_with_polarity[0][0]) + self.assertLess(0, match_response.similarity_with_polarity[0][1]) + self.assertLess(0.99, match_response.similarity_with_polarity[1][1]) + self.assertLess(0, match_response.similarity_with_polarity[1][0]) if __name__ == '__main__': diff --git a/tests/test_pdf_tables.py b/tests/test_pdf_tables.py index a2ad2fd..273a593 100644 --- a/tests/test_pdf_tables.py +++ b/tests/test_pdf_tables.py @@ -30,31 +30,28 @@ sys.path.append("../src") -from harmony import convert_pdf_to_instruments from harmony.schemas.requests.text import RawFile -from harmony import download_models - pdf_empty_table = RawFile.model_validate({ "file_id": "d39f31718513413fbfc620c6b6135d0c", "file_name": "GAD-7.pdf", "file_type": "pdf", "tables": [], - "text_content":"aaa", - "content":"" + "text_content": "aaa", + "content": "" }) pdf_non_empty_table = RawFile.model_validate({ "file_id": "d39f31718513413fbfc620c6b6135d0c", "file_name": "GAD-7.pdf", "file_type": "pdf", - 'tables': [["hello"]], - "text_content":"aaa", - "content":"" + 'tables': [["hello"]], + "text_content": "aaa", + "content": "" }) -class TestConvertPdfTables(unittest.TestCase): +class TestConvertPdfTables(unittest.TestCase): pass # Not using tables at the moment