From 997c2270911c33851bea7d0be8685bda7a03b00e Mon Sep 17 00:00:00 2001 From: tikendraw Date: Thu, 19 Dec 2024 17:39:52 +0530 Subject: [PATCH 1/2] Crashes on malformed image Fixes #4 , logs bad images --- v2/embed_model.py | 68 ++++++++++++++++++++++++++++--------------- v2/embedding_store.py | 2 +- 2 files changed, 46 insertions(+), 24 deletions(-) diff --git a/v2/embed_model.py b/v2/embed_model.py index 15b2b71..a41f848 100644 --- a/v2/embed_model.py +++ b/v2/embed_model.py @@ -3,11 +3,10 @@ from chromadb.api.types import Documents, EmbeddingFunction, Embeddings from PIL import Image import torch -import numpy as np - +import logging class EfficientNetEmbeddingFunction(EmbeddingFunction[Documents]): - """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key.""" + """To use this EmbeddingFunction, you must have the transformers Python package installed.""" def __init__(self, model_name: str = "google/efficientnet-b0", device="cuda"): self.device = ( @@ -17,23 +16,27 @@ def __init__(self, model_name: str = "google/efficientnet-b0", device="cuda"): self.image_processor = AutoImageProcessor.from_pretrained(model_name) self.model = AutoModel.from_pretrained(model_name).to(self.device) - def load_image(self, image: str | Path | Image.Image | np.ndarray): + def load_image(self, image: str | Path ): """ Loads an image and processes it using the model's image processor. Reusable helper function for single and batch image processing. 
""" - if isinstance(image, str) or isinstance(image, Path): - image = Image.open(image) + try: + if isinstance(image, str) or isinstance(image, Path): + image = Image.open(image) + else: + logging.error('Image type not supported(knowingly), use str or pathlib.Path ') + return None - if isinstance(image, np.ndarray): - image = Image.fromarray(image) + if image.mode != "RGB": + image = image.convert("RGB") - if image.mode != "RGB": - image = image.convert("RGB") - - image = self.image_processor(images=image, return_tensors="pt").to(self.device) + image = self.image_processor(images=image, return_tensors="pt").to(self.device) - return image["pixel_values"] + return image["pixel_values"] + except OSError as e: + logging.error(f"Error processing image {image}: {e}") + return None def _embed(self, pixel_values): """ @@ -48,27 +51,46 @@ def embed_image(self, image): Embeds a single image by first loading and processing it, then passing it to the model. """ pixel_values = self.load_image(image) - embeddings = self._embed(pixel_values) - return embeddings.cpu().numpy().tolist() + if pixel_values is not None: + embeddings = self._embed(pixel_values) + return embeddings.cpu().numpy().tolist() + else: + logging.error(f"Failed to load image: {image}") + return None def batch_embed_images( - self, images: list[str | Path | Image.Image | np.ndarray], batch_size: int = 32 + self, images: list[str | Path], batch_size: int = 32 ): """ Embeds a batch of images, processing them in batches of the specified batch_size. 
""" all_embeddings = [] - + bad_images = [] for i in range(0, len(images), batch_size): batch = images[i : i + batch_size] - processed_images = [self.load_image(image) for image in batch] - batched_images = torch.cat(processed_images).to(self.device) - batch_embeddings = self._embed(batched_images) - all_embeddings.extend(batch_embeddings.cpu().numpy().tolist()) + processed_images = [] + for image in batch: + pixel_values = self.load_image(image) + if pixel_values is not None: + processed_images.append(pixel_values) + else: + bad_images.append(str(image)) + + if not processed_images: + continue + + try: + batched_images = torch.cat(processed_images).to(self.device) + batch_embeddings = self._embed(batched_images) + all_embeddings.extend(batch_embeddings.cpu().numpy().tolist()) + except Exception as e: + logging.error(f"Error during batch embedding, skipping batch: {e}") + if bad_images: + logging.warning(f"The following images failed to process : {', '.join(bad_images)}") return all_embeddings def __call__( - self, images: list[Image.Image | Path | str], batch_size: int = 8 + self, images: list[Path | str], batch_size: int = 8 ) -> Embeddings: - return self.batch_embed_images(images, batch_size=batch_size) + return self.batch_embed_images(images, batch_size=batch_size) \ No newline at end of file diff --git a/v2/embedding_store.py b/v2/embedding_store.py index 147d1da..9a21c60 100644 --- a/v2/embedding_store.py +++ b/v2/embedding_store.py @@ -275,4 +275,4 @@ def delete_embeddings(self, image_paths: list[str] = None) -> None: delete_ids = [idd for idd, uri in zip(ids, uris) if uri in image_paths] self.collection.delete(ids=delete_ids) self._delete_cache(image_paths=image_paths) - print(f"deleted {len(delete_ids)} embeddins!") + print(f"deleted {len(delete_ids)} embeddings!") From 1a297dd3367e8ba7a6a23f16511a9b598ab3937d Mon Sep 17 00:00:00 2001 From: tikendraw Date: Thu, 19 Dec 2024 17:52:00 +0530 Subject: [PATCH 2/2] logs bad images --- v2/embed_model.py | 10 
+++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/v2/embed_model.py b/v2/embed_model.py index a41f848..6302cec 100644 --- a/v2/embed_model.py +++ b/v2/embed_model.py @@ -68,13 +68,9 @@ def batch_embed_images( bad_images = [] for i in range(0, len(images), batch_size): batch = images[i : i + batch_size] - processed_images = [] - for image in batch: - pixel_values = self.load_image(image) - if pixel_values is not None: - processed_images.append(pixel_values) - else: - bad_images.append(str(image)) + processed_images = [self.load_image(image) for image in batch] + bad_images.extend(str(image) for image, pixel_values in zip(batch, processed_images) if pixel_values is None) + processed_images = [pixel_values for pixel_values in processed_images if pixel_values is not None] if not processed_images: continue