import sys

import cv2
import gradio as gr
import numpy as np
import torch
from torchvision.transforms import functional as F

# --- Fix pytorchvideo import error for torchvision >= 0.17 ---
# torchvision removed the private torchvision.transforms.functional_tensor
# module, which pytorchvideo still imports at load time. Registering the
# public functional module under the old name must happen BEFORE pytorchvideo
# is imported, otherwise the import below fails.
sys.modules["torchvision.transforms.functional_tensor"] = F

from pytorchvideo.transforms import (
    Normalize,
    UniformTemporalSubsample,
)
from torchvision.transforms import Compose, Lambda, Resize
from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
# Load model and processor
MODEL_CKPT = "Shawon16/VideoMAE_BdSLW401_20_epochs_p5_SR_10"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
PROCESSOR = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)
RESIZE_TO = PROCESSOR.size["shortest_edge"]
NUM_FRAMES_TO_SAMPLE = MODEL.config.num_frames
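# ImageNet normalization statistics, hard-coded here; presumably these match
# the image_mean/image_std in the checkpoint's processor config.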
IMAGE_STATS = {"image_mean": [0.485, 0.456, 0.406], "image_std": [0.229, 0.224, 0.225]}
VAL_TRANSFORMS = Compose(
    [
        UniformTemporalSubsample(NUM_FRAMES_TO_SAMPLE),
        Lambda(lambda x: x / 255.0),
        Normalize(IMAGE_STATS["image_mean"], IMAGE_STATS["image_std"]),
        Resize((RESIZE_TO, RESIZE_TO)),
    ]
)
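# VAL_TRANSFORMS operates on a (num_channels, num_frames, height, width)
# tensor: UniformTemporalSubsample picks NUM_FRAMES_TO_SAMPLE evenly spaced
# frames along the temporal axis, the Lambda rescales uint8 pixels to [0, 1],
# and Normalize/Resize apply the statistics and square input size set above.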
# Order labels by logit index via id2label so predictions line up with logits
LABELS = [MODEL.config.id2label[i] for i in range(MODEL.config.num_labels)]
def parse_video(video_file):
    """Extract RGB frames from a video file at a sample rate of 10."""
    vs = cv2.VideoCapture(video_file)
    frames = []
    frame_id = 0
    while True:
        grabbed, frame = vs.read()
        if not grabbed:
            break
        if frame_id % 10 == 0:  # Sample every 10th frame
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            frames.append(frame)
        frame_id += 1
    vs.release()
    return frames
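# Illustrative example (assumed clip length): a ~3 s clip at 30 fps (~90
# frames) yields 9 frames here; UniformTemporalSubsample in VAL_TRANSFORMS
# then maps that sequence to exactly NUM_FRAMES_TO_SAMPLE frames. The rate
# of 10 presumably mirrors the "SR_10" suffix in the checkpoint name.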
def preprocess_video(frames):
    """Preprocess sampled video frames for inference."""
    video_tensor = torch.from_numpy(np.stack(frames))  # (num_frames, height, width, num_channels)
    video_tensor = video_tensor.permute(3, 0, 1, 2)  # (num_channels, num_frames, height, width)
    video_tensor_pp = VAL_TRANSFORMS(video_tensor)
    video_tensor_pp = video_tensor_pp.permute(1, 0, 2, 3)  # (num_frames, num_channels, height, width)
    video_tensor_pp = video_tensor_pp.unsqueeze(0)  # Add batch dimension
    return video_tensor_pp.to(DEVICE)
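# The resulting (1, num_frames, num_channels, height, width) tensor is the
# pixel_values layout that VideoMAEForVideoClassification expects.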
def infer(video_file):
    frames = parse_video(video_file)
    if not frames:
        raise gr.Error("Could not read any frames from the uploaded video.")
    video_tensor = preprocess_video(frames)
    inputs = {"pixel_values": video_tensor}
    # Forward pass
    with torch.no_grad():
        outputs = MODEL(**inputs)
        logits = outputs.logits
    softmax_scores = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)
    confidences = {LABELS[i]: float(softmax_scores[i]) for i in range(len(LABELS))}
    return confidences, frames  # Label scores plus the sampled frames for the gallery
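# Illustrative usage (not executed by the app): the Label component renders
# the top classes itself, but the single best prediction could be read as
#   top_label = max(confidences, key=confidences.get)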
custom_css = """
/* Hide the webcam button */
button[data-testid="webcam-button"] {
    display: none !important;
}
/* Reduce padding and margins */
.gradio-container {
    max-width: 700px !important;  /* Set a smaller max width */
    margin: auto;
    padding: 10px !important;
}
/* Reduce the gallery size */
.gr-gallery {
    max-height: 200px !important;  /* Make the frames smaller */
}
/* Center the title */
h1 {
    text-align: center !important;
}
"""
gr.Interface(
    fn=infer,
    inputs=[gr.Video(label="Upload Video")],  # Video input with built-in preview
    outputs=[
        gr.Label(num_top_classes=5, label="Top 5 Predictions"),
        gr.Gallery(label="Sampled Frames (Rate: 10)", columns=4, height="200px"),  # Compact gallery
    ],
    examples=[
        ["W002S08F_03.mp4"],
        ["W003S08F_11.mp4"],
        # ["W205S08F_02.mp4"],
        # ["W211S04F_03.mp4"],
        ["W389S08F_02.mp4"],
        ["W401S04F_06.mp4"],
#[r"C:\Users\shawo\Desktop\BdSLW60 Full DataSet\FrameRate Corrected Clips\W2\U8W2F_trial_6_R.mp4"],
#[r"C:\Users\shawo\Desktop\BdSLW60 Full DataSet\FrameRate Corrected Clips\W20\U4W20F_trial_9_R.mp4"],
    ],
    title="Bangla Word Level (BdSLW401) Sign Language Recognition Interface",
    description=(
        "This demo uses a fine-tuned video Transformer (VideoMAE) to classify Bangla Sign Language"
        " words from video inputs. Upload a video to get predictions."
    ),
    flagging_mode="never",
    css=custom_css,  # Apply custom CSS for a compact layout
).launch()