Optimize pose server processing

This commit is contained in:
2026-06-09 23:07:48 +08:00
parent a16b3e2d77
commit 8b878cb9e5
6 changed files with 238 additions and 43 deletions
+1
View File
@@ -1,3 +1,4 @@
.venv/ .venv/
.idea/ .idea/
__pycache__/ __pycache__/
logs/
+70 -32
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import threading import threading
import time
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
@@ -8,6 +9,7 @@ from pathlib import Path
import cv2 import cv2
import mediapipe as mp import mediapipe as mp
import numpy as np import numpy as np
from loguru import logger
PoseLandmarker = mp.tasks.vision.PoseLandmarker PoseLandmarker = mp.tasks.vision.PoseLandmarker
@@ -41,7 +43,6 @@ class DeadBugMetrics:
right_elbow_angle: float right_elbow_angle: float
left_knee_angle: float left_knee_angle: float
right_knee_angle: float right_knee_angle: float
torso_tilt: float
feedback: list[str] feedback: list[str]
@@ -98,6 +99,7 @@ class DeadBugDetector:
visibility_threshold: float = 0.45, visibility_threshold: float = 0.45,
extension_confirm_frames: int = 4, extension_confirm_frames: int = 4,
reset_confirm_frames: int = 3, reset_confirm_frames: int = 3,
prefer_gpu: bool = True,
) -> None: ) -> None:
if model_path is None: if model_path is None:
model_path = Path(__file__).resolve().parent / "pose_models" / "pose_landmarker_full.task" model_path = Path(__file__).resolve().parent / "pose_models" / "pose_landmarker_full.task"
@@ -106,26 +108,22 @@ class DeadBugDetector:
self.visibility_threshold = visibility_threshold self.visibility_threshold = visibility_threshold
self.extension_confirm_frames = extension_confirm_frames self.extension_confirm_frames = extension_confirm_frames
self.reset_confirm_frames = reset_confirm_frames self.reset_confirm_frames = reset_confirm_frames
self.delegate = BaseOptions.Delegate.GPU if prefer_gpu else BaseOptions.Delegate.CPU
self._latest_result = None self._latest_result = None
self._result_lock = threading.Lock() self._result_lock = threading.Lock()
self._result_event = threading.Event() self._result_event = threading.Event()
self._inflight = False
self._inflight_started_at = 0.0
def on_result(pose_result, _image, _timestamp_ms): def on_result(pose_result, _image, _timestamp_ms):
with self._result_lock: with self._result_lock:
self._latest_result = pose_result self._latest_result = pose_result
self._inflight = False
self._inflight_started_at = 0.0
self._result_event.set() self._result_event.set()
options = PoseLandmarkerOptions( self._landmarker = self._create_landmarker(on_result)
base_options=BaseOptions(model_asset_path=self.model_path),
running_mode=VisionRunningMode.LIVE_STREAM,
result_callback=on_result,
num_poses=1,
min_pose_detection_confidence=0.5,
min_pose_presence_confidence=0.5,
min_tracking_confidence=0.5,
)
self._landmarker = PoseLandmarker.create_from_options(options)
self.rep_count = 0 self.rep_count = 0
self.phase = DeadBugPhase.READY self.phase = DeadBugPhase.READY
@@ -138,20 +136,67 @@ class DeadBugDetector:
def close(self) -> None: def close(self) -> None:
self._landmarker.close() self._landmarker.close()
def _create_landmarker(self, result_callback):
try:
landmarker = PoseLandmarker.create_from_options(
self._build_options(self.delegate, result_callback)
)
logger.info("MediaPipe PoseLandmarker initialized with {} delegate", self.delegate.name)
return landmarker
except Exception as exc:
if self.delegate == BaseOptions.Delegate.CPU:
raise
logger.warning("MediaPipe GPU delegate unavailable, falling back to CPU: {}", exc)
self.delegate = BaseOptions.Delegate.CPU
landmarker = PoseLandmarker.create_from_options(
self._build_options(self.delegate, result_callback)
)
logger.info("MediaPipe PoseLandmarker initialized with CPU delegate")
return landmarker
def _build_options(self, delegate, result_callback):
return PoseLandmarkerOptions(
base_options=BaseOptions(model_asset_path=self.model_path, delegate=delegate),
running_mode=VisionRunningMode.LIVE_STREAM,
result_callback=result_callback,
num_poses=1,
min_pose_detection_confidence=0.5,
min_pose_presence_confidence=0.5,
min_tracking_confidence=0.5,
)
def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]: def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]:
timestamp_ms = self._normalize_timestamp(timestamp_ms) timestamp_ms = self._normalize_timestamp(timestamp_ms)
rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)
self._result_event.clear() with self._result_lock:
self._landmarker.detect_async(mp_image, timestamp_ms) if self._inflight and time.monotonic() - self._inflight_started_at > 0.5:
self._result_event.wait(timeout=0.1) logger.warning("MediaPipe detect_async timed out; allowing next frame submission")
self._inflight = False
self._inflight_started_at = 0.0
should_submit = not self._inflight
if should_submit:
self._inflight = True
self._inflight_started_at = time.monotonic()
if should_submit:
rgba_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGBA)
mp_image = mp.Image(image_format=mp.ImageFormat.SRGBA, data=rgba_frame)
self._result_event.clear()
try:
self._landmarker.detect_async(mp_image, timestamp_ms)
except Exception:
with self._result_lock:
self._inflight = False
self._inflight_started_at = 0.0
raise
self._result_event.wait(timeout=0.08)
with self._result_lock: with self._result_lock:
pose_result = self._latest_result pose_result = self._latest_result
annotated = bgr_frame.copy() annotated = bgr_frame.copy()
if not pose_result.pose_landmarks: if pose_result is None or not pose_result.pose_landmarks:
result = DeadBugResult( result = DeadBugResult(
rep_count=self.rep_count, rep_count=self.rep_count,
phase=DeadBugPhase.NO_POSE, phase=DeadBugPhase.NO_POSE,
@@ -224,10 +269,7 @@ class DeadBugDetector:
and lm[self.RIGHT_ANKLE].y >= lm[self.RIGHT_KNEE].y - scale * 0.2 and lm[self.RIGHT_ANKLE].y >= lm[self.RIGHT_KNEE].y - scale * 0.2
) )
torso_tilt = abs(lm[self.LEFT_HIP].y - lm[self.RIGHT_HIP].y) / scale
feedback: list[str] = [] feedback: list[str] = []
if torso_tilt > 0.35:
feedback.append("Keep pelvis level and core stable")
if left_arm_extended and left_elbow < 160: if left_arm_extended and left_elbow < 160:
feedback.append("Straighten left arm") feedback.append("Straighten left arm")
if right_arm_extended and right_elbow < 160: if right_arm_extended and right_elbow < 160:
@@ -246,7 +288,6 @@ class DeadBugDetector:
right_elbow_angle=right_elbow, right_elbow_angle=right_elbow,
left_knee_angle=left_knee, left_knee_angle=left_knee,
right_knee_angle=right_knee, right_knee_angle=right_knee,
torso_tilt=torso_tilt,
feedback=feedback, feedback=feedback,
) )
@@ -305,19 +346,16 @@ class DeadBugDetector:
) )
def _detect_diagonal_extension(self, metrics: DeadBugMetrics) -> str | None: def _detect_diagonal_extension(self, metrics: DeadBugMetrics) -> str | None:
left_arm_right_leg = metrics.left_arm_extended and metrics.right_leg_extended if metrics.left_leg_extended and metrics.right_leg_extended:
right_arm_left_leg = metrics.right_arm_extended and metrics.left_leg_extended
same_side_noise = (
metrics.left_arm_extended
and metrics.left_leg_extended
or metrics.right_arm_extended
and metrics.right_leg_extended
)
if same_side_noise:
return None return None
if left_arm_right_leg and not right_arm_left_leg:
# Dead bug starts with both arms raised, so the non-moving arm may also
# look "extended" in 2D. Infer the rep from the single extended leg and
# require the opposite arm to be extended, instead of rejecting both-arm
# frames as same-side noise.
if metrics.right_leg_extended and metrics.left_arm_extended:
return "left_arm_right_leg" return "left_arm_right_leg"
if right_arm_left_leg and not left_arm_right_leg: if metrics.left_leg_extended and metrics.right_arm_extended:
return "right_arm_left_leg" return "right_arm_left_leg"
return None return None
+62 -9
View File
@@ -1,12 +1,34 @@
import asyncio import asyncio
import json import json
import os
import re import re
import websockets import websockets
import cv2 import cv2
from loguru import logger from loguru import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceCandidate from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceCandidate
from aiortc.mediastreams import MediaStreamError
from dead_bug_detector import DeadBugDetector from dead_bug_detector import DeadBugDetector
from rep_announcer import RepAnnouncer
PROCESS_EVERY_N_FRAMES = max(1, int(os.getenv("POSEFIT_PROCESS_EVERY_N_FRAMES", "1")))
TARGET_FRAME_WIDTH = max(1, int(os.getenv("POSEFIT_FRAME_WIDTH", "1080")))
TARGET_FRAME_HEIGHT = max(1, int(os.getenv("POSEFIT_FRAME_HEIGHT", "720")))
def format_pose_debug(pose_result):
metrics = pose_result.metrics
if metrics is None:
return "metrics=None"
return (
f"side={pose_result.side}, standard={pose_result.is_standard}, "
f"angles(le={metrics.left_elbow_angle:.1f}, re={metrics.right_elbow_angle:.1f}, "
f"lk={metrics.left_knee_angle:.1f}, rk={metrics.right_knee_angle:.1f}), "
f"extended(la={metrics.left_arm_extended}, ra={metrics.right_arm_extended}, "
f"ll={metrics.left_leg_extended}, rl={metrics.right_leg_extended})"
)
async def handle_client(websocket): async def handle_client(websocket):
@@ -40,26 +62,47 @@ async def handle_client(websocket):
return cand return cand
async def receive_video(track): async def receive_video(track):
logger.info("Start receiving video frames") logger.info(
"Start receiving video frames, process_every_n_frames={}, target_frame={}x{}",
PROCESS_EVERY_N_FRAMES,
TARGET_FRAME_WIDTH,
TARGET_FRAME_HEIGHT,
)
frame_count = 0 frame_count = 0
processed_count = 0
detector = DeadBugDetector() detector = DeadBugDetector()
announcer = RepAnnouncer()
last_announced_rep = 0
last_pose_result = None
last_annotated = None
try: try:
while True: while True:
frame = await track.recv() frame = await track.recv()
frame_count += 1 frame_count += 1
img = frame.to_ndarray(format="bgr24") raw_img = frame.to_ndarray(format="bgr24")
img = normalize_frame(raw_img)
timestamp_ms = int(frame.time * 1000) if frame.time is not None else frame_count * 33 timestamp_ms = int(frame.time * 1000) if frame.time is not None else frame_count * 33
annotated, pose_result = detector.process_frame(img, timestamp_ms)
cv2.imshow("Android Camera (WebRTC)", annotated) if frame_count % PROCESS_EVERY_N_FRAMES == 0 or last_pose_result is None:
processed_count += 1
last_annotated, last_pose_result = detector.process_frame(img, timestamp_ms)
if last_pose_result.rep_count > last_announced_rep:
last_announced_rep = last_pose_result.rep_count
announcer.announce_count(last_announced_rep)
cv2.imshow("Android Camera (WebRTC)", last_annotated if last_annotated is not None else img)
if frame_count % 100 == 0: if frame_count % 100 == 0:
logger.info( logger.info(
"Received {} frames, shape={}, reps={}, phase={}, feedback={}", "Received {} frames, processed={}, raw_shape={}, shape={}, reps={}, phase={}, feedback={}, {}",
frame_count, frame_count,
processed_count,
raw_img.shape,
img.shape, img.shape,
pose_result.rep_count, last_pose_result.rep_count if last_pose_result is not None else 0,
pose_result.phase.value, last_pose_result.phase.value if last_pose_result is not None else "none",
" | ".join(pose_result.feedback), " | ".join(last_pose_result.feedback) if last_pose_result is not None else "",
format_pose_debug(last_pose_result) if last_pose_result is not None else "metrics=None",
) )
if cv2.waitKey(1) & 0xFF == 27: if cv2.waitKey(1) & 0xFF == 27:
@@ -67,9 +110,12 @@ async def handle_client(websocket):
break break
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Video receive task cancelled") logger.info("Video receive task cancelled")
except MediaStreamError:
logger.info("Video track ended")
except Exception as e: except Exception as e:
logger.error(f"Video receive error: {e}") logger.exception(f"Video receive error: {e!r}")
finally: finally:
announcer.close()
detector.close() detector.close()
@pc.on("track") @pc.on("track")
@@ -131,5 +177,12 @@ async def main():
await asyncio.Future() await asyncio.Future()
def normalize_frame(image):
height, width = image.shape[:2]
if width == TARGET_FRAME_WIDTH and height == TARGET_FRAME_HEIGHT:
return image
return cv2.resize(image, (TARGET_FRAME_WIDTH, TARGET_FRAME_HEIGHT), interpolation=cv2.INTER_AREA)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())
+18
View File
@@ -1,4 +1,9 @@
import os import os
import faulthandler
from pathlib import Path
from loguru import logger
os.environ["MEDIAPIPE_DISABLE_LOGGING"] = "1" os.environ["MEDIAPIPE_DISABLE_LOGGING"] = "1"
os.environ["GLOG_minloglevel"] = "3" os.environ["GLOG_minloglevel"] = "3"
@@ -8,4 +13,17 @@ from handle_client import main
if __name__ == "__main__": if __name__ == "__main__":
log_dir = Path(__file__).resolve().parent / "logs"
log_dir.mkdir(exist_ok=True)
crash_log = open(log_dir / "posefit-crash.log", "a", buffering=1)
faulthandler.enable(file=crash_log, all_threads=True)
logger.add(
log_dir / "posefit-server_{time:YYYY-MM-DD}.log",
rotation="20 MB",
retention="14 days",
enqueue=True,
backtrace=True,
diagnose=True,
)
logger.info("Starting server...")
asyncio.run(main()) asyncio.run(main())
+84
View File
@@ -0,0 +1,84 @@
from __future__ import annotations
import queue
import subprocess
import sys
import threading
from typing import Any
from loguru import logger
class RepAnnouncer:
def __init__(self, *, enabled: bool = True, rate: int = 185, volume: float = 1.0) -> None:
self.enabled = enabled
self.rate = rate
self.volume = volume
self._queue: queue.Queue[str | None] = queue.Queue()
self._thread: threading.Thread | None = None
self._engine: Any | None = None
self._use_macos_say = False
self._current_process: subprocess.Popen | None = None
if self.enabled:
self._start()
def announce_count(self, count: int) -> None:
if not self.enabled or count <= 0:
return
while True:
try:
self._queue.get_nowait()
except queue.Empty:
break
self._queue.put(str(count))
def close(self) -> None:
if not self.enabled:
return
self._queue.put(None)
if self._thread is not None:
self._thread.join(timeout=1.0)
if self._current_process is not None and self._current_process.poll() is None:
self._current_process.terminate()
def _start(self) -> None:
if sys.platform == "darwin":
self._use_macos_say = True
logger.info("Rep announcer initialized with macOS say")
else:
try:
import pyttsx3
self._engine = pyttsx3.init()
self._engine.setProperty("rate", self.rate)
self._engine.setProperty("volume", self.volume)
logger.info("Rep announcer initialized with pyttsx3")
except Exception as exc:
self.enabled = False
logger.warning("Rep announcer disabled, pyttsx3 unavailable: {}", exc)
return
self._thread = threading.Thread(target=self._run, name="RepAnnouncer", daemon=True)
self._thread.start()
def _run(self) -> None:
while True:
text = self._queue.get()
if text is None:
return
try:
if self._use_macos_say:
if self._current_process is not None and self._current_process.poll() is None:
self._current_process.terminate()
self._current_process = subprocess.Popen(
["say", "-r", str(self.rate), text],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
elif self._engine is not None:
self._engine.say(text)
self._engine.runAndWait()
except Exception as exc:
logger.warning("Failed to announce rep count {}: {}", text, exc)
+2 -1
View File
@@ -1,6 +1,7 @@
aiortc>=1.9.0 aiortc>=1.9.0
websockets>=13.0 websockets>=13.0
opencv-contrib-python>=4.10.0 opencv-contrib-python>=4.10.0
numpy>=2.0.0 numpy>=1.26,<2
loguru>=0.7.0 loguru>=0.7.0
mediapipe==0.10.21 mediapipe==0.10.21
pyttsx3>=2.99