Optimize pose server processing

This commit is contained in:
2026-06-09 23:07:48 +08:00
parent a16b3e2d77
commit 8b878cb9e5
6 changed files with 238 additions and 43 deletions
+62 -9
View File
@@ -1,12 +1,34 @@
import asyncio
import json
import os
import re
import websockets
import cv2
from loguru import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceCandidate
from aiortc.mediastreams import MediaStreamError
from dead_bug_detector import DeadBugDetector
from rep_announcer import RepAnnouncer
PROCESS_EVERY_N_FRAMES = max(1, int(os.getenv("POSEFIT_PROCESS_EVERY_N_FRAMES", "1")))
TARGET_FRAME_WIDTH = max(1, int(os.getenv("POSEFIT_FRAME_WIDTH", "1080")))
TARGET_FRAME_HEIGHT = max(1, int(os.getenv("POSEFIT_FRAME_HEIGHT", "720")))
def format_pose_debug(pose_result):
metrics = pose_result.metrics
if metrics is None:
return "metrics=None"
return (
f"side={pose_result.side}, standard={pose_result.is_standard}, "
f"angles(le={metrics.left_elbow_angle:.1f}, re={metrics.right_elbow_angle:.1f}, "
f"lk={metrics.left_knee_angle:.1f}, rk={metrics.right_knee_angle:.1f}), "
f"extended(la={metrics.left_arm_extended}, ra={metrics.right_arm_extended}, "
f"ll={metrics.left_leg_extended}, rl={metrics.right_leg_extended})"
)
async def handle_client(websocket):
@@ -40,26 +62,47 @@ async def handle_client(websocket):
return cand
async def receive_video(track):
logger.info("Start receiving video frames")
logger.info(
"Start receiving video frames, process_every_n_frames={}, target_frame={}x{}",
PROCESS_EVERY_N_FRAMES,
TARGET_FRAME_WIDTH,
TARGET_FRAME_HEIGHT,
)
frame_count = 0
processed_count = 0
detector = DeadBugDetector()
announcer = RepAnnouncer()
last_announced_rep = 0
last_pose_result = None
last_annotated = None
try:
while True:
frame = await track.recv()
frame_count += 1
img = frame.to_ndarray(format="bgr24")
raw_img = frame.to_ndarray(format="bgr24")
img = normalize_frame(raw_img)
timestamp_ms = int(frame.time * 1000) if frame.time is not None else frame_count * 33
annotated, pose_result = detector.process_frame(img, timestamp_ms)
cv2.imshow("Android Camera (WebRTC)", annotated)
if frame_count % PROCESS_EVERY_N_FRAMES == 0 or last_pose_result is None:
processed_count += 1
last_annotated, last_pose_result = detector.process_frame(img, timestamp_ms)
if last_pose_result.rep_count > last_announced_rep:
last_announced_rep = last_pose_result.rep_count
announcer.announce_count(last_announced_rep)
cv2.imshow("Android Camera (WebRTC)", last_annotated if last_annotated is not None else img)
if frame_count % 100 == 0:
logger.info(
"Received {} frames, shape={}, reps={}, phase={}, feedback={}",
"Received {} frames, processed={}, raw_shape={}, shape={}, reps={}, phase={}, feedback={}, {}",
frame_count,
processed_count,
raw_img.shape,
img.shape,
pose_result.rep_count,
pose_result.phase.value,
" | ".join(pose_result.feedback),
last_pose_result.rep_count if last_pose_result is not None else 0,
last_pose_result.phase.value if last_pose_result is not None else "none",
" | ".join(last_pose_result.feedback) if last_pose_result is not None else "",
format_pose_debug(last_pose_result) if last_pose_result is not None else "metrics=None",
)
if cv2.waitKey(1) & 0xFF == 27:
@@ -67,9 +110,12 @@ async def handle_client(websocket):
break
except asyncio.CancelledError:
logger.info("Video receive task cancelled")
except MediaStreamError:
logger.info("Video track ended")
except Exception as e:
logger.error(f"Video receive error: {e}")
logger.exception(f"Video receive error: {e!r}")
finally:
announcer.close()
detector.close()
@pc.on("track")
@@ -131,5 +177,12 @@ async def main():
await asyncio.Future()
def normalize_frame(image):
height, width = image.shape[:2]
if width == TARGET_FRAME_WIDTH and height == TARGET_FRAME_HEIGHT:
return image
return cv2.resize(image, (TARGET_FRAME_WIDTH, TARGET_FRAME_HEIGHT), interpolation=cv2.INTER_AREA)
if __name__ == "__main__":
asyncio.run(main())