为所有函数和类添加中文注释文档字符串
This commit is contained in:
@@ -8,9 +8,11 @@ from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class RepAnnouncer:
|
||||
"""运动次数语音播报器"""
|
||||
|
||||
def __init__(self, *, enabled: bool = True, rate: int = 185, volume: float = 1.0) -> None:
|
||||
"""初始化TTS引擎(macOS用say,其他系统用pyttsx3)"""
|
||||
self.enabled = enabled
|
||||
self.rate = rate
|
||||
self.volume = volume
|
||||
@@ -24,6 +26,7 @@ class RepAnnouncer:
|
||||
self._start()
|
||||
|
||||
def announce_count(self, count: int) -> None:
|
||||
"""将次数放入队列进行异步语音播报"""
|
||||
if not self.enabled or count <= 0:
|
||||
return
|
||||
while True:
|
||||
@@ -34,6 +37,7 @@ class RepAnnouncer:
|
||||
self._queue.put(str(count))
|
||||
|
||||
def close(self) -> None:
|
||||
"""停止播报线程并释放资源"""
|
||||
if not self.enabled:
|
||||
return
|
||||
self._queue.put(None)
|
||||
@@ -43,6 +47,7 @@ class RepAnnouncer:
|
||||
self._current_process.terminate()
|
||||
|
||||
def _start(self) -> None:
|
||||
"""根据平台初始化TTS引擎并启动后台播报线程"""
|
||||
if sys.platform == "darwin":
|
||||
self._use_macos_say = True
|
||||
logger.info("Rep announcer initialized with macOS say")
|
||||
@@ -63,6 +68,7 @@ class RepAnnouncer:
|
||||
self._thread.start()
|
||||
|
||||
def _run(self) -> None:
|
||||
"""后台线程:从队列读取文本并调用TTS播放"""
|
||||
while True:
|
||||
text = self._queue.get()
|
||||
if text is None:
|
||||
|
||||
@@ -5,6 +5,7 @@ from configs.load import config
|
||||
|
||||
|
||||
def startup() -> None:
|
||||
"""应用启动初始化:开启崩溃日志和日志系统"""
|
||||
enable_crash_handler(config.logging.dir_path)
|
||||
from app.core.logging import setup_logging
|
||||
setup_logging()
|
||||
|
||||
@@ -8,6 +8,7 @@ from configs.load import config
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
"""配置loguru日志输出到按日期轮转的日志文件"""
|
||||
log_dir = config.logging.dir_path
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
|
||||
|
||||
def enable_crash_handler(log_dir: str | Path) -> None:
|
||||
"""启用faulthandler,将崩溃堆栈写入日志文件"""
|
||||
log_dir = Path(log_dir)
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
crash_log = open(log_dir / "posefit-crash.log", "a", buffering=1)
|
||||
|
||||
@@ -5,28 +5,33 @@ from contextlib import contextmanager
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class PerfTimer:
|
||||
"""性能计时器,用于测量代码段执行耗时"""
|
||||
|
||||
def __init__(self, name: str = "") -> None:
|
||||
self.name = name
|
||||
self._start = 0.0
|
||||
self._elapsed = 0.0
|
||||
|
||||
def start(self) -> PerfTimer:
|
||||
"""启动计时器"""
|
||||
self._start = time.perf_counter()
|
||||
return self
|
||||
|
||||
def stop(self) -> float:
|
||||
"""停止计时器并返回耗时(秒)"""
|
||||
self._elapsed = time.perf_counter() - self._start
|
||||
return self._elapsed
|
||||
|
||||
@property
|
||||
def elapsed_ms(self) -> float:
|
||||
"""返回已记录耗时(毫秒)"""
|
||||
return self._elapsed * 1000
|
||||
|
||||
|
||||
@contextmanager
|
||||
def measure(name: str = ""):
|
||||
"""上下文管理器:进入时计时,退出时记录耗时日志"""
|
||||
timer = PerfTimer(name).start()
|
||||
yield timer
|
||||
elapsed = timer.stop()
|
||||
|
||||
@@ -32,8 +32,9 @@ from app.vision.pose_types import (
|
||||
RIGHT_WRIST,
|
||||
)
|
||||
|
||||
|
||||
class DeadBugDetector:
|
||||
"""死虫式(Dead Bug)运动检测器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -43,6 +44,7 @@ class DeadBugDetector:
|
||||
reset_confirm_frames: int = 3,
|
||||
prefer_gpu: bool = True,
|
||||
) -> None:
|
||||
"""初始化姿态检测器、状态机和可视化渲染组件"""
|
||||
self.visibility_threshold = visibility_threshold
|
||||
|
||||
self._latest_result = None
|
||||
@@ -72,9 +74,11 @@ class DeadBugDetector:
|
||||
self._last_timestamp_ms = -1
|
||||
|
||||
def close(self) -> None:
|
||||
"""释放MediaPipe模型资源"""
|
||||
self._landmarker.close()
|
||||
|
||||
def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]:
|
||||
"""处理单帧:姿态检测、指标计算、状态机更新、可视化叠加"""
|
||||
timestamp_ms = self._normalize_timestamp(timestamp_ms)
|
||||
|
||||
with self._result_lock:
|
||||
@@ -166,6 +170,7 @@ class DeadBugDetector:
|
||||
return annotated, result
|
||||
|
||||
def _normalize_timestamp(self, timestamp_ms: int) -> int:
|
||||
"""确保时间戳严格递增(MediaPipe要求)"""
|
||||
if timestamp_ms <= self._last_timestamp_ms:
|
||||
timestamp_ms = self._last_timestamp_ms + 1
|
||||
self._last_timestamp_ms = timestamp_ms
|
||||
|
||||
@@ -7,6 +7,7 @@ from app.exercises.dead_bug.types import Point
|
||||
|
||||
|
||||
def angle(a: Point, b: Point, c: Point) -> float:
|
||||
"""计算以b为顶点的三点夹角(度数)"""
|
||||
ba = np.array([a.x - b.x, a.y - b.y], dtype=np.float32)
|
||||
bc = np.array([c.x - b.x, c.y - b.y], dtype=np.float32)
|
||||
denom = float(np.linalg.norm(ba) * np.linalg.norm(bc))
|
||||
@@ -17,6 +18,7 @@ def angle(a: Point, b: Point, c: Point) -> float:
|
||||
|
||||
|
||||
def distance(a: Point, b: Point) -> float:
|
||||
"""计算两点之间的欧几里得距离(归一化坐标空间)"""
|
||||
return float(np.hypot(a.x - b.x, a.y - b.y))
|
||||
|
||||
|
||||
@@ -37,6 +39,7 @@ def calculate_metrics(
|
||||
right_ankle: int,
|
||||
visibility_threshold: float = 0.45,
|
||||
) -> dict:
|
||||
"""计算四肢关节角度、伸展状态及反馈信息"""
|
||||
left_elbow_angle = angle(lm[left_shoulder], lm[left_elbow], lm[left_wrist])
|
||||
right_elbow_angle = angle(lm[right_shoulder], lm[right_elbow], lm[right_wrist])
|
||||
left_knee_angle = angle(lm[left_hip], lm[left_knee], lm[left_ankle])
|
||||
|
||||
@@ -4,10 +4,12 @@ from app.exercises.dead_bug.types import DeadBugMetrics, Point
|
||||
|
||||
|
||||
def has_required_visibility(landmarks: list[Point], required_indices: tuple[int, ...], visibility_threshold: float) -> bool:
|
||||
"""检查所有必需关键点的可见度是否高于阈值"""
|
||||
return all(landmarks[i].visibility >= visibility_threshold for i in required_indices)
|
||||
|
||||
|
||||
def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None:
|
||||
"""检测是否存在对角伸展(左臂+右腿 或 右臂+左腿)"""
|
||||
if metrics.left_leg_extended and metrics.right_leg_extended:
|
||||
return None
|
||||
|
||||
@@ -19,6 +21,7 @@ def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None:
|
||||
|
||||
|
||||
def is_ready_position(metrics: DeadBugMetrics) -> bool:
|
||||
"""判断是否处于准备姿态(膝盖弯曲且四肢未伸展)"""
|
||||
knees_bent = metrics.left_knee_angle <= 140 and metrics.right_knee_angle <= 140
|
||||
legs_not_extended = not metrics.left_leg_extended and not metrics.right_leg_extended
|
||||
return knees_bent and legs_not_extended and detect_diagonal_extension(metrics) is None
|
||||
|
||||
@@ -3,9 +3,11 @@ from __future__ import annotations
|
||||
from app.exercises.dead_bug.rules import detect_diagonal_extension, is_ready_position
|
||||
from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase, DeadBugResult
|
||||
|
||||
|
||||
class DeadBugStateMachine:
|
||||
"""死虫式动作状态机:管理READY/EXTENDING/NEED_RESET/NO_POSE状态转换"""
|
||||
|
||||
def __init__(self, *, extension_confirm_frames: int = 4, reset_confirm_frames: int = 3) -> None:
|
||||
"""初始化并设置状态转换确认帧数"""
|
||||
self.extension_confirm_frames = extension_confirm_frames
|
||||
self.reset_confirm_frames = reset_confirm_frames
|
||||
|
||||
@@ -17,6 +19,7 @@ class DeadBugStateMachine:
|
||||
self._reset_frames = 0
|
||||
|
||||
def update(self, metrics: DeadBugMetrics) -> DeadBugResult:
|
||||
"""根据传入指标更新状态机并返回本次结果"""
|
||||
side = detect_diagonal_extension(metrics)
|
||||
ready = is_ready_position(metrics)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from enum import Enum
|
||||
|
||||
|
||||
class DeadBugPhase(str, Enum):
|
||||
"""死虫式动作阶段枚举"""
|
||||
READY = "ready"
|
||||
EXTENDING = "extending"
|
||||
NEED_RESET = "need_reset"
|
||||
@@ -13,6 +14,7 @@ class DeadBugPhase(str, Enum):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Point:
|
||||
"""三维关键点坐标及可见度"""
|
||||
x: float
|
||||
y: float
|
||||
z: float
|
||||
@@ -21,6 +23,7 @@ class Point:
|
||||
|
||||
@dataclass
|
||||
class DeadBugMetrics:
|
||||
"""四肢关节度量数据"""
|
||||
left_arm_extended: bool
|
||||
right_arm_extended: bool
|
||||
left_leg_extended: bool
|
||||
@@ -34,6 +37,7 @@ class DeadBugMetrics:
|
||||
|
||||
@dataclass
|
||||
class DeadBugResult:
|
||||
"""单帧检测结果"""
|
||||
rep_count: int
|
||||
phase: DeadBugPhase
|
||||
side: str | None
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.signaling.websocket_server import main as serve
|
||||
|
||||
|
||||
def main():
|
||||
"""应用入口:启动服务并运行WebSocket信令服务器"""
|
||||
startup()
|
||||
logger.info("Starting server...")
|
||||
try:
|
||||
|
||||
@@ -7,6 +7,7 @@ from app.exercises.dead_bug.types import DeadBugResult
|
||||
|
||||
|
||||
def draw_status_overlay(image: np.ndarray, result: DeadBugResult) -> None:
|
||||
"""在图像上叠加动作状态信息(次数、阶段、反馈)"""
|
||||
color = (60, 220, 90) if result.is_standard else (50, 180, 255)
|
||||
cv2.rectangle(image, (12, 12), (520, 142), (20, 20, 20), -1)
|
||||
cv2.putText(image, f"Dead bug reps: {result.rep_count}", (28, 48), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
|
||||
|
||||
@@ -18,6 +18,7 @@ def draw_landmarks(
|
||||
line_thickness: int = 2,
|
||||
point_radius: int = 4,
|
||||
) -> None:
|
||||
"""绘制人体骨架关键点与连接线(仅绘制可见度达标的点)"""
|
||||
if connections is None:
|
||||
connections = _POSE_CONNECTIONS
|
||||
|
||||
|
||||
@@ -6,16 +6,20 @@ WINDOW_NAME = "Android Camera (WebRTC)"
|
||||
|
||||
|
||||
def show_frame(image, window_name: str = WINDOW_NAME) -> None:
|
||||
"""在OpenCV窗口中显示图像帧"""
|
||||
cv2.imshow(window_name, image)
|
||||
|
||||
|
||||
def wait_key(delay_ms: int = 1) -> int:
|
||||
"""等待按键并返回ASCII码"""
|
||||
return cv2.waitKey(delay_ms) & 0xFF
|
||||
|
||||
|
||||
def is_esc_pressed() -> bool:
|
||||
"""检测ESC键是否被按下"""
|
||||
return wait_key(1) == 27
|
||||
|
||||
|
||||
def close_window() -> None:
|
||||
"""关闭所有OpenCV窗口"""
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
@@ -7,6 +7,7 @@ from aiortc import RTCIceCandidate
|
||||
|
||||
|
||||
def parse_ice(data: dict[str, Any]) -> RTCIceCandidate | None:
|
||||
"""解析ICE候选者字符串为RTCIceCandidate对象"""
|
||||
match = re.match(
|
||||
r'candidate:(\S+) (\d) (\S+) (\d+) (\S+) (\d+) typ (\S+)(?: raddr (\S+) rport (\d+))?',
|
||||
data["candidate"],
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignalingMessage:
|
||||
"""WebRTC信令消息数据模型"""
|
||||
type: str
|
||||
sdp: str = ""
|
||||
candidate: str = ""
|
||||
|
||||
@@ -11,6 +11,7 @@ from configs.load import config
|
||||
|
||||
|
||||
async def handle_client(websocket):
|
||||
"""处理单个WebSocket客户端连接"""
|
||||
client = websocket.remote_address
|
||||
logger.info(f"Client connected: {client}")
|
||||
|
||||
@@ -21,6 +22,7 @@ async def handle_client(websocket):
|
||||
|
||||
|
||||
async def main():
|
||||
"""启动WebSocket信令服务器"""
|
||||
cfg = config.server
|
||||
logger.info(f"WebRTC signaling server: ws://{cfg.host}:{cfg.port}")
|
||||
async with websockets.serve(handle_client, cfg.host, cfg.port, max_size=cfg.max_ws_size):
|
||||
|
||||
@@ -9,6 +9,7 @@ TARGET_HEIGHT = 720
|
||||
|
||||
|
||||
def resize_to_target(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> np.ndarray:
|
||||
"""将图像缩放到目标尺寸(仅当尺寸不一致时)"""
|
||||
h, w = image.shape[:2]
|
||||
if w == width and h == height:
|
||||
return image
|
||||
@@ -16,8 +17,10 @@ def resize_to_target(image: np.ndarray, width: int = TARGET_WIDTH, height: int =
|
||||
|
||||
|
||||
def bgr_to_rgba(bgr: np.ndarray) -> np.ndarray:
|
||||
"""将BGR格式图像转换为RGBA格式"""
|
||||
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGBA)
|
||||
|
||||
|
||||
def bgr_to_rgb(bgr: np.ndarray) -> np.ndarray:
|
||||
"""将BGR格式图像转换为RGB格式"""
|
||||
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||
|
||||
@@ -14,8 +14,9 @@ PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
|
||||
VisionRunningMode = mp.tasks.vision.RunningMode
|
||||
BaseOptions = mp.tasks.BaseOptions
|
||||
|
||||
|
||||
class PoseLandmarkerWrapper:
|
||||
"""MediaPipe姿态关键点检测器封装"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -23,6 +24,7 @@ class PoseLandmarkerWrapper:
|
||||
prefer_gpu: bool = True,
|
||||
result_callback: Callable | None = None,
|
||||
) -> None:
|
||||
"""初始化姿态检测器,优先尝试GPU委托,失败则回退到CPU"""
|
||||
self.model_path = model_path or DEFAULT_MODEL_PATH
|
||||
|
||||
if prefer_gpu:
|
||||
@@ -39,6 +41,7 @@ class PoseLandmarkerWrapper:
|
||||
logger.info("MediaPipe PoseLandmarker initialized with CPU delegate")
|
||||
|
||||
def _create(self, delegate, result_callback=None):
|
||||
"""根据委托类型和回调创建PoseLandmarker实例"""
|
||||
options = PoseLandmarkerOptions(
|
||||
base_options=BaseOptions(model_asset_path=self.model_path, delegate=delegate),
|
||||
running_mode=VisionRunningMode.LIVE_STREAM,
|
||||
@@ -51,7 +54,9 @@ class PoseLandmarkerWrapper:
|
||||
return PoseLandmarker.create_from_options(options)
|
||||
|
||||
def detect_async(self, mp_image, timestamp_ms: int) -> None:
|
||||
"""异步执行姿态检测"""
|
||||
return self._landmarker.detect_async(mp_image, timestamp_ms)
|
||||
|
||||
def close(self) -> None:
|
||||
"""释放MediaPipe资源"""
|
||||
self._landmarker.close()
|
||||
|
||||
@@ -9,6 +9,7 @@ TARGET_HEIGHT = 720
|
||||
|
||||
|
||||
def validate_frame_size(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> None:
|
||||
"""验证视频帧尺寸是否与目标尺寸一致,不一致时记录警告"""
|
||||
h, w = image.shape[:2]
|
||||
if w != width or h != height:
|
||||
logger.warning("Unexpected frame size: {}x{}", w, h)
|
||||
|
||||
@@ -10,13 +10,15 @@ from loguru import logger
|
||||
from app.signaling.ice_parser import parse_ice
|
||||
from app.webrtc.video_receiver import VideoReceiver
|
||||
|
||||
|
||||
class PeerSession:
|
||||
"""WebRTC对等连接会话管理"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pc = RTCPeerConnection()
|
||||
self._video_task: asyncio.Task | None = None
|
||||
|
||||
async def handle(self, websocket) -> None:
|
||||
"""处理WebSocket信令交互与WebRTC连接建立"""
|
||||
self._setup_events()
|
||||
|
||||
try:
|
||||
@@ -47,6 +49,7 @@ class PeerSession:
|
||||
await self._cleanup()
|
||||
|
||||
def _setup_events(self) -> None:
|
||||
"""注册ICE连接状态变化和视频轨道接收事件处理器"""
|
||||
@self._pc.on("track")
|
||||
async def on_track(track):
|
||||
logger.info(f"Track received: kind={track.kind}")
|
||||
@@ -61,6 +64,7 @@ class PeerSession:
|
||||
await self._pc.close()
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""清理视频任务并关闭对等连接"""
|
||||
if self._video_task:
|
||||
self._video_task.cancel()
|
||||
try:
|
||||
|
||||
@@ -13,6 +13,7 @@ from configs.load import config
|
||||
|
||||
|
||||
def _format_pose_debug(pose_result) -> str:
|
||||
"""格式化姿态检测结果用于调试日志输出"""
|
||||
metrics = pose_result.metrics
|
||||
if metrics is None:
|
||||
return "metrics=None"
|
||||
@@ -24,12 +25,14 @@ def _format_pose_debug(pose_result) -> str:
|
||||
f"ll={metrics.left_leg_extended}, rl={metrics.right_leg_extended})"
|
||||
)
|
||||
|
||||
|
||||
class VideoReceiver:
|
||||
"""视频轨道接收与运动检测流水线"""
|
||||
|
||||
def __init__(self, track) -> None:
|
||||
self._track = track
|
||||
|
||||
async def run(self) -> None:
|
||||
"""持续接收视频帧并进行姿态检测、渲染和语音播报"""
|
||||
logger.info("Start receiving video frames, process_every_n={}", config.video.process_every_n_frames)
|
||||
|
||||
frame_count = 0
|
||||
|
||||
Reference in New Issue
Block a user