from __future__ import annotations import threading import time import cv2 import mediapipe as mp import numpy as np from loguru import logger from app.exercises.dead_bug.metrics import calculate_metrics from app.exercises.dead_bug.rules import has_required_visibility from app.exercises.dead_bug.state_machine import DeadBugStateMachine from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase, DeadBugResult, Point from app.rendering.overlay_renderer import draw_status_overlay from app.rendering.skeleton_renderer import draw_landmarks from app.vision.frame_utils import bgr_to_rgba from app.vision.pose_landmarker import PoseLandmarkerWrapper from app.vision.pose_types import ( LEFT_ANKLE, LEFT_ELBOW, LEFT_HIP, LEFT_KNEE, LEFT_SHOULDER, LEFT_WRIST, REQUIRED_LANDMARKS, RIGHT_ANKLE, RIGHT_ELBOW, RIGHT_HIP, RIGHT_KNEE, RIGHT_SHOULDER, RIGHT_WRIST, ) class DeadBugDetector: """死虫式(Dead Bug)运动检测器""" def __init__( self, *, model_path: str | None = None, visibility_threshold: float = 0.45, extension_confirm_frames: int = 4, reset_confirm_frames: int = 3, prefer_gpu: bool = True, ) -> None: """初始化姿态检测器、状态机和可视化渲染组件""" self.visibility_threshold = visibility_threshold self._latest_result = None self._result_lock = threading.Lock() self._result_event = threading.Event() self._inflight = False self._inflight_started_at = 0.0 self.last_timing: dict[str, float | bool] = {} def on_result(pose_result, _image, _timestamp_ms): with self._result_lock: self._latest_result = pose_result self._inflight = False self._inflight_started_at = 0.0 self._result_event.set() self._landmarker = PoseLandmarkerWrapper( model_path=model_path, prefer_gpu=prefer_gpu, result_callback=on_result, ) self._state = DeadBugStateMachine( extension_confirm_frames=extension_confirm_frames, reset_confirm_frames=reset_confirm_frames, ) self._last_timestamp_ms = -1 def close(self) -> None: """释放MediaPipe模型资源""" self._landmarker.close() def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]: """处理单帧:姿态检测、指标计算、状态机更新、可视化叠加""" total_started = time.perf_counter() timestamp_ms = self._normalize_timestamp(timestamp_ms) normalize_done = time.perf_counter() with self._result_lock: if self._inflight and time.monotonic() - self._inflight_started_at > 0.5: logger.warning("MediaPipe detect_async timed out; allowing next frame submission") self._inflight = False self._inflight_started_at = 0.0 should_submit = not self._inflight if should_submit: self._inflight = True self._inflight_started_at = time.monotonic() lock_done = time.perf_counter() if should_submit: rgba_frame = bgr_to_rgba(bgr_frame) convert_done = time.perf_counter() mp_image = mp.Image(image_format=mp.ImageFormat.SRGBA, data=rgba_frame) self._result_event.clear() try: self._landmarker.detect_async(mp_image, timestamp_ms) except Exception: with self._result_lock: self._inflight = False self._inflight_started_at = 0.0 raise submit_done = time.perf_counter() self._result_event.wait(timeout=0.08) wait_done = time.perf_counter() else: convert_done = lock_done submit_done = lock_done wait_done = lock_done with self._result_lock: pose_result = self._latest_result result_read_done = time.perf_counter() annotated = bgr_frame.copy() copy_done = time.perf_counter() if pose_result is None or not pose_result.pose_landmarks: self._state.mark_no_pose() result = DeadBugResult( rep_count=self._state.rep_count, phase=DeadBugPhase.NO_POSE, side=self._state.active_side, is_standard=False, feedback=["No full body detected"], metrics=None, ) draw_status_overlay(annotated, result) self._record_timing( total_started, normalize_done, lock_done, convert_done, submit_done, wait_done, result_read_done, copy_done, time.perf_counter(), should_submit, ) return annotated, result landmarks = [Point(lm.x, lm.y, lm.z, getattr(lm, "visibility", 1.0)) for lm in pose_result.pose_landmarks[0]] draw_landmarks(annotated, landmarks, REQUIRED_LANDMARKS, visibility_threshold=self.visibility_threshold) if not has_required_visibility(landmarks, REQUIRED_LANDMARKS, self.visibility_threshold): self._state.mark_no_pose() result = DeadBugResult( rep_count=self._state.rep_count, phase=DeadBugPhase.NO_POSE, side=self._state.active_side, is_standard=False, feedback=["Keep shoulders, elbows, wrists, hips, knees, ankles visible"], metrics=None, ) draw_status_overlay(annotated, result) self._record_timing( total_started, normalize_done, lock_done, convert_done, submit_done, wait_done, result_read_done, copy_done, time.perf_counter(), should_submit, ) return annotated, result raw = calculate_metrics( landmarks, left_shoulder=LEFT_SHOULDER, right_shoulder=RIGHT_SHOULDER, left_elbow=LEFT_ELBOW, right_elbow=RIGHT_ELBOW, left_wrist=LEFT_WRIST, right_wrist=RIGHT_WRIST, left_hip=LEFT_HIP, right_hip=RIGHT_HIP, left_knee=LEFT_KNEE, right_knee=RIGHT_KNEE, left_ankle=LEFT_ANKLE, right_ankle=RIGHT_ANKLE, visibility_threshold=self.visibility_threshold, ) metrics = DeadBugMetrics( left_arm_extended=raw["left_arm_extended"], right_arm_extended=raw["right_arm_extended"], left_leg_extended=raw["left_leg_extended"], right_leg_extended=raw["right_leg_extended"], left_elbow_angle=raw["left_elbow_angle"], right_elbow_angle=raw["right_elbow_angle"], left_knee_angle=raw["left_knee_angle"], right_knee_angle=raw["right_knee_angle"], feedback=raw["feedback"], ) result = self._state.update(metrics) draw_status_overlay(annotated, result) self._record_timing( total_started, normalize_done, lock_done, convert_done, submit_done, wait_done, result_read_done, copy_done, time.perf_counter(), should_submit, ) return annotated, result def _record_timing( self, total_started: float, normalize_done: float, lock_done: float, convert_done: float, submit_done: float, wait_done: float, result_read_done: float, copy_done: float, finished: float, submitted: bool, ) -> None: self.last_timing = { "total_ms": (finished - total_started) * 1000, "timestamp_ms": (normalize_done - total_started) * 1000, "lock_ms": (lock_done - normalize_done) * 1000, "convert_ms": (convert_done - lock_done) * 1000, "submit_ms": (submit_done - convert_done) * 1000, "wait_ms": (wait_done - submit_done) * 1000, "result_read_ms": (result_read_done - wait_done) * 1000, "copy_ms": (copy_done - result_read_done) * 1000, "postprocess_draw_ms": (finished - copy_done) * 1000, "submitted": submitted, } def _normalize_timestamp(self, timestamp_ms: int) -> int: """确保时间戳严格递增(MediaPipe要求)""" if timestamp_ms <= self._last_timestamp_ms: timestamp_ms = self._last_timestamp_ms + 1 self._last_timestamp_ms = timestamp_ms return timestamp_ms