"""
HuggingFace Spaces App for ConvNeXt CheXpert Classification with GradCAM
This app provides a web interface for chest X-ray classification with GradCAM visualization
showing model attention regions for confident predictions (>60% confidence).
Usage:
Run this file and access the Gradio interface via the provided URL
"""
import os
import torch
import timm
import gradio as gr
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import cv2
# GradCAM imports
try:
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    from pytorch_grad_cam.utils.image import show_cam_on_image
except ImportError:
    print("Installing required packages...")
    os.system("pip install grad-cam")  # PyPI name for the pytorch_grad_cam package
    # Retry the imports now that the package is installed
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    from pytorch_grad_cam.utils.image import show_cam_on_image
# Disease labels in the correct order
DISEASE_LABELS = [
"No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly",
"Lung Opacity", "Lung Lesion", "Edema", "Consolidation",
"Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion",
"Pleural Other", "Fracture", "Support Devices"
]
# Model configuration
MODEL_CONFIG = {
"input_size": 384,
"num_classes": 14,
"mean": [0.5029414296150208] * 3,
"std": [0.2892409563064575] * 3
}
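# The mean/std above are single-channel grayscale statistics (presumably computed
# on the training X-rays) replicated across 3 channels, so grayscale images can be
# fed to the RGB ConvNeXt backbone.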
class ConvNeXtWithCBAM(nn.Module):
"""ConvNeXt model with CBAM attention for GradCAM compatibility"""
def __init__(self, num_classes=14, model_name="convnext_base"):
super().__init__()
# Create ConvNeXt backbone
self.backbone = timm.create_model(
model_name,
pretrained=False,
num_classes=0,
features_only=True
)
# Add CBAM attention
feature_dim = self.backbone.feature_info.channels()[-1]
self.cbam = self._create_cbam_attention(feature_dim)
# Global pooling and classifier
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(feature_dim, num_classes)
    def _create_cbam_attention(self, channels, reduction=16, kernel_size=7):
        """Create CBAM attention module.

        The returned Sequential is never called end-to-end; forward() slices it
        into the channel branch (indices 0-4) and the spatial branch (indices 5-6).
        """
        return nn.Sequential(
            # Channel attention (indices 0-4): squeeze-and-excite style MLP
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),
            nn.Sigmoid(),
            # Spatial attention (indices 5-6): 7x7 conv over [avg, max] channel maps
            nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False),
            nn.Sigmoid()
        )
def forward(self, x):
# Extract features
features = self.backbone(x)[-1]
# Apply CBAM attention
ca = self.cbam[:5](features) # Channel attention
features = features * ca
        # Spatial attention (simplified for GradCAM)
        avg_out = torch.mean(features, dim=1, keepdim=True)
        max_out, _ = torch.max(features, dim=1, keepdim=True)
        # Slice [5:] so both the conv and its sigmoid gate are applied; indexing
        # [5] alone would return unbounded conv outputs instead of a 0-1 mask
        sa = self.cbam[5:](torch.cat([avg_out, max_out], dim=1))
        features = features * sa
# Global pooling and classification
features = self.global_pool(features)
features = features.view(features.size(0), -1)
return self.classifier(features)
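# Illustrative shape check for the model above (a sketch, not part of the app):
#   model = ConvNeXtWithCBAM(num_classes=14)
#   logits = model(torch.randn(1, 3, 384, 384))   # 384 matches MODEL_CONFIG["input_size"]
#   assert logits.shape == (1, 14)                # one logit per CheXpert label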
def load_model(model_repo="calender/Convnext-Chexpert-Attention"):
"""Load the trained model from HuggingFace Hub"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Create model
    model = ConvNeXtWithCBAM(num_classes=MODEL_CONFIG["num_classes"]).to(device)
# Load state dict from HuggingFace Hub
    try:
        from huggingface_hub import hf_hub_download
        model_path = hf_hub_download(repo_id=model_repo, filename="model.pth")
        print(f"Downloaded model from {model_repo}")
    except Exception as e:
        # Covers both a missing huggingface_hub and download failures
        print(f"Hub download unavailable ({e}), trying local model...")
        model_path = "model/model.pth"
state_dict = torch.load(model_path, map_location=device)
# Handle DataParallel
if any(key.startswith('module.') for key in state_dict.keys()):
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()
print("Model loaded successfully!")
return model, device
def predict_with_gradcam(model, device, image, confidence_threshold=0.6):
"""Get predictions and GradCAM visualizations for confident predictions"""
# Image preprocessing
transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),  # force grayscale, replicated across 3 channels
transforms.Resize((MODEL_CONFIG["input_size"], MODEL_CONFIG["input_size"])),
transforms.ToTensor(),
transforms.Normalize(mean=MODEL_CONFIG["mean"], std=MODEL_CONFIG["std"])
])
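    # These preprocessing steps are assumed to mirror the training pipeline;
    # a mismatched resize or normalization would silently degrade predictions.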
# Prepare input
input_tensor = transform(image).unsqueeze(0).to(device)
# Get predictions
with torch.no_grad():
logits = model(input_tensor)
probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
# Find confident predictions
confident_indices = []
confident_predictions = []
for idx, (prob, disease) in enumerate(zip(probabilities, DISEASE_LABELS)):
if prob > confidence_threshold:
confident_indices.append(idx)
confident_predictions.append({
'disease': disease,
'confidence': float(prob),
'class_idx': idx
})
if not confident_predictions:
return {
'predictions': [],
'message': f'No findings above {confidence_threshold:.0%} confidence threshold',
'visualizations': None
}
# Find target layer for GradCAM
target_layer = None
for module in reversed(list(model.backbone.modules())):
if isinstance(module, nn.Conv2d):
target_layer = module
break
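    # GradCAM needs a feature map with spatial extent, so the last Conv2d defined
    # in the backbone is used here as a reasonable default target layer.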
if target_layer is None:
return {
'predictions': confident_predictions,
'message': 'Could not find suitable layer for GradCAM',
'visualizations': None
}
# Generate GradCAM for each confident prediction
visualizations = {}
for pred in confident_predictions:
class_idx = pred['class_idx']
disease = pred['disease']
confidence = pred['confidence']
# Generate GradCAM
targets = [ClassifierOutputTarget(class_idx)]
try:
with GradCAM(model=model, target_layers=[target_layer]) as cam:
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
# Convert to RGB for visualization
rgb_img = np.array(image.convert('RGB'), dtype=np.float32) / 255.0
# Resize heatmap to match image
grayscale_cam_resized = cv2.resize(grayscale_cam, (rgb_img.shape[1], rgb_img.shape[0]))
# Create overlay
cam_overlay = show_cam_on_image(
rgb_img,
grayscale_cam_resized,
use_rgb=True,
image_weight=0.5,
colormap=cv2.COLORMAP_JET
)
visualizations[disease] = {
'heatmap': grayscale_cam_resized,
'overlay': cam_overlay,
'confidence': confidence
}
except Exception as e:
print(f"Error generating GradCAM for {disease}: {e}")
continue
return {
'predictions': confident_predictions,
'message': f'Found {len(confident_predictions)} confident predictions above {confidence_threshold:.0%} threshold',
'visualizations': visualizations
}
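# Illustrative standalone usage of predict_with_gradcam (a sketch assuming a
# local file named "xray.png"; the app itself calls it via the Gradio handler):
#   model, device = load_model()
#   results = predict_with_gradcam(model, device, Image.open("xray.png"))
#   for pred in results['predictions']:
#       print(f"{pred['disease']}: {pred['confidence']:.1%}")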
def create_gradio_interface():
"""Create and configure the Gradio interface"""
model, device = load_model()
def analyze_xray(image):
"""Analyze uploaded X-ray image"""
if image is None:
return "Please upload a chest X-ray image", None, None
try:
# Get predictions and GradCAM
results = predict_with_gradcam(model, device, image)
if not results['predictions']:
return results['message'], None, None
# Create prediction text
prediction_text = f"## Analysis Results\n\n{results['message']}\n\n"
prediction_text += "### Confident Predictions:\n\n"
for pred in results['predictions']:
prediction_text += f"πŸ” **{pred['disease']}**: {pred['confidence']:.1%}\n"
# Create visualization plots
if results['visualizations']:
num_plots = len(results['visualizations'])
fig, axes = plt.subplots(num_plots, 3, figsize=(15, 5 * num_plots))
if num_plots == 1:
axes = axes.reshape(1, -1)
for i, (disease, vis_data) in enumerate(results['visualizations'].items()):
# Original image
axes[i, 0].imshow(image, cmap='gray')
axes[i, 0].set_title(f"Original X-ray\n{disease}", fontsize=10)
axes[i, 0].axis('off')
# GradCAM heatmap
axes[i, 1].imshow(vis_data['heatmap'], cmap='jet')
axes[i, 1].set_title(f"GradCAM Heatmap\n{vis_data['confidence']:.1%}", fontsize=10)
axes[i, 1].axis('off')
# GradCAM overlay
axes[i, 2].imshow(vis_data['overlay'])
axes[i, 2].set_title(f"GradCAM Overlay\n{disease}", fontsize=10)
axes[i, 2].axis('off')
plt.tight_layout()
                return prediction_text, fig, "✅ Analysis completed successfully!"
            return prediction_text, None, "✅ Analysis completed successfully!"
except Exception as e:
return f"❌ Error analyzing image: {str(e)}", None, "Analysis failed"
# Create Gradio interface
interface = gr.Interface(
fn=analyze_xray,
inputs=gr.Image(label="Upload Chest X-ray", type="pil"),
outputs=[
gr.Markdown(label="Analysis Results"),
gr.Plot(label="GradCAM Visualizations"),
gr.Textbox(label="Status", interactive=False)
],
title="🫁 ConvNeXt CheXpert Classifier with GradCAM",
        description="""
**Medical AI for Chest X-ray Analysis**

This tool uses a ConvNeXt-Base model with CBAM attention to analyze chest X-rays and identify 14 different thoracic pathologies.

**Features:**
- 🔍 Multi-label classification of 14 chest conditions
- 📊 Shows only confident predictions (>60% confidence)
- 🎯 GradCAM visualization showing model attention regions
- 🏥 Designed for research and educational purposes

**⚠️ Important Medical Disclaimer:**
This tool is for research and educational purposes only. Always consult qualified healthcare professionals for medical decisions.

**Supported Conditions:**
No Finding, Enlarged Cardiomediastinum, Cardiomegaly, Lung Opacity, Lung Lesion, Edema, Consolidation, Pneumonia, Atelectasis, Pneumothorax, Pleural Effusion, Pleural Other, Fracture, Support Devices
""",
theme="default",
allow_flagging="never"
)
return interface
# Main execution
if __name__ == "__main__":
print("Starting ConvNeXt CheXpert GradCAM App...")
interface = create_gradio_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True  # share=True is unnecessary on Spaces, which already serves a public URL
    )
)