Spaces:
Paused
Paused
| import cv2 | |
| import numpy as np | |
| from scipy.optimize import linear_sum_assignment | |
| import requests | |
| import json | |
| from fastapi.responses import JSONResponse | |
| from config import API_URL, API_KEY | |
| import logging | |
# Module-level logger named after this module; send_results_api logs through it.
logger = logging.getLogger(__name__)
class RepCounter:
    """Count deadlift repetitions and estimate per-repetition speed and power.

    A repetition is registered when the wrist position crosses from below the
    knee line to above it. Image y grows downward, so "above" means a smaller y.
    Speed is the wrist displacement between the two crossings divided by the
    elapsed time, converted to centimetres with ``cm_per_pixel``. The caller
    must assign ``cm_per_pixel`` during calibration before reps can be counted.
    """

    def __init__(self, fps, height_cm, mass_kg=0, target_reps=None):
        """
        Args:
            fps: Frame rate of the video; converts frame differences to seconds.
            height_cm: Athlete height in cm. It is scaled by 0.2735 to give the
                real-world reference distance used for pixel calibration.
            mass_kg: Lifted mass in kg. Power is only computed when > 0.
            target_reps: Optional repetition limit. Counting stops once it is
                reached. ``None`` (the default) means no limit.
        """
        self.count = 0
        self.last_state = None
        self.cooldown_frames = 15  # update() calls ignored after each counted rep
        self.cooldown = 0
        self.rep_start_frame = None
        self.start_wrist_y = None
        self.rep_data = []  # per-rep speed in cm/s
        self.power_data = []  # per-rep power in W (only filled when mass_kg > 0)
        self.fps = fps
        self.cm_per_pixel = None  # assigned by the caller once calibration succeeds
        # NOTE(review): 0.2735 is presumably the knee-to-ankle length as a fraction
        # of body height — confirm against the source of this constant.
        self.real_distance_cm = height_cm * 0.2735
        self.calibration_done = False
        self.mass_kg = mass_kg
        self.gravity = 9.81
        # Honour the documented default: int(None) would raise TypeError.
        self.target_reps = int(target_reps) if target_reps is not None else None
        self.target_reached = False
        self.final_speed = None
        self.final_power = None
        # Keypoint index pairs (presumably COCO-17 ordering) describing limb segments.
        self.SKELETON = [
            (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
            (9, 10), (11, 12), (5, 11), (6, 12),
            (11, 13), (13, 15), (12, 14), (14, 16), (13, 14)
        ]

    def update(self, wrist_y, knee_y, current_frame):
        """Feed one frame's wrist and knee-line y coordinates into the counter.

        Args:
            wrist_y: Wrist y coordinate in pixels.
            knee_y: Knee-line y coordinate in pixels.
            current_frame: Index of the frame being processed.
        """
        # Once the target is reached, or while cooling down after a rep, ignore input.
        if self.target_reached or self.cooldown > 0:
            self.cooldown = max(0, self.cooldown - 1)
            return
        current_state = 'above' if wrist_y < knee_y else 'below'
        if self.last_state != current_state:
            if current_state == 'below':
                # The wrist just dropped below the knees: start timing a rep.
                self.rep_start_frame = current_frame
                self.start_wrist_y = wrist_y
            elif current_state == 'above' and self.last_state == 'below':
                # NOTE(review): reps are only counted after calibration has set
                # cm_per_pixel; earlier crossings are ignored.
                if self.rep_start_frame is not None and self.cm_per_pixel is not None:
                    duration = (current_frame - self.rep_start_frame) / self.fps
                    distance_cm = (self.start_wrist_y - wrist_y) * self.cm_per_pixel
                    if duration > 0:
                        speed_cmps = abs(distance_cm) / duration
                        self.rep_data.append(speed_cmps)
                        if self.mass_kg > 0:
                            # P = m * g * v, with the speed converted to m/s.
                            force = self.mass_kg * self.gravity
                            self.power_data.append(force * (speed_cmps / 100))
                    self.count += 1
                    if self.target_reps and self.count >= self.target_reps:
                        self.count = self.target_reps
                        self.target_reached = True
                        self.final_speed = np.mean(self.rep_data) if self.rep_data else 0
                        self.final_power = np.mean(self.power_data) if self.power_data else 0
                    self.cooldown = self.cooldown_frames
            self.last_state = current_state
# Centroid-based multi-object tracker.
class CentroidTracker:
    """Assign persistent integer IDs to bounding boxes across frames.

    Each tracked object stores its last centroid and how many consecutive
    updates it went unmatched. New detections are matched to existing objects
    with the Hungarian algorithm (``linear_sum_assignment``) on Euclidean
    centroid distance. Pairs farther apart than ``max_distance`` are rejected,
    and the unmatched detection gets a fresh ID.
    """

    def __init__(self, max_disappeared=50, max_distance=100):
        self.next_id = 0
        self.objects = {}  # obj_id -> {"centroid": ndarray of shape (2,), "missed": int}
        self.max_disappeared = max_disappeared
        self.max_distance = max_distance

    def _update_missing(self):
        """Increment every object's miss counter and drop those past the limit."""
        for obj_id in list(self.objects):  # copy: entries are deleted while looping
            self.objects[obj_id]["missed"] += 1
            if self.objects[obj_id]["missed"] > self.max_disappeared:
                del self.objects[obj_id]

    def update(self, detections):
        """Match this frame's detections to tracked objects.

        Args:
            detections: Sequence of ``(x1, y1, x2, y2)`` boxes.

        Returns:
            list[int]: One object ID per detection, in detection order.
        """
        if len(detections) == 0:
            self._update_missing()
            return []
        centroids = np.array([[(x1 + x2) / 2, (y1 + y2) / 2] for x1, y1, x2, y2 in detections])
        if len(self.objects) == 0:
            return self._register_new(centroids)
        return self._match_existing(centroids, detections)

    def _register(self, centroid):
        """Start tracking ``centroid`` under a fresh ID and return that ID."""
        obj_id = self.next_id
        self.objects[obj_id] = {"centroid": centroid, "missed": 0}
        self.next_id += 1
        return obj_id

    def _register_new(self, centroids):
        """Register every centroid as a new object; return IDs in input order."""
        return [self._register(centroid) for centroid in centroids]

    def _match_existing(self, centroids, detections):
        """Hungarian-match centroids to tracked objects; return IDs per detection.

        ``detections`` is unused (centroids already encode it) but kept so the
        method signature stays unchanged.
        """
        existing_ids = list(self.objects)
        existing_centroids = np.array([self.objects[obj_id]["centroid"] for obj_id in existing_ids])
        # cost[i, j] = distance between tracked object i and detection j.
        cost = np.linalg.norm(existing_centroids[:, np.newaxis] - centroids, axis=2)
        row_ind, col_ind = linear_sum_assignment(cost)

        assigned = [None] * len(centroids)
        matched_ids = set()
        for row, col in zip(row_ind, col_ind):
            if cost[row, col] <= self.max_distance:
                obj_id = existing_ids[row]
                self.objects[obj_id]["centroid"] = centroids[col]
                self.objects[obj_id]["missed"] = 0
                assigned[col] = obj_id
                matched_ids.add(obj_id)

        # Age out tracked objects that found no detection this frame.
        for obj_id in existing_ids:
            if obj_id not in matched_ids:
                self.objects[obj_id]["missed"] += 1
                if self.objects[obj_id]["missed"] > self.max_disappeared:
                    del self.objects[obj_id]

        # Detections left unmatched become new objects, registered in detection order.
        for col, obj_id in enumerate(assigned):
            if obj_id is None:
                assigned[col] = self._register(centroids[col])
        return assigned
# Per-frame processing: pose estimation, one-time calibration, rep counting.
def process_frame_for_counting(frame, tracker, rep_counter, frame_number, vitpose):
    """Run pose estimation on one frame and feed the result to ``rep_counter``.

    On the first frame where both calibration keypoints are visible, the
    pixel-to-cm scale is stored on ``rep_counter``. On every frame where both
    wrists and both knees are visible, the wrist midpoint y and the knee-line
    y are passed to ``rep_counter.update``.

    Args:
        frame: Video frame, passed straight to ``vitpose.pipeline``.
        tracker: Currently unused; kept for call-site compatibility.
        rep_counter: RepCounter-like object, calibrated and updated in place.
        frame_number: Index of ``frame``, forwarded to ``rep_counter.update``.
        vitpose: Object exposing ``pipeline(frame)`` that returns ``keypoints_xy``
            and ``scores`` tensors.
    """
    pose_results = vitpose.pipeline(frame)
    # NOTE(review): presumably shaped (num_people, num_keypoints, ...); only the
    # first person is used. Frames with no detected person are skipped instead
    # of raising IndexError.
    all_keypoints = pose_results.keypoints_xy.float().cpu().numpy()
    all_scores = pose_results.scores.float().cpu().numpy()
    if len(all_keypoints) == 0 or len(all_scores) == 0:
        return
    keypoints = all_keypoints[0]
    scores = all_scores[0]
    logger.debug("keypoints: %s", keypoints)
    logger.debug("scores: %s", scores)

    # Keep confident keypoints from index 5 to 16 as integer pixel coordinates.
    valid_points = {}
    for i, (kp, conf) in enumerate(zip(keypoints, scores)):
        if conf > 0.3 and 5 <= i <= 16:
            x, y = map(int, kp[:2])
            valid_points[i] = (x, y)

    # One-time calibration from keypoints 14 (knee) and 16 (ankle/foot).
    if not rep_counter.calibration_done and 14 in valid_points and 16 in valid_points:
        knee = valid_points[14]
        ankle = valid_points[16]
        pixel_distance = np.sqrt((knee[0] - ankle[0]) ** 2 + (knee[1] - ankle[1]) ** 2)
        if pixel_distance > 0:
            rep_counter.cm_per_pixel = rep_counter.real_distance_cm / pixel_distance
            rep_counter.calibration_done = True

    wrist_midpoint = None
    knee_line_y = None
    if 9 in valid_points and 10 in valid_points:
        wrist_midpoint = (
            (valid_points[9][0] + valid_points[10][0]) // 2,
            (valid_points[9][1] + valid_points[10][1]) // 2,
        )
    if 13 in valid_points and 14 in valid_points:
        # Midpoint y of the two knees. Extending the segment symmetrically
        # would not move its midpoint, so no extension is applied.
        knee_line_y = (valid_points[13][1] + valid_points[14][1]) // 2

    # Explicit None checks: a knee line at y == 0 is still a valid position.
    if wrist_midpoint is not None and knee_line_y is not None:
        rep_counter.update(wrist_midpoint[1], knee_line_y, frame_number)
# Main entry point: analyze a deadlift video and post per-repetition metrics.
def analyze_dead_lift(input_video, reps, weight, height, vitpose, player_id, exercise_id):
    """Count deadlift reps in a video, estimate per-rep speed/power, and post them.

    Args:
        input_video: Path to the video. It is read with ``cv2.VideoCapture``
            and later uploaded.
        reps: Target repetition count, converted with ``int()``.
        weight: Lifted mass in kg. It may be fractional. Power is reported as 0
            when it is not positive.
        height: Athlete height in cm, used for pixel-to-cm calibration. It may
            be fractional.
        vitpose: Pose-estimation object forwarded to process_frame_for_counting.
        player_id: Player identifier forwarded to the webhook.
        exercise_id: Exercise identifier forwarded to the webhook.

    Returns:
        Whatever ``send_results_api`` returns (the webhook HTTP response).
    """
    cap = cv2.VideoCapture(input_video)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # float(): int() would silently truncate e.g. 72.5 kg to 72 kg.
        rep_counter = RepCounter(fps, float(height), float(weight), int(reps))
        tracker = CentroidTracker(max_distance=150)
        frame_number = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            process_frame_for_counting(frame, tracker, rep_counter, frame_number, vitpose)
            frame_number += 1
    finally:
        # Release the capture even if pose estimation raises mid-video.
        cap.release()

    # Without a mass there is no power: report 0 for each measured rep.
    if rep_counter.mass_kg > 0:
        power_data = rep_counter.power_data
    else:
        power_data = [0] * len(rep_counter.rep_data)

    # "velocidad" is the speed in cm/s and "potencia" the power in W. An empty
    # rep_data list yields an empty repetition list.
    payload = {"repetition_data": [
        {"repetition": i, "velocidad": round(s, 1), "potencia": round(p, 1)}
        for i, (s, p) in enumerate(zip(rep_counter.rep_data, power_data), start=1)
    ]}
    return send_results_api(payload, player_id, exercise_id, input_video)
def send_results_api(results_dict: dict,
                     player_id: str,
                     exercise_id: str,
                     video_path: str,
                     timeout=(10, 300)) -> "requests.Response":
    """
    Upload the analyzed video and its computed metrics to the API webhook.

    The video is sent as a multipart file upload. The player ID, exercise ID
    and JSON-serialized results are sent as form fields.

    Args:
        results_dict (dict): JSON-serializable analysis results. The caller in
            this module sends ``{"repetition_data": [...]}`` with one
            ``{"repetition", "velocidad", "potencia"}`` entry per deadlift
            repetition.
        player_id (str): Unique identifier for the player.
        exercise_id (str): Unique identifier for the exercise.
        video_path (str): Path to the video file to upload.
        timeout: Forwarded to ``requests.post``. Either ``(connect, read)``
            seconds, a single number, or ``None`` to wait indefinitely.

    Returns:
        requests.Response: The HTTP response from the webhook. Non-2xx
        statuses are logged as warnings, not raised.

    Raises:
        FileNotFoundError: If the video file doesn't exist.
        requests.RequestException: On connection errors or timeouts.
        TypeError / ValueError: If results_dict cannot be serialized to JSON.
    """
    import os

    url = API_URL + "/exercises/webhooks/video-processed-results"
    logger.info("Sending video results to %s", url)
    with open(video_path, 'rb') as video_file:
        files = {
            # basename() handles the platform's path separators, not just '/'.
            'file': (os.path.basename(video_path), video_file, 'video/mp4')
        }
        data = {
            'player_id': player_id,
            'exercise_id': exercise_id,
            'results': json.dumps(results_dict),
        }
        # A timeout keeps a stalled server from hanging the worker forever.
        response = requests.post(
            url,
            headers={"token": API_KEY},
            files=files,
            data=data,
            timeout=timeout,
        )
    logger.info("Response: %s", response.status_code)
    logger.info("Response: %s", response.text)
    if not response.ok:
        logger.warning("Webhook returned non-success status %s", response.status_code)
    return response