"""
HuggingFace Spaces App for ConvNeXt CheXpert Classification with GradCAM

This app provides a web interface for chest X-ray classification with GradCAM
visualization, showing model attention regions for confident predictions
(>60% confidence).

Usage:
    Run this file and access the Gradio interface via the provided URL.
"""
|
|
|
|
|
import importlib
import os
import subprocess
import sys

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
from matplotlib.figure import Figure
from PIL import Image
from torchvision import transforms
|
|
|
|
|
|
|
|
# Grad-CAM helpers come from the third-party ``pytorch_grad_cam`` module, which
# is distributed on PyPI under the name "grad-cam".
try:
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    from pytorch_grad_cam.utils.image import show_cam_on_image
except ImportError:
    print("Installing required packages...")
    # Use the running interpreter's pip so the package is installed into the
    # environment this app actually imports from.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "grad-cam"],
        check=False,
    )
    # Make the import system notice the freshly installed package.
    importlib.invalidate_caches()
    try:
        # Retry: without this the three names would stay undefined even after
        # a successful installation.
        from pytorch_grad_cam import GradCAM
        from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
        from pytorch_grad_cam.utils.image import show_cam_on_image
    except ImportError as err:
        # Best effort: keep the app importable; GradCAM rendering will fail
        # later, but the reason is now visible in the logs.
        print(f"pytorch_grad_cam is still unavailable after installation: {err}")
|
|
|
|
|
|
|
|
# Names of the 14 findings the classifier scores. The position of a name in
# this list is used as the class index of the corresponding model output.
DISEASE_LABELS = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]
|
|
|
|
|
|
|
|
# Model input/output settings: square input edge length in pixels, number of
# output classes, and the per-channel normalization statistics. The single
# mean/std value is repeated once for each of the three input channels.
MODEL_CONFIG = dict(
    input_size=384,
    num_classes=14,
    mean=[0.5029414296150208] * 3,
    std=[0.2892409563064575] * 3,
)
|
|
|
|
|
class ConvNeXtWithCBAM(nn.Module):
    """ConvNeXt feature extractor with CBAM-style attention and a linear head.

    The last feature map of a ``timm`` backbone is re-weighted by a channel
    gate and then by a spatial gate, globally average-pooled, and mapped to
    ``num_classes`` logits.
    """

    def __init__(self, num_classes=14, model_name="convnext_base"):
        super().__init__()

        # Head-less backbone that returns a list of intermediate feature maps.
        self.backbone = timm.create_model(
            model_name, pretrained=False, num_classes=0, features_only=True
        )

        # Channel count of the deepest feature map sizes the attention and head.
        feature_dim = self.backbone.feature_info.channels()[-1]
        self.cbam = self._create_cbam_attention(feature_dim)

        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(feature_dim, num_classes)

    def _create_cbam_attention(self, channels, reduction=16, kernel_size=7):
        """Return the attention layers as one ``nn.Sequential``.

        Indices 0-4 form the channel gate (pool, 1x1 conv down, ReLU, 1x1 conv
        up, sigmoid); indices 5-6 form the spatial gate (2->1 conv, sigmoid).
        ``forward`` addresses the layers by these indices.
        """
        squeezed = channels // reduction
        channel_gate = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, squeezed, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(squeezed, channels, 1, bias=False),
            nn.Sigmoid(),
        ]
        spatial_gate = [
            nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False),
            nn.Sigmoid(),
        ]
        return nn.Sequential(*channel_gate, *spatial_gate)

    def forward(self, x):
        """Return the classification logits for the input batch ``x``."""
        # Only the deepest feature map is used.
        feature_map = self.backbone(x)[-1]

        # Channel attention: one weight per channel, broadcast over H and W.
        channel_weights = self.cbam[:5](feature_map)
        feature_map = feature_map * channel_weights

        # Spatial attention input: per-pixel mean and max across channels.
        channel_mean = feature_map.mean(dim=1, keepdim=True)
        channel_max = feature_map.max(dim=1, keepdim=True).values
        spatial_input = torch.cat([channel_mean, channel_max], dim=1)
        # NOTE(review): only the conv at index 5 is applied here; the Sigmoid
        # at index 6 is never used, so the spatial map is unbounded. Confirm
        # against the training code before changing -- the checkpoint may
        # depend on this exact behaviour.
        spatial_weights = self.cbam[5](spatial_input)
        feature_map = feature_map * spatial_weights

        pooled = self.global_pool(feature_map)
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)
|
|
|
|
|
def load_model(model_repo="calender/Convnext-Chexpert-Attention"):
    """Build the classifier and load its trained weights.

    The weights file ``model.pth`` is downloaded from the HuggingFace Hub
    repository ``model_repo``. The local file ``model/model.pth`` is used
    instead when ``huggingface_hub`` is not installed, or when the download
    fails and that local file exists.

    Args:
        model_repo: HuggingFace Hub repository id hosting ``model.pth``.

    Returns:
        A ``(model, device)`` tuple; the model is on ``device`` in eval mode.

    Raises:
        Exception: whatever the download raised, when it fails and no local
            weights file is available; errors from ``torch.load`` and
            ``load_state_dict`` also propagate.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Take the class count from the shared config rather than repeating it.
    model = ConvNeXtWithCBAM(num_classes=MODEL_CONFIG["num_classes"]).to(device)

    local_path = "model/model.pth"
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        print("huggingface_hub not available, trying local model...")
        model_path = local_path
    else:
        try:
            model_path = hf_hub_download(repo_id=model_repo, filename="model.pth")
        except Exception as err:
            # Only fall back when there is something to fall back to;
            # otherwise surface the real download error.
            if not os.path.exists(local_path):
                raise
            print(f"Download from {model_repo} failed ({err}), trying local model...")
            model_path = local_path
        else:
            print(f"Downloaded model from {model_repo}")

    # NOTE(review): torch.load unpickles the file, and the file comes from a
    # remote repository. Consider weights_only=True if the installed torch
    # version supports it and the checkpoint is a plain state dict.
    state_dict = torch.load(model_path, map_location=device)

    # Drop a leading "module." (e.g. from a model saved while wrapped in
    # nn.DataParallel). Only the prefix is removed, so a key that merely
    # contains "module." further in is left intact.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(state_dict)
    model.eval()

    print("Model loaded successfully!")
    return model, device
|
|
|
|
|
def predict_with_gradcam(model, device, image, confidence_threshold=0.6):
    """Classify an X-ray and build GradCAM views for the confident findings.

    Args:
        model: Classifier exposing a ``backbone`` attribute; called on a
            ``(1, 3, S, S)`` tensor where ``S`` is ``MODEL_CONFIG["input_size"]``.
        device: Torch device the input tensor is moved to.
        image: PIL image to analyse.
        confidence_threshold: A finding is reported only when its sigmoid
            probability is strictly greater than this value.

    Returns:
        A dict with three keys:

        * ``'predictions'``: list of ``{'disease', 'confidence', 'class_idx'}``
          dicts, one per confident finding (empty if there are none).
        * ``'message'``: human-readable summary.
        * ``'visualizations'``: ``None`` when there are no confident findings
          or no ``nn.Conv2d`` exists in the backbone; otherwise a dict mapping
          disease name to ``{'heatmap', 'overlay', 'confidence'}``. Findings
          whose GradCAM computation failed are omitted.
    """
    side = MODEL_CONFIG["input_size"]
    preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MODEL_CONFIG["mean"], std=MODEL_CONFIG["std"]),
    ])
    input_tensor = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_tensor)
        # Index the single batch element instead of squeeze() so the result
        # stays 1-D even for a one-class head.
        probabilities = torch.sigmoid(logits)[0].cpu().numpy()

    confident_predictions = [
        {'disease': disease, 'confidence': float(prob), 'class_idx': idx}
        for idx, (prob, disease) in enumerate(zip(probabilities, DISEASE_LABELS))
        if prob > confidence_threshold
    ]

    if not confident_predictions:
        return {
            'predictions': [],
            'message': f'No findings above {confidence_threshold:.0%} confidence threshold',
            'visualizations': None
        }

    # GradCAM target: the last Conv2d found when walking the backbone's modules.
    target_layer = next(
        (
            module
            for module in reversed(list(model.backbone.modules()))
            if isinstance(module, nn.Conv2d)
        ),
        None,
    )

    if target_layer is None:
        return {
            'predictions': confident_predictions,
            'message': 'Could not find suitable layer for GradCAM',
            'visualizations': None
        }

    # The RGB copy of the input is the same for every finding; build it once.
    rgb_img = np.array(image.convert('RGB'), dtype=np.float32) / 255.0
    height, width = rgb_img.shape[:2]

    visualizations = {}
    for pred in confident_predictions:
        disease = pred['disease']
        try:
            # Built inside the try so that a missing or failing grad-cam
            # dependency costs only the visualization, not the predictions.
            targets = [ClassifierOutputTarget(pred['class_idx'])]

            with GradCAM(model=model, target_layers=[target_layer]) as cam:
                grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

            # Scale the activation map back to the original image size
            # (cv2.resize takes the size as (width, height)).
            grayscale_cam_resized = cv2.resize(grayscale_cam, (width, height))

            cam_overlay = show_cam_on_image(
                rgb_img,
                grayscale_cam_resized,
                use_rgb=True,
                image_weight=0.5,
                colormap=cv2.COLORMAP_JET
            )

            visualizations[disease] = {
                'heatmap': grayscale_cam_resized,
                'overlay': cam_overlay,
                'confidence': pred['confidence']
            }
        except Exception as e:
            # Best effort per finding: report and move on to the next one.
            print(f"Error generating GradCAM for {disease}: {e}")
            continue

    return {
        'predictions': confident_predictions,
        'message': f'Found {len(confident_predictions)} confident predictions above {confidence_threshold:.0%} threshold',
        'visualizations': visualizations
    }
|
|
|
|
|
def _format_prediction_report(results):
    """Render the findings in ``results`` as the Markdown report shown in the UI."""
    header = (
        f"## Analysis Results\n\n{results['message']}\n\n"
        "### Confident Predictions:\n\n"
    )
    # One Markdown list item per finding; lines separated only by a single
    # newline would be merged into one paragraph when rendered.
    items = [
        f"- **{pred['disease']}**: {pred['confidence']:.1%}\n"
        for pred in results['predictions']
    ]
    return header + "".join(items)


def _render_gradcam_figure(image, visualizations):
    """Draw one row (original, heatmap, overlay) per entry of ``visualizations``."""
    num_plots = len(visualizations)
    # Build the figure without pyplot: pyplot keeps every figure it creates
    # alive until it is explicitly closed, which would accumulate one figure
    # per request in a long-running server.
    fig = Figure(figsize=(15, 5 * num_plots))
    # squeeze=False keeps the axes array 2-D even when there is a single row.
    axes = fig.subplots(num_plots, 3, squeeze=False)

    for row, (disease, vis_data) in zip(axes, visualizations.items()):
        original_ax, heatmap_ax, overlay_ax = row

        original_ax.imshow(image, cmap='gray')
        original_ax.set_title(f"Original X-ray\n{disease}", fontsize=10)
        original_ax.axis('off')

        heatmap_ax.imshow(vis_data['heatmap'], cmap='jet')
        heatmap_ax.set_title(f"GradCAM Heatmap\n{vis_data['confidence']:.1%}", fontsize=10)
        heatmap_ax.axis('off')

        overlay_ax.imshow(vis_data['overlay'])
        overlay_ax.set_title(f"GradCAM Overlay\n{disease}", fontsize=10)
        overlay_ax.axis('off')

    fig.tight_layout()
    return fig


def create_gradio_interface():
    """Load the model and return the configured ``gr.Interface``.

    The interface takes one PIL image and produces three outputs: a Markdown
    report, a GradCAM figure (or ``None``), and a status string.
    """
    model, device = load_model()

    def analyze_xray(image):
        """Analyze an uploaded X-ray; returns ``(report, figure, status)``."""
        if image is None:
            return "Please upload a chest X-ray image", None, None

        try:
            results = predict_with_gradcam(model, device, image)

            if not results['predictions']:
                return results['message'], None, None

            report = _format_prediction_report(results)
            visualizations = results['visualizations']
            # No figure when GradCAM was unavailable or failed for every finding.
            figure = (
                _render_gradcam_figure(image, visualizations)
                if visualizations
                else None
            )
            return report, figure, "✅ Analysis completed successfully!"

        except Exception as e:
            # UI boundary: show the failure to the user instead of crashing.
            return f"❌ Error analyzing image: {str(e)}", None, "Analysis failed"

    # Emoji below were restored from mis-decoded UTF-8. The glyphs of the
    # first two feature bullets could not be recovered, so 🔍 and 📊 are
    # substitutes.
    description = "\n".join([
        "**Medical AI for Chest X-ray Analysis**",
        "",
        "This tool uses a ConvNeXt-Base model with CBAM attention to analyze "
        "chest X-rays and identify 14 different thoracic pathologies.",
        "",
        "**Features:**",
        "- 🔍 Multi-label classification of 14 chest conditions",
        "- 📊 Shows only confident predictions (>60% confidence)",
        "- 🎯 GradCAM visualization showing model attention regions",
        "- 🏥 Designed for research and educational purposes",
        "",
        "**⚠️ Important Medical Disclaimer:**",
        "This tool is for research and educational purposes only. Always "
        "consult qualified healthcare professionals for medical decisions.",
        "",
        "**Supported Conditions:**",
        # Derived from the label list so the text cannot drift from the model.
        ", ".join(DISEASE_LABELS),
    ])

    interface = gr.Interface(
        fn=analyze_xray,
        inputs=gr.Image(label="Upload Chest X-ray", type="pil"),
        outputs=[
            gr.Markdown(label="Analysis Results"),
            gr.Plot(label="GradCAM Visualizations"),
            gr.Textbox(label="Status", interactive=False)
        ],
        title="🫁 ConvNeXt CheXpert Classifier with GradCAM",
        description=description,
        theme="default",
        allow_flagging="never"
    )

    return interface
|
|
|
|
|
|
|
|
# Script entry point: build the interface and start the Gradio server.
if __name__ == "__main__":
    print("Starting ConvNeXt CheXpert GradCAM App...")

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "show_error": True,
    }
    app = create_gradio_interface()
    app.launch(**launch_options)
|
|
|