from typing import Literal

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
from PIL import Image
from transformers.utils import ModelOutput
from tqdm import tqdm


class MultimodalEmbedder:
    """A multimodal embedder that supports text and image embeddings."""

    def __init__(
        self,
        text_model: str = 'nomic-ai/nomic-embed-text-v1.5',
        image_model: str = 'nomic-ai/nomic-embed-vision-v1.5',
        batch_size: int = 64
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(text_model)
        self.text_model = AutoModel.from_pretrained(text_model, trust_remote_code=True)
        self.text_model.eval()
        self.text_embedding_size = self.text_model.config.hidden_size

        self.processor = AutoImageProcessor.from_pretrained(image_model)
        self.image_model = AutoModel.from_pretrained(image_model, trust_remote_code=True)
        self.image_model.eval()  # inference only, so disable dropout in both models
        self.image_embedding_size = self.image_model.config.hidden_size

        self.batch_size = batch_size

    def embed_texts(
        self,
        texts: list[str],
        kind: Literal['query', 'document'] = 'document',
        device: str = 'cpu'
    ) -> list[list[float]]:
        """Embed a list of texts."""
        # Nomic text models expect a task prefix on every input string.
        texts = [f'search_query: {text}' if kind == 'query' else f'search_document: {text}' for text in texts]
        model = self.text_model.to(device)
        all_embeddings = []
        for start in tqdm(range(0, len(texts), self.batch_size), desc='Embed texts'):
            batch_texts = texts[start:start + self.batch_size]
            inputs = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            embeddings = mean_pooling(outputs, inputs['attention_mask'])
            embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
            embeddings = F.normalize(embeddings, p=2, dim=1)
            all_embeddings.append(embeddings.cpu())
        return torch.cat(all_embeddings, dim=0).tolist()

    def embed_images(self, images: list[str | Image.Image], device: str = 'cpu') -> list[list[float]]:
        """Embed a list of images, which can be file paths or PIL Image objects."""
        images = [Image.open(img) if isinstance(img, str) else img for img in images]
        images = [img.convert('RGB') for img in images]
        model = self.image_model.to(device)
        all_embeddings = []
        for start in tqdm(range(0, len(images), self.batch_size), desc='Embed images'):
            batch_images = images[start:start + self.batch_size]
            inputs = self.processor(batch_images, return_tensors='pt').to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            embeddings = outputs.last_hidden_state[:, 0]  # CLS token
            embeddings = F.normalize(embeddings, p=2, dim=1)
            all_embeddings.append(embeddings.cpu())
        return torch.cat(all_embeddings, dim=0).tolist()

    def similarity(
        self,
        embeddings1: list[list[float]],
        embeddings2: list[list[float]],
        pair_type: Literal['text-text', 'image-image', 'text-image']
    ) -> list[list[float]]:
        """Calculate cosine similarity between two sets of embeddings, rescaled to [0, 1]."""
        # Empirically chosen raw-cosine ranges per pair type, used to rescale scores into [0, 1].
        pair_min_max = {
            'text-text': (0.4, 1.0),
            'image-image': (0.75, 1.0),
            'text-image': (0.01, 0.09)
        }
        min_val, max_val = pair_min_max[pair_type]
        similarities = np.dot(embeddings1, np.transpose(embeddings2))
        similarities = np.clip((similarities - min_val) / (max_val - min_val), 0, 1)
        return similarities.tolist()


def mean_pooling(model_output: ModelOutput, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean pooling over token embeddings, weighted by the attention mask."""
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
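A minimal usage sketch of the class above follows. The query strings and image paths ('cat.jpg', 'bread.jpg') are placeholders; the first call downloads the Nomic checkpoints, which require trust_remote_code=True as already set in the constructor.

# Hypothetical usage: placeholder queries and image paths.
embedder = MultimodalEmbedder(batch_size=16)

queries = ['a cat sleeping on a sofa']
docs = ['A tabby cat curled up on the couch.', 'A recipe for banana bread.']

# Text-to-text retrieval: queries use the 'query' prefix, documents the 'document' prefix.
query_emb = embedder.embed_texts(queries, kind='query')
doc_emb = embedder.embed_texts(docs, kind='document')
print(embedder.similarity(query_emb, doc_emb, pair_type='text-text'))

# Cross-modal: compare the same query against images on disk.
image_emb = embedder.embed_images(['cat.jpg', 'bread.jpg'])
print(embedder.similarity(query_emb, image_emb, pair_type='text-image'))

Because the text and vision models share an embedding space, the same text vectors can be scored against either modality; only the rescaling range passed via pair_type changes.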