diff --git a/README.md b/README.md index d3d8295..0c16551 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,4 @@ Real-time exercise pose detection and coaching via WebRTC. ``` pip install -r requirements.txt python run.py -``` - -## Configuration - -Copy `.env.example` to `.env` and adjust settings, or set environment variables directly. +`` \ No newline at end of file diff --git a/app/audio/rep_announcer.py b/app/audio/rep_announcer.py index 3b228e9..3c42449 100644 --- a/app/audio/rep_announcer.py +++ b/app/audio/rep_announcer.py @@ -8,9 +8,11 @@ from typing import Any from loguru import logger - class RepAnnouncer: + """运动次数语音播报器""" + def __init__(self, *, enabled: bool = True, rate: int = 185, volume: float = 1.0) -> None: + """初始化TTS引擎(macOS用say,其他系统用pyttsx3)""" self.enabled = enabled self.rate = rate self.volume = volume @@ -24,6 +26,7 @@ class RepAnnouncer: self._start() def announce_count(self, count: int) -> None: + """将次数放入队列进行异步语音播报""" if not self.enabled or count <= 0: return while True: @@ -34,6 +37,7 @@ class RepAnnouncer: self._queue.put(str(count)) def close(self) -> None: + """停止播报线程并释放资源""" if not self.enabled: return self._queue.put(None) @@ -43,6 +47,7 @@ class RepAnnouncer: self._current_process.terminate() def _start(self) -> None: + """根据平台初始化TTS引擎并启动后台播报线程""" if sys.platform == "darwin": self._use_macos_say = True logger.info("Rep announcer initialized with macOS say") @@ -63,6 +68,7 @@ class RepAnnouncer: self._thread.start() def _run(self) -> None: + """后台线程:从队列读取文本并调用TTS播放""" while True: text = self._queue.get() if text is None: diff --git a/app/core/lifecycle.py b/app/core/lifecycle.py index 45ff7c7..9d2138f 100644 --- a/app/core/lifecycle.py +++ b/app/core/lifecycle.py @@ -5,6 +5,7 @@ from configs.load import config def startup() -> None: + """应用启动初始化:开启崩溃日志和日志系统""" enable_crash_handler(config.logging.dir_path) from app.core.logging import setup_logging setup_logging() diff --git a/app/core/logging.py b/app/core/logging.py index 69d2111..fbf9e94 100644 --- a/app/core/logging.py +++ b/app/core/logging.py @@ -8,6 +8,7 @@ from configs.load import config def setup_logging() -> None: + """配置loguru日志输出到按日期轮转的日志文件""" log_dir = config.logging.dir_path log_dir.mkdir(parents=True, exist_ok=True) diff --git a/app/diagnostics/crash_handler.py b/app/diagnostics/crash_handler.py index a5486bf..8dfb79d 100644 --- a/app/diagnostics/crash_handler.py +++ b/app/diagnostics/crash_handler.py @@ -5,6 +5,7 @@ from pathlib import Path def enable_crash_handler(log_dir: str | Path) -> None: + """启用faulthandler,将崩溃堆栈写入日志文件""" log_dir = Path(log_dir) log_dir.mkdir(parents=True, exist_ok=True) crash_log = open(log_dir / "posefit-crash.log", "a", buffering=1) diff --git a/app/diagnostics/perf_timer.py b/app/diagnostics/perf_timer.py index 25c2edd..313f134 100644 --- a/app/diagnostics/perf_timer.py +++ b/app/diagnostics/perf_timer.py @@ -5,28 +5,33 @@ from contextlib import contextmanager from loguru import logger - class PerfTimer: + """性能计时器,用于测量代码段执行耗时""" + def __init__(self, name: str = "") -> None: self.name = name self._start = 0.0 self._elapsed = 0.0 def start(self) -> PerfTimer: + """启动计时器""" self._start = time.perf_counter() return self def stop(self) -> float: + """停止计时器并返回耗时(秒)""" self._elapsed = time.perf_counter() - self._start return self._elapsed @property def elapsed_ms(self) -> float: + """返回已记录耗时(毫秒)""" return self._elapsed * 1000 @contextmanager def measure(name: str = ""): + """上下文管理器:进入时计时,退出时记录耗时日志""" timer = PerfTimer(name).start() yield timer elapsed = timer.stop() diff --git a/app/exercises/dead_bug/detector.py b/app/exercises/dead_bug/detector.py index 8f05094..670bcee 100644 --- a/app/exercises/dead_bug/detector.py +++ b/app/exercises/dead_bug/detector.py @@ -32,8 +32,9 @@ from app.vision.pose_types import ( RIGHT_WRIST, ) - class DeadBugDetector: + """死虫式(Dead Bug)运动检测器""" + def __init__( self, *, @@ -43,6 +44,7 @@ class DeadBugDetector: reset_confirm_frames: int = 3, prefer_gpu: bool = True, ) -> None: + """初始化姿态检测器、状态机和可视化渲染组件""" self.visibility_threshold = visibility_threshold self._latest_result = None @@ -72,9 +74,11 @@ class DeadBugDetector: self._last_timestamp_ms = -1 def close(self) -> None: + """释放MediaPipe模型资源""" self._landmarker.close() def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]: + """处理单帧:姿态检测、指标计算、状态机更新、可视化叠加""" timestamp_ms = self._normalize_timestamp(timestamp_ms) with self._result_lock: @@ -166,6 +170,7 @@ class DeadBugDetector: return annotated, result def _normalize_timestamp(self, timestamp_ms: int) -> int: + """确保时间戳严格递增(MediaPipe要求)""" if timestamp_ms <= self._last_timestamp_ms: timestamp_ms = self._last_timestamp_ms + 1 self._last_timestamp_ms = timestamp_ms diff --git a/app/exercises/dead_bug/metrics.py b/app/exercises/dead_bug/metrics.py index ad5601b..8c90cd9 100644 --- a/app/exercises/dead_bug/metrics.py +++ b/app/exercises/dead_bug/metrics.py @@ -7,6 +7,7 @@ from app.exercises.dead_bug.types import Point def angle(a: Point, b: Point, c: Point) -> float: + """计算以b为顶点的三点夹角(度数)""" ba = np.array([a.x - b.x, a.y - b.y], dtype=np.float32) bc = np.array([c.x - b.x, c.y - b.y], dtype=np.float32) denom = float(np.linalg.norm(ba) * np.linalg.norm(bc)) @@ -17,6 +18,7 @@ def angle(a: Point, b: Point, c: Point) -> float: def distance(a: Point, b: Point) -> float: + """计算两点之间的欧几里得距离(归一化坐标空间)""" return float(np.hypot(a.x - b.x, a.y - b.y)) @@ -37,6 +39,7 @@ def calculate_metrics( right_ankle: int, visibility_threshold: float = 0.45, ) -> dict: + """计算四肢关节角度、伸展状态及反馈信息""" left_elbow_angle = angle(lm[left_shoulder], lm[left_elbow], lm[left_wrist]) right_elbow_angle = angle(lm[right_shoulder], lm[right_elbow], lm[right_wrist]) left_knee_angle = angle(lm[left_hip], lm[left_knee], lm[left_ankle]) diff --git a/app/exercises/dead_bug/rules.py b/app/exercises/dead_bug/rules.py index 2256d81..e477525 100644 --- a/app/exercises/dead_bug/rules.py +++ b/app/exercises/dead_bug/rules.py @@ -4,10 +4,12 @@ from app.exercises.dead_bug.types import DeadBugMetrics, Point def has_required_visibility(landmarks: list[Point], required_indices: tuple[int, ...], visibility_threshold: float) -> bool: + """检查所有必需关键点的可见度是否高于阈值""" return all(landmarks[i].visibility >= visibility_threshold for i in required_indices) def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None: + """检测是否存在对角伸展(左臂+右腿 或 右臂+左腿)""" if metrics.left_leg_extended and metrics.right_leg_extended: return None @@ -19,6 +21,7 @@ def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None: def is_ready_position(metrics: DeadBugMetrics) -> bool: + """判断是否处于准备姿态(膝盖弯曲且四肢未伸展)""" knees_bent = metrics.left_knee_angle <= 140 and metrics.right_knee_angle <= 140 legs_not_extended = not metrics.left_leg_extended and not metrics.right_leg_extended return knees_bent and legs_not_extended and detect_diagonal_extension(metrics) is None diff --git a/app/exercises/dead_bug/state_machine.py b/app/exercises/dead_bug/state_machine.py index d79073a..a5e48dc 100644 --- a/app/exercises/dead_bug/state_machine.py +++ b/app/exercises/dead_bug/state_machine.py @@ -3,9 +3,11 @@ from __future__ import annotations from app.exercises.dead_bug.rules import detect_diagonal_extension, is_ready_position from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase, DeadBugResult - class DeadBugStateMachine: + """死虫式动作状态机:管理READY/EXTENDING/NEED_RESET/NO_POSE状态转换""" + def __init__(self, *, extension_confirm_frames: int = 4, reset_confirm_frames: int = 3) -> None: + """初始化并设置状态转换确认帧数""" self.extension_confirm_frames = extension_confirm_frames self.reset_confirm_frames = reset_confirm_frames @@ -17,6 +19,7 @@ class DeadBugStateMachine: self._reset_frames = 0 def update(self, metrics: DeadBugMetrics) -> DeadBugResult: + """根据传入指标更新状态机并返回本次结果""" side = detect_diagonal_extension(metrics) ready = is_ready_position(metrics) diff --git a/app/exercises/dead_bug/types.py b/app/exercises/dead_bug/types.py index 15c10ec..462d8ec 100644 --- a/app/exercises/dead_bug/types.py +++ b/app/exercises/dead_bug/types.py @@ -5,6 +5,7 @@ from enum import Enum class DeadBugPhase(str, Enum): + """死虫式动作阶段枚举""" READY = "ready" EXTENDING = "extending" NEED_RESET = "need_reset" @@ -13,6 +14,7 @@ class DeadBugPhase(str, Enum): @dataclass(frozen=True) class Point: + """三维关键点坐标及可见度""" x: float y: float z: float @@ -21,6 +23,7 @@ class Point: @dataclass class DeadBugMetrics: + """四肢关节度量数据""" left_arm_extended: bool right_arm_extended: bool left_leg_extended: bool @@ -34,6 +37,7 @@ class DeadBugMetrics: @dataclass class DeadBugResult: + """单帧检测结果""" rep_count: int phase: DeadBugPhase side: str | None diff --git a/app/main.py b/app/main.py index aa7b913..a72e699 100644 --- a/app/main.py +++ b/app/main.py @@ -9,6 +9,7 @@ from app.signaling.websocket_server import main as serve def main(): + """应用入口:启动服务并运行WebSocket信令服务器""" startup() logger.info("Starting server...") try: diff --git a/app/rendering/overlay_renderer.py b/app/rendering/overlay_renderer.py index 3e30ad5..02cf574 100644 --- a/app/rendering/overlay_renderer.py +++ b/app/rendering/overlay_renderer.py @@ -7,6 +7,7 @@ from app.exercises.dead_bug.types import DeadBugResult def draw_status_overlay(image: np.ndarray, result: DeadBugResult) -> None: + """在图像上叠加动作状态信息(次数、阶段、反馈)""" color = (60, 220, 90) if result.is_standard else (50, 180, 255) cv2.rectangle(image, (12, 12), (520, 142), (20, 20, 20), -1) cv2.putText(image, f"Dead bug reps: {result.rep_count}", (28, 48), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2) diff --git a/app/rendering/skeleton_renderer.py b/app/rendering/skeleton_renderer.py index 7e033e4..3cc9b86 100644 --- a/app/rendering/skeleton_renderer.py +++ b/app/rendering/skeleton_renderer.py @@ -18,6 +18,7 @@ def draw_landmarks( line_thickness: int = 2, point_radius: int = 4, ) -> None: + """绘制人体骨架关键点与连接线(仅绘制可见度达标的点)""" if connections is None: connections = _POSE_CONNECTIONS diff --git a/app/rendering/window_display.py b/app/rendering/window_display.py index 58c0e0c..f6efe35 100644 --- a/app/rendering/window_display.py +++ b/app/rendering/window_display.py @@ -6,16 +6,20 @@ WINDOW_NAME = "Android Camera (WebRTC)" def show_frame(image, window_name: str = WINDOW_NAME) -> None: + """在OpenCV窗口中显示图像帧""" cv2.imshow(window_name, image) def wait_key(delay_ms: int = 1) -> int: + """等待按键并返回ASCII码""" return cv2.waitKey(delay_ms) & 0xFF def is_esc_pressed() -> bool: + """检测ESC键是否被按下""" return wait_key(1) == 27 def close_window() -> None: + """关闭所有OpenCV窗口""" cv2.destroyAllWindows() diff --git a/app/signaling/ice_parser.py b/app/signaling/ice_parser.py index fcf4af4..e20c5f3 100644 --- a/app/signaling/ice_parser.py +++ b/app/signaling/ice_parser.py @@ -7,6 +7,7 @@ from aiortc import RTCIceCandidate def parse_ice(data: dict[str, Any]) -> RTCIceCandidate | None: + """解析ICE候选者字符串为RTCIceCandidate对象""" match = re.match( r'candidate:(\S+) (\d) (\S+) (\d+) (\S+) (\d+) typ (\S+)(?: raddr (\S+) rport (\d+))?', data["candidate"], diff --git a/app/signaling/message_models.py b/app/signaling/message_models.py index 16f4bfc..27c177f 100644 --- a/app/signaling/message_models.py +++ b/app/signaling/message_models.py @@ -1,10 +1,11 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass @dataclass class SignalingMessage: + """WebRTC信令消息数据模型""" type: str sdp: str = "" candidate: str = "" diff --git a/app/signaling/websocket_server.py b/app/signaling/websocket_server.py index 0b71e55..2a1d174 100644 --- a/app/signaling/websocket_server.py +++ b/app/signaling/websocket_server.py @@ -11,6 +11,7 @@ from configs.load import config async def handle_client(websocket): + """处理单个WebSocket客户端连接""" client = websocket.remote_address logger.info(f"Client connected: {client}") @@ -21,6 +22,7 @@ async def handle_client(websocket): async def main(): + """启动WebSocket信令服务器""" cfg = config.server logger.info(f"WebRTC signaling server: ws://{cfg.host}:{cfg.port}") async with websockets.serve(handle_client, cfg.host, cfg.port, max_size=cfg.max_ws_size): diff --git a/app/vision/frame_utils.py b/app/vision/frame_utils.py index b7a894c..a38f35e 100644 --- a/app/vision/frame_utils.py +++ b/app/vision/frame_utils.py @@ -9,6 +9,7 @@ TARGET_HEIGHT = 720 def resize_to_target(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> np.ndarray: + """将图像缩放到目标尺寸(仅当尺寸不一致时)""" h, w = image.shape[:2] if w == width and h == height: return image @@ -16,8 +17,10 @@ def resize_to_target(image: np.ndarray, width: int = TARGET_WIDTH, height: int = def bgr_to_rgba(bgr: np.ndarray) -> np.ndarray: + """将BGR格式图像转换为RGBA格式""" return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGBA) def bgr_to_rgb(bgr: np.ndarray) -> np.ndarray: + """将BGR格式图像转换为RGB格式""" return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) diff --git a/app/vision/pose_landmarker.py b/app/vision/pose_landmarker.py index 7257c46..13de5e9 100644 --- a/app/vision/pose_landmarker.py +++ b/app/vision/pose_landmarker.py @@ -14,8 +14,9 @@ PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions VisionRunningMode = mp.tasks.vision.RunningMode BaseOptions = mp.tasks.BaseOptions - class PoseLandmarkerWrapper: + """MediaPipe姿态关键点检测器封装""" + def __init__( self, *, @@ -23,6 +24,7 @@ class PoseLandmarkerWrapper: prefer_gpu: bool = True, result_callback: Callable | None = None, ) -> None: + """初始化姿态检测器,优先尝试GPU委托,失败则回退到CPU""" self.model_path = model_path or DEFAULT_MODEL_PATH if prefer_gpu: @@ -39,6 +41,7 @@ class PoseLandmarkerWrapper: logger.info("MediaPipe PoseLandmarker initialized with CPU delegate") def _create(self, delegate, result_callback=None): + """根据委托类型和回调创建PoseLandmarker实例""" options = PoseLandmarkerOptions( base_options=BaseOptions(model_asset_path=self.model_path, delegate=delegate), running_mode=VisionRunningMode.LIVE_STREAM, @@ -51,7 +54,9 @@ class PoseLandmarkerWrapper: return PoseLandmarker.create_from_options(options) def detect_async(self, mp_image, timestamp_ms: int) -> None: + """异步执行姿态检测""" return self._landmarker.detect_async(mp_image, timestamp_ms) def close(self) -> None: + """释放MediaPipe资源""" self._landmarker.close() diff --git a/app/webrtc/frame_source.py b/app/webrtc/frame_source.py index 80d7ed1..d9f2411 100644 --- a/app/webrtc/frame_source.py +++ b/app/webrtc/frame_source.py @@ -9,6 +9,7 @@ TARGET_HEIGHT = 720 def validate_frame_size(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> None: + """验证视频帧尺寸是否与目标尺寸一致,不一致时记录警告""" h, w = image.shape[:2] if w != width or h != height: logger.warning("Unexpected frame size: {}x{}", w, h) diff --git a/app/webrtc/peer_session.py b/app/webrtc/peer_session.py index 611324a..5b7b76c 100644 --- a/app/webrtc/peer_session.py +++ b/app/webrtc/peer_session.py @@ -10,13 +10,15 @@ from loguru import logger from app.signaling.ice_parser import parse_ice from app.webrtc.video_receiver import VideoReceiver - class PeerSession: + """WebRTC对等连接会话管理""" + def __init__(self) -> None: self._pc = RTCPeerConnection() self._video_task: asyncio.Task | None = None async def handle(self, websocket) -> None: + """处理WebSocket信令交互与WebRTC连接建立""" self._setup_events() try: @@ -47,6 +49,7 @@ class PeerSession: await self._cleanup() def _setup_events(self) -> None: + """注册ICE连接状态变化和视频轨道接收事件处理器""" @self._pc.on("track") async def on_track(track): logger.info(f"Track received: kind={track.kind}") @@ -61,6 +64,7 @@ class PeerSession: await self._pc.close() async def _cleanup(self) -> None: + """清理视频任务并关闭对等连接""" if self._video_task: self._video_task.cancel() try: diff --git a/app/webrtc/video_receiver.py b/app/webrtc/video_receiver.py index 8b5b65c..60c6892 100644 --- a/app/webrtc/video_receiver.py +++ b/app/webrtc/video_receiver.py @@ -13,6 +13,7 @@ from configs.load import config def _format_pose_debug(pose_result) -> str: + """格式化姿态检测结果用于调试日志输出""" metrics = pose_result.metrics if metrics is None: return "metrics=None" @@ -24,12 +25,14 @@ def _format_pose_debug(pose_result) -> str: f"ll={metrics.left_leg_extended}, rl={metrics.right_leg_extended})" ) - class VideoReceiver: + """视频轨道接收与运动检测流水线""" + def __init__(self, track) -> None: self._track = track async def run(self) -> None: + """持续接收视频帧并进行姿态检测、渲染和语音播报""" logger.info("Start receiving video frames, process_every_n={}", config.video.process_every_n_frames) frame_count = 0 diff --git a/configs/load.py b/configs/load.py index 9ba8ec9..b0c50ef 100644 --- a/configs/load.py +++ b/configs/load.py @@ -29,6 +29,7 @@ _SECTION_CLASS = { def _dict_to_dataclass(cls: type, data: dict[str, Any] | None) -> dict[str, Any]: + """将字典过滤为仅包含指定dataclass字段的键值对""" if data is None: return {} field_names = {f.name for f in dataclasses.fields(cls)} @@ -36,6 +37,7 @@ def _dict_to_dataclass(cls: type, data: dict[str, Any] | None) -> dict[str, Any] def _read_yaml(path: Path) -> dict[str, Any]: + """读取YAML配置文件并返回字典""" if not path.exists(): return {} with open(path, encoding="utf-8") as f: @@ -43,6 +45,7 @@ def _read_yaml(path: Path) -> dict[str, Any]: def load_config(config_path: str | Path | None = None) -> AppConfig: + """加载并解析应用配置,返回AppConfig实例""" if config_path is None: config_path = _PROJECT_ROOT / "config.yaml" diff --git a/configs/models.py b/configs/models.py index ae31778..1853318 100644 --- a/configs/models.py +++ b/configs/models.py @@ -6,6 +6,7 @@ from pathlib import Path @dataclass class ServerConfig: + """WebSocket服务器配置""" host: str = "0.0.0.0" port: int = 8765 max_ws_size: int = 10_485_760 @@ -13,16 +14,19 @@ class ServerConfig: @dataclass class VideoConfig: + """视频帧处理配置""" process_every_n_frames: int = 1 @dataclass class ModelConfig: + """姿态检测模型配置""" path: str = "" prefer_gpu: bool = True @property def resolved_path(self) -> str: + """返回模型文件的绝对路径""" if self.path: return self.path return str(Path(__file__).resolve().parent.parent / "pose_models" / "pose_landmarker_full.task") @@ -30,6 +34,7 @@ class ModelConfig: @dataclass class DeadBugConfig: + """死虫式(Dead Bug)运动检测配置""" visibility_threshold: float = 0.45 extension_confirm_frames: int = 4 reset_confirm_frames: int = 3 @@ -37,6 +42,7 @@ class DeadBugConfig: @dataclass class AudioConfig: + """语音播报配置""" rep_announcer_enabled: bool = True rep_announcer_rate: int = 185 rep_announcer_volume: float = 1.0 @@ -44,17 +50,20 @@ class AudioConfig: @dataclass class LoggingConfig: + """日志配置""" dir: str = "logs" rotation: str = "20 MB" retention: str = "14 days" @property def dir_path(self) -> Path: + """返回日志目录的绝对路径""" return Path(__file__).resolve().parent.parent / self.dir @dataclass class AppConfig: + """应用总配置,聚合所有子配置""" server: ServerConfig = field(default_factory=ServerConfig) video: VideoConfig = field(default_factory=VideoConfig) model: ModelConfig = field(default_factory=ModelConfig) diff --git a/requirements.txt b/requirements.txt index 6a5fef2..9e20cea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ aiortc>=1.9.0 websockets>=13.0 -opencv-contrib-python>=4.10.0 +opencv-contrib-python>=4.13.0.92 numpy>=1.26,<2 loguru>=0.7.0 mediapipe==0.10.21 diff --git a/tests/test_dead_bug_rules.py b/tests/test_dead_bug_rules.py index 86d48d9..4c413de 100644 --- a/tests/test_dead_bug_rules.py +++ b/tests/test_dead_bug_rules.py @@ -3,26 +3,32 @@ from __future__ import annotations from app.exercises.dead_bug.rules import detect_diagonal_extension, has_required_visibility, is_ready_position from app.exercises.dead_bug.types import DeadBugMetrics, Point - class TestDeadBugRules: + """死虫式规则函数单元测试""" + def _make_landmark(self, x=0.5, y=0.5, z=0.0, visibility=1.0): + """创建测试用Point对象""" return Point(x, y, z, visibility) def _make_visible_landmarks(self): + """创建33个全可见的测试用关键点""" return [self._make_landmark() for _ in range(33)] def test_has_required_visibility_all_visible(self): + """测试:所有关键点可见时应返回True""" lm = self._make_visible_landmarks() indices = (11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28) assert has_required_visibility(lm, indices, 0.45) def test_has_required_visibility_low(self): + """测试:关键点可见度过低时应返回False""" lm = self._make_visible_landmarks() lm[11] = self._make_landmark(visibility=0.1) indices = (11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28) assert not has_required_visibility(lm, indices, 0.45) def test_detect_diagonal_extension_none(self): + """测试:四肢均未伸展时应返回None""" metrics = DeadBugMetrics( left_arm_extended=False, right_arm_extended=False, left_leg_extended=False, right_leg_extended=False, @@ -33,6 +39,7 @@ class TestDeadBugRules: assert detect_diagonal_extension(metrics) is None def test_detect_diagonal_extension_left_arm_right_leg(self): + """测试:左臂+右腿对角伸展检测""" metrics = DeadBugMetrics( left_arm_extended=True, right_arm_extended=False, left_leg_extended=False, right_leg_extended=True, @@ -43,6 +50,7 @@ class TestDeadBugRules: assert detect_diagonal_extension(metrics) == "left_arm_right_leg" def test_is_ready_position(self): + """测试:膝盖弯曲且四肢收缩应识别为准备姿态""" metrics = DeadBugMetrics( left_arm_extended=False, right_arm_extended=False, left_leg_extended=False, right_leg_extended=False, @@ -53,6 +61,7 @@ class TestDeadBugRules: assert is_ready_position(metrics) def test_is_not_ready_legs_extended(self): + """测试:腿部伸展时不识别为准备姿态""" metrics = DeadBugMetrics( left_arm_extended=False, right_arm_extended=False, left_leg_extended=True, right_leg_extended=False, diff --git a/tests/test_dead_bug_state_machine.py b/tests/test_dead_bug_state_machine.py index a7a10de..e187d79 100644 --- a/tests/test_dead_bug_state_machine.py +++ b/tests/test_dead_bug_state_machine.py @@ -3,9 +3,11 @@ from __future__ import annotations from app.exercises.dead_bug.state_machine import DeadBugStateMachine from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase - class TestDeadBugStateMachine: + """死虫式状态机单元测试""" + def _ready_metrics(self) -> DeadBugMetrics: + """构建准备姿态的度量数据""" return DeadBugMetrics( left_arm_extended=False, right_arm_extended=False, left_leg_extended=False, right_leg_extended=False, @@ -15,6 +17,7 @@ class TestDeadBugStateMachine: ) def _extended_left(self) -> DeadBugMetrics: + """构建左臂+右腿对角伸展的度量数据""" return DeadBugMetrics( left_arm_extended=True, right_arm_extended=False, left_leg_extended=False, right_leg_extended=True, @@ -24,17 +27,20 @@ class TestDeadBugStateMachine: ) def test_initial_state(self): + """测试:状态机初始化后应为READY且计数为0""" sm = DeadBugStateMachine() assert sm.phase == DeadBugPhase.READY assert sm.rep_count == 0 def test_no_transition_in_ready(self): + """测试:准备姿态下不触发状态转换""" sm = DeadBugStateMachine() result = sm.update(self._ready_metrics()) assert sm.phase == DeadBugPhase.READY assert result.rep_count == 0 def test_confirm_extension(self): + """测试:连续确认帧数后从READY转换到EXTENDING""" sm = DeadBugStateMachine(extension_confirm_frames=2, reset_confirm_frames=2) sm.update(self._extended_left()) assert sm.phase == DeadBugPhase.READY diff --git a/tests/test_ice_parser.py b/tests/test_ice_parser.py index e69e3ba..0b79947 100644 --- a/tests/test_ice_parser.py +++ b/tests/test_ice_parser.py @@ -2,9 +2,11 @@ from __future__ import annotations from app.signaling.ice_parser import parse_ice - class TestIceParser: + """ICE候选者解析单元测试""" + def test_parse_valid_ice(self): + """测试:解析有效的ICE host候选者""" data = { "candidate": "1234567890 1 UDP 2130706431 192.168.1.1 12345 typ host", "sdpMid": "0", @@ -20,9 +22,11 @@ class TestIceParser: assert cand.type == "host" def test_parse_invalid_ice(self): + """测试:解析无效ICE字符串应返回None""" assert parse_ice({"candidate": "invalid"}) is None def test_parse_srflx(self): + """测试:解析含有raddr/rport的srflx候选者""" data = { "candidate": "abcdef 1 UDP 1686052607 203.0.113.1 50000 typ srflx raddr 192.168.1.1 rport 12345", "sdpMid": "0",