Add dead bug pose detection
This commit is contained in:
@@ -0,0 +1,376 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import mediapipe as mp
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Short module-level aliases for the MediaPipe Tasks classes used by the
# detector below, so call sites do not repeat the long dotted paths.
BaseOptions = mp.tasks.BaseOptions
VisionRunningMode = mp.tasks.vision.RunningMode
PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
PoseLandmarker = mp.tasks.vision.PoseLandmarker
|
||||
|
||||
|
||||
class DeadBugPhase(str, Enum):
    """Phases reported by the dead bug detector.

    The enum also derives from ``str``, so each member compares equal to its
    plain string value.
    """

    # Bent-knee start position; waiting for a diagonal extension.
    READY = "ready"
    # A diagonal arm/leg extension has been confirmed.
    EXTENDING = "extending"
    # The extension was seen; the body must return to the ready position.
    NEED_RESET = "need_reset"
    # No pose, or not all required landmarks, could be used for this frame.
    NO_POSE = "no_pose"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Point:
    """Immutable snapshot of a single pose landmark."""

    # Horizontal position; the drawing code scales it by the image width.
    x: float
    # Vertical position; the drawing code scales it by the image height.
    y: float
    # Depth value copied from the landmark; the 2D geometry helpers ignore it.
    z: float
    # Landmark confidence that is compared against the visibility threshold.
    visibility: float
|
||||
|
||||
|
||||
@dataclass
class DeadBugMetrics:
    """Per-frame measurements derived from one set of pose landmarks."""

    # Whether each limb currently satisfies the detector's "extended" rule.
    left_arm_extended: bool
    right_arm_extended: bool
    left_leg_extended: bool
    right_leg_extended: bool

    # Joint angles in degrees (shoulder-elbow-wrist and hip-knee-ankle).
    left_elbow_angle: float
    right_elbow_angle: float
    left_knee_angle: float
    right_knee_angle: float

    # Vertical offset between the two hips, divided by the body scale.
    torso_tilt: float

    # Form-correction messages produced for this frame.
    feedback: list[str]
|
||||
|
||||
|
||||
@dataclass
class DeadBugResult:
    """Outcome of processing one frame, as returned to the caller."""

    # Number of repetitions counted so far.
    rep_count: int
    # Phase reported for this frame.
    phase: DeadBugPhase
    # Diagonal pair associated with this frame ("left_arm_right_leg" or
    # "right_arm_left_leg"), or None when there is none.
    side: str | None
    # True when a diagonal extension is detected without form feedback.
    is_standard: bool
    # Messages to show to the user for this frame.
    feedback: list[str]
    # Measurements for this frame; None when no usable pose was available.
    metrics: DeadBugMetrics | None
|
||||
|
||||
|
||||
class DeadBugDetector:
    """MediaPipe Pose based dead bug detector and counter.

    The rules are intentionally conservative because a phone stream only gives
    us 2D landmarks. A rep is counted when one diagonal pair extends cleanly and
    the body returns to the bent-knee ready position.

    Phase transitions (``self.phase``)::

        READY      -> EXTENDING   same diagonal seen for ``extension_confirm_frames`` frames
        EXTENDING  -> NEED_RESET  diagonal still held, or body already back in ready
        NEED_RESET -> READY       ready pose for ``reset_confirm_frames`` frames (rep counted)

    The detector owns a MediaPipe pose landmarker. Release it with
    :meth:`close`, or use the detector as a context manager.
    """

    # Landmark indices used by the detector.
    LEFT_SHOULDER = 11
    RIGHT_SHOULDER = 12
    LEFT_ELBOW = 13
    RIGHT_ELBOW = 14
    LEFT_WRIST = 15
    RIGHT_WRIST = 16
    LEFT_HIP = 23
    RIGHT_HIP = 24
    LEFT_KNEE = 25
    RIGHT_KNEE = 26
    LEFT_ANKLE = 27
    RIGHT_ANKLE = 28

    # Every landmark that must be visible before a frame is analysed.
    REQUIRED_LANDMARKS = (
        LEFT_SHOULDER,
        RIGHT_SHOULDER,
        LEFT_ELBOW,
        RIGHT_ELBOW,
        LEFT_WRIST,
        RIGHT_WRIST,
        LEFT_HIP,
        RIGHT_HIP,
        LEFT_KNEE,
        RIGHT_KNEE,
        LEFT_ANKLE,
        RIGHT_ANKLE,
    )

    # Identifiers for the two diagonal pairs, as exposed in DeadBugResult.side.
    SIDE_LEFT_ARM_RIGHT_LEG = "left_arm_right_leg"
    SIDE_RIGHT_ARM_LEFT_LEG = "right_arm_left_leg"

    # Rule thresholds. Angles are in degrees; the remaining values are
    # multiples of the body scale (the larger of shoulder and hip width).
    MIN_BODY_SCALE = 0.08  # floor that keeps the ratios finite
    ARM_MIN_ELBOW_ANGLE = 145
    ARM_MIN_REACH = 1.15  # shoulder-to-wrist distance
    ARM_MAX_WRIST_OFFSET = 0.35  # wrist.y may exceed shoulder.y by at most this
    LEG_MIN_KNEE_ANGLE = 150
    LEG_MIN_REACH = 1.55  # hip-to-ankle distance
    LEG_MAX_ANKLE_OFFSET = 0.2  # ankle.y may be below knee.y by at most this
    STRAIGHT_ELBOW_ANGLE = 160  # below this an extended arm gets feedback
    STRAIGHT_KNEE_ANGLE = 165  # below this an extended leg gets feedback
    READY_MAX_KNEE_ANGLE = 140
    MAX_TORSO_TILT = 0.35
    MAX_FEEDBACK_LINES = 3

    def __init__(
        self,
        model_path: str | Path | None = None,
        *,
        visibility_threshold: float = 0.45,
        extension_confirm_frames: int = 4,
        reset_confirm_frames: int = 3,
    ) -> None:
        """Create the pose landmarker and reset the rep counter.

        Args:
            model_path: Pose landmarker model file. Defaults to
                ``pose_models/pose_landmarker_full.task`` next to this module.
            visibility_threshold: Minimum landmark visibility for a landmark
                to be used or drawn.
            extension_confirm_frames: Frames the same diagonal must be seen
                before the extension is accepted.
            reset_confirm_frames: Ready-position frames required after an
                extension before the rep is counted.
        """
        if model_path is None:
            model_path = Path(__file__).resolve().parent / "pose_models" / "pose_landmarker_full.task"

        self.model_path = str(model_path)
        self.visibility_threshold = visibility_threshold
        self.extension_confirm_frames = extension_confirm_frames
        self.reset_confirm_frames = reset_confirm_frames

        options = PoseLandmarkerOptions(
            base_options=BaseOptions(model_asset_path=self.model_path),
            running_mode=VisionRunningMode.VIDEO,
            num_poses=1,
            min_pose_detection_confidence=0.5,
            min_pose_presence_confidence=0.5,
            min_tracking_confidence=0.5,
        )
        self._landmarker = PoseLandmarker.create_from_options(options)

        self.rep_count = 0
        self.phase = DeadBugPhase.READY
        self.active_side: str | None = None
        self._candidate_side: str | None = None
        self._candidate_frames = 0
        self._reset_frames = 0
        self._last_timestamp_ms = -1

    def __enter__(self) -> DeadBugDetector:
        """Return the detector itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None:
        """Close the landmarker when the ``with`` block ends."""
        self.close()

    def close(self) -> None:
        """Release the underlying pose landmarker."""
        self._landmarker.close()

    def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]:
        """Analyse one video frame and update the rep counter.

        Args:
            bgr_frame: Frame in OpenCV BGR channel order; it is not modified.
            timestamp_ms: Frame timestamp in milliseconds. Values that do not
                increase are bumped so the landmarker always sees a strictly
                increasing sequence.

        Returns:
            A copy of the frame with the skeleton and status overlay drawn on
            it, and the result for this frame.
        """
        timestamp_ms = self._normalize_timestamp(timestamp_ms)
        rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)
        pose_result = self._landmarker.detect_for_video(mp_image, timestamp_ms)

        annotated = bgr_frame.copy()
        if not pose_result.pose_landmarks:
            result = self._no_pose_result("No full body detected")
        else:
            landmarks = [self._to_point(lm) for lm in pose_result.pose_landmarks[0]]
            self._draw_landmarks(annotated, landmarks)
            if self._has_required_visibility(landmarks):
                result = self._update_state(self._calculate_metrics(landmarks))
            else:
                result = self._no_pose_result("Keep shoulders, elbows, wrists, hips, knees, ankles visible")

        self._draw_status(annotated, result)
        return annotated, result

    @staticmethod
    def _to_point(landmark: object) -> Point:
        """Convert a landmark object into a Point.

        A visibility that is missing or ``None`` is treated as fully visible,
        so the later threshold comparison never sees ``None``.
        """
        visibility = getattr(landmark, "visibility", None)
        return Point(landmark.x, landmark.y, landmark.z, 1.0 if visibility is None else visibility)

    def _no_pose_result(self, message: str) -> DeadBugResult:
        """Build the result for a frame that cannot be analysed.

        The counter and the internal phase are left untouched.
        """
        return DeadBugResult(
            rep_count=self.rep_count,
            phase=DeadBugPhase.NO_POSE,
            side=self.active_side,
            is_standard=False,
            feedback=[message],
            metrics=None,
        )

    def _normalize_timestamp(self, timestamp_ms: int) -> int:
        """Return an integer timestamp strictly greater than the previous one."""
        timestamp_ms = int(timestamp_ms)
        if timestamp_ms <= self._last_timestamp_ms:
            timestamp_ms = self._last_timestamp_ms + 1
        self._last_timestamp_ms = timestamp_ms
        return timestamp_ms

    def _has_required_visibility(self, landmarks: list[Point]) -> bool:
        """Return True when every required landmark exists and is visible enough."""
        # A list too short to hold all required indices cannot be analysed.
        if len(landmarks) <= max(self.REQUIRED_LANDMARKS):
            return False
        return all(landmarks[i].visibility >= self.visibility_threshold for i in self.REQUIRED_LANDMARKS)

    def _arm_extended(self, shoulder: Point, wrist: Point, elbow_angle: float, scale: float) -> bool:
        """Return True when an arm is straight, reaching and not dropped.

        The last rule compares ``y`` values directly; larger ``y`` is presumably
        lower in the image (TODO confirm against the landmark convention).
        """
        return (
            elbow_angle >= self.ARM_MIN_ELBOW_ANGLE
            and distance(shoulder, wrist) >= scale * self.ARM_MIN_REACH
            and wrist.y <= shoulder.y + scale * self.ARM_MAX_WRIST_OFFSET
        )

    def _leg_extended(self, hip: Point, knee: Point, ankle: Point, knee_angle: float, scale: float) -> bool:
        """Return True when a leg is straight, reaching and the ankle is not tucked."""
        return (
            knee_angle >= self.LEG_MIN_KNEE_ANGLE
            and distance(hip, ankle) >= scale * self.LEG_MIN_REACH
            and ankle.y >= knee.y - scale * self.LEG_MAX_ANKLE_OFFSET
        )

    def _calculate_metrics(self, lm: list[Point]) -> DeadBugMetrics:
        """Measure joint angles, limb extension and pelvis tilt for one frame."""
        l_shoulder, r_shoulder = lm[self.LEFT_SHOULDER], lm[self.RIGHT_SHOULDER]
        l_wrist, r_wrist = lm[self.LEFT_WRIST], lm[self.RIGHT_WRIST]
        l_hip, r_hip = lm[self.LEFT_HIP], lm[self.RIGHT_HIP]
        l_knee, r_knee = lm[self.LEFT_KNEE], lm[self.RIGHT_KNEE]
        l_ankle, r_ankle = lm[self.LEFT_ANKLE], lm[self.RIGHT_ANKLE]

        left_elbow = angle(l_shoulder, lm[self.LEFT_ELBOW], l_wrist)
        right_elbow = angle(r_shoulder, lm[self.RIGHT_ELBOW], r_wrist)
        left_knee = angle(l_hip, l_knee, l_ankle)
        right_knee = angle(r_hip, r_knee, r_ankle)

        # Distances are judged relative to body width so the rules do not
        # depend on how large the person appears in the frame.
        scale = max(distance(l_shoulder, r_shoulder), distance(l_hip, r_hip), self.MIN_BODY_SCALE)

        left_arm_extended = self._arm_extended(l_shoulder, l_wrist, left_elbow, scale)
        right_arm_extended = self._arm_extended(r_shoulder, r_wrist, right_elbow, scale)
        left_leg_extended = self._leg_extended(l_hip, l_knee, l_ankle, left_knee, scale)
        right_leg_extended = self._leg_extended(r_hip, r_knee, r_ankle, right_knee, scale)

        torso_tilt = abs(l_hip.y - r_hip.y) / scale
        feedback: list[str] = []
        if torso_tilt > self.MAX_TORSO_TILT:
            feedback.append("Keep pelvis level and core stable")
        if left_arm_extended and left_elbow < self.STRAIGHT_ELBOW_ANGLE:
            feedback.append("Straighten left arm")
        if right_arm_extended and right_elbow < self.STRAIGHT_ELBOW_ANGLE:
            feedback.append("Straighten right arm")
        if left_leg_extended and left_knee < self.STRAIGHT_KNEE_ANGLE:
            feedback.append("Straighten left leg")
        if right_leg_extended and right_knee < self.STRAIGHT_KNEE_ANGLE:
            feedback.append("Straighten right leg")

        return DeadBugMetrics(
            left_arm_extended=left_arm_extended,
            right_arm_extended=right_arm_extended,
            left_leg_extended=left_leg_extended,
            right_leg_extended=right_leg_extended,
            left_elbow_angle=left_elbow,
            right_elbow_angle=right_elbow,
            left_knee_angle=left_knee,
            right_knee_angle=right_knee,
            torso_tilt=torso_tilt,
            feedback=feedback,
        )

    def _update_state(self, metrics: DeadBugMetrics) -> DeadBugResult:
        """Advance the rep state machine with one frame of metrics."""
        side = self._detect_diagonal_extension(metrics)
        ready = self._is_ready_position(metrics)

        # Count how many frames in a row (among analysed frames) showed the
        # same diagonal.
        if side is None:
            self._candidate_side = None
            self._candidate_frames = 0
        elif side == self._candidate_side:
            self._candidate_frames += 1
        else:
            self._candidate_side = side
            self._candidate_frames = 1

        if self.phase in (DeadBugPhase.READY, DeadBugPhase.NO_POSE):
            if side is not None and self._candidate_frames >= self.extension_confirm_frames:
                self.phase = DeadBugPhase.EXTENDING
                self.active_side = side
                self._reset_frames = 0
        elif self.phase == DeadBugPhase.EXTENDING:
            # Leave EXTENDING when the diagonal is still held, or when the
            # body is already back in the ready pose. Without the second exit
            # an extension that ends right after confirmation would leave the
            # detector stuck in EXTENDING.
            if side == self.active_side or ready:
                self.phase = DeadBugPhase.NEED_RESET

        # Deliberately a separate ``if``: a ready frame that just moved the
        # phase to NEED_RESET also counts towards the reset confirmation.
        if self.phase == DeadBugPhase.NEED_RESET:
            if ready:
                self._reset_frames += 1
                if self._reset_frames >= self.reset_confirm_frames:
                    self.rep_count += 1
                    self.phase = DeadBugPhase.READY
                    self.active_side = None
                    self._candidate_side = None
                    self._candidate_frames = 0
                    self._reset_frames = 0
            else:
                self._reset_frames = 0

        feedback = list(metrics.feedback)
        if side is None and not ready:
            feedback.append("Extend opposite arm and leg only")
        if ready:
            feedback.append("Ready position")
        elif side == self.SIDE_LEFT_ARM_RIGHT_LEG:
            feedback.append("Left arm + right leg")
        elif side == self.SIDE_RIGHT_ARM_LEFT_LEG:
            feedback.append("Right arm + left leg")

        return DeadBugResult(
            rep_count=self.rep_count,
            phase=self.phase,
            side=side,
            # "Standard" means a diagonal is extended with no form feedback.
            is_standard=side is not None and not metrics.feedback,
            feedback=feedback[: self.MAX_FEEDBACK_LINES],
            metrics=metrics,
        )

    def _detect_diagonal_extension(self, metrics: DeadBugMetrics) -> str | None:
        """Return which diagonal pair is extended, or None.

        Any same-side arm and leg extended together is treated as noise and
        yields None.
        """
        same_side = (metrics.left_arm_extended and metrics.left_leg_extended) or (
            metrics.right_arm_extended and metrics.right_leg_extended
        )
        if same_side:
            return None
        # Both diagonals at once would need all four limbs extended, which the
        # same-side check above already rejected, so at most one matches here.
        if metrics.left_arm_extended and metrics.right_leg_extended:
            return self.SIDE_LEFT_ARM_RIGHT_LEG
        if metrics.right_arm_extended and metrics.left_leg_extended:
            return self.SIDE_RIGHT_ARM_LEFT_LEG
        return None

    def _is_ready_position(self, metrics: DeadBugMetrics) -> bool:
        """Return True when both knees are bent and neither leg is extended."""
        knees_bent = (
            metrics.left_knee_angle <= self.READY_MAX_KNEE_ANGLE
            and metrics.right_knee_angle <= self.READY_MAX_KNEE_ANGLE
        )
        # A diagonal extension always involves an extended leg, so no separate
        # diagonal check is needed once both legs are known to be bent.
        any_leg_extended = metrics.left_leg_extended or metrics.right_leg_extended
        return knees_bent and not any_leg_extended

    def _draw_landmarks(self, image: np.ndarray, landmarks: list[Point]) -> None:
        """Draw the visible skeleton lines and required joints onto ``image``."""
        h, w = image.shape[:2]
        # Use MediaPipe's own connection list when this build exposes it,
        # otherwise fall back to the module-level table.
        pose_module = getattr(getattr(mp, "solutions", None), "pose", None)
        pose_connections = getattr(pose_module, "POSE_CONNECTIONS", _POSE_CONNECTIONS)
        count = len(landmarks)

        for start, end in pose_connections:
            if start >= count or end >= count:
                continue
            p1 = landmarks[start]
            p2 = landmarks[end]
            if p1.visibility < self.visibility_threshold or p2.visibility < self.visibility_threshold:
                continue
            cv2.line(
                image,
                (int(p1.x * w), int(p1.y * h)),
                (int(p2.x * w), int(p2.y * h)),
                (65, 180, 255),
                2,
            )

        for idx in self.REQUIRED_LANDMARKS:
            if idx >= count:
                continue
            p = landmarks[idx]
            if p.visibility >= self.visibility_threshold:
                cv2.circle(image, (int(p.x * w), int(p.y * h)), 4, (80, 255, 120), -1)

    def _draw_status(self, image: np.ndarray, result: DeadBugResult) -> None:
        """Draw the rep counter, phase, status and feedback lines onto ``image``."""
        color = (60, 220, 90) if result.is_standard else (50, 180, 255)
        cv2.rectangle(image, (12, 12), (520, 142), (20, 20, 20), -1)
        cv2.putText(image, f"Dead bug reps: {result.rep_count}", (28, 48), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
        cv2.putText(image, f"phase: {result.phase.value}", (28, 82), cv2.FONT_HERSHEY_SIMPLEX, 0.68, (230, 230, 230), 2)
        status = "standard" if result.is_standard else "adjust"
        cv2.putText(image, f"status: {status}", (28, 116), cv2.FONT_HERSHEY_SIMPLEX, 0.68, color, 2)

        # Feedback lines are stacked below the status panel.
        y = 170
        for text in result.feedback:
            cv2.putText(image, text, (28, y), cv2.FONT_HERSHEY_SIMPLEX, 0.68, (255, 255, 255), 2)
            y += 30
|
||||
|
||||
|
||||
def angle(a: Point, b: Point, c: Point) -> float:
    """Return the angle at ``b`` between rays ``b->a`` and ``b->c``, in degrees.

    Only the ``x`` and ``y`` components are used. When either ray has zero
    length the angle is undefined and 0.0 is returned.
    """
    to_a = np.array([a.x - b.x, a.y - b.y], dtype=np.float32)
    to_c = np.array([c.x - b.x, c.y - b.y], dtype=np.float32)

    magnitude = float(np.linalg.norm(to_a) * np.linalg.norm(to_c))
    if not magnitude:
        return 0.0

    cosine = float(np.dot(to_a, to_c) / magnitude)
    # Clamp so rounding error cannot push the cosine outside arccos' domain.
    clamped = np.clip(cosine, -1.0, 1.0)
    return float(np.degrees(np.arccos(clamped)))
|
||||
|
||||
|
||||
def distance(a: Point, b: Point) -> float:
    """Return the straight-line distance between two points in the x/y plane."""
    dx = a.x - b.x
    dy = a.y - b.y
    return float(np.hypot(dx, dy))
|
||||
|
||||
|
||||
# Fallback skeleton used for drawing when the installed MediaPipe build does
# not expose its own connection list. Each pair is (start, end) landmark index.
_POSE_CONNECTIONS = (
    # shoulder line
    (11, 12),
    # left arm: shoulder - elbow - wrist
    (11, 13), (13, 15),
    # right arm: shoulder - elbow - wrist
    (12, 14), (14, 16),
    # torso sides and hip line
    (11, 23), (12, 24), (23, 24),
    # left leg: hip - knee - ankle
    (23, 25), (25, 27),
    # right leg: hip - knee - ankle
    (24, 26), (26, 28),
)
|
||||
Reference in New Issue
Block a user