| | |
| |
|
| | import torch |
| | import torchvision.transforms as T |
| | from PIL import Image |
| | import io |
| | import json |
| |
|
| | |
# Segmentation class names, returned verbatim under "classes" by infer().
# NOTE(review): presumably list position i is the class id i found in the
# predicted mask -- confirm against the model's training label map.
CLASS_LABELS = [
    "glove_outline",
    "webbing",
    "thumb",
    "palm_pocket",
    "hand",
    "glove_exterior"
]
| |
|
| | |
| | |
| | |
def load_model(path="pytorch_model.bin"):
    """Load the pickled model from *path* onto the CPU and set it to eval mode.

    Args:
        path: Checkpoint file to read. Defaults to ``"pytorch_model.bin"``
            (resolved against the current working directory), which is what
            the original hard-coded.

    Returns:
        The deserialized ``torch.nn.Module`` in eval mode.

    Raises:
        TypeError: If the file deserializes to something other than a
            ``torch.nn.Module`` (for example a bare ``state_dict``), which
            cannot be called for inference.
    """
    # SECURITY: torch.load unpickles the file, and unpickling can execute
    # arbitrary code. Only load checkpoints that come from a trusted source.
    # NOTE(review): newer PyTorch releases default to weights_only=True, which
    # rejects fully pickled modules -- confirm the deployed torch version.
    loaded = torch.load(path, map_location="cpu")

    # Fail with an actionable message instead of an AttributeError on .eval()
    # when the checkpoint only holds weights.
    if not isinstance(loaded, torch.nn.Module):
        raise TypeError(
            f"{path!r} contains a {type(loaded).__name__}, not a torch.nn.Module; "
            "a state_dict must be loaded into a model instance first"
        )

    loaded.eval()
    return loaded
| |
|
# Loaded once at import time; infer() calls this module-level instance.
model = load_model()
| |
|
| | |
| | |
| | |
# Preprocessing pipeline used by preprocess(): resize every image to a fixed
# 720x1280 (height, width), then convert the PIL image to a tensor.
transform = T.Compose([
    T.Resize((720, 1280)),
    T.ToTensor()
])
| |
|
def preprocess(input_bytes):
    """Decode encoded image bytes into a batched tensor.

    Args:
        input_bytes: The raw bytes of an image file in any format PIL can open.

    Returns:
        The result of the module-level ``transform`` applied to the RGB image,
        with a leading batch dimension of size 1 added.
    """
    buffer = io.BytesIO(input_bytes)
    # Force three channels so grayscale / RGBA / palette images are handled
    # uniformly by the transform.
    rgb_image = Image.open(buffer).convert("RGB")
    batched = transform(rgb_image).unsqueeze(0)
    return batched
| |
|
| | |
| | |
| | |
class DummyInput:
    """Container that wraps an image batch with empty prompt/mask fields.

    Every field other than ``images`` is a placeholder: all-zero masks, no
    point prompts and no boxes.
    NOTE(review): the attribute set presumably mirrors the batched-input
    object the loaded model's forward() expects -- confirm against the model.

    Args:
        image_tensor: Image batch with a 4-element ``shape`` of
            (batch, channels, height, width).

    Raises:
        ValueError: If ``image_tensor`` is not 4-dimensional.
    """

    def __init__(self, image_tensor):
        shape = tuple(image_tensor.shape)
        if len(shape) != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) image batch, got shape {shape}"
            )
        B, _, H, W = shape

        # Allocate placeholders on the same device as the images so the model
        # never sees mixed-device inputs; falls back to the default device for
        # array-likes that carry no ``device`` attribute.
        device = getattr(image_tensor, "device", None)

        self.images = image_tensor
        self.masks = [torch.zeros(B, H, W, dtype=torch.bool, device=device)]
        self.num_frames = 1
        # Both sizes are taken from the tensor passed in, so they describe the
        # already-transformed image, not the file that was uploaded.
        self.original_size = [(H, W)]
        self.target_size = [(H, W)]
        self.point_coords = [None]
        self.point_labels = [None]
        self.boxes = [None]
        self.mask_inputs = torch.zeros(B, 1, H, W, device=device)
        self.video_mask = torch.zeros(B, 1, H, W, device=device)
        self.flat_obj_to_img_idx = [[0]]
| |
|
| | |
| | |
| | |
def postprocess(output_tensor):
    """Convert raw model output into a nested list of per-pixel class indices.

    Args:
        output_tensor: Either a tensor of scores, or a dict holding that
            tensor under the ``"masks"`` key.
            NOTE(review): assumed layout is (batch, classes, height, width)
            -- confirm against the model.

    Returns:
        The argmax over dimension 1 for the first batch element only, as
        nested Python lists of ints.
    """
    if isinstance(output_tensor, dict) and "masks" in output_tensor:
        logits = output_tensor["masks"]
    else:
        logits = output_tensor

    # Tensor.tolist() already copies to host memory and yields plain Python
    # ints, so no intermediate NumPy array is needed.
    return torch.argmax(logits, dim=1)[0].tolist()
| |
|
| | |
| | |
| | |
def infer(payload):
    """Run the model on one encoded image and return its mask with the labels.

    Args:
        payload: Either the raw bytes of an image file (``bytes`` or
            ``bytearray``), or a dict whose ``"inputs"`` value is the
            base64-encoded image.

    Returns:
        A dict with ``"mask"`` (the nested list produced by ``postprocess``)
        and ``"classes"`` (a copy of ``CLASS_LABELS``).

    Raises:
        ValueError: If ``payload`` is neither byte data nor a dict containing
            ``"inputs"``.
    """
    if isinstance(payload, (bytes, bytearray)):
        raw = bytes(payload)
    elif isinstance(payload, dict) and "inputs" in payload:
        from base64 import b64decode
        raw = b64decode(payload["inputs"])
    else:
        raise ValueError("Unsupported input format")

    image_tensor = preprocess(raw)
    input_obj = DummyInput(image_tensor)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(input_obj)

    return {
        "mask": postprocess(output),
        # Hand out a copy so callers mutating the response cannot alter the
        # module-level constant.
        "classes": list(CLASS_LABELS),
    }
| |
|