zye0616 commited on
Commit
2ab6e0a
·
1 Parent(s): bfe562f

initial commit

Browse files
models/model_loader.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from transformers import Owlv2ForObjectDetection, Owlv2Processor
6
+
7
+ MODEL_NAME = "google/owlv2-large-patch14"
8
+ _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ logging.info("Loading %s onto %s", MODEL_NAME, _DEVICE)
11
+ _PROCESSOR = Owlv2Processor.from_pretrained(MODEL_NAME)
12
+ torch_dtype = torch.float16 if _DEVICE.type == "cuda" else torch.float32
13
+ _MODEL = Owlv2ForObjectDetection.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype)
14
+ _MODEL.to(_DEVICE)
15
+ _MODEL.eval()
16
+
17
+
18
+ def load_model() -> Tuple[Owlv2Processor, Owlv2ForObjectDetection, torch.device]:
19
+ """Expose processor/model singletons so the API never reloads weights."""
20
+ return _PROCESSOR, _MODEL, _DEVICE
utils/.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ .venv/
3
+ *.mp4
4
+ *.log
5
+ *.tmp
6
+ .DS_Store
utils/app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ from pathlib import Path
5
+
6
+ from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
7
+ from fastapi.responses import FileResponse, JSONResponse
8
+ import uvicorn
9
+
10
+ from inference import run_inference
11
+
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
+ app = FastAPI(title="Video Processing Backend")
15
+
16
+
17
+ def _save_upload_to_tmp(upload: UploadFile) -> str:
18
+ suffix = Path(upload.filename or "upload.mp4").suffix or ".mp4"
19
+ fd, path = tempfile.mkstemp(prefix="input_", suffix=suffix, dir="/tmp")
20
+ os.close(fd)
21
+ with open(path, "wb") as buffer:
22
+ data = upload.file.read()
23
+ buffer.write(data)
24
+ return path
25
+
26
+
27
+ def _safe_delete(path: str) -> None:
28
+ try:
29
+ os.remove(path)
30
+ except FileNotFoundError:
31
+ return
32
+ except Exception:
33
+ logging.exception("Failed to remove temporary file: %s", path)
34
+
35
+
36
+ def _schedule_cleanup(background_tasks: BackgroundTasks, path: str) -> None:
37
+ def _cleanup(target: str = path) -> None:
38
+ _safe_delete(target)
39
+
40
+ background_tasks.add_task(_cleanup)
41
+
42
+
43
+ @app.post("/process_video")
44
+ async def process_video(
45
+ background_tasks: BackgroundTasks,
46
+ video: UploadFile = File(...),
47
+ prompt: str = Form(...),
48
+ ):
49
+ if video is None:
50
+ raise HTTPException(status_code=400, detail="Video file is required.")
51
+ if not prompt:
52
+ raise HTTPException(status_code=400, detail="Prompt is required.")
53
+
54
+ try:
55
+ input_path = _save_upload_to_tmp(video)
56
+ except Exception:
57
+ logging.exception("Failed to save uploaded file.")
58
+ raise HTTPException(status_code=500, detail="Failed to save uploaded video.")
59
+ finally:
60
+ await video.close()
61
+
62
+ fd, output_path = tempfile.mkstemp(prefix="output_", suffix=".mp4", dir="/tmp")
63
+ os.close(fd)
64
+
65
+ try:
66
+ run_inference(input_path, output_path, prompt, max_frames=10)
67
+ except ValueError as exc:
68
+ logging.exception("Video decoding failed.")
69
+ _safe_delete(input_path)
70
+ _safe_delete(output_path)
71
+ raise HTTPException(status_code=500, detail=str(exc))
72
+ except Exception as exc:
73
+ logging.exception("Inference failed.")
74
+ _safe_delete(input_path)
75
+ _safe_delete(output_path)
76
+ return JSONResponse(status_code=500, content={"error": str(exc)})
77
+
78
+ _schedule_cleanup(background_tasks, input_path)
79
+ _schedule_cleanup(background_tasks, output_path)
80
+
81
+ return FileResponse(
82
+ path=output_path,
83
+ media_type="video/mp4",
84
+ filename="processed.mp4",
85
+ )
86
+
87
+
88
+ if __name__ == "__main__":
89
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
utils/inference.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Optional
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+
8
+ from models.model_loader import load_model
9
+ from utils.video import extract_frames, write_video
10
+
11
+
12
+ def draw_boxes(frame: np.ndarray, boxes: np.ndarray) -> np.ndarray:
13
+ output = frame.copy()
14
+ if boxes is None:
15
+ return output
16
+ for box in boxes:
17
+ x1, y1, x2, y2 = [int(coord) for coord in box]
18
+ cv2.rectangle(output, (x1, y1), (x2, y2), (0, 255, 0), thickness=2)
19
+ return output
20
+
21
+
22
+ def infer_frame(frame: np.ndarray, prompt: str) -> np.ndarray:
23
+ processor, model, device = load_model()
24
+ try:
25
+ inputs = processor(text=[prompt], images=frame, return_tensors="pt")
26
+ if hasattr(inputs, "to"):
27
+ inputs = inputs.to(device)
28
+ else:
29
+ inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
30
+ with torch.no_grad():
31
+ outputs = model(**inputs)
32
+ results = processor.post_process_object_detection(
33
+ outputs,
34
+ threshold=0.3,
35
+ target_sizes=[frame.shape[:2]],
36
+ )[0]
37
+ boxes = results["boxes"]
38
+ if hasattr(boxes, "cpu"):
39
+ boxes_np = boxes.cpu().numpy()
40
+ else:
41
+ boxes_np = np.asarray(boxes)
42
+ except Exception:
43
+ logging.exception("Inference failed for prompt '%s'", prompt)
44
+ raise
45
+ return draw_boxes(frame, boxes_np)
46
+
47
+
48
+ def run_inference(
49
+ input_video_path: str,
50
+ output_video_path: str,
51
+ prompt: str,
52
+ max_frames: Optional[int] = None,
53
+ ) -> str:
54
+ try:
55
+ frames, fps, width, height = extract_frames(input_video_path)
56
+ except ValueError as exc:
57
+ logging.exception("Failed to decode video at %s", input_video_path)
58
+ raise
59
+
60
+ processed_frames: List[np.ndarray] = []
61
+ for idx, frame in enumerate(frames):
62
+ if max_frames is not None and idx >= max_frames:
63
+ break
64
+ logging.debug("Processing frame %d", idx)
65
+ processed_frame = infer_frame(frame, prompt)
66
+ processed_frames.append(processed_frame)
67
+
68
+ write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
69
+ return output_video_path
utils/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ transformers
5
+ opencv-python
6
+ python-multipart
7
+ accelerate
8
+ pillow
9
+ scipy
utils/video.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+
7
+ def extract_frames(video_path: str) -> Tuple[List[np.ndarray], float, int, int]:
8
+ cap = cv2.VideoCapture(video_path)
9
+ if not cap.isOpened():
10
+ raise ValueError("Unable to open video.")
11
+
12
+ fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
13
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
14
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
15
+
16
+ frames: List[np.ndarray] = []
17
+ success, frame = cap.read()
18
+ while success:
19
+ frames.append(frame)
20
+ success, frame = cap.read()
21
+
22
+ cap.release()
23
+
24
+ if not frames:
25
+ raise ValueError("Video decode produced zero frames.")
26
+
27
+ return frames, fps, width, height
28
+
29
+
30
+ def write_video(frames: List[np.ndarray], output_path: str, fps: float, width: int, height: int) -> None:
31
+ if not frames:
32
+ raise ValueError("No frames available for writing.")
33
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
34
+ writer = cv2.VideoWriter(output_path, fourcc, fps or 1.0, (width, height))
35
+ if not writer.isOpened():
36
+ raise ValueError("Failed to open VideoWriter.")
37
+
38
+ for frame in frames:
39
+ writer.write(frame)
40
+
41
+ writer.release()