Spaces:
Runtime error
Runtime error
| import json | |
| from pprint import pprint | |
| import torch | |
| import torch.hub | |
| from gradio import Interface, inputs, outputs | |
| from PIL import Image | |
| from torchvision import transforms | |
# Keep a handle on the genuine torch.hub downloader before replacing it.
real_load = torch.hub.load_state_dict_from_url


def load_state_dict_from_url(*args, **kwargs):
    """Delegate to the original ``torch.hub.load_state_dict_from_url`` on CPU.

    All positional and keyword arguments are forwarded unchanged, except that
    ``map_location`` is always forced to ``"cpu"`` (any caller-supplied value
    is overridden), so downloaded weights are deserialized onto the CPU.
    """
    return real_load(*args, **{**kwargs, "map_location": "cpu"})


# Install the wrapper globally. Presumably the hub entry point loaded below
# fetches its weights through this function -- verify against its hubconf.
torch.hub.load_state_dict_from_url = load_state_dict_from_url
# Load the ResNet-50 entry point from the RF5/danbooru-pretrained hub repo and
# switch it to inference mode. nn.Module.eval() returns the module itself,
# so the call can be chained directly onto the load.
model = torch.hub.load("RF5/danbooru-pretrained", "resnet50").eval()

# Tag vocabulary read from tags.json. `main` indexes it with the model's
# output positions, so tags[i] is the name of output i.
with open("./tags.json", mode="rt", encoding="utf-8") as f:
    tags = json.load(f)
# Preprocessing for the danbooru ResNet-50. The transforms are stateless, so
# the pipeline is built once at import time instead of on every request.
# The mean/std are the per-channel statistics this app has always used;
# presumably the ones the model was trained with -- TODO confirm upstream.
preprocess = transforms.Compose(
    [
        transforms.Resize(360),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]
        ),
    ]
)


def main(input_image: Image.Image, threshold: float):
    """Predict danbooru tags for an image.

    Args:
        input_image: The uploaded PIL image, or ``None`` if the user submitted
            without uploading one. Any PIL mode (RGBA, L, P, ...) is accepted.
            It is converted to RGB because the normalization above has exactly
            three channels of statistics.
        threshold: Only tags whose sigmoid probability is strictly greater
            than this value are returned.

    Returns:
        A dict mapping tag name to probability (a Python float), inserted in
        order of descending confidence. It is empty when ``input_image`` is
        ``None`` or when no tag clears the threshold.
    """
    if input_image is None:
        return {}

    # ToTensor keeps the source channel count (4 for RGBA, 1 for L/P), which
    # would make Normalize fail, so force three channels first.
    batch = preprocess(input_image.convert("RGB")).unsqueeze(0)  # (1, 3, H, W)

    with torch.no_grad():
        # First (and only) item of the mini-batch. These are raw scores; the
        # sigmoid below turns them into independent per-tag probabilities.
        logits = model(batch)[0]

    probs = torch.sigmoid(logits)
    # Rank every tag by confidence, then keep those above the threshold.
    # Masking the ranked indices preserves the descending order.
    ranked = probs.argsort(descending=True)
    ranked = ranked[probs[ranked] > threshold]

    tag_confidences = {tags[i]: probs[i].item() for i in ranked.tolist()}
    pprint(tag_confidences)
    return tag_confidences
# Gradio UI: an image and a threshold slider go into `main`; a label showing
# tag confidences comes out. This uses the top-level component classes
# because the legacy `gradio.inputs` / `gradio.outputs` namespaces (and the
# Slider `default=` keyword) were deprecated in Gradio 3 and removed in 4.
# NOTE(review): on Gradio >= 4 the file-level
# `from gradio import Interface, inputs, outputs` must also drop
# `inputs, outputs`, or the import itself fails.
import gradio as gr

image = gr.Image(label="Upload your image here!", type="pil")
threshold = gr.Slider(
    minimum=0,
    maximum=1,
    value=0.2,
    # Explicit step so the slider granularity does not depend on the
    # installed Gradio version's default.
    step=0.01,
    label="Hide tags with confidence under",
)
# `main` returns a {tag: probability} dict, which gr.Label renders as
# per-class confidences.
labels = gr.Label(label="Tags")
interface = Interface(main, inputs=[image, threshold], outputs=[labels])
interface.launch()