为所有函数和类添加中文注释文档字符串
This commit is contained in:
@@ -32,8 +32,9 @@ from app.vision.pose_types import (
|
||||
RIGHT_WRIST,
|
||||
)
|
||||
|
||||
|
||||
class DeadBugDetector:
|
||||
"""死虫式(Dead Bug)运动检测器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -43,6 +44,7 @@ class DeadBugDetector:
|
||||
reset_confirm_frames: int = 3,
|
||||
prefer_gpu: bool = True,
|
||||
) -> None:
|
||||
"""初始化姿态检测器、状态机和可视化渲染组件"""
|
||||
self.visibility_threshold = visibility_threshold
|
||||
|
||||
self._latest_result = None
|
||||
@@ -72,9 +74,11 @@ class DeadBugDetector:
|
||||
self._last_timestamp_ms = -1
|
||||
|
||||
def close(self) -> None:
|
||||
"""释放MediaPipe模型资源"""
|
||||
self._landmarker.close()
|
||||
|
||||
def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]:
|
||||
"""处理单帧:姿态检测、指标计算、状态机更新、可视化叠加"""
|
||||
timestamp_ms = self._normalize_timestamp(timestamp_ms)
|
||||
|
||||
with self._result_lock:
|
||||
@@ -166,6 +170,7 @@ class DeadBugDetector:
|
||||
return annotated, result
|
||||
|
||||
def _normalize_timestamp(self, timestamp_ms: int) -> int:
|
||||
"""确保时间戳严格递增(MediaPipe要求)"""
|
||||
if timestamp_ms <= self._last_timestamp_ms:
|
||||
timestamp_ms = self._last_timestamp_ms + 1
|
||||
self._last_timestamp_ms = timestamp_ms
|
||||
|
||||
@@ -7,6 +7,7 @@ from app.exercises.dead_bug.types import Point
|
||||
|
||||
|
||||
def angle(a: Point, b: Point, c: Point) -> float:
|
||||
"""计算以b为顶点的三点夹角(度数)"""
|
||||
ba = np.array([a.x - b.x, a.y - b.y], dtype=np.float32)
|
||||
bc = np.array([c.x - b.x, c.y - b.y], dtype=np.float32)
|
||||
denom = float(np.linalg.norm(ba) * np.linalg.norm(bc))
|
||||
@@ -17,6 +18,7 @@ def angle(a: Point, b: Point, c: Point) -> float:
|
||||
|
||||
|
||||
def distance(a: Point, b: Point) -> float:
|
||||
"""计算两点之间的欧几里得距离(归一化坐标空间)"""
|
||||
return float(np.hypot(a.x - b.x, a.y - b.y))
|
||||
|
||||
|
||||
@@ -37,6 +39,7 @@ def calculate_metrics(
|
||||
right_ankle: int,
|
||||
visibility_threshold: float = 0.45,
|
||||
) -> dict:
|
||||
"""计算四肢关节角度、伸展状态及反馈信息"""
|
||||
left_elbow_angle = angle(lm[left_shoulder], lm[left_elbow], lm[left_wrist])
|
||||
right_elbow_angle = angle(lm[right_shoulder], lm[right_elbow], lm[right_wrist])
|
||||
left_knee_angle = angle(lm[left_hip], lm[left_knee], lm[left_ankle])
|
||||
|
||||
@@ -4,10 +4,12 @@ from app.exercises.dead_bug.types import DeadBugMetrics, Point
|
||||
|
||||
|
||||
def has_required_visibility(landmarks: list[Point], required_indices: tuple[int, ...], visibility_threshold: float) -> bool:
|
||||
"""检查所有必需关键点的可见度是否高于阈值"""
|
||||
return all(landmarks[i].visibility >= visibility_threshold for i in required_indices)
|
||||
|
||||
|
||||
def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None:
|
||||
"""检测是否存在对角伸展(左臂+右腿 或 右臂+左腿)"""
|
||||
if metrics.left_leg_extended and metrics.right_leg_extended:
|
||||
return None
|
||||
|
||||
@@ -19,6 +21,7 @@ def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None:
|
||||
|
||||
|
||||
def is_ready_position(metrics: DeadBugMetrics) -> bool:
|
||||
"""判断是否处于准备姿态(膝盖弯曲且四肢未伸展)"""
|
||||
knees_bent = metrics.left_knee_angle <= 140 and metrics.right_knee_angle <= 140
|
||||
legs_not_extended = not metrics.left_leg_extended and not metrics.right_leg_extended
|
||||
return knees_bent and legs_not_extended and detect_diagonal_extension(metrics) is None
|
||||
|
||||
@@ -3,9 +3,11 @@ from __future__ import annotations
|
||||
from app.exercises.dead_bug.rules import detect_diagonal_extension, is_ready_position
|
||||
from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase, DeadBugResult
|
||||
|
||||
|
||||
class DeadBugStateMachine:
|
||||
"""死虫式动作状态机:管理READY/EXTENDING/NEED_RESET/NO_POSE状态转换"""
|
||||
|
||||
def __init__(self, *, extension_confirm_frames: int = 4, reset_confirm_frames: int = 3) -> None:
|
||||
"""初始化并设置状态转换确认帧数"""
|
||||
self.extension_confirm_frames = extension_confirm_frames
|
||||
self.reset_confirm_frames = reset_confirm_frames
|
||||
|
||||
@@ -17,6 +19,7 @@ class DeadBugStateMachine:
|
||||
self._reset_frames = 0
|
||||
|
||||
def update(self, metrics: DeadBugMetrics) -> DeadBugResult:
|
||||
"""根据传入指标更新状态机并返回本次结果"""
|
||||
side = detect_diagonal_extension(metrics)
|
||||
ready = is_ready_position(metrics)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from enum import Enum
|
||||
|
||||
|
||||
class DeadBugPhase(str, Enum):
|
||||
"""死虫式动作阶段枚举"""
|
||||
READY = "ready"
|
||||
EXTENDING = "extending"
|
||||
NEED_RESET = "need_reset"
|
||||
@@ -13,6 +14,7 @@ class DeadBugPhase(str, Enum):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Point:
|
||||
"""三维关键点坐标及可见度"""
|
||||
x: float
|
||||
y: float
|
||||
z: float
|
||||
@@ -21,6 +23,7 @@ class Point:
|
||||
|
||||
@dataclass
|
||||
class DeadBugMetrics:
|
||||
"""四肢关节度量数据"""
|
||||
left_arm_extended: bool
|
||||
right_arm_extended: bool
|
||||
left_leg_extended: bool
|
||||
@@ -34,6 +37,7 @@ class DeadBugMetrics:
|
||||
|
||||
@dataclass
|
||||
class DeadBugResult:
|
||||
"""单帧检测结果"""
|
||||
rep_count: int
|
||||
phase: DeadBugPhase
|
||||
side: str | None
|
||||
|
||||
Reference in New Issue
Block a user