Spaces:
Sleeping
Sleeping
import os
import shutil
import tempfile
from collections import defaultdict
from typing import List, Optional, Tuple
from urllib.parse import urlparse

import cv2
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import requests
import streamlit as st
from ultralytics import YOLO

# Try to import yt_dlp; if not available, we will show a helpful message when user tries YouTube
try:
    import yt_dlp  # type: ignore

    _YT_DLP_AVAILABLE = True
except Exception:
    _YT_DLP_AVAILABLE = False
# --- Page config ---
# set_page_config is issued before any other st.* call (older Streamlit versions
# require it to be the first command). Emoji were restored from mis-encoded text.
st.set_page_config(page_title="YOLOv8 Object Tracking & Counter", page_icon="🤖", layout="wide")
st.title("🚦 Smart Object Traffic Analyzer (YOLOv8)")
st.markdown(
    """
Process local videos, direct public video URLs, or YouTube links to track and count unique object crossings.
Uses YOLOv8 detection and ByteTrack (when available) for robust multi-object tracking.
"""
)
# --- Class mappings (subset of COCO) ---
# COCO class id -> lowercase class name, for ids 0 through 9.
COCO_CLASS_NAMES = dict(enumerate([
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
]))
# Sidebar label -> COCO class id, for the classes the user may choose to count.
# Labels are the capitalized COCO names; insertion order drives sidebar order.
CLASS_MAPPING = {
    COCO_CLASS_NAMES[coco_id].capitalize(): coco_id
    for coco_id in (0, 1, 2, 3, 5, 7)
}
# --- Session state initialization ---
if "processed_data" not in st.session_state:
    # Seed the results container only once per session: the script re-executes on
    # every interaction, and this guard keeps earlier analysis results intact.
    st.session_state["processed_data"] = dict(
        total_counts=defaultdict(int),   # class name -> number of line crossings
        frame_counts=[],                 # one dict per analysed frame
        processed_video=None,            # path of the annotated output video
        processing_complete=False,
        tracked_objects={},              # track id -> per-object crossing state
    )
# --- Sidebar: configuration ---
# These module-level values (model_name, confidence, selected_classes_ui, show_line,
# line_position, process_every_nth, max_frames) are read by the functions below.
with st.sidebar:
    st.header("⚙️ Configuration")

    st.subheader("Model & detection")
    model_name = st.selectbox("Select YOLO model", options=["yolov8n.pt", "yolov8s.pt"],
                              help="Nano (n) is fast; Small (s) is more accurate.")
    confidence = st.slider("Detection confidence threshold", min_value=0.1, max_value=1.0,
                           value=0.40, step=0.05, help="Minimum confidence to consider a detection valid.")

    st.subheader("Objects for counting")
    # Sidebar label -> checkbox state; "Person" and "Car" start ticked.
    selected_classes_ui = {}
    for class_label in CLASS_MAPPING:
        selected_classes_ui[class_label] = st.checkbox(class_label, value=class_label in ("Person", "Car"))

    st.subheader("Counting line settings")
    show_line = st.checkbox("Show crossing line", value=True)
    line_position = st.slider("Line position (vertical % from left)", min_value=10, max_value=90, value=50,
                              help="Place the vertical counting line as a percentage of frame width.")

    st.subheader("Performance options")
    process_every_nth = st.slider("Frame skip (process every Nth frame)", min_value=1, max_value=10, value=2,
                                  help="Higher values speed up processing but reduce tracking smoothness.")
    # Coerced to int because it is used in loop bounds and multiplied with the frame skip.
    max_frames = int(st.number_input("Maximum frames to analyze", min_value=10, max_value=5000, value=500,
                                     help="Limit processing for long videos. Increase for full videos."))
# --- Helpers ---
def load_model(model_path: str) -> "YOLO":
    """Create a YOLO model from the weights file at *model_path*.

    Nothing is cached: every call constructs a new ``YOLO`` instance, so each
    analysis run starts from a freshly constructed model.

    NOTE(review): before adding caching (e.g. ``st.cache_resource``), check how
    ultralytics keeps tracker state on a reused model. ``process_video`` calls
    ``model.track(..., persist=True)``, so that state would presumably carry over
    from one video to the next.
    """
    return YOLO(model_path)
def get_selected_class_ids() -> List[int]:
    """Translate the ticked sidebar checkboxes into COCO class IDs, in sidebar order."""
    chosen: List[int] = []
    for label, is_ticked in selected_classes_ui.items():
        if is_ticked:
            chosen.append(CLASS_MAPPING[label])
    return chosen
def download_direct_url(url: str, timeout: int = 30) -> Tuple[Optional[str], Optional[str]]:
    """
    Download a direct video URL (mp4/mov/etc.) to a temporary file.

    The temporary file's extension is chosen in this order:
    1. the extension of the URL path, when it is a known video extension;
    2. an extension implied by the response Content-Type;
    3. ".mp4" as the fallback.

    A partially written file is deleted when the download fails.

    Args:
        url: Public URL that serves the video file itself (not an HTML page).
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        (file_path, error_message). On success error_message is None; on
        failure file_path is None and error_message describes the problem.
    """
    known_exts = (".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v")
    mime_exts = {
        "video/mp4": ".mp4",
        "video/quicktime": ".mov",
        "video/x-msvideo": ".avi",
        "video/x-matroska": ".mkv",
        "video/webm": ".webm",
    }
    temp_path: Optional[str] = None
    try:
        # Using the response as a context manager releases the streamed connection.
        with requests.get(url, stream=True, timeout=timeout) as resp:
            resp.raise_for_status()
            suffix = os.path.splitext(urlparse(url).path)[1].lower()
            if suffix not in known_exts:
                mime = resp.headers.get("Content-Type", "").split(";", 1)[0].strip().lower()
                suffix = mime_exts.get(mime, ".mp4")
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_path = temp_file.name
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # skip empty keep-alive chunks
                        temp_file.write(chunk)
        return temp_path, None
    except requests.exceptions.RequestException as e:
        error = f"Failed to download direct URL: {e}. Check the URL and network access."
    except Exception as e:
        error = f"Unexpected error while downloading direct URL: {e}"
    # Only reached on failure: do not leave a truncated video on disk.
    if temp_path is not None:
        try:
            os.remove(temp_path)
        except OSError:
            pass
    return None, error
def download_youtube_video(youtube_url: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Attempt to download a YouTube video using yt-dlp.

    The video is written into a fresh temporary directory. That directory is
    removed again whenever the function reports a failure, so failed attempts
    do not leave partial downloads behind.

    Args:
        youtube_url: URL of a single video (``noplaylist`` is set).

    Returns:
        (file_path, error_message). If download succeeds, error_message is None;
        otherwise file_path is None and error_message explains the failure.
    """
    if not _YT_DLP_AVAILABLE:
        return None, "yt-dlp is not available in this environment. Install yt-dlp or use a direct URL / upload."
    temp_dir: Optional[str] = None
    try:
        temp_dir = tempfile.mkdtemp()
        ydl_opts = {
            "format": "best[ext=mp4]/best",
            "outtmpl": os.path.join(temp_dir, "video.%(ext)s"),
            "noplaylist": True,
            "quiet": True,
            "no_warnings": True,
            "retries": 2,
            "merge_output_format": "mp4",
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(youtube_url, download=True)
            filename = ydl.prepare_filename(info)
        # Prefer the .mp4 produced by a merge when the template resolved to another extension.
        if not filename.endswith(".mp4"):
            mp4_candidate = os.path.splitext(filename)[0] + ".mp4"
            if os.path.exists(mp4_candidate):
                filename = mp4_candidate
        if os.path.exists(filename):
            return filename, None
        error = "Download completed but output file not found."
    except yt_dlp.utils.DownloadError as e:
        # Likely network or availability issue
        guidance = (
            "yt-dlp failed to download the YouTube video. This can happen if the runtime has no outbound network access "
            "or YouTube is blocked. Alternatives:\n"
            "• Upload the video file directly using the uploader.\n"
            "• Provide a direct public MP4 URL (use the Direct URL option).\n"
            "• Host the video in the Space repository or on the Hugging Face Hub and provide the path.\n"
            "• Run the app locally where internet access is available."
        )
        error = f"{e}\n\n{guidance}"
    except Exception as e:
        error = f"Unexpected error while downloading YouTube video: {e}"
    # Only reached on failure: remove the temp directory and any partial download in it.
    if temp_dir is not None:
        shutil.rmtree(temp_dir, ignore_errors=True)
    return None, error
# --- Core processing function ---
def _discard_file(path: str) -> None:
    """Best-effort delete of *path*; errors are ignored (used for temp-output cleanup)."""
    try:
        os.remove(path)
    except OSError:
        pass


def _extract_detections(results):
    """Pull (boxes, class_ids, track_ids) out of the first ultralytics result.

    ``boxes`` holds one ``(x1, y1, x2, y2)`` integer row per detection and
    ``class_ids`` the matching class ids. ``track_ids`` is None when the result
    carries no tracker IDs (e.g. detection-only output). Empty sequences are
    returned when nothing usable can be parsed.
    """
    if not results or not hasattr(results[0], "boxes"):
        return [], [], None
    boxes_obj = results[0].boxes
    try:
        boxes = boxes_obj.xyxy.cpu().numpy().astype(int)
        class_ids = boxes_obj.cls.cpu().numpy().astype(int)
    except Exception:
        return [], [], None
    track_ids = None
    ids_attr = getattr(boxes_obj, "id", None)  # absent/None for detection-only outputs
    if ids_attr is not None:
        try:
            track_ids = ids_attr.cpu().numpy().astype(int)
        except Exception:
            track_ids = None
    return boxes, class_ids, track_ids


def _draw_overlay(image, line_x: int, width: int, height: int, totals, draw_line: bool) -> None:
    """Draw the vertical counting line (if enabled) and the running per-class totals onto *image* in place."""
    if draw_line:
        line_color = (0, 255, 255)
        cv2.line(image, (line_x, 0), (line_x, height), line_color, 2)
        cv2.putText(image, "COUNTING LINE", (min(width - 180, line_x + 5), 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, line_color, 2)
    y_offset = 30
    for obj_type, count in totals.items():
        cv2.putText(image, f"TOTAL {obj_type.upper()}: {count}", (max(10, width - 320), y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
        y_offset += 35


def process_video(video_path: str, selected_class_ids: List[int], model_path: str) -> Optional[str]:
    """
    Process the video, perform detection + tracking, count crossings, and write an annotated output video.

    Sidebar settings are read from module scope: ``confidence``, ``process_every_nth``
    (analyse every Nth source frame), ``max_frames`` (cap on analysed frames),
    ``line_position`` and ``show_line``.

    An object is counted once: the first time its tracked centroid moves across the
    vertical counting line, in either direction. Detections without a tracker ID
    (detection-only fallback) are drawn and included in per-frame presence counts,
    but they are never counted as crossings.

    Results are stored in ``st.session_state.processed_data``:
    - ``total_counts``
    - ``frame_counts``: one row per analysed frame, with ``"frame"`` holding the
      1-based source frame number
    - ``tracked_objects``
    - ``processed_video``
    - ``processing_complete``

    Args:
        video_path: Path of an input video readable by OpenCV.
        selected_class_ids: Class ids to detect; an empty list means no class filter.
        model_path: YOLO weights passed to ``load_model``.

    Returns:
        Path to the annotated MP4 on success, otherwise None. Errors raised by the
        model propagate to the caller, after the partial output file is removed.
    """
    model = load_model(model_path)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        st.error("Could not open the video file. The file may be corrupted or in an unsupported format.")
        return None

    # Container metadata can be 0, NaN or nonsense; fall back to safe defaults.
    source_fps = cap.get(cv2.CAP_PROP_FPS)
    if not 0 < source_fps < 1000:  # the chained comparison also rejects NaN
        source_fps = 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 640
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 360
    total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)

    # max_frames caps *analysed* frames, so it covers N times as many source frames.
    frame_budget = max_frames * process_every_nth
    if total_frames > frame_budget:
        st.warning(f"Video will be processed for the first {frame_budget} frames only (sidebar setting).")
    # Analysed frames this run is expected to produce; scales the progress bar for short videos too.
    expected_frames = max(min(max_frames, total_frames // process_every_nth) if total_frames else max_frames, 1)

    # Create the placeholder and close its handle at once so VideoWriter can open the path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_output:
        output_path = temp_output.name
    # NOTE(review): mp4v-encoded MP4 may not play inline in every browser; consider
    # H.264 ("avc1") if the installed OpenCV build provides that encoder.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, max(source_fps / process_every_nth, 1.0), (width, height))
    if not out.isOpened():
        cap.release()
        out.release()
        _discard_file(output_path)
        st.error("Could not create the output video file.")
        return None

    state = st.session_state.processed_data
    state["total_counts"] = defaultdict(int)
    state["frame_counts"] = []
    state["tracked_objects"] = {}
    # Results stay marked stale until this run finishes, so a failed run never pairs
    # an old output video with freshly reset counts.
    state["processing_complete"] = False

    line_x = int(width * line_position / 100)
    class_filter = selected_class_ids if selected_class_ids else None
    progress_bar = st.progress(0)
    status_text = st.empty()
    tracker_warning_shown = False
    frame_idx = 0
    processed_frames = 0
    finished = False
    try:
        while processed_frames < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame_idx += 1
            if frame_idx % process_every_nth != 0:
                continue

            # Run YOLOv8 tracking (ByteTrack). If it fails for this frame, fall back to detection only.
            try:
                results = model.track(frame, conf=confidence, classes=class_filter, persist=True,
                                      tracker="bytetrack.yaml", verbose=False)
            except Exception as exc:
                if not tracker_warning_shown:
                    st.warning(f"Tracking failed ({exc}); using detection only for affected frames. "
                               "Objects without a track ID are not counted.")
                    tracker_warning_shown = True
                results = model(frame, conf=confidence, classes=class_filter, verbose=False)

            boxes, class_ids, track_ids = _extract_detections(results)
            annotated = frame.copy()
            frame_counts = defaultdict(int)
            for i, box in enumerate(boxes):
                x1, y1, x2, y2 = (int(v) for v in box)
                cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
                cls_id = int(class_ids[i]) if i < len(class_ids) else -1
                cls_name = COCO_CLASS_NAMES.get(cls_id, "Unknown")
                frame_counts[cls_name.lower()] += 1

                track_id = int(track_ids[i]) if track_ids is not None and i < len(track_ids) else None
                # Without a stable track ID a crossing can't be detected, so untracked
                # boxes are never stored (this keeps tracked_objects bounded).
                if track_id is not None:
                    tracked = state["tracked_objects"].get(track_id)
                    if tracked is None:
                        state["tracked_objects"][track_id] = {
                            "class": cls_name,
                            "last_centroid": (cx, cy),
                            "counted": False,
                        }
                    else:
                        prev_x = tracked["last_centroid"][0]
                        crossed = prev_x < line_x <= cx or cx <= line_x < prev_x
                        if crossed and not tracked["counted"]:
                            state["total_counts"][cls_name] += 1
                            tracked["counted"] = True
                        tracked["last_centroid"] = (cx, cy)

                cv2.rectangle(annotated, (x1, y1), (x2, y2), (255, 0, 0), 2)
                cv2.circle(annotated, (cx, cy), 5, (0, 0, 255), -1)
                id_text = track_id if track_id is not None else "-"
                cv2.putText(annotated, f"ID:{id_text} {cls_name}", (x1, max(10, y1 - 10)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            _draw_overlay(annotated, line_x, width, height, state["total_counts"], show_line)
            # VideoWriter can silently drop frames whose size differs from the size it was opened with.
            if annotated.shape[:2] != (height, width):
                annotated = cv2.resize(annotated, (width, height))

            frame_data = {"frame": frame_idx}
            for class_label in CLASS_MAPPING:
                key = class_label.lower()
                frame_data[key] = frame_counts.get(key, 0)
            state["frame_counts"].append(frame_data)

            out.write(annotated)
            processed_frames += 1
            progress_bar.progress(min(processed_frames / expected_frames, 1.0))
            status_text.text(f"Analyzing frame {frame_idx}/{total_frames or 'unknown'} (Processed {processed_frames})")
        finished = True
    finally:
        cap.release()
        out.release()
        if not finished:
            # An exception is propagating: don't leave a half-written video behind.
            _discard_file(output_path)

    progress_bar.progress(1.0)
    state["processing_complete"] = True
    state["processed_video"] = output_path
    st.session_state.processed_data = state
    return output_path
# --- UI layout: tabs ---
# Tab labels with emoji restored from mis-encoded text.
tab1, tab2, tab3 = st.tabs(["📹 Video input", "📊 Analysis & results", "ℹ️ Documentation"])
def _show_first_frame_preview(path: str) -> None:
    """Show the first decodable frame of the video at *path*, or a warning if that fails."""
    cap = None
    try:
        cap = cv2.VideoCapture(path)
        ret, frame = cap.read()
        if ret:
            st.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), caption="Video preview", use_column_width=True)
    except Exception:
        st.warning("Could not display video preview.")
    finally:
        if cap is not None:
            cap.release()


with tab1:
    col1, col2 = st.columns(2)
    # Streamlit reruns the whole script on every interaction, and st.button() is True
    # only in the run where it was clicked. The chosen input video therefore lives in
    # session state; a plain local variable would lose a downloaded video as soon as
    # START is pressed.
    # Shape: {"path": str, "source": "upload" | "download", "label": str} or None.
    current_input: Optional[dict] = st.session_state.get("input_video")

    with col1:
        st.subheader("📁 Upload video file")
        uploaded_file = st.file_uploader("Choose a video file", type=["mp4", "avi", "mov", "mkv"],
                                         help="Supported formats. For large files, consider shorter clips.")
        if uploaded_file is not None:
            upload_key = getattr(uploaded_file, "file_id", None) or (uploaded_file.name, uploaded_file.size)
            if st.session_state.get("last_upload_key") != upload_key:
                # Write each new upload to disk once (not on every rerun), keeping its extension.
                suffix = os.path.splitext(uploaded_file.name)[1].lower() or ".mp4"
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
                    tfile.write(uploaded_file.getbuffer())
                st.session_state["last_upload_key"] = upload_key
                current_input = {"path": tfile.name, "source": "upload", "label": uploaded_file.name}
            st.info(f"Video ready: {uploaded_file.name}")
            st.video(uploaded_file)
        else:
            if "last_upload_key" in st.session_state:
                del st.session_state["last_upload_key"]
            if current_input and current_input.get("source") == "upload":
                # The upload was cleared in the widget, so it is no longer the input.
                current_input = None

    with col2:
        st.subheader("🔗 Direct public video URL")
        direct_url = st.text_input("Enter a direct public video URL (e.g., .mp4)", placeholder="https://example.com/video.mp4")
        if st.button("⬇️ Download from URL", use_container_width=True) and direct_url:
            st.info("Attempting to download the direct video URL...")
            path, err = download_direct_url(direct_url)
            if path:
                current_input = {"path": path, "source": "download", "label": direct_url}
                st.success("Direct URL downloaded and ready for processing.")
                _show_first_frame_preview(path)
            else:
                st.error(err)

        st.markdown("---")
        st.subheader("🎥 YouTube link (optional)")
        youtube_url = st.text_input("Enter a YouTube video URL", placeholder="https://www.youtube.com/watch?v=...")
        if st.button("⬇️ Download from YouTube", use_container_width=True) and youtube_url:
            if not _YT_DLP_AVAILABLE:
                st.error("yt-dlp is not installed in this environment. Use a direct URL or upload the file.")
            else:
                st.info("Attempting to download YouTube video...")
                path, err = download_youtube_video(youtube_url)
                if path:
                    current_input = {"path": path, "source": "download", "label": youtube_url}
                    st.success("YouTube video downloaded and ready for processing.")
                    _show_first_frame_preview(path)
                else:
                    st.error(err)

    st.session_state["input_video"] = current_input
    video_path: Optional[str] = current_input["path"] if current_input else None

    st.markdown("---")
    if video_path:
        st.caption(f"Selected input: {current_input.get('label', 'video')}")
        if st.button("🚀 START TRACKING AND COUNTING", type="primary", use_container_width=True):
            selected_class_ids = get_selected_class_ids()
            if not selected_class_ids:
                st.error("Please select at least one object type to count in the sidebar.")
            else:
                try:
                    with st.spinner(f"Analyzing video with {model_name}..."):
                        out_path = process_video(video_path, selected_class_ids, model_name)
                    if out_path:
                        st.success("Analysis complete! See results in the 'Analysis & results' tab.")
                    else:
                        st.error("Processing failed. Check the logs and input file.")
                except Exception as e:
                    st.error(f"An error occurred during video processing: {e}")
    else:
        st.info("Upload a video, provide a direct URL, or a YouTube link to begin.")
with tab2:
    # Results written by process_video into session state (emoji labels restored from mis-encoded text).
    data = st.session_state.processed_data
    if data["processing_complete"]:
        st.header("Results summary")
        col1, col2 = st.columns([2, 1])
        with col1:
            st.subheader("🎥 Analyzed video output")
            try:
                with open(data["processed_video"], "rb") as video_file:
                    video_bytes = video_file.read()
                st.video(video_bytes)
                st.download_button(label="📥 Download annotated video (MP4)", data=video_bytes,
                                   file_name="analyzed_tracking_video.mp4", mime="video/mp4",
                                   use_container_width=True)
            except Exception:
                st.error("Could not load the processed video file.")
        with col2:
            st.subheader("✅ Object crossing totals")
            if data["total_counts"]:
                for obj_type, count in data["total_counts"].items():
                    st.metric(label=f"Total {obj_type.capitalize()} crossed", value=count)
            else:
                st.info("No objects crossed the counting line in the analyzed section.")

        st.subheader("📈 Object presence over processed frames")
        if data["frame_counts"]:
            df = pd.DataFrame(data["frame_counts"]).fillna(0)
            fig = go.Figure()
            for column in df.columns:
                if column != "frame":
                    fig.add_trace(go.Scatter(x=df["frame"], y=df[column], name=column.capitalize(),
                                             mode="lines+markers"))
            fig.update_layout(title="Count of objects present per processed frame",
                              xaxis_title="Frame number (processed frames)",
                              yaxis_title="Instance count", hovermode="x unified", height=400)
            st.plotly_chart(fig, use_container_width=True)

            st.subheader("Data export")
            st.dataframe(df.tail(10), use_container_width=True, height=200)
            csv = df.to_csv(index=False).encode("utf-8")
            st.download_button(label="⬇️ Download frame-by-frame data (CSV)", data=csv,
                               file_name="object_count_data.csv", mime="text/csv")
        else:
            st.warning("No tracking data available. Process a video first.")
    else:
        st.info("Process a video in the 'Video input' tab to view analysis results.")
with tab3:
    st.header("Documentation & Notes")
    # Blank lines separate the sections: without them Markdown merges each bold
    # heading into the preceding list item or the following paragraph.
    st.markdown(
        """
**Supported inputs**

- Local upload (recommended for Spaces demos).
- Direct public video URL (MP4 preferred).
- YouTube link (requires `yt-dlp` and outbound network access).

**Why YouTube downloads may fail in Spaces**

Hugging Face Spaces may restrict outbound network access or DNS resolution. If YouTube download fails, use a direct URL or upload the file. Running the app locally will allow YouTube downloads if your machine has internet access.

**Performance tips**

- Use `yolov8n.pt` for faster processing.
- Increase `Frame skip` (process every Nth frame) to speed up long videos.
- Reduce `Maximum frames` for quick demos.

**System packages**

This app uses `opencv-python-headless` to avoid GUI dependencies. You generally do not need a `setup.sh` that installs `libgl1-mesa-glx` or `libglib2.0-0`. Remove `setup.sh` unless you switch to non-headless OpenCV or require specific system libraries.
"""
    )