Spaces:
Running
Running
| import os | |
| from dotenv import load_dotenv | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from transformers import AutoProcessor, AutoModel | |
| import tensorflow as tf | |
# Populate os.environ with any variables defined in a local .env file.
load_dotenv()

# Hugging Face Hub access token handed to from_pretrained when loading the
# model/processor; None when ACCESS_TOKEN is not set in the environment.
access_token = os.getenv("ACCESS_TOKEN")
class MedSigLIPClassifier:
    """Zero-shot classifier for medical images built on a MedSigLIP checkpoint.

    ``__init__`` downloads the model and processor from the Hugging Face Hub;
    ``predict`` scores a single image against a list of free-text labels.
    """

    # Default side length (pixels) of the square image fed to the processor,
    # matching the default ``google/medsiglip-448`` checkpoint.
    image_size = 448

    def __init__(self, model_id="google/medsiglip-448", token=None, image_size=None):
        """Load the model and processor for ``model_id``.

        Args:
            model_id: Hugging Face Hub model identifier to load.
            token: Hub access token. When None, falls back to the module-level
                ``access_token`` (read from the ACCESS_TOKEN env variable).
            image_size: Side length the input image is resized to before
                preprocessing. When None, the class default (448) is used;
                override it for checkpoints trained at another resolution.
        """
        if token is None:
            token = access_token
        if image_size is not None:
            self.image_size = image_size
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModel.from_pretrained(model_id, token=token).to(self.device)
        self.processor = AutoProcessor.from_pretrained(
            model_id, use_fast=True, token=token
        )

    def _resize(self, image):
        """Resize ``image`` to a square of ``self.image_size`` pixels via TensorFlow.

        Bilinear interpolation without antialiasing is used to mirror the
        preprocessing MedSigLIP was trained with. The float result is cast
        back to uint8, which truncates (does not round) fractional values.

        Args:
            image: PIL image, expected to already be in RGB mode.

        Returns:
            A new PIL image of size (image_size, image_size).
        """
        # Hand TensorFlow an explicit ndarray rather than relying on it to
        # coerce a PIL image object.
        pixels = np.asarray(image)
        resized = tf.image.resize(
            images=pixels,
            size=[self.image_size, self.image_size],
            method="bilinear",
            antialias=False,
        )
        return Image.fromarray(resized.numpy().astype(np.uint8))

    def predict(self, image: Image.Image, candidate_labels: list[str]):
        """Return the probability of each candidate label for ``image``.

        Args:
            image: Input image; converted to RGB if it is in another mode.
            candidate_labels: Non-empty list of text labels to score. A single
                string is treated as one label (not as a list of characters).

        Returns:
            Dict mapping each label to its softmax probability; values sum to
            1 across the labels (so a single label always scores 1.0).
            Duplicate labels collapse to one key.

        Raises:
            ValueError: If ``candidate_labels`` is empty.
        """
        # Normalize once: a bare str would otherwise be zipped character by
        # character, and an iterator would be exhausted by the processor
        # before the final zip.
        if isinstance(candidate_labels, str):
            labels = [candidate_labels]
        else:
            labels = list(candidate_labels)
        if not labels:
            raise ValueError("candidate_labels must contain at least one label")

        if image.mode != "RGB":
            image = image.convert("RGB")
        resized_image = self._resize(image)

        inputs = self.processor(
            text=labels,
            images=resized_image,
            # SigLIP-family text encoders expect max-length padding (per the
            # Hugging Face SigLIP documentation).
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # logits_per_image has one row per image; we passed exactly one image,
        # so row 0 holds its scores against every label.
        probs = torch.softmax(outputs.logits_per_image, dim=1)[0].tolist()
        return dict(zip(labels, probs))