import torchvision
import gradio as gr
import torch
import torch.nn.functional as F
import einops
import matplotlib.cm as cm
import numpy as np

def colorize(tensor, cmap_fn=cm.turbo):
    colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
    colors = torch.from_numpy(colors).to(tensor)
    tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
    ids = (tensor * 256).clamp(0, 255).long()
    tensor = F.embedding(ids, colors).permute(0, 3, 1, 2)
    tensor = tensor.mul(255).clamp(0, 255).byte()
    return tensor
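
# colorize() shape sketch, e.g. assuming a batched single-channel map in [0, 1]:
#   colorize(torch.rand(1, 1, 7, 7))  # squeezed to (1, 7, 7), mapped through the
#                                     # 256-entry turbo LUT via F.embedding,
#                                     # returned as uint8 RGB of shape (1, 3, 7, 7)
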
| with open("classes.txt") as f: | |
| id2label = f.read().splitlines() | |
| id2label = [c.split(",")[0].lower() for c in id2label] | |
| label2id = dict([(c, i) for i, c in enumerate(id2label)]) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model = torchvision.models.resnet50(weights="DEFAULT") | |
| model.eval() | |
| model.to(device) | |
| fmap_pool = dict() | |
| grad_pool = dict() | |
def forward_hook(name):
    def _hook(module, input, output):
        fmap_pool[name] = output.detach()
    return _hook


def backward_hook(name):
    def _hook(module, grad_in, grad_out):
        grad_pool[name] = grad_out[0].detach()
    return _hook
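
# forward_hook / backward_hook are closures keyed by layer name: the forward hook
# caches the layer's activation A_k, the backward hook caches dL/dA_k taken from
# grad_out[0]; gradcam() reads both back after its forward/backward pass.
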
layer_choices = []
for n, m in model.named_children():
    layer_choices.append(n)
    m.register_forward_hook(forward_hook(n))
    m.register_backward_hook(backward_hook(n))
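
# Hooks are registered on every top-level child of torchvision's ResNet-50
# (conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc), which is also what the
# layer dropdown below offers. For a 224x224 input, fmap_pool["layer4"] ends up
# with shape (1, 2048, 7, 7); that is the map the Grad-CAM weights are applied to.
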
preprocess = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)
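
# preprocess maps the raw Gradio image (an H x W x 3 uint8 numpy array) to a
# normalized float tensor: ToTensor -> CHW in [0, 1], Resize -> 224x224,
# Normalize -> ImageNet statistics, so preprocess(img)[None] has shape (1, 3, 224, 224).
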
def predict(image):
    if image is None:
        return None, None
    image = preprocess(image)[None].to(device)
    with torch.no_grad():  # plain inference; no autograd graph needed here
        probs = model(image).softmax(dim=1)
    result = dict([(c, float(p)) for c, p in zip(id2label, probs[0])])
    return result, None
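
# gr.Label renders the returned {class_name: probability} dict and shows the ten most
# confident entries; clicking one of them emits a SelectData event whose .value is the
# class name, which is what gradcam() below consumes.
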
def gradcam(image_orig, layer, event: gr.SelectData):
    # forward & backward pass for the selected class
    target_class = torch.tensor([label2id[event.value]], device=device)
    gradient = F.one_hot(target_class, num_classes=len(label2id)).float()
    image = preprocess(image_orig)[None].to(device)
    model(image).backward(gradient=gradient)
    # Grad-CAM: weight the cached feature maps by their pooled gradients
    fmaps = fmap_pool[layer]
    grads = grad_pool[layer]
    weights = F.adaptive_avg_pool2d(grads, 1)
    gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
    gcam = F.relu(gcam)
    # post-process: upsample to the input size, normalize to [0, 1], colorize
    gcam = F.interpolate(
        gcam, size=image_orig.shape[:2], mode="bicubic", antialias=True
    )
    gcam -= einops.reduce(gcam, "b c h w -> b () () ()", "min")
    gcam /= einops.reduce(gcam, "b c h w -> b () () ()", "max")
    gcam = colorize(gcam)[0].permute(1, 2, 0).cpu().numpy()
    return gcam
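
# gradcam() above follows the recipe from the Grad-CAM paper: the channel weights are
# the globally average-pooled gradients, alpha_k = mean_ij dy_c/dA_k[i, j], the raw map
# is ReLU(sum_k alpha_k * A_k), and it is then upsampled to the input resolution,
# min-max normalized and colorized for display.
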
with gr.Blocks(title="Grad-CAM") as demo:
    gr.Markdown(
        """
        # Grad-CAM
        Unofficial re-implementation of Grad-CAM (https://arxiv.org/abs/1610.02391).<br>
        Upload an image and select one of the predictions to show its Grad-CAM heatmap.
        """
    )
    with gr.Row():
        with gr.Column():
            layer = gr.Dropdown(layer_choices, label="ResNet-50 layer", value="layer4")
            image = gr.Image(label="input", type="numpy")
            label = gr.Label(num_top_classes=10, label="top-10 predictions")
            exmpl = gr.Examples(["cat_dog.png"], image)
        with gr.Column():
            img_out = gr.Image(type="numpy", label="result")
    image.change(predict, inputs=[image], outputs=[label, img_out])
    label.select(gradcam, inputs=[image, layer], outputs=[img_out])

demo.launch()
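
# Assumed surrounding files for this Space (inferred from the code, not listed here):
# classes.txt with the ImageNet-1k class names (the first comma-separated synonym per
# line is used) and the cat_dog.png example image next to app.py, plus torch,
# torchvision, gradio, einops, matplotlib and numpy installed. Running `python app.py`
# locally serves the demo at the URL printed by demo.launch().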