diff --git a/src/harmony/matching/default_matcher.py b/src/harmony/matching/default_matcher.py index 1f8ada7..902f931 100644 --- a/src/harmony/matching/default_matcher.py +++ b/src/harmony/matching/default_matcher.py @@ -28,15 +28,14 @@ from typing import List import numpy as np -from numpy import ndarray -from sentence_transformers import SentenceTransformer - from harmony import match_instruments_with_function from harmony.schemas.requests.text import Instrument +from numpy import ndarray +from sentence_transformers import SentenceTransformer if ( - os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) is not None - and os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) != "" + os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) is not None + and os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) != "" ): sentence_transformer_path = os.environ["HARMONY_SENTENCE_TRANSFORMER_PATH"] else: @@ -47,24 +46,42 @@ model = SentenceTransformer(sentence_transformer_path) -def convert_texts_to_vector(texts: List) -> ndarray: - embeddings = model.encode(sentences=texts, convert_to_numpy=True) +def convert_texts_to_vector(texts: List, batch_size=50, max_batches=2000) -> ndarray: + if batch_size == 0: + embeddings = model.encode(sentences=texts, convert_to_numpy=True) - return embeddings + return embeddings + + embeddings = [] + batch_count = 0 + + # Process texts in batches + for i in range(0, len(texts), batch_size): + if batch_count >= max_batches: + break + batch = texts[i:i + batch_size] + batch_embeddings = model.encode(sentences=batch, convert_to_numpy=True) + embeddings.append(batch_embeddings) + batch_count += 1 + + # Concatenate all batch embeddings into a single NumPy array + return np.concatenate(embeddings, axis=0) def match_instruments( - instruments: List[Instrument], - query: str = None, - mhc_questions: List = [], - mhc_all_metadatas: List = [], - mhc_embeddings: np.ndarray = np.zeros((0, 0)), - texts_cached_vectors: dict[str, List[float]] = {}, + instruments: List[Instrument], + query: str = None, + mhc_questions: List = [], + mhc_all_metadatas: List = [], + mhc_embeddings: np.ndarray = np.zeros((0, 0)), + texts_cached_vectors: dict[str, List[float]] = {}, batch_size: int = 50, max_batches: int = 2000, + ) -> tuple: return match_instruments_with_function( instruments=instruments, query=query, - vectorisation_function=convert_texts_to_vector, + vectorisation_function=lambda texts: convert_texts_to_vector(texts, batch_size=batch_size, + max_batches=max_batches), mhc_questions=mhc_questions, mhc_all_metadatas=mhc_all_metadatas, mhc_embeddings=mhc_embeddings, diff --git a/tests/test_batch.py b/tests/test_batch.py new file mode 100644 index 0000000..9795c44 --- /dev/null +++ b/tests/test_batch.py @@ -0,0 +1,63 @@ +''' +MIT License + +Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk). +Project: Harmony (https://harmonydata.ac.uk) +Maintainer: Thomas Wood (https://fastdatascience.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' + +import sys +import unittest +import numpy + +sys.path.append("../src") + +from harmony.matching.default_matcher import convert_texts_to_vector + +class createModel: + def encode(self, sentences, convert_to_numpy=True): + # Generate a dummy embedding with 768 dimensions for each sentence + return numpy.array([[1] * 768] * len(sentences)) + + + +model = createModel() + +class TestBatching(unittest.TestCase): + def test_convert_texts_to_vector_with_batching(self): + # Create a list of 10 dummy texts + texts = ["text" + str(i) for i in range(10)] + + + batch_size = 5 + max_batches = 2 + embeddings = convert_texts_to_vector(texts, batch_size=batch_size, max_batches=max_batches) + + + self.assertEqual(embeddings.shape[0], 10) + + + self.assertEqual(embeddings.shape[1], 384) + + +if __name__ == "__main__": + unittest.main()