Skip to content

Commit

Permalink
Merge pull request #5 from tikendraw/tikendraw/issue4
Browse files Browse the repository at this point in the history
Tikendraw/issue4
  • Loading branch information
tikendraw authored Dec 19, 2024
2 parents 80a46d9 + 1a297dd commit c41b093
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 23 deletions.
62 changes: 40 additions & 22 deletions v2/embed_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from PIL import Image
import torch
import numpy as np

import logging

class EfficientNetEmbeddingFunction(EmbeddingFunction[Documents]):
"""To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""
"""To use this EmbeddingFunction, you must have the transformers Python package installed."""

def __init__(self, model_name: str = "google/efficientnet-b0", device="cuda"):
self.device = (
Expand All @@ -17,23 +16,27 @@ def __init__(self, model_name: str = "google/efficientnet-b0", device="cuda"):
self.image_processor = AutoImageProcessor.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device)

def load_image(self, image: str | Path | Image.Image | np.ndarray):
def load_image(self, image: str | Path ):
"""
Loads an image and processes it using the model's image processor.
Reusable helper function for single and batch image processing.
"""
if isinstance(image, str) or isinstance(image, Path):
image = Image.open(image)
try:
if isinstance(image, str) or isinstance(image, Path):
image = Image.open(image)
else:
logging.error('Image type not supported(knowingly), use str or pathlib.Path ')
return None

if isinstance(image, np.ndarray):
image = Image.fromarray(image)
if image.mode != "RGB":
image = image.convert("RGB")

if image.mode != "RGB":
image = image.convert("RGB")

image = self.image_processor(images=image, return_tensors="pt").to(self.device)
image = self.image_processor(images=image, return_tensors="pt").to(self.device)

return image["pixel_values"]
return image["pixel_values"]
except OSError as e:
logging.error(f"Error processing image {image}: {e}")
return None

def _embed(self, pixel_values):
"""
Expand All @@ -48,27 +51,42 @@ def embed_image(self, image):
Embeds a single image by first loading and processing it, then passing it to the model.
"""
pixel_values = self.load_image(image)
embeddings = self._embed(pixel_values)
return embeddings.cpu().numpy().tolist()
if pixel_values is not None:
embeddings = self._embed(pixel_values)
return embeddings.cpu().numpy().tolist()
else:
logging.error(f"Failed to load image: {image}")
return None

def batch_embed_images(
self, images: list[str | Path | Image.Image | np.ndarray], batch_size: int = 32
self, images: list[str | Path], batch_size: int = 32
):
"""
Embeds a batch of images, processing them in batches of the specified batch_size.
"""
all_embeddings = []

bad_images = []
for i in range(0, len(images), batch_size):
batch = images[i : i + batch_size]
processed_images = [self.load_image(image) for image in batch]
batched_images = torch.cat(processed_images).to(self.device)
batch_embeddings = self._embed(batched_images)
all_embeddings.extend(batch_embeddings.cpu().numpy().tolist())
bad_images = [image for image, pixel_values in zip(batch, processed_images) if pixel_values is None]
processed_images = [pixel_values for pixel_values in processed_images if pixel_values is not None]

if not processed_images:
continue

try:
batched_images = torch.cat(processed_images).to(self.device)
batch_embeddings = self._embed(batched_images)
all_embeddings.extend(batch_embeddings.cpu().numpy().tolist())
except Exception as e:
logging.error(f"Error during batch embedding, skipping batch: {e}")

if bad_images:
logging.warning(f"The following images failed to process : {', '.join(bad_images)}")
return all_embeddings

def __call__(
    self, images: list[Path | str], batch_size: int = 8
) -> Embeddings:
    """
    Make the instance callable as an embedding function.

    Forwards ``images`` and ``batch_size`` to ``batch_embed_images`` and
    returns whatever that method returns.
    """
    embeddings = self.batch_embed_images(images, batch_size=batch_size)
    return embeddings
2 changes: 1 addition & 1 deletion v2/embedding_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,4 +275,4 @@ def delete_embeddings(self, image_paths: list[str] = None) -> None:
delete_ids = [idd for idd, uri in zip(ids, uris) if uri in image_paths]
self.collection.delete(ids=delete_ids)
self._delete_cache(image_paths=image_paths)
print(f"deleted {len(delete_ids)} embeddins!")
print(f"deleted {len(delete_ids)} embeddings!")

0 comments on commit c41b093

Please sign in to comment.