Optimize pose server processing

This commit is contained in:
2026-06-09 23:07:48 +08:00
parent a16b3e2d77
commit 8b878cb9e5
6 changed files with 238 additions and 43 deletions
+70 -32
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import threading
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
@@ -8,6 +9,7 @@ from pathlib import Path
import cv2
import mediapipe as mp
import numpy as np
from loguru import logger
PoseLandmarker = mp.tasks.vision.PoseLandmarker
@@ -41,7 +43,6 @@ class DeadBugMetrics:
right_elbow_angle: float
left_knee_angle: float
right_knee_angle: float
torso_tilt: float
feedback: list[str]
@@ -98,6 +99,7 @@ class DeadBugDetector:
visibility_threshold: float = 0.45,
extension_confirm_frames: int = 4,
reset_confirm_frames: int = 3,
prefer_gpu: bool = True,
) -> None:
if model_path is None:
model_path = Path(__file__).resolve().parent / "pose_models" / "pose_landmarker_full.task"
@@ -106,26 +108,22 @@ class DeadBugDetector:
self.visibility_threshold = visibility_threshold
self.extension_confirm_frames = extension_confirm_frames
self.reset_confirm_frames = reset_confirm_frames
self.delegate = BaseOptions.Delegate.GPU if prefer_gpu else BaseOptions.Delegate.CPU
self._latest_result = None
self._result_lock = threading.Lock()
self._result_event = threading.Event()
self._inflight = False
self._inflight_started_at = 0.0
def on_result(pose_result, _image, _timestamp_ms):
with self._result_lock:
self._latest_result = pose_result
self._inflight = False
self._inflight_started_at = 0.0
self._result_event.set()
options = PoseLandmarkerOptions(
base_options=BaseOptions(model_asset_path=self.model_path),
running_mode=VisionRunningMode.LIVE_STREAM,
result_callback=on_result,
num_poses=1,
min_pose_detection_confidence=0.5,
min_pose_presence_confidence=0.5,
min_tracking_confidence=0.5,
)
self._landmarker = PoseLandmarker.create_from_options(options)
self._landmarker = self._create_landmarker(on_result)
self.rep_count = 0
self.phase = DeadBugPhase.READY
@@ -138,20 +136,67 @@ class DeadBugDetector:
def close(self) -> None:
    """Shut down the pose landmarker held in ``self._landmarker``."""
    active_landmarker = self._landmarker
    active_landmarker.close()
def _create_landmarker(self, result_callback):
    """Create the PoseLandmarker, retrying on CPU if the first attempt fails.

    The first attempt uses ``self.delegate``. If that attempt raises and the
    delegate was already CPU, the exception propagates unchanged. Otherwise
    a warning is logged, ``self.delegate`` is switched to CPU, and creation
    is retried once; a failure of the retry propagates to the caller.

    Args:
        result_callback: Callback forwarded to ``_build_options``.

    Returns:
        The landmarker returned by ``PoseLandmarker.create_from_options``.
    """
    try:
        first_options = self._build_options(self.delegate, result_callback)
        created = PoseLandmarker.create_from_options(first_options)
        logger.info("MediaPipe PoseLandmarker initialized with {} delegate", self.delegate.name)
        return created
    except Exception as gpu_error:
        # Nothing left to fall back to when CPU itself was the failing delegate.
        if self.delegate == BaseOptions.Delegate.CPU:
            raise
        logger.warning("MediaPipe GPU delegate unavailable, falling back to CPU: {}", gpu_error)
        # Record the delegate actually in use so later readers see CPU.
        self.delegate = BaseOptions.Delegate.CPU
        retry_options = self._build_options(self.delegate, result_callback)
        created = PoseLandmarker.create_from_options(retry_options)
        logger.info("MediaPipe PoseLandmarker initialized with CPU delegate")
        return created
def _build_options(self, delegate, result_callback):
    """Assemble ``PoseLandmarkerOptions`` for live-stream pose detection.

    Args:
        delegate: Value passed as ``delegate`` to ``BaseOptions``.
        result_callback: Callback registered as ``result_callback``.

    Returns:
        Options using ``self.model_path``, LIVE_STREAM running mode, a
        single pose, and 0.5 for the detection, presence and tracking
        confidence thresholds.
    """
    base = BaseOptions(model_asset_path=self.model_path, delegate=delegate)
    landmarker_options = PoseLandmarkerOptions(
        base_options=base,
        running_mode=VisionRunningMode.LIVE_STREAM,
        result_callback=result_callback,
        num_poses=1,
        min_pose_detection_confidence=0.5,
        min_pose_presence_confidence=0.5,
        min_tracking_confidence=0.5,
    )
    return landmarker_options
def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]:
timestamp_ms = self._normalize_timestamp(timestamp_ms)
rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)
self._result_event.clear()
self._landmarker.detect_async(mp_image, timestamp_ms)
self._result_event.wait(timeout=0.1)
with self._result_lock:
if self._inflight and time.monotonic() - self._inflight_started_at > 0.5:
logger.warning("MediaPipe detect_async timed out; allowing next frame submission")
self._inflight = False
self._inflight_started_at = 0.0
should_submit = not self._inflight
if should_submit:
self._inflight = True
self._inflight_started_at = time.monotonic()
if should_submit:
rgba_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGBA)
mp_image = mp.Image(image_format=mp.ImageFormat.SRGBA, data=rgba_frame)
self._result_event.clear()
try:
self._landmarker.detect_async(mp_image, timestamp_ms)
except Exception:
with self._result_lock:
self._inflight = False
self._inflight_started_at = 0.0
raise
self._result_event.wait(timeout=0.08)
with self._result_lock:
pose_result = self._latest_result
annotated = bgr_frame.copy()
if not pose_result.pose_landmarks:
if pose_result is None or not pose_result.pose_landmarks:
result = DeadBugResult(
rep_count=self.rep_count,
phase=DeadBugPhase.NO_POSE,
@@ -224,10 +269,7 @@ class DeadBugDetector:
and lm[self.RIGHT_ANKLE].y >= lm[self.RIGHT_KNEE].y - scale * 0.2
)
torso_tilt = abs(lm[self.LEFT_HIP].y - lm[self.RIGHT_HIP].y) / scale
feedback: list[str] = []
if torso_tilt > 0.35:
feedback.append("Keep pelvis level and core stable")
if left_arm_extended and left_elbow < 160:
feedback.append("Straighten left arm")
if right_arm_extended and right_elbow < 160:
@@ -246,7 +288,6 @@ class DeadBugDetector:
right_elbow_angle=right_elbow,
left_knee_angle=left_knee,
right_knee_angle=right_knee,
torso_tilt=torso_tilt,
feedback=feedback,
)
@@ -305,19 +346,16 @@ class DeadBugDetector:
)
def _detect_diagonal_extension(self, metrics: DeadBugMetrics) -> str | None:
    """Classify which diagonal (opposite arm + leg) is currently extended.

    Args:
        metrics: Object exposing the boolean flags ``left_arm_extended``,
            ``right_arm_extended``, ``left_leg_extended`` and
            ``right_leg_extended``.

    Returns:
        ``"left_arm_right_leg"`` when only the right leg is extended and the
        left arm is extended, ``"right_arm_left_leg"`` when only the left leg
        is extended and the right arm is extended, otherwise ``None``.
    """
    left_leg_out = metrics.left_leg_extended
    right_leg_out = metrics.right_leg_extended

    # Two extended legs cannot be attributed to a single diagonal.
    if left_leg_out and right_leg_out:
        return None

    # Dead bug starts with both arms raised, so the non-moving arm may also
    # look "extended" in 2D. Infer the rep from the single extended leg and
    # require the opposite arm to be extended, instead of rejecting both-arm
    # frames as same-side noise.
    if right_leg_out and metrics.left_arm_extended:
        return "left_arm_right_leg"
    if left_leg_out and metrics.right_arm_extended:
        return "right_arm_left_leg"
    return None