Add dead bug pose detection

This commit is contained in:
2026-06-02 00:59:41 +08:00
parent feb456261c
commit fde0e0383d
4 changed files with 398 additions and 27 deletions
+376
View File
@@ -0,0 +1,376 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
import cv2
import mediapipe as mp
import numpy as np
# Short module-level aliases for the MediaPipe Tasks API objects used by the
# detector below, so the construction code does not repeat the long paths.
BaseOptions = mp.tasks.BaseOptions
VisionRunningMode = mp.tasks.vision.RunningMode
PoseLandmarker = mp.tasks.vision.PoseLandmarker
PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
class DeadBugPhase(str, Enum):
    """Phase label reported for each processed frame.

    Members are also plain strings (the class mixes in ``str``), so they
    compare equal to their literal values.
    """

    # Waiting for a diagonal extension to start.
    READY = "ready"
    # A diagonal extension has just been confirmed.
    EXTENDING = "extending"
    # Extension registered; waiting for the return to the ready position.
    NEED_RESET = "need_reset"
    # No usable pose in the current frame.
    NO_POSE = "no_pose"
@dataclass(frozen=True)
class Point:
    """Immutable pose landmark: a position plus a visibility score."""

    # Horizontal / vertical position; the drawing code multiplies these by the
    # image width / height, so they are treated as normalized image coordinates.
    x: float
    y: float
    # Depth component as reported by the pose model.
    z: float
    # Visibility score compared against the detector's visibility threshold.
    visibility: float
@dataclass
class DeadBugMetrics:
    """Per-frame measurements derived from the pose landmarks."""

    # Whether each limb currently satisfies the detector's "extended" rule.
    left_arm_extended: bool
    right_arm_extended: bool
    left_leg_extended: bool
    right_leg_extended: bool
    # Joint angles in degrees (shoulder-elbow-wrist and hip-knee-ankle).
    left_elbow_angle: float
    right_elbow_angle: float
    left_knee_angle: float
    right_knee_angle: float
    # Vertical offset between the two hips, divided by the body scale.
    torso_tilt: float
    # Form corrections computed for this frame.
    feedback: list[str]
@dataclass
class DeadBugResult:
    """Outcome of processing one frame, as handed back to the caller."""

    # Total repetitions counted so far.
    rep_count: int
    # Phase reported for this frame.
    phase: DeadBugPhase
    # Diagonal pair identifier ("left_arm_right_leg" / "right_arm_left_leg"),
    # or None when no diagonal applies.
    side: str | None
    # True when a diagonal extension was seen with no form corrections.
    is_standard: bool
    # Short messages intended for display to the user.
    feedback: list[str]
    # Measurements behind the decision; None when no usable pose was found.
    metrics: DeadBugMetrics | None
class DeadBugDetector:
    """MediaPipe Pose based dead bug detector and counter.

    The rules are intentionally conservative because a phone stream only gives
    us 2D landmarks. A rep is counted when one diagonal pair extends cleanly and
    the body returns to the bent-knee ready position.

    State machine stored in ``self.phase``:

    * ``READY -> EXTENDING`` once the same diagonal has been seen for
      ``extension_confirm_frames`` consecutive frames.
    * ``EXTENDING -> NEED_RESET`` on the next frame that still shows that
      diagonal, or as soon as the ready position is seen again.
    * ``NEED_RESET -> READY`` after ``reset_confirm_frames`` consecutive ready
      frames; the repetition is counted on this transition.

    ``DeadBugPhase.NO_POSE`` is only reported in results; it is never stored in
    ``self.phase``, so losing the pose for a few frames does not reset progress.

    The detector owns a MediaPipe landmarker; call :meth:`close` when done, or
    use the instance as a context manager.
    """

    # Pose landmark indices used by the rules.
    LEFT_SHOULDER = 11
    RIGHT_SHOULDER = 12
    LEFT_ELBOW = 13
    RIGHT_ELBOW = 14
    LEFT_WRIST = 15
    RIGHT_WRIST = 16
    LEFT_HIP = 23
    RIGHT_HIP = 24
    LEFT_KNEE = 25
    RIGHT_KNEE = 26
    LEFT_ANKLE = 27
    RIGHT_ANKLE = 28
    REQUIRED_LANDMARKS = (
        LEFT_SHOULDER,
        RIGHT_SHOULDER,
        LEFT_ELBOW,
        RIGHT_ELBOW,
        LEFT_WRIST,
        RIGHT_WRIST,
        LEFT_HIP,
        RIGHT_HIP,
        LEFT_KNEE,
        RIGHT_KNEE,
        LEFT_ANKLE,
        RIGHT_ANKLE,
    )

    # Identifiers reported in ``DeadBugResult.side``.
    SIDE_LEFT_ARM_RIGHT_LEG = "left_arm_right_leg"
    SIDE_RIGHT_ARM_LEFT_LEG = "right_arm_left_leg"

    # Rule thresholds. Angles are in degrees; the remaining values are
    # multiples of the per-frame body scale (see ``_calculate_metrics``).
    MIN_BODY_SCALE = 0.08
    ARM_EXTENDED_MIN_ELBOW_ANGLE = 145
    ARM_STRAIGHT_ELBOW_ANGLE = 160
    ARM_MIN_REACH = 1.15
    ARM_MAX_WRIST_DROP = 0.35
    LEG_EXTENDED_MIN_KNEE_ANGLE = 150
    LEG_STRAIGHT_KNEE_ANGLE = 165
    LEG_MIN_REACH = 1.55
    LEG_MAX_ANKLE_RISE = 0.2
    READY_MAX_KNEE_ANGLE = 140
    MAX_TORSO_TILT = 0.35
    MAX_FEEDBACK_LINES = 3

    def __init__(
        self,
        model_path: str | Path | None = None,
        *,
        visibility_threshold: float = 0.45,
        extension_confirm_frames: int = 4,
        reset_confirm_frames: int = 3,
    ) -> None:
        """Create the landmarker and reset the counting state.

        Args:
            model_path: Pose landmarker model file. Defaults to
                ``pose_models/pose_landmarker_full.task`` next to this module.
            visibility_threshold: Minimum visibility a required landmark must
                have for a frame to be analysed (also used when drawing).
            extension_confirm_frames: Consecutive frames a diagonal must be
                seen before it is accepted as an extension.
            reset_confirm_frames: Consecutive ready frames needed after an
                extension before the repetition is counted.
        """
        if model_path is None:
            model_path = Path(__file__).resolve().parent / "pose_models" / "pose_landmarker_full.task"
        self.model_path = str(model_path)
        self.visibility_threshold = visibility_threshold
        self.extension_confirm_frames = extension_confirm_frames
        self.reset_confirm_frames = reset_confirm_frames

        options = PoseLandmarkerOptions(
            base_options=BaseOptions(model_asset_path=self.model_path),
            running_mode=VisionRunningMode.VIDEO,
            num_poses=1,
            min_pose_detection_confidence=0.5,
            min_pose_presence_confidence=0.5,
            min_tracking_confidence=0.5,
        )
        self._landmarker = PoseLandmarker.create_from_options(options)

        self.rep_count = 0
        self.phase = DeadBugPhase.READY
        self.active_side: str | None = None
        self._candidate_side: str | None = None
        self._candidate_frames = 0
        self._reset_frames = 0
        self._last_timestamp_ms = -1

    def __enter__(self) -> DeadBugDetector:
        """Return the detector itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Release the landmarker when leaving a ``with`` block."""
        self.close()

    def close(self) -> None:
        """Release the underlying MediaPipe landmarker."""
        self._landmarker.close()

    def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]:
        """Analyse one video frame and update the repetition counter.

        Args:
            bgr_frame: Frame in OpenCV BGR channel order. It is not modified.
            timestamp_ms: Frame timestamp in milliseconds; non-increasing
                values are bumped forward (see ``_normalize_timestamp``).

        Returns:
            ``(annotated, result)`` where ``annotated`` is a copy of the frame
            with the skeleton and status overlay drawn on it.
        """
        timestamp_ms = self._normalize_timestamp(timestamp_ms)
        rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)
        pose_result = self._landmarker.detect_for_video(mp_image, timestamp_ms)
        annotated = bgr_frame.copy()

        if not pose_result.pose_landmarks:
            result = self._no_pose_result("No full body detected")
        else:
            # Only the first pose is used (the landmarker is built with num_poses=1).
            landmarks = [
                Point(lm.x, lm.y, lm.z, getattr(lm, "visibility", 1.0))
                for lm in pose_result.pose_landmarks[0]
            ]
            self._draw_landmarks(annotated, landmarks)
            if self._has_required_visibility(landmarks):
                result = self._update_state(self._calculate_metrics(landmarks))
            else:
                result = self._no_pose_result(
                    "Keep shoulders, elbows, wrists, hips, knees, ankles visible"
                )

        self._draw_status(annotated, result)
        return annotated, result

    def _no_pose_result(self, message: str) -> DeadBugResult:
        """Build the result for a frame that cannot be analysed.

        The counting state is left untouched so a short dropout does not lose
        an in-progress repetition.
        """
        return DeadBugResult(
            rep_count=self.rep_count,
            phase=DeadBugPhase.NO_POSE,
            side=self.active_side,
            is_standard=False,
            feedback=[message],
            metrics=None,
        )

    def _normalize_timestamp(self, timestamp_ms: int) -> int:
        """Return a timestamp strictly greater than the previous one.

        VIDEO running mode expects monotonically increasing timestamps, so a
        repeated or out-of-order value is replaced by ``last + 1``.
        """
        if timestamp_ms <= self._last_timestamp_ms:
            timestamp_ms = self._last_timestamp_ms + 1
        self._last_timestamp_ms = timestamp_ms
        return timestamp_ms

    def _has_required_visibility(self, landmarks: list[Point]) -> bool:
        """Return True when every required landmark exists and is visible enough."""
        # A truncated landmark list cannot satisfy the requirement; report it
        # as "not visible" instead of raising IndexError.
        if len(landmarks) <= max(self.REQUIRED_LANDMARKS):
            return False
        return all(landmarks[i].visibility >= self.visibility_threshold for i in self.REQUIRED_LANDMARKS)

    def _arm_extended(self, shoulder: Point, wrist: Point, elbow_angle: float, scale: float) -> bool:
        """Return True when the arm is straight, long enough, and the wrist is not far below the shoulder in the image."""
        return (
            elbow_angle >= self.ARM_EXTENDED_MIN_ELBOW_ANGLE
            and distance(shoulder, wrist) >= scale * self.ARM_MIN_REACH
            and wrist.y <= shoulder.y + scale * self.ARM_MAX_WRIST_DROP
        )

    def _leg_extended(self, hip: Point, knee: Point, ankle: Point, knee_angle: float, scale: float) -> bool:
        """Return True when the leg is straight, long enough, and the ankle is not far above the knee in the image."""
        return (
            knee_angle >= self.LEG_EXTENDED_MIN_KNEE_ANGLE
            and distance(hip, ankle) >= scale * self.LEG_MIN_REACH
            and ankle.y >= knee.y - scale * self.LEG_MAX_ANKLE_RISE
        )

    def _calculate_metrics(self, lm: list[Point]) -> DeadBugMetrics:
        """Compute joint angles, limb-extension flags and form feedback for one frame."""
        l_shoulder, r_shoulder = lm[self.LEFT_SHOULDER], lm[self.RIGHT_SHOULDER]
        l_wrist, r_wrist = lm[self.LEFT_WRIST], lm[self.RIGHT_WRIST]
        l_hip, r_hip = lm[self.LEFT_HIP], lm[self.RIGHT_HIP]
        l_knee, r_knee = lm[self.LEFT_KNEE], lm[self.RIGHT_KNEE]
        l_ankle, r_ankle = lm[self.LEFT_ANKLE], lm[self.RIGHT_ANKLE]

        left_elbow = angle(l_shoulder, lm[self.LEFT_ELBOW], l_wrist)
        right_elbow = angle(r_shoulder, lm[self.RIGHT_ELBOW], r_wrist)
        left_knee = angle(l_hip, l_knee, l_ankle)
        right_knee = angle(r_hip, r_knee, r_ankle)

        # Body-size reference so distance thresholds do not depend on how far
        # the person is from the camera; the floor avoids a near-zero scale.
        scale = max(distance(l_shoulder, r_shoulder), distance(l_hip, r_hip), self.MIN_BODY_SCALE)

        left_arm_extended = self._arm_extended(l_shoulder, l_wrist, left_elbow, scale)
        right_arm_extended = self._arm_extended(r_shoulder, r_wrist, right_elbow, scale)
        left_leg_extended = self._leg_extended(l_hip, l_knee, l_ankle, left_knee, scale)
        right_leg_extended = self._leg_extended(r_hip, r_knee, r_ankle, right_knee, scale)

        torso_tilt = abs(l_hip.y - r_hip.y) / scale

        # A limb can count as extended while still short of fully straight;
        # in that band the user is nudged to straighten it.
        feedback: list[str] = []
        if torso_tilt > self.MAX_TORSO_TILT:
            feedback.append("Keep pelvis level and core stable")
        if left_arm_extended and left_elbow < self.ARM_STRAIGHT_ELBOW_ANGLE:
            feedback.append("Straighten left arm")
        if right_arm_extended and right_elbow < self.ARM_STRAIGHT_ELBOW_ANGLE:
            feedback.append("Straighten right arm")
        if left_leg_extended and left_knee < self.LEG_STRAIGHT_KNEE_ANGLE:
            feedback.append("Straighten left leg")
        if right_leg_extended and right_knee < self.LEG_STRAIGHT_KNEE_ANGLE:
            feedback.append("Straighten right leg")

        return DeadBugMetrics(
            left_arm_extended=left_arm_extended,
            right_arm_extended=right_arm_extended,
            left_leg_extended=left_leg_extended,
            right_leg_extended=right_leg_extended,
            left_elbow_angle=left_elbow,
            right_elbow_angle=right_elbow,
            left_knee_angle=left_knee,
            right_knee_angle=right_knee,
            torso_tilt=torso_tilt,
            feedback=feedback,
        )

    def _update_state(self, metrics: DeadBugMetrics) -> DeadBugResult:
        """Advance the rep-counting state machine by one analysed frame."""
        side = self._detect_diagonal_extension(metrics)
        ready = self._is_ready_position(metrics)

        # Track how many consecutive frames have shown the same diagonal.
        if side is None:
            self._candidate_side = None
            self._candidate_frames = 0
        elif side == self._candidate_side:
            self._candidate_frames += 1
        else:
            self._candidate_side = side
            self._candidate_frames = 1

        if self.phase in (DeadBugPhase.READY, DeadBugPhase.NO_POSE):
            if side is not None and self._candidate_frames >= self.extension_confirm_frames:
                self.phase = DeadBugPhase.EXTENDING
                self.active_side = side
                self._reset_frames = 0
        elif self.phase == DeadBugPhase.EXTENDING:
            # Also advance when the user is already back in the ready position:
            # if the extension was dropped right after being confirmed, waiting
            # for the same diagonal to reappear would leave the phase stuck and
            # the repetition would never be counted.
            if side == self.active_side or ready:
                self.phase = DeadBugPhase.NEED_RESET
        elif self.phase == DeadBugPhase.NEED_RESET:
            if ready:
                self._reset_frames += 1
                if self._reset_frames >= self.reset_confirm_frames:
                    self.rep_count += 1
                    self.phase = DeadBugPhase.READY
                    self.active_side = None
                    self._candidate_side = None
                    self._candidate_frames = 0
                    self._reset_frames = 0
            else:
                # Ready frames must be consecutive.
                self._reset_frames = 0

        feedback = list(metrics.feedback)
        if side is None and not ready:
            feedback.append("Extend opposite arm and leg only")
        if ready:
            feedback.append("Ready position")
        elif side == self.SIDE_LEFT_ARM_RIGHT_LEG:
            feedback.append("Left arm + right leg")
        elif side == self.SIDE_RIGHT_ARM_LEFT_LEG:
            feedback.append("Right arm + left leg")

        is_standard = side is not None and not metrics.feedback
        return DeadBugResult(
            rep_count=self.rep_count,
            phase=self.phase,
            side=side,
            is_standard=is_standard,
            feedback=feedback[: self.MAX_FEEDBACK_LINES],
            metrics=metrics,
        )

    def _detect_diagonal_extension(self, metrics: DeadBugMetrics) -> str | None:
        """Return which diagonal pair is extended, or None.

        Any same-side arm+leg extension is treated as noise and rejects the
        frame, as does having both diagonals extended at once.
        """
        left_arm_right_leg = metrics.left_arm_extended and metrics.right_leg_extended
        right_arm_left_leg = metrics.right_arm_extended and metrics.left_leg_extended
        same_side_noise = (
            (metrics.left_arm_extended and metrics.left_leg_extended)
            or (metrics.right_arm_extended and metrics.right_leg_extended)
        )
        if same_side_noise:
            return None
        if left_arm_right_leg and not right_arm_left_leg:
            return self.SIDE_LEFT_ARM_RIGHT_LEG
        if right_arm_left_leg and not left_arm_right_leg:
            return self.SIDE_RIGHT_ARM_LEFT_LEG
        return None

    def _is_ready_position(self, metrics: DeadBugMetrics) -> bool:
        """Return True when both knees are bent and neither leg is extended."""
        knees_bent = (
            metrics.left_knee_angle <= self.READY_MAX_KNEE_ANGLE
            and metrics.right_knee_angle <= self.READY_MAX_KNEE_ANGLE
        )
        legs_not_extended = not metrics.left_leg_extended and not metrics.right_leg_extended
        # No separate diagonal check is needed: a diagonal requires an extended
        # leg, so it cannot be present once both legs are not extended.
        return knees_bent and legs_not_extended

    def _draw_landmarks(self, image: np.ndarray, landmarks: list[Point]) -> None:
        """Draw the skeleton lines and required joints onto ``image`` in place."""
        h, w = image.shape[:2]
        # Use MediaPipe's own connection list when this build exposes the
        # ``solutions`` API; otherwise fall back to the module-level list.
        pose_module = getattr(getattr(mp, "solutions", None), "pose", None)
        pose_connections = getattr(pose_module, "POSE_CONNECTIONS", _POSE_CONNECTIONS)
        count = len(landmarks)

        for start, end in pose_connections:
            if start >= count or end >= count:
                continue
            p1 = landmarks[start]
            p2 = landmarks[end]
            if p1.visibility < self.visibility_threshold or p2.visibility < self.visibility_threshold:
                continue
            cv2.line(
                image,
                (int(p1.x * w), int(p1.y * h)),
                (int(p2.x * w), int(p2.y * h)),
                (65, 180, 255),
                2,
            )

        for idx in self.REQUIRED_LANDMARKS:
            # Same bounds guard as for the connections above.
            if idx >= count:
                continue
            p = landmarks[idx]
            if p.visibility >= self.visibility_threshold:
                cv2.circle(image, (int(p.x * w), int(p.y * h)), 4, (80, 255, 120), -1)

    def _draw_status(self, image: np.ndarray, result: DeadBugResult) -> None:
        """Draw the rep counter, phase, status and feedback lines onto ``image`` in place."""
        color = (60, 220, 90) if result.is_standard else (50, 180, 255)
        # Dark panel behind the three status lines.
        cv2.rectangle(image, (12, 12), (520, 142), (20, 20, 20), -1)
        cv2.putText(image, f"Dead bug reps: {result.rep_count}", (28, 48), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
        cv2.putText(image, f"phase: {result.phase.value}", (28, 82), cv2.FONT_HERSHEY_SIMPLEX, 0.68, (230, 230, 230), 2)
        status = "standard" if result.is_standard else "adjust"
        cv2.putText(image, f"status: {status}", (28, 116), cv2.FONT_HERSHEY_SIMPLEX, 0.68, color, 2)
        # Feedback lines are stacked below the panel.
        y = 170
        for text in result.feedback:
            cv2.putText(image, text, (28, y), cv2.FONT_HERSHEY_SIMPLEX, 0.68, (255, 255, 255), 2)
            y += 30
def angle(a: Point, b: Point, c: Point) -> float:
    """Return the angle at vertex ``b`` between rays ``b->a`` and ``b->c``, in degrees.

    Only the x/y components are used. If either ray has zero length the
    angle is undefined and 0.0 is returned.
    """
    to_a = np.array([a.x - b.x, a.y - b.y], dtype=np.float32)
    to_c = np.array([c.x - b.x, c.y - b.y], dtype=np.float32)

    norm_product = float(np.linalg.norm(to_a) * np.linalg.norm(to_c))
    if norm_product == 0:
        return 0.0

    cosine = float(np.dot(to_a, to_c) / norm_product)
    # Rounding can push the cosine slightly outside [-1, 1]; clamp it so
    # arccos stays defined.
    cosine = np.clip(cosine, -1.0, 1.0)
    return float(np.degrees(np.arccos(cosine)))
def distance(a: Point, b: Point) -> float:
    """Return the Euclidean distance between ``a`` and ``b`` using x/y only."""
    dx = a.x - b.x
    dy = a.y - b.y
    return float(np.hypot(dx, dy))
# Fallback skeleton edges, as (start, end) landmark index pairs, used by
# DeadBugDetector._draw_landmarks when ``mp.solutions.pose.POSE_CONNECTIONS``
# is not available. Indices follow the LEFT_*/RIGHT_* constants on the class.
_POSE_CONNECTIONS = (
    # Shoulder line and both arms.
    (11, 12), (11, 13), (13, 15), (12, 14), (14, 16),
    # Torso sides and hip line.
    (11, 23), (12, 24), (23, 24),
    # Both legs.
    (23, 25), (25, 27), (24, 26), (26, 28),
)