Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| # In[1]: | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from torchvision.models import resnet50, ResNet50_Weights | |
| # In[4]: | |
class MalariaResNet50(nn.Module):
    """ResNet50 whose final fully connected layer is replaced by a new classifier head.

    The wrapped torchvision model is stored in ``self.backbone``; its ``fc``
    layer is swapped for an ``nn.Linear`` with ``num_classes`` outputs.
    """

    # Index -> label mapping used by ``predict``.
    # NOTE(review): assumes index 0 is "Uninfected" and index 1 is "Infected";
    # confirm this matches the label order used when the weights were trained.
    CLASS_NAMES = ('Uninfected', 'Infected')

    def __init__(self, num_classes=2, pretrained=True):
        """Build the backbone and attach a fresh classification head.

        Args:
            num_classes (int): Number of outputs of the replacement ``fc`` layer.
            pretrained (bool): When True (the default, matching the previous
                behaviour) the backbone is created with
                ``ResNet50_Weights.DEFAULT``. Pass False to skip loading those
                weights, e.g. when ``load()`` will overwrite them right away.
        """
        super().__init__()
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet50(weights=weights)
        # Keep the backbone's feature width, change only the output size.
        num_ftrs = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """Return the backbone's output (raw logits) for input batch ``x``."""
        return self.backbone(x)

    def predict(self, image_path, device='cpu', show_image=False):
        """Predict the class of a single image.

        The model is switched to eval mode and left in that mode afterwards.
        Only the input tensor is moved to ``device``; the model itself is not
        moved, so the caller must already have placed it on the same device.

        Args:
            image_path (str): Path to input image.
            device (str or torch.device): Device the input tensor is moved to,
                e.g. 'cuda' or 'cpu'.
            show_image (bool): Whether to display the image with matplotlib.

        Returns:
            pred_label (str): "Infected" or "Uninfected".
            confidence (float): Softmax probability of the predicted class.
        """
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # Open inside a context manager so the file handle is released
        # immediately; convert() returns a loaded copy that stays usable.
        with Image.open(image_path) as raw_img:
            img = raw_img.convert('RGB')
        img_tensor = transform(img).unsqueeze(0).to(device)

        # Inference
        self.eval()
        with torch.no_grad():
            output = self(img_tensor)
            probs = F.softmax(output, dim=1)
            pred_idx = output.argmax(dim=1).item()
            confidence = probs[0][pred_idx].item()

        pred_label = self.CLASS_NAMES[pred_idx]

        if show_image:
            # Imported lazily so matplotlib is only required when plotting.
            import matplotlib.pyplot as plt

            plt.imshow(img)
            plt.title(f"Predicted: {pred_label} ({confidence:.2%})")
            plt.axis("off")
            plt.show()

        return pred_label, confidence

    def save(self, path):
        """Save the model's state dict to ``path`` and print a confirmation."""
        torch.save(self.state_dict(), path)
        print(f"Model saved to {path}")

    def load(self, path):
        """Load a state dict from ``path`` (mapped to CPU) into this model.

        Args:
            path (str): File previously written by ``save`` / ``torch.save``.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources (or pass weights_only=True where the installed
        # torch version supports it).
        state_dict = torch.load(path, map_location=torch.device('cpu'))
        self.load_state_dict(state_dict)
        print(f"Model loaded from {path}")
| # In[ ]: | |