# medsiglip-classifier / classifier.py
# Hosting-page residue: "sitammeur's picture", "Update classifier.py", "b256293 verified"
import os
from dotenv import load_dotenv
import numpy as np
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModel
import tensorflow as tf
# Read variables defined in a .env file into the process environment.
load_dotenv()

# Token handed to `from_pretrained` when fetching the model; None when
# ACCESS_TOKEN is not set in the environment.
access_token = os.getenv("ACCESS_TOKEN")
class MedSigLIPClassifier:
    """Zero-shot classifier for medical images backed by a MedSigLIP model.

    The model and its processor are loaded once at construction time; the
    model is moved to the GPU when CUDA is available. ``predict`` then scores
    a single image against a caller-supplied set of text labels.
    """

    def __init__(self, model_id="google/medsiglip-448"):
        """Load the model and processor identified by ``model_id``.

        Args:
            model_id: Model identifier passed to ``from_pretrained`` for both
                the model and the processor.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModel.from_pretrained(model_id, token=access_token).to(
            self.device
        )
        self.processor = AutoProcessor.from_pretrained(
            model_id, use_fast=True, token=access_token
        )

    def _resize(self, image: "Image.Image") -> "Image.Image":
        """Return ``image`` resized to 448x448 as a new PIL image.

        TensorFlow's bilinear resize with antialiasing disabled is used on
        purpose, to match MedSigLIP training preprocessing.
        """
        resized = tf.image.resize(
            images=image, size=[448, 448], method="bilinear", antialias=False
        )
        # Cast back to 8-bit pixels so PIL can build an image from the array.
        return Image.fromarray(resized.numpy().astype(np.uint8))

    @staticmethod
    def _normalize_labels(candidate_labels) -> list[str]:
        """Return the candidate labels as a non-empty, duplicate-free list.

        A bare string is treated as a single label rather than being iterated
        character by character. Duplicates are dropped (first occurrence
        wins, order preserved) because the result of ``predict`` is keyed by
        label: a repeated label would take a share of the softmax mass and
        then be overwritten in the returned dict.

        Raises:
            ValueError: If ``candidate_labels`` is ``None`` or empty.
        """
        if candidate_labels is None:
            labels = []
        elif isinstance(candidate_labels, str):
            labels = [candidate_labels]
        else:
            labels = list(dict.fromkeys(candidate_labels))
        if not labels:
            raise ValueError("candidate_labels must contain at least one label")
        return labels

    def predict(
        self, image: "Image.Image", candidate_labels: list[str]
    ) -> dict[str, float]:
        """Score ``image`` against each candidate label.

        Args:
            image: PIL image to classify; converted to RGB when it is in any
                other mode.
            candidate_labels: Text labels to compare the image with.

        Returns:
            A dict mapping each distinct label to its probability. The values
            come from a softmax over the labels, so they sum to 1.

        Raises:
            ValueError: If no candidate labels are supplied.
        """
        # Validate first so bad input fails fast, before any image/model work.
        labels = self._normalize_labels(candidate_labels)

        # Ensure image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Resize image
        resized_image = self._resize(image)

        # Prepare inputs; every label is padded to the tokenizer's max length.
        inputs = self.processor(
            text=labels,
            images=resized_image,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)

        # Inference (no gradients needed)
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits_per_image, dim=1)

        # Only one image is passed in, so row 0 holds its per-label scores.
        return dict(zip(labels, probs[0].tolist()))