为所有函数和类添加中文注释文档字符串

This commit is contained in:
2026-06-10 10:34:11 +08:00
parent c612a7ad71
commit c3f93e4441
29 changed files with 103 additions and 17 deletions
+7 -1
View File
@@ -8,9 +8,11 @@ from typing import Any
from loguru import logger
class RepAnnouncer:
"""运动次数语音播报器"""
def __init__(self, *, enabled: bool = True, rate: int = 185, volume: float = 1.0) -> None:
"""初始化TTS引擎(macOS用say,其他系统用pyttsx3"""
self.enabled = enabled
self.rate = rate
self.volume = volume
@@ -24,6 +26,7 @@ class RepAnnouncer:
self._start()
def announce_count(self, count: int) -> None:
"""将次数放入队列进行异步语音播报"""
if not self.enabled or count <= 0:
return
while True:
@@ -34,6 +37,7 @@ class RepAnnouncer:
self._queue.put(str(count))
def close(self) -> None:
"""停止播报线程并释放资源"""
if not self.enabled:
return
self._queue.put(None)
@@ -43,6 +47,7 @@ class RepAnnouncer:
self._current_process.terminate()
def _start(self) -> None:
"""根据平台初始化TTS引擎并启动后台播报线程"""
if sys.platform == "darwin":
self._use_macos_say = True
logger.info("Rep announcer initialized with macOS say")
@@ -63,6 +68,7 @@ class RepAnnouncer:
self._thread.start()
def _run(self) -> None:
"""后台线程:从队列读取文本并调用TTS播放"""
while True:
text = self._queue.get()
if text is None:
+1
View File
@@ -5,6 +5,7 @@ from configs.load import config
def startup() -> None:
"""应用启动初始化:开启崩溃日志和日志系统"""
enable_crash_handler(config.logging.dir_path)
from app.core.logging import setup_logging
setup_logging()
+1
View File
@@ -8,6 +8,7 @@ from configs.load import config
def setup_logging() -> None:
"""配置loguru日志输出到按日期轮转的日志文件"""
log_dir = config.logging.dir_path
log_dir.mkdir(parents=True, exist_ok=True)
+1
View File
@@ -5,6 +5,7 @@ from pathlib import Path
def enable_crash_handler(log_dir: str | Path) -> None:
"""启用faulthandler,将崩溃堆栈写入日志文件"""
log_dir = Path(log_dir)
log_dir.mkdir(parents=True, exist_ok=True)
crash_log = open(log_dir / "posefit-crash.log", "a", buffering=1)
+6 -1
View File
@@ -5,28 +5,33 @@ from contextlib import contextmanager
from loguru import logger
class PerfTimer:
"""性能计时器,用于测量代码段执行耗时"""
def __init__(self, name: str = "") -> None:
self.name = name
self._start = 0.0
self._elapsed = 0.0
def start(self) -> PerfTimer:
"""启动计时器"""
self._start = time.perf_counter()
return self
def stop(self) -> float:
"""停止计时器并返回耗时(秒)"""
self._elapsed = time.perf_counter() - self._start
return self._elapsed
@property
def elapsed_ms(self) -> float:
"""返回已记录耗时(毫秒)"""
return self._elapsed * 1000
@contextmanager
def measure(name: str = ""):
"""上下文管理器:进入时计时,退出时记录耗时日志"""
timer = PerfTimer(name).start()
yield timer
elapsed = timer.stop()
+6 -1
View File
@@ -32,8 +32,9 @@ from app.vision.pose_types import (
RIGHT_WRIST,
)
class DeadBugDetector:
"""死虫式(Dead Bug)运动检测器"""
def __init__(
self,
*,
@@ -43,6 +44,7 @@ class DeadBugDetector:
reset_confirm_frames: int = 3,
prefer_gpu: bool = True,
) -> None:
"""初始化姿态检测器、状态机和可视化渲染组件"""
self.visibility_threshold = visibility_threshold
self._latest_result = None
@@ -72,9 +74,11 @@ class DeadBugDetector:
self._last_timestamp_ms = -1
def close(self) -> None:
"""释放MediaPipe模型资源"""
self._landmarker.close()
def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]:
"""处理单帧:姿态检测、指标计算、状态机更新、可视化叠加"""
timestamp_ms = self._normalize_timestamp(timestamp_ms)
with self._result_lock:
@@ -166,6 +170,7 @@ class DeadBugDetector:
return annotated, result
def _normalize_timestamp(self, timestamp_ms: int) -> int:
"""确保时间戳严格递增(MediaPipe要求)"""
if timestamp_ms <= self._last_timestamp_ms:
timestamp_ms = self._last_timestamp_ms + 1
self._last_timestamp_ms = timestamp_ms
+3
View File
@@ -7,6 +7,7 @@ from app.exercises.dead_bug.types import Point
def angle(a: Point, b: Point, c: Point) -> float:
"""计算以b为顶点的三点夹角(度数)"""
ba = np.array([a.x - b.x, a.y - b.y], dtype=np.float32)
bc = np.array([c.x - b.x, c.y - b.y], dtype=np.float32)
denom = float(np.linalg.norm(ba) * np.linalg.norm(bc))
@@ -17,6 +18,7 @@ def angle(a: Point, b: Point, c: Point) -> float:
def distance(a: Point, b: Point) -> float:
"""计算两点之间的欧几里得距离(归一化坐标空间)"""
return float(np.hypot(a.x - b.x, a.y - b.y))
@@ -37,6 +39,7 @@ def calculate_metrics(
right_ankle: int,
visibility_threshold: float = 0.45,
) -> dict:
"""计算四肢关节角度、伸展状态及反馈信息"""
left_elbow_angle = angle(lm[left_shoulder], lm[left_elbow], lm[left_wrist])
right_elbow_angle = angle(lm[right_shoulder], lm[right_elbow], lm[right_wrist])
left_knee_angle = angle(lm[left_hip], lm[left_knee], lm[left_ankle])
+3
View File
@@ -4,10 +4,12 @@ from app.exercises.dead_bug.types import DeadBugMetrics, Point
def has_required_visibility(landmarks: list[Point], required_indices: tuple[int, ...], visibility_threshold: float) -> bool:
"""检查所有必需关键点的可见度是否高于阈值"""
return all(landmarks[i].visibility >= visibility_threshold for i in required_indices)
def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None:
"""检测是否存在对角伸展(左臂+右腿 或 右臂+左腿)"""
if metrics.left_leg_extended and metrics.right_leg_extended:
return None
@@ -19,6 +21,7 @@ def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None:
def is_ready_position(metrics: DeadBugMetrics) -> bool:
"""判断是否处于准备姿态(膝盖弯曲且四肢未伸展)"""
knees_bent = metrics.left_knee_angle <= 140 and metrics.right_knee_angle <= 140
legs_not_extended = not metrics.left_leg_extended and not metrics.right_leg_extended
return knees_bent and legs_not_extended and detect_diagonal_extension(metrics) is None
+4 -1
View File
@@ -3,9 +3,11 @@ from __future__ import annotations
from app.exercises.dead_bug.rules import detect_diagonal_extension, is_ready_position
from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase, DeadBugResult
class DeadBugStateMachine:
"""死虫式动作状态机:管理READY/EXTENDING/NEED_RESET/NO_POSE状态转换"""
def __init__(self, *, extension_confirm_frames: int = 4, reset_confirm_frames: int = 3) -> None:
"""初始化并设置状态转换确认帧数"""
self.extension_confirm_frames = extension_confirm_frames
self.reset_confirm_frames = reset_confirm_frames
@@ -17,6 +19,7 @@ class DeadBugStateMachine:
self._reset_frames = 0
def update(self, metrics: DeadBugMetrics) -> DeadBugResult:
"""根据传入指标更新状态机并返回本次结果"""
side = detect_diagonal_extension(metrics)
ready = is_ready_position(metrics)
+4
View File
@@ -5,6 +5,7 @@ from enum import Enum
class DeadBugPhase(str, Enum):
"""死虫式动作阶段枚举"""
READY = "ready"
EXTENDING = "extending"
NEED_RESET = "need_reset"
@@ -13,6 +14,7 @@ class DeadBugPhase(str, Enum):
@dataclass(frozen=True)
class Point:
"""三维关键点坐标及可见度"""
x: float
y: float
z: float
@@ -21,6 +23,7 @@ class Point:
@dataclass
class DeadBugMetrics:
"""四肢关节度量数据"""
left_arm_extended: bool
right_arm_extended: bool
left_leg_extended: bool
@@ -34,6 +37,7 @@ class DeadBugMetrics:
@dataclass
class DeadBugResult:
"""单帧检测结果"""
rep_count: int
phase: DeadBugPhase
side: str | None
+1
View File
@@ -9,6 +9,7 @@ from app.signaling.websocket_server import main as serve
def main():
"""应用入口:启动服务并运行WebSocket信令服务器"""
startup()
logger.info("Starting server...")
try:
+1
View File
@@ -7,6 +7,7 @@ from app.exercises.dead_bug.types import DeadBugResult
def draw_status_overlay(image: np.ndarray, result: DeadBugResult) -> None:
"""在图像上叠加动作状态信息(次数、阶段、反馈)"""
color = (60, 220, 90) if result.is_standard else (50, 180, 255)
cv2.rectangle(image, (12, 12), (520, 142), (20, 20, 20), -1)
cv2.putText(image, f"Dead bug reps: {result.rep_count}", (28, 48), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
+1
View File
@@ -18,6 +18,7 @@ def draw_landmarks(
line_thickness: int = 2,
point_radius: int = 4,
) -> None:
"""绘制人体骨架关键点与连接线(仅绘制可见度达标的点)"""
if connections is None:
connections = _POSE_CONNECTIONS
+4
View File
@@ -6,16 +6,20 @@ WINDOW_NAME = "Android Camera (WebRTC)"
def show_frame(image, window_name: str = WINDOW_NAME) -> None:
"""在OpenCV窗口中显示图像帧"""
cv2.imshow(window_name, image)
def wait_key(delay_ms: int = 1) -> int:
"""等待按键并返回ASCII码"""
return cv2.waitKey(delay_ms) & 0xFF
def is_esc_pressed() -> bool:
"""检测ESC键是否被按下"""
return wait_key(1) == 27
def close_window() -> None:
"""关闭所有OpenCV窗口"""
cv2.destroyAllWindows()
+1
View File
@@ -7,6 +7,7 @@ from aiortc import RTCIceCandidate
def parse_ice(data: dict[str, Any]) -> RTCIceCandidate | None:
"""解析ICE候选者字符串为RTCIceCandidate对象"""
match = re.match(
r'candidate:(\S+) (\d) (\S+) (\d+) (\S+) (\d+) typ (\S+)(?: raddr (\S+) rport (\d+))?',
data["candidate"],
+2 -1
View File
@@ -1,10 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass, field
from dataclasses import dataclass
@dataclass
class SignalingMessage:
"""WebRTC信令消息数据模型"""
type: str
sdp: str = ""
candidate: str = ""
+2
View File
@@ -11,6 +11,7 @@ from configs.load import config
async def handle_client(websocket):
"""处理单个WebSocket客户端连接"""
client = websocket.remote_address
logger.info(f"Client connected: {client}")
@@ -21,6 +22,7 @@ async def handle_client(websocket):
async def main():
"""启动WebSocket信令服务器"""
cfg = config.server
logger.info(f"WebRTC signaling server: ws://{cfg.host}:{cfg.port}")
async with websockets.serve(handle_client, cfg.host, cfg.port, max_size=cfg.max_ws_size):
+3
View File
@@ -9,6 +9,7 @@ TARGET_HEIGHT = 720
def resize_to_target(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> np.ndarray:
"""将图像缩放到目标尺寸(仅当尺寸不一致时)"""
h, w = image.shape[:2]
if w == width and h == height:
return image
@@ -16,8 +17,10 @@ def resize_to_target(image: np.ndarray, width: int = TARGET_WIDTH, height: int =
def bgr_to_rgba(bgr: np.ndarray) -> np.ndarray:
"""将BGR格式图像转换为RGBA格式"""
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGBA)
def bgr_to_rgb(bgr: np.ndarray) -> np.ndarray:
"""将BGR格式图像转换为RGB格式"""
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
+6 -1
View File
@@ -14,8 +14,9 @@ PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
VisionRunningMode = mp.tasks.vision.RunningMode
BaseOptions = mp.tasks.BaseOptions
class PoseLandmarkerWrapper:
"""MediaPipe姿态关键点检测器封装"""
def __init__(
self,
*,
@@ -23,6 +24,7 @@ class PoseLandmarkerWrapper:
prefer_gpu: bool = True,
result_callback: Callable | None = None,
) -> None:
"""初始化姿态检测器,优先尝试GPU委托,失败则回退到CPU"""
self.model_path = model_path or DEFAULT_MODEL_PATH
if prefer_gpu:
@@ -39,6 +41,7 @@ class PoseLandmarkerWrapper:
logger.info("MediaPipe PoseLandmarker initialized with CPU delegate")
def _create(self, delegate, result_callback=None):
"""根据委托类型和回调创建PoseLandmarker实例"""
options = PoseLandmarkerOptions(
base_options=BaseOptions(model_asset_path=self.model_path, delegate=delegate),
running_mode=VisionRunningMode.LIVE_STREAM,
@@ -51,7 +54,9 @@ class PoseLandmarkerWrapper:
return PoseLandmarker.create_from_options(options)
def detect_async(self, mp_image, timestamp_ms: int) -> None:
"""异步执行姿态检测"""
return self._landmarker.detect_async(mp_image, timestamp_ms)
def close(self) -> None:
"""释放MediaPipe资源"""
self._landmarker.close()
+1
View File
@@ -9,6 +9,7 @@ TARGET_HEIGHT = 720
def validate_frame_size(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> None:
"""验证视频帧尺寸是否与目标尺寸一致,不一致时记录警告"""
h, w = image.shape[:2]
if w != width or h != height:
logger.warning("Unexpected frame size: {}x{}", w, h)
+5 -1
View File
@@ -10,13 +10,15 @@ from loguru import logger
from app.signaling.ice_parser import parse_ice
from app.webrtc.video_receiver import VideoReceiver
class PeerSession:
"""WebRTC对等连接会话管理"""
def __init__(self) -> None:
self._pc = RTCPeerConnection()
self._video_task: asyncio.Task | None = None
async def handle(self, websocket) -> None:
"""处理WebSocket信令交互与WebRTC连接建立"""
self._setup_events()
try:
@@ -47,6 +49,7 @@ class PeerSession:
await self._cleanup()
def _setup_events(self) -> None:
"""注册ICE连接状态变化和视频轨道接收事件处理器"""
@self._pc.on("track")
async def on_track(track):
logger.info(f"Track received: kind={track.kind}")
@@ -61,6 +64,7 @@ class PeerSession:
await self._pc.close()
async def _cleanup(self) -> None:
"""清理视频任务并关闭对等连接"""
if self._video_task:
self._video_task.cancel()
try:
+4 -1
View File
@@ -13,6 +13,7 @@ from configs.load import config
def _format_pose_debug(pose_result) -> str:
"""格式化姿态检测结果用于调试日志输出"""
metrics = pose_result.metrics
if metrics is None:
return "metrics=None"
@@ -24,12 +25,14 @@ def _format_pose_debug(pose_result) -> str:
f"ll={metrics.left_leg_extended}, rl={metrics.right_leg_extended})"
)
class VideoReceiver:
"""视频轨道接收与运动检测流水线"""
def __init__(self, track) -> None:
self._track = track
async def run(self) -> None:
"""持续接收视频帧并进行姿态检测、渲染和语音播报"""
logger.info("Start receiving video frames, process_every_n={}", config.video.process_every_n_frames)
frame_count = 0