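"""Gradio demo for word-level Bangla Sign Language recognition (BdSLW401)
with a fine-tuned VideoMAE checkpoint."""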
import sys

import cv2
import gradio as gr
import numpy as np
import torch
from torchvision.transforms import Compose, Lambda, Resize
from torchvision.transforms import functional as F

# torchvision >= 0.17 removed `torchvision.transforms.functional_tensor`;
# alias the public functional module under the old name so pytorchvideo's
# imports still resolve. This shim must run before pytorchvideo is imported.
sys.modules["torchvision.transforms.functional_tensor"] = F

from pytorchvideo.transforms import (
    Normalize,
    UniformTemporalSubsample,
)
from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification

MODEL_CKPT = "Shawon16/VideoMAE_BdSLW401_20_epochs_p5_SR_10"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
PROCESSOR = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)

RESIZE_TO = PROCESSOR.size["shortest_edge"]
NUM_FRAMES_TO_SAMPLE = MODEL.config.num_frames
FRAME_SAMPLE_RATE = 10  # keep every 10th frame, matching the checkpoint's sample rate (SR_10)
# ImageNet normalization statistics used during fine-tuning.
IMAGE_STATS = {"image_mean": [0.485, 0.456, 0.406], "image_std": [0.229, 0.224, 0.225]}

VAL_TRANSFORMS = Compose(
    [
        UniformTemporalSubsample(NUM_FRAMES_TO_SAMPLE),
        Lambda(lambda x: x / 255.0),  # scale uint8 pixels to [0, 1]
        Normalize(IMAGE_STATS["image_mean"], IMAGE_STATS["image_std"]),
        Resize((RESIZE_TO, RESIZE_TO)),
    ]
)
# Index id2label by position so that LABELS[i] matches logit i; iterating
# label2id keys is not guaranteed to follow id order.
LABELS = [MODEL.config.id2label[i] for i in range(len(MODEL.config.id2label))]


def parse_video(video_file):
    """Extract RGB frames from a video file, keeping every FRAME_SAMPLE_RATE-th frame."""
    vs = cv2.VideoCapture(video_file)
    frames = []
    frame_id = 0

    while True:
        grabbed, frame = vs.read()
        if not grabbed:
            break

        if frame_id % FRAME_SAMPLE_RATE == 0:
            # OpenCV decodes to BGR; convert to RGB for the model and the gallery.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)

        frame_id += 1

    vs.release()
    return frames
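

# Example (hypothetical input taken from the bundled examples): parse_video("W002S08F_03.mp4")
# returns a list of RGB uint8 arrays of shape (H, W, 3), one per sampled frame.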


def preprocess_video(frames):
    """Preprocess sampled frames into a batched tensor for inference."""
    # (T, H, W, C) uint8 -> (C, T, H, W), the layout the pytorchvideo transforms expect.
    video_tensor = torch.from_numpy(np.stack(frames))
    video_tensor = video_tensor.permute(3, 0, 1, 2)
    video_tensor_pp = VAL_TRANSFORMS(video_tensor)
    # VideoMAE expects pixel_values of shape (batch, frames, channels, height, width).
    video_tensor_pp = video_tensor_pp.permute(1, 0, 2, 3)
    video_tensor_pp = video_tensor_pp.unsqueeze(0)
    return video_tensor_pp.to(DEVICE)
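

# Sanity check (assumption: a 16-frame VideoMAE checkpoint with 224 px inputs):
# preprocess_video(frames).shape == torch.Size([1, 16, 3, 224, 224])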


def infer(video_file):
    """Classify an uploaded video; return per-label confidences and the sampled frames."""
    frames = parse_video(video_file)
    if not frames:
        raise gr.Error("Could not read any frames from the uploaded video.")
    video_tensor = preprocess_video(frames)
    inputs = {"pixel_values": video_tensor}

    with torch.no_grad():
        outputs = MODEL(**inputs)
        logits = outputs.logits

    softmax_scores = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)
    confidences = {LABELS[i]: float(softmax_scores[i]) for i in range(len(LABELS))}
    return confidences, frames
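

# Quick local check (hypothetical, using a bundled example clip):
#   scores, sampled = infer("W002S08F_03.mp4")
#   print(max(scores, key=scores.get))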


custom_css = """
/* Hide the webcam button */
button[data-testid="webcam-button"] {
    display: none !important;
}

/* Reduce padding and margins */
.gradio-container {
    max-width: 700px !important; /* Set a smaller max width */
    margin: auto;
    padding: 10px !important;
}

/* Reduce the gallery size */
.gr-gallery {
    max-height: 200px !important; /* Make the frames smaller */
}

/* Center the title */
h1 {
    text-align: center !important;
}
"""


gr.Interface(
    fn=infer,
    inputs=[gr.Video(label="Upload Video")],
    outputs=[
        gr.Label(num_top_classes=5, label="Top 5 Predictions"),
        gr.Gallery(
            label=f"Sampled Frames (Rate: {FRAME_SAMPLE_RATE})",
            columns=4,
            height="200px",
        ),
    ],
    examples=[
        ["W002S08F_03.mp4"],
        ["W003S08F_11.mp4"],
        ["W389S08F_02.mp4"],
        ["W401S04F_06.mp4"],
    ],
    title="Bangla Word-Level (BdSLW401) Sign Language Recognition Interface",
    description=(
        "This demo uses a fine-tuned video transformer (VideoMAE) to classify"
        " Bangla Sign Language words from video inputs."
        " Upload a video to get predictions."
    ),
    flagging_mode="never",
    css=custom_css,
).launch()
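
# Tip: pass share=True to launch() for a temporary public link when running locally.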