import cv2
import gradio as gr
import numpy as np
import sys
import torch
import types

from torchvision.transforms import functional as F

# --- Compatibility shim for torchvision >= 0.17 (e.g. on Kaggle) ---
# pytorchvideo still imports the private module
# `torchvision.transforms.functional_tensor`, which newer torchvision removed.
# Register a stand-in under the old name *before* importing pytorchvideo so
# its import succeeds; borrow the public functional API so attribute lookups
# on the stand-in also resolve.
fake_ft = types.ModuleType("torchvision.transforms.functional_tensor")
fake_ft.__dict__.update({k: v for k, v in vars(F).items() if not k.startswith("__")})
sys.modules["torchvision.transforms.functional_tensor"] = fake_ft

from pytorchvideo.transforms import (
    Normalize,
    UniformTemporalSubsample,
)
from torchvision.transforms import Compose, Lambda, Resize
from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
# Load the fine-tuned model and its preprocessor from the Hugging Face Hub.
MODEL_CKPT = "Shawon16/VideoMAE_BdSLW401_20_epochs_p5_SR_10"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
# Note: VideoMAEFeatureExtractor is deprecated in recent transformers releases
# in favor of VideoMAEImageProcessor; both expose the same size attributes.
PROCESSOR = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)
RESIZE_TO = PROCESSOR.size["shortest_edge"]
NUM_FRAMES_TO_SAMPLE = MODEL.config.num_frames  # frames the model expects per clip
# ImageNet normalization statistics, matching the processor's defaults.
IMAGE_STATS = {"image_mean": [0.485, 0.456, 0.406], "image_std": [0.229, 0.224, 0.225]}
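# Validation-time transform: subsample to the model's clip length, scale pixel
# values to [0, 1], normalize with the ImageNet statistics above, and resize.
# It operates on a video tensor laid out as (num_channels, num_frames, height, width).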
VAL_TRANSFORMS = Compose(
    [
        UniformTemporalSubsample(NUM_FRAMES_TO_SAMPLE),
        Lambda(lambda x: x / 255.0),  # scale uint8 pixels to [0, 1]
        Normalize(IMAGE_STATS["image_mean"], IMAGE_STATS["image_std"]),
        Resize((RESIZE_TO, RESIZE_TO)),
    ]
)
# Build the label list in logits order: output index i corresponds to
# id2label[i], which is more robust than relying on the key order of label2id.
LABELS = [MODEL.config.id2label[i] for i in range(MODEL.config.num_labels)]
def parse_video(video_file):
    """Extract RGB frames from a video file, sampling every 10th frame
    (matching the SR_10 sampling rate in the checkpoint name)."""
    vs = cv2.VideoCapture(video_file)
    frames = []
    frame_id = 0
    while True:
        grabbed, frame = vs.read()
        if not grabbed:
            break
        if frame_id % 10 == 0:  # Sample every 10th frame
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            frames.append(frame)
        frame_id += 1
    vs.release()
    if not frames:
        raise gr.Error("Could not read any frames from the uploaded video.")
    return frames
def preprocess_video(frames):
    """Preprocess sampled frames into a batched tensor for inference."""
    video_tensor = torch.from_numpy(np.stack(frames))  # (num_frames, height, width, num_channels), uint8
    video_tensor = video_tensor.permute(3, 0, 1, 2)  # (num_channels, num_frames, height, width)
    video_tensor_pp = VAL_TRANSFORMS(video_tensor)
    video_tensor_pp = video_tensor_pp.permute(1, 0, 2, 3)  # (num_frames, num_channels, height, width)
    video_tensor_pp = video_tensor_pp.unsqueeze(0)  # Add batch dimension
    return video_tensor_pp.to(DEVICE)
def infer(video_file):
    """Run classification on an uploaded video and return per-class confidences."""
    frames = parse_video(video_file)
    video_tensor = preprocess_video(frames)
    inputs = {"pixel_values": video_tensor}
    # Forward pass
    with torch.no_grad():
        outputs = MODEL(**inputs)
        logits = outputs.logits
    softmax_scores = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)
    confidences = {LABELS[i]: float(softmax_scores[i]) for i in range(len(LABELS))}
    return confidences, frames  # confidences feed the Label widget, frames the Gallery
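# Quick local sanity check (a minimal sketch; assumes one of the bundled
# example clips, e.g. W002S08F_03.mp4, sits next to this script):
#   confidences, _ = infer("W002S08F_03.mp4")
#   print(max(confidences, key=confidences.get))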
custom_css = """
/* Hide the webcam button */
button[data-testid="webcam-button"] {
    display: none !important;
}
/* Reduce padding and margins */
.gradio-container {
    max-width: 700px !important;  /* Set a smaller max width */
    margin: auto;
    padding: 10px !important;
}
/* Reduce the gallery size */
.gr-gallery {
    max-height: 200px !important;  /* Make the frames smaller */
}
/* Center the title */
h1 {
    text-align: center !important;
}
"""
gr.Interface(
    fn=infer,
    inputs=[gr.Video(label="Upload Video")],  # Video component also provides a preview
    outputs=[
        gr.Label(num_top_classes=5, label="Top 5 Predictions"),
        gr.Gallery(label="Sampled Frames (Rate: 10)", columns=4, height="200px"),  # Compact gallery
    ],
    examples=[
        ["W002S08F_03.mp4"],
        ["W003S08F_11.mp4"],
        ["W389S08F_02.mp4"],
        ["W401S04F_06.mp4"],
    ],
    title="Bangla Word-Level (BdSLW401) Sign Language Recognition Interface",
    description=(
        "This demo uses a fine-tuned video Transformer (VideoMAE) to classify Bangla Sign Language"
        " words from video input. Upload a video for predictions."
    ),
    flagging_mode="never",
    css=custom_css,  # Apply custom CSS for a compact layout
).launch()