为所有函数和类添加中文注释文档字符串

This commit is contained in:
2026-06-10 10:34:11 +08:00
parent c612a7ad71
commit c3f93e4441
29 changed files with 103 additions and 17 deletions
+1 -5
View File
@@ -7,8 +7,4 @@ Real-time exercise pose detection and coaching via WebRTC.
``` ```
pip install -r requirements.txt pip install -r requirements.txt
python run.py python run.py
``` ``
## Configuration
Copy `.env.example` to `.env` and adjust settings, or set environment variables directly.
+7 -1
View File
@@ -8,9 +8,11 @@ from typing import Any
from loguru import logger from loguru import logger
class RepAnnouncer: class RepAnnouncer:
"""运动次数语音播报器"""
def __init__(self, *, enabled: bool = True, rate: int = 185, volume: float = 1.0) -> None: def __init__(self, *, enabled: bool = True, rate: int = 185, volume: float = 1.0) -> None:
"""初始化TTS引擎(macOS用say,其他系统用pyttsx3"""
self.enabled = enabled self.enabled = enabled
self.rate = rate self.rate = rate
self.volume = volume self.volume = volume
@@ -24,6 +26,7 @@ class RepAnnouncer:
self._start() self._start()
def announce_count(self, count: int) -> None: def announce_count(self, count: int) -> None:
"""将次数放入队列进行异步语音播报"""
if not self.enabled or count <= 0: if not self.enabled or count <= 0:
return return
while True: while True:
@@ -34,6 +37,7 @@ class RepAnnouncer:
self._queue.put(str(count)) self._queue.put(str(count))
def close(self) -> None: def close(self) -> None:
"""停止播报线程并释放资源"""
if not self.enabled: if not self.enabled:
return return
self._queue.put(None) self._queue.put(None)
@@ -43,6 +47,7 @@ class RepAnnouncer:
self._current_process.terminate() self._current_process.terminate()
def _start(self) -> None: def _start(self) -> None:
"""根据平台初始化TTS引擎并启动后台播报线程"""
if sys.platform == "darwin": if sys.platform == "darwin":
self._use_macos_say = True self._use_macos_say = True
logger.info("Rep announcer initialized with macOS say") logger.info("Rep announcer initialized with macOS say")
@@ -63,6 +68,7 @@ class RepAnnouncer:
self._thread.start() self._thread.start()
def _run(self) -> None: def _run(self) -> None:
"""后台线程:从队列读取文本并调用TTS播放"""
while True: while True:
text = self._queue.get() text = self._queue.get()
if text is None: if text is None:
+1
View File
@@ -5,6 +5,7 @@ from configs.load import config
def startup() -> None: def startup() -> None:
"""应用启动初始化:开启崩溃日志和日志系统"""
enable_crash_handler(config.logging.dir_path) enable_crash_handler(config.logging.dir_path)
from app.core.logging import setup_logging from app.core.logging import setup_logging
setup_logging() setup_logging()
+1
View File
@@ -8,6 +8,7 @@ from configs.load import config
def setup_logging() -> None: def setup_logging() -> None:
"""配置loguru日志输出到按日期轮转的日志文件"""
log_dir = config.logging.dir_path log_dir = config.logging.dir_path
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
+1
View File
@@ -5,6 +5,7 @@ from pathlib import Path
def enable_crash_handler(log_dir: str | Path) -> None: def enable_crash_handler(log_dir: str | Path) -> None:
"""启用faulthandler,将崩溃堆栈写入日志文件"""
log_dir = Path(log_dir) log_dir = Path(log_dir)
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
crash_log = open(log_dir / "posefit-crash.log", "a", buffering=1) crash_log = open(log_dir / "posefit-crash.log", "a", buffering=1)
+6 -1
View File
@@ -5,28 +5,33 @@ from contextlib import contextmanager
from loguru import logger from loguru import logger
class PerfTimer: class PerfTimer:
"""性能计时器,用于测量代码段执行耗时"""
def __init__(self, name: str = "") -> None: def __init__(self, name: str = "") -> None:
self.name = name self.name = name
self._start = 0.0 self._start = 0.0
self._elapsed = 0.0 self._elapsed = 0.0
def start(self) -> PerfTimer: def start(self) -> PerfTimer:
"""启动计时器"""
self._start = time.perf_counter() self._start = time.perf_counter()
return self return self
def stop(self) -> float: def stop(self) -> float:
"""停止计时器并返回耗时(秒)"""
self._elapsed = time.perf_counter() - self._start self._elapsed = time.perf_counter() - self._start
return self._elapsed return self._elapsed
@property @property
def elapsed_ms(self) -> float: def elapsed_ms(self) -> float:
"""返回已记录耗时(毫秒)"""
return self._elapsed * 1000 return self._elapsed * 1000
@contextmanager @contextmanager
def measure(name: str = ""): def measure(name: str = ""):
"""上下文管理器:进入时计时,退出时记录耗时日志"""
timer = PerfTimer(name).start() timer = PerfTimer(name).start()
yield timer yield timer
elapsed = timer.stop() elapsed = timer.stop()
+6 -1
View File
@@ -32,8 +32,9 @@ from app.vision.pose_types import (
RIGHT_WRIST, RIGHT_WRIST,
) )
class DeadBugDetector: class DeadBugDetector:
"""死虫式(Dead Bug)运动检测器"""
def __init__( def __init__(
self, self,
*, *,
@@ -43,6 +44,7 @@ class DeadBugDetector:
reset_confirm_frames: int = 3, reset_confirm_frames: int = 3,
prefer_gpu: bool = True, prefer_gpu: bool = True,
) -> None: ) -> None:
"""初始化姿态检测器、状态机和可视化渲染组件"""
self.visibility_threshold = visibility_threshold self.visibility_threshold = visibility_threshold
self._latest_result = None self._latest_result = None
@@ -72,9 +74,11 @@ class DeadBugDetector:
self._last_timestamp_ms = -1 self._last_timestamp_ms = -1
def close(self) -> None: def close(self) -> None:
"""释放MediaPipe模型资源"""
self._landmarker.close() self._landmarker.close()
def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]: def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]:
"""处理单帧:姿态检测、指标计算、状态机更新、可视化叠加"""
timestamp_ms = self._normalize_timestamp(timestamp_ms) timestamp_ms = self._normalize_timestamp(timestamp_ms)
with self._result_lock: with self._result_lock:
@@ -166,6 +170,7 @@ class DeadBugDetector:
return annotated, result return annotated, result
def _normalize_timestamp(self, timestamp_ms: int) -> int: def _normalize_timestamp(self, timestamp_ms: int) -> int:
"""确保时间戳严格递增(MediaPipe要求)"""
if timestamp_ms <= self._last_timestamp_ms: if timestamp_ms <= self._last_timestamp_ms:
timestamp_ms = self._last_timestamp_ms + 1 timestamp_ms = self._last_timestamp_ms + 1
self._last_timestamp_ms = timestamp_ms self._last_timestamp_ms = timestamp_ms
+3
View File
@@ -7,6 +7,7 @@ from app.exercises.dead_bug.types import Point
def angle(a: Point, b: Point, c: Point) -> float: def angle(a: Point, b: Point, c: Point) -> float:
"""计算以b为顶点的三点夹角(度数)"""
ba = np.array([a.x - b.x, a.y - b.y], dtype=np.float32) ba = np.array([a.x - b.x, a.y - b.y], dtype=np.float32)
bc = np.array([c.x - b.x, c.y - b.y], dtype=np.float32) bc = np.array([c.x - b.x, c.y - b.y], dtype=np.float32)
denom = float(np.linalg.norm(ba) * np.linalg.norm(bc)) denom = float(np.linalg.norm(ba) * np.linalg.norm(bc))
@@ -17,6 +18,7 @@ def angle(a: Point, b: Point, c: Point) -> float:
def distance(a: Point, b: Point) -> float: def distance(a: Point, b: Point) -> float:
"""计算两点之间的欧几里得距离(归一化坐标空间)"""
return float(np.hypot(a.x - b.x, a.y - b.y)) return float(np.hypot(a.x - b.x, a.y - b.y))
@@ -37,6 +39,7 @@ def calculate_metrics(
right_ankle: int, right_ankle: int,
visibility_threshold: float = 0.45, visibility_threshold: float = 0.45,
) -> dict: ) -> dict:
"""计算四肢关节角度、伸展状态及反馈信息"""
left_elbow_angle = angle(lm[left_shoulder], lm[left_elbow], lm[left_wrist]) left_elbow_angle = angle(lm[left_shoulder], lm[left_elbow], lm[left_wrist])
right_elbow_angle = angle(lm[right_shoulder], lm[right_elbow], lm[right_wrist]) right_elbow_angle = angle(lm[right_shoulder], lm[right_elbow], lm[right_wrist])
left_knee_angle = angle(lm[left_hip], lm[left_knee], lm[left_ankle]) left_knee_angle = angle(lm[left_hip], lm[left_knee], lm[left_ankle])
+3
View File
@@ -4,10 +4,12 @@ from app.exercises.dead_bug.types import DeadBugMetrics, Point
def has_required_visibility(landmarks: list[Point], required_indices: tuple[int, ...], visibility_threshold: float) -> bool: def has_required_visibility(landmarks: list[Point], required_indices: tuple[int, ...], visibility_threshold: float) -> bool:
"""检查所有必需关键点的可见度是否高于阈值"""
return all(landmarks[i].visibility >= visibility_threshold for i in required_indices) return all(landmarks[i].visibility >= visibility_threshold for i in required_indices)
def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None: def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None:
"""检测是否存在对角伸展(左臂+右腿 或 右臂+左腿)"""
if metrics.left_leg_extended and metrics.right_leg_extended: if metrics.left_leg_extended and metrics.right_leg_extended:
return None return None
@@ -19,6 +21,7 @@ def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None:
def is_ready_position(metrics: DeadBugMetrics) -> bool: def is_ready_position(metrics: DeadBugMetrics) -> bool:
"""判断是否处于准备姿态(膝盖弯曲且四肢未伸展)"""
knees_bent = metrics.left_knee_angle <= 140 and metrics.right_knee_angle <= 140 knees_bent = metrics.left_knee_angle <= 140 and metrics.right_knee_angle <= 140
legs_not_extended = not metrics.left_leg_extended and not metrics.right_leg_extended legs_not_extended = not metrics.left_leg_extended and not metrics.right_leg_extended
return knees_bent and legs_not_extended and detect_diagonal_extension(metrics) is None return knees_bent and legs_not_extended and detect_diagonal_extension(metrics) is None
+4 -1
View File
@@ -3,9 +3,11 @@ from __future__ import annotations
from app.exercises.dead_bug.rules import detect_diagonal_extension, is_ready_position from app.exercises.dead_bug.rules import detect_diagonal_extension, is_ready_position
from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase, DeadBugResult from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase, DeadBugResult
class DeadBugStateMachine: class DeadBugStateMachine:
"""死虫式动作状态机:管理READY/EXTENDING/NEED_RESET/NO_POSE状态转换"""
def __init__(self, *, extension_confirm_frames: int = 4, reset_confirm_frames: int = 3) -> None: def __init__(self, *, extension_confirm_frames: int = 4, reset_confirm_frames: int = 3) -> None:
"""初始化并设置状态转换确认帧数"""
self.extension_confirm_frames = extension_confirm_frames self.extension_confirm_frames = extension_confirm_frames
self.reset_confirm_frames = reset_confirm_frames self.reset_confirm_frames = reset_confirm_frames
@@ -17,6 +19,7 @@ class DeadBugStateMachine:
self._reset_frames = 0 self._reset_frames = 0
def update(self, metrics: DeadBugMetrics) -> DeadBugResult: def update(self, metrics: DeadBugMetrics) -> DeadBugResult:
"""根据传入指标更新状态机并返回本次结果"""
side = detect_diagonal_extension(metrics) side = detect_diagonal_extension(metrics)
ready = is_ready_position(metrics) ready = is_ready_position(metrics)
+4
View File
@@ -5,6 +5,7 @@ from enum import Enum
class DeadBugPhase(str, Enum): class DeadBugPhase(str, Enum):
"""死虫式动作阶段枚举"""
READY = "ready" READY = "ready"
EXTENDING = "extending" EXTENDING = "extending"
NEED_RESET = "need_reset" NEED_RESET = "need_reset"
@@ -13,6 +14,7 @@ class DeadBugPhase(str, Enum):
@dataclass(frozen=True) @dataclass(frozen=True)
class Point: class Point:
"""三维关键点坐标及可见度"""
x: float x: float
y: float y: float
z: float z: float
@@ -21,6 +23,7 @@ class Point:
@dataclass @dataclass
class DeadBugMetrics: class DeadBugMetrics:
"""四肢关节度量数据"""
left_arm_extended: bool left_arm_extended: bool
right_arm_extended: bool right_arm_extended: bool
left_leg_extended: bool left_leg_extended: bool
@@ -34,6 +37,7 @@ class DeadBugMetrics:
@dataclass @dataclass
class DeadBugResult: class DeadBugResult:
"""单帧检测结果"""
rep_count: int rep_count: int
phase: DeadBugPhase phase: DeadBugPhase
side: str | None side: str | None
+1
View File
@@ -9,6 +9,7 @@ from app.signaling.websocket_server import main as serve
def main(): def main():
"""应用入口:启动服务并运行WebSocket信令服务器"""
startup() startup()
logger.info("Starting server...") logger.info("Starting server...")
try: try:
+1
View File
@@ -7,6 +7,7 @@ from app.exercises.dead_bug.types import DeadBugResult
def draw_status_overlay(image: np.ndarray, result: DeadBugResult) -> None: def draw_status_overlay(image: np.ndarray, result: DeadBugResult) -> None:
"""在图像上叠加动作状态信息(次数、阶段、反馈)"""
color = (60, 220, 90) if result.is_standard else (50, 180, 255) color = (60, 220, 90) if result.is_standard else (50, 180, 255)
cv2.rectangle(image, (12, 12), (520, 142), (20, 20, 20), -1) cv2.rectangle(image, (12, 12), (520, 142), (20, 20, 20), -1)
cv2.putText(image, f"Dead bug reps: {result.rep_count}", (28, 48), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2) cv2.putText(image, f"Dead bug reps: {result.rep_count}", (28, 48), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
+1
View File
@@ -18,6 +18,7 @@ def draw_landmarks(
line_thickness: int = 2, line_thickness: int = 2,
point_radius: int = 4, point_radius: int = 4,
) -> None: ) -> None:
"""绘制人体骨架关键点与连接线(仅绘制可见度达标的点)"""
if connections is None: if connections is None:
connections = _POSE_CONNECTIONS connections = _POSE_CONNECTIONS
+4
View File
@@ -6,16 +6,20 @@ WINDOW_NAME = "Android Camera (WebRTC)"
def show_frame(image, window_name: str = WINDOW_NAME) -> None: def show_frame(image, window_name: str = WINDOW_NAME) -> None:
"""在OpenCV窗口中显示图像帧"""
cv2.imshow(window_name, image) cv2.imshow(window_name, image)
def wait_key(delay_ms: int = 1) -> int: def wait_key(delay_ms: int = 1) -> int:
"""等待按键并返回ASCII码"""
return cv2.waitKey(delay_ms) & 0xFF return cv2.waitKey(delay_ms) & 0xFF
def is_esc_pressed() -> bool: def is_esc_pressed() -> bool:
"""检测ESC键是否被按下"""
return wait_key(1) == 27 return wait_key(1) == 27
def close_window() -> None: def close_window() -> None:
"""关闭所有OpenCV窗口"""
cv2.destroyAllWindows() cv2.destroyAllWindows()
+1
View File
@@ -7,6 +7,7 @@ from aiortc import RTCIceCandidate
def parse_ice(data: dict[str, Any]) -> RTCIceCandidate | None: def parse_ice(data: dict[str, Any]) -> RTCIceCandidate | None:
"""解析ICE候选者字符串为RTCIceCandidate对象"""
match = re.match( match = re.match(
r'candidate:(\S+) (\d) (\S+) (\d+) (\S+) (\d+) typ (\S+)(?: raddr (\S+) rport (\d+))?', r'candidate:(\S+) (\d) (\S+) (\d+) (\S+) (\d+) typ (\S+)(?: raddr (\S+) rport (\d+))?',
data["candidate"], data["candidate"],
+2 -1
View File
@@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass
@dataclass @dataclass
class SignalingMessage: class SignalingMessage:
"""WebRTC信令消息数据模型"""
type: str type: str
sdp: str = "" sdp: str = ""
candidate: str = "" candidate: str = ""
+2
View File
@@ -11,6 +11,7 @@ from configs.load import config
async def handle_client(websocket): async def handle_client(websocket):
"""处理单个WebSocket客户端连接"""
client = websocket.remote_address client = websocket.remote_address
logger.info(f"Client connected: {client}") logger.info(f"Client connected: {client}")
@@ -21,6 +22,7 @@ async def handle_client(websocket):
async def main(): async def main():
"""启动WebSocket信令服务器"""
cfg = config.server cfg = config.server
logger.info(f"WebRTC signaling server: ws://{cfg.host}:{cfg.port}") logger.info(f"WebRTC signaling server: ws://{cfg.host}:{cfg.port}")
async with websockets.serve(handle_client, cfg.host, cfg.port, max_size=cfg.max_ws_size): async with websockets.serve(handle_client, cfg.host, cfg.port, max_size=cfg.max_ws_size):
+3
View File
@@ -9,6 +9,7 @@ TARGET_HEIGHT = 720
def resize_to_target(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> np.ndarray: def resize_to_target(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> np.ndarray:
"""将图像缩放到目标尺寸(仅当尺寸不一致时)"""
h, w = image.shape[:2] h, w = image.shape[:2]
if w == width and h == height: if w == width and h == height:
return image return image
@@ -16,8 +17,10 @@ def resize_to_target(image: np.ndarray, width: int = TARGET_WIDTH, height: int =
def bgr_to_rgba(bgr: np.ndarray) -> np.ndarray: def bgr_to_rgba(bgr: np.ndarray) -> np.ndarray:
"""将BGR格式图像转换为RGBA格式"""
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGBA) return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGBA)
def bgr_to_rgb(bgr: np.ndarray) -> np.ndarray: def bgr_to_rgb(bgr: np.ndarray) -> np.ndarray:
"""将BGR格式图像转换为RGB格式"""
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
+6 -1
View File
@@ -14,8 +14,9 @@ PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
VisionRunningMode = mp.tasks.vision.RunningMode VisionRunningMode = mp.tasks.vision.RunningMode
BaseOptions = mp.tasks.BaseOptions BaseOptions = mp.tasks.BaseOptions
class PoseLandmarkerWrapper: class PoseLandmarkerWrapper:
"""MediaPipe姿态关键点检测器封装"""
def __init__( def __init__(
self, self,
*, *,
@@ -23,6 +24,7 @@ class PoseLandmarkerWrapper:
prefer_gpu: bool = True, prefer_gpu: bool = True,
result_callback: Callable | None = None, result_callback: Callable | None = None,
) -> None: ) -> None:
"""初始化姿态检测器,优先尝试GPU委托,失败则回退到CPU"""
self.model_path = model_path or DEFAULT_MODEL_PATH self.model_path = model_path or DEFAULT_MODEL_PATH
if prefer_gpu: if prefer_gpu:
@@ -39,6 +41,7 @@ class PoseLandmarkerWrapper:
logger.info("MediaPipe PoseLandmarker initialized with CPU delegate") logger.info("MediaPipe PoseLandmarker initialized with CPU delegate")
def _create(self, delegate, result_callback=None): def _create(self, delegate, result_callback=None):
"""根据委托类型和回调创建PoseLandmarker实例"""
options = PoseLandmarkerOptions( options = PoseLandmarkerOptions(
base_options=BaseOptions(model_asset_path=self.model_path, delegate=delegate), base_options=BaseOptions(model_asset_path=self.model_path, delegate=delegate),
running_mode=VisionRunningMode.LIVE_STREAM, running_mode=VisionRunningMode.LIVE_STREAM,
@@ -51,7 +54,9 @@ class PoseLandmarkerWrapper:
return PoseLandmarker.create_from_options(options) return PoseLandmarker.create_from_options(options)
def detect_async(self, mp_image, timestamp_ms: int) -> None: def detect_async(self, mp_image, timestamp_ms: int) -> None:
"""异步执行姿态检测"""
return self._landmarker.detect_async(mp_image, timestamp_ms) return self._landmarker.detect_async(mp_image, timestamp_ms)
def close(self) -> None: def close(self) -> None:
"""释放MediaPipe资源"""
self._landmarker.close() self._landmarker.close()
+1
View File
@@ -9,6 +9,7 @@ TARGET_HEIGHT = 720
def validate_frame_size(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> None: def validate_frame_size(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> None:
"""验证视频帧尺寸是否与目标尺寸一致,不一致时记录警告"""
h, w = image.shape[:2] h, w = image.shape[:2]
if w != width or h != height: if w != width or h != height:
logger.warning("Unexpected frame size: {}x{}", w, h) logger.warning("Unexpected frame size: {}x{}", w, h)
+5 -1
View File
@@ -10,13 +10,15 @@ from loguru import logger
from app.signaling.ice_parser import parse_ice from app.signaling.ice_parser import parse_ice
from app.webrtc.video_receiver import VideoReceiver from app.webrtc.video_receiver import VideoReceiver
class PeerSession: class PeerSession:
"""WebRTC对等连接会话管理"""
def __init__(self) -> None: def __init__(self) -> None:
self._pc = RTCPeerConnection() self._pc = RTCPeerConnection()
self._video_task: asyncio.Task | None = None self._video_task: asyncio.Task | None = None
async def handle(self, websocket) -> None: async def handle(self, websocket) -> None:
"""处理WebSocket信令交互与WebRTC连接建立"""
self._setup_events() self._setup_events()
try: try:
@@ -47,6 +49,7 @@ class PeerSession:
await self._cleanup() await self._cleanup()
def _setup_events(self) -> None: def _setup_events(self) -> None:
"""注册ICE连接状态变化和视频轨道接收事件处理器"""
@self._pc.on("track") @self._pc.on("track")
async def on_track(track): async def on_track(track):
logger.info(f"Track received: kind={track.kind}") logger.info(f"Track received: kind={track.kind}")
@@ -61,6 +64,7 @@ class PeerSession:
await self._pc.close() await self._pc.close()
async def _cleanup(self) -> None: async def _cleanup(self) -> None:
"""清理视频任务并关闭对等连接"""
if self._video_task: if self._video_task:
self._video_task.cancel() self._video_task.cancel()
try: try:
+4 -1
View File
@@ -13,6 +13,7 @@ from configs.load import config
def _format_pose_debug(pose_result) -> str: def _format_pose_debug(pose_result) -> str:
"""格式化姿态检测结果用于调试日志输出"""
metrics = pose_result.metrics metrics = pose_result.metrics
if metrics is None: if metrics is None:
return "metrics=None" return "metrics=None"
@@ -24,12 +25,14 @@ def _format_pose_debug(pose_result) -> str:
f"ll={metrics.left_leg_extended}, rl={metrics.right_leg_extended})" f"ll={metrics.left_leg_extended}, rl={metrics.right_leg_extended})"
) )
class VideoReceiver: class VideoReceiver:
"""视频轨道接收与运动检测流水线"""
def __init__(self, track) -> None: def __init__(self, track) -> None:
self._track = track self._track = track
async def run(self) -> None: async def run(self) -> None:
"""持续接收视频帧并进行姿态检测、渲染和语音播报"""
logger.info("Start receiving video frames, process_every_n={}", config.video.process_every_n_frames) logger.info("Start receiving video frames, process_every_n={}", config.video.process_every_n_frames)
frame_count = 0 frame_count = 0
+3
View File
@@ -29,6 +29,7 @@ _SECTION_CLASS = {
def _dict_to_dataclass(cls: type, data: dict[str, Any] | None) -> dict[str, Any]: def _dict_to_dataclass(cls: type, data: dict[str, Any] | None) -> dict[str, Any]:
"""将字典过滤为仅包含指定dataclass字段的键值对"""
if data is None: if data is None:
return {} return {}
field_names = {f.name for f in dataclasses.fields(cls)} field_names = {f.name for f in dataclasses.fields(cls)}
@@ -36,6 +37,7 @@ def _dict_to_dataclass(cls: type, data: dict[str, Any] | None) -> dict[str, Any]
def _read_yaml(path: Path) -> dict[str, Any]: def _read_yaml(path: Path) -> dict[str, Any]:
"""读取YAML配置文件并返回字典"""
if not path.exists(): if not path.exists():
return {} return {}
with open(path, encoding="utf-8") as f: with open(path, encoding="utf-8") as f:
@@ -43,6 +45,7 @@ def _read_yaml(path: Path) -> dict[str, Any]:
def load_config(config_path: str | Path | None = None) -> AppConfig: def load_config(config_path: str | Path | None = None) -> AppConfig:
"""加载并解析应用配置,返回AppConfig实例"""
if config_path is None: if config_path is None:
config_path = _PROJECT_ROOT / "config.yaml" config_path = _PROJECT_ROOT / "config.yaml"
+9
View File
@@ -6,6 +6,7 @@ from pathlib import Path
@dataclass @dataclass
class ServerConfig: class ServerConfig:
"""WebSocket服务器配置"""
host: str = "0.0.0.0" host: str = "0.0.0.0"
port: int = 8765 port: int = 8765
max_ws_size: int = 10_485_760 max_ws_size: int = 10_485_760
@@ -13,16 +14,19 @@ class ServerConfig:
@dataclass @dataclass
class VideoConfig: class VideoConfig:
"""视频帧处理配置"""
process_every_n_frames: int = 1 process_every_n_frames: int = 1
@dataclass @dataclass
class ModelConfig: class ModelConfig:
"""姿态检测模型配置"""
path: str = "" path: str = ""
prefer_gpu: bool = True prefer_gpu: bool = True
@property @property
def resolved_path(self) -> str: def resolved_path(self) -> str:
"""返回模型文件的绝对路径"""
if self.path: if self.path:
return self.path return self.path
return str(Path(__file__).resolve().parent.parent / "pose_models" / "pose_landmarker_full.task") return str(Path(__file__).resolve().parent.parent / "pose_models" / "pose_landmarker_full.task")
@@ -30,6 +34,7 @@ class ModelConfig:
@dataclass @dataclass
class DeadBugConfig: class DeadBugConfig:
"""死虫式(Dead Bug)运动检测配置"""
visibility_threshold: float = 0.45 visibility_threshold: float = 0.45
extension_confirm_frames: int = 4 extension_confirm_frames: int = 4
reset_confirm_frames: int = 3 reset_confirm_frames: int = 3
@@ -37,6 +42,7 @@ class DeadBugConfig:
@dataclass @dataclass
class AudioConfig: class AudioConfig:
"""语音播报配置"""
rep_announcer_enabled: bool = True rep_announcer_enabled: bool = True
rep_announcer_rate: int = 185 rep_announcer_rate: int = 185
rep_announcer_volume: float = 1.0 rep_announcer_volume: float = 1.0
@@ -44,17 +50,20 @@ class AudioConfig:
@dataclass @dataclass
class LoggingConfig: class LoggingConfig:
"""日志配置"""
dir: str = "logs" dir: str = "logs"
rotation: str = "20 MB" rotation: str = "20 MB"
retention: str = "14 days" retention: str = "14 days"
@property @property
def dir_path(self) -> Path: def dir_path(self) -> Path:
"""返回日志目录的绝对路径"""
return Path(__file__).resolve().parent.parent / self.dir return Path(__file__).resolve().parent.parent / self.dir
@dataclass @dataclass
class AppConfig: class AppConfig:
"""应用总配置,聚合所有子配置"""
server: ServerConfig = field(default_factory=ServerConfig) server: ServerConfig = field(default_factory=ServerConfig)
video: VideoConfig = field(default_factory=VideoConfig) video: VideoConfig = field(default_factory=VideoConfig)
model: ModelConfig = field(default_factory=ModelConfig) model: ModelConfig = field(default_factory=ModelConfig)
+1 -1
View File
@@ -1,6 +1,6 @@
aiortc>=1.9.0 aiortc>=1.9.0
websockets>=13.0 websockets>=13.0
opencv-contrib-python>=4.10.0 opencv-contrib-python>=4.13.0.92
numpy>=1.26,<2 numpy>=1.26,<2
loguru>=0.7.0 loguru>=0.7.0
mediapipe==0.10.21 mediapipe==0.10.21
+10 -1
View File
@@ -3,26 +3,32 @@ from __future__ import annotations
from app.exercises.dead_bug.rules import detect_diagonal_extension, has_required_visibility, is_ready_position from app.exercises.dead_bug.rules import detect_diagonal_extension, has_required_visibility, is_ready_position
from app.exercises.dead_bug.types import DeadBugMetrics, Point from app.exercises.dead_bug.types import DeadBugMetrics, Point
class TestDeadBugRules: class TestDeadBugRules:
"""死虫式规则函数单元测试"""
def _make_landmark(self, x=0.5, y=0.5, z=0.0, visibility=1.0): def _make_landmark(self, x=0.5, y=0.5, z=0.0, visibility=1.0):
"""创建测试用Point对象"""
return Point(x, y, z, visibility) return Point(x, y, z, visibility)
def _make_visible_landmarks(self): def _make_visible_landmarks(self):
"""创建33个全可见的测试用关键点"""
return [self._make_landmark() for _ in range(33)] return [self._make_landmark() for _ in range(33)]
def test_has_required_visibility_all_visible(self): def test_has_required_visibility_all_visible(self):
"""测试:所有关键点可见时应返回True"""
lm = self._make_visible_landmarks() lm = self._make_visible_landmarks()
indices = (11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28) indices = (11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28)
assert has_required_visibility(lm, indices, 0.45) assert has_required_visibility(lm, indices, 0.45)
def test_has_required_visibility_low(self): def test_has_required_visibility_low(self):
"""测试:关键点可见度过低时应返回False"""
lm = self._make_visible_landmarks() lm = self._make_visible_landmarks()
lm[11] = self._make_landmark(visibility=0.1) lm[11] = self._make_landmark(visibility=0.1)
indices = (11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28) indices = (11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28)
assert not has_required_visibility(lm, indices, 0.45) assert not has_required_visibility(lm, indices, 0.45)
def test_detect_diagonal_extension_none(self): def test_detect_diagonal_extension_none(self):
"""测试:四肢均未伸展时应返回None"""
metrics = DeadBugMetrics( metrics = DeadBugMetrics(
left_arm_extended=False, right_arm_extended=False, left_arm_extended=False, right_arm_extended=False,
left_leg_extended=False, right_leg_extended=False, left_leg_extended=False, right_leg_extended=False,
@@ -33,6 +39,7 @@ class TestDeadBugRules:
assert detect_diagonal_extension(metrics) is None assert detect_diagonal_extension(metrics) is None
def test_detect_diagonal_extension_left_arm_right_leg(self): def test_detect_diagonal_extension_left_arm_right_leg(self):
"""测试:左臂+右腿对角伸展检测"""
metrics = DeadBugMetrics( metrics = DeadBugMetrics(
left_arm_extended=True, right_arm_extended=False, left_arm_extended=True, right_arm_extended=False,
left_leg_extended=False, right_leg_extended=True, left_leg_extended=False, right_leg_extended=True,
@@ -43,6 +50,7 @@ class TestDeadBugRules:
assert detect_diagonal_extension(metrics) == "left_arm_right_leg" assert detect_diagonal_extension(metrics) == "left_arm_right_leg"
def test_is_ready_position(self): def test_is_ready_position(self):
"""测试:膝盖弯曲且四肢收缩应识别为准备姿态"""
metrics = DeadBugMetrics( metrics = DeadBugMetrics(
left_arm_extended=False, right_arm_extended=False, left_arm_extended=False, right_arm_extended=False,
left_leg_extended=False, right_leg_extended=False, left_leg_extended=False, right_leg_extended=False,
@@ -53,6 +61,7 @@ class TestDeadBugRules:
assert is_ready_position(metrics) assert is_ready_position(metrics)
def test_is_not_ready_legs_extended(self): def test_is_not_ready_legs_extended(self):
"""测试:腿部伸展时不识别为准备姿态"""
metrics = DeadBugMetrics( metrics = DeadBugMetrics(
left_arm_extended=False, right_arm_extended=False, left_arm_extended=False, right_arm_extended=False,
left_leg_extended=True, right_leg_extended=False, left_leg_extended=True, right_leg_extended=False,
+7 -1
View File
@@ -3,9 +3,11 @@ from __future__ import annotations
from app.exercises.dead_bug.state_machine import DeadBugStateMachine from app.exercises.dead_bug.state_machine import DeadBugStateMachine
from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase
class TestDeadBugStateMachine: class TestDeadBugStateMachine:
"""死虫式状态机单元测试"""
def _ready_metrics(self) -> DeadBugMetrics: def _ready_metrics(self) -> DeadBugMetrics:
"""构建准备姿态的度量数据"""
return DeadBugMetrics( return DeadBugMetrics(
left_arm_extended=False, right_arm_extended=False, left_arm_extended=False, right_arm_extended=False,
left_leg_extended=False, right_leg_extended=False, left_leg_extended=False, right_leg_extended=False,
@@ -15,6 +17,7 @@ class TestDeadBugStateMachine:
) )
def _extended_left(self) -> DeadBugMetrics: def _extended_left(self) -> DeadBugMetrics:
"""构建左臂+右腿对角伸展的度量数据"""
return DeadBugMetrics( return DeadBugMetrics(
left_arm_extended=True, right_arm_extended=False, left_arm_extended=True, right_arm_extended=False,
left_leg_extended=False, right_leg_extended=True, left_leg_extended=False, right_leg_extended=True,
@@ -24,17 +27,20 @@ class TestDeadBugStateMachine:
) )
def test_initial_state(self): def test_initial_state(self):
"""测试:状态机初始化后应为READY且计数为0"""
sm = DeadBugStateMachine() sm = DeadBugStateMachine()
assert sm.phase == DeadBugPhase.READY assert sm.phase == DeadBugPhase.READY
assert sm.rep_count == 0 assert sm.rep_count == 0
def test_no_transition_in_ready(self): def test_no_transition_in_ready(self):
"""测试:准备姿态下不触发状态转换"""
sm = DeadBugStateMachine() sm = DeadBugStateMachine()
result = sm.update(self._ready_metrics()) result = sm.update(self._ready_metrics())
assert sm.phase == DeadBugPhase.READY assert sm.phase == DeadBugPhase.READY
assert result.rep_count == 0 assert result.rep_count == 0
def test_confirm_extension(self): def test_confirm_extension(self):
"""测试:连续确认帧数后从READY转换到EXTENDING"""
sm = DeadBugStateMachine(extension_confirm_frames=2, reset_confirm_frames=2) sm = DeadBugStateMachine(extension_confirm_frames=2, reset_confirm_frames=2)
sm.update(self._extended_left()) sm.update(self._extended_left())
assert sm.phase == DeadBugPhase.READY assert sm.phase == DeadBugPhase.READY
+5 -1
View File
@@ -2,9 +2,11 @@ from __future__ import annotations
from app.signaling.ice_parser import parse_ice from app.signaling.ice_parser import parse_ice
class TestIceParser: class TestIceParser:
"""ICE候选者解析单元测试"""
def test_parse_valid_ice(self): def test_parse_valid_ice(self):
"""测试:解析有效的ICE host候选者"""
data = { data = {
"candidate": "1234567890 1 UDP 2130706431 192.168.1.1 12345 typ host", "candidate": "1234567890 1 UDP 2130706431 192.168.1.1 12345 typ host",
"sdpMid": "0", "sdpMid": "0",
@@ -20,9 +22,11 @@ class TestIceParser:
assert cand.type == "host" assert cand.type == "host"
def test_parse_invalid_ice(self): def test_parse_invalid_ice(self):
"""测试:解析无效ICE字符串应返回None"""
assert parse_ice({"candidate": "invalid"}) is None assert parse_ice({"candidate": "invalid"}) is None
def test_parse_srflx(self): def test_parse_srflx(self):
"""测试:解析含有raddr/rport的srflx候选者"""
data = { data = {
"candidate": "abcdef 1 UDP 1686052607 203.0.113.1 50000 typ srflx raddr 192.168.1.1 rport 12345", "candidate": "abcdef 1 UDP 1686052607 203.0.113.1 50000 typ srflx raddr 192.168.1.1 rport 12345",
"sdpMid": "0", "sdpMid": "0",