diff --git a/dead_bug_detector.py b/dead_bug_detector.py new file mode 100644 index 0000000..0707727 --- /dev/null +++ b/dead_bug_detector.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from pathlib import Path + +import cv2 +import mediapipe as mp +import numpy as np + + +PoseLandmarker = mp.tasks.vision.PoseLandmarker +PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions +VisionRunningMode = mp.tasks.vision.RunningMode +BaseOptions = mp.tasks.BaseOptions + + +class DeadBugPhase(str, Enum): + READY = "ready" + EXTENDING = "extending" + NEED_RESET = "need_reset" + NO_POSE = "no_pose" + + +@dataclass(frozen=True) +class Point: + x: float + y: float + z: float + visibility: float + + +@dataclass +class DeadBugMetrics: + left_arm_extended: bool + right_arm_extended: bool + left_leg_extended: bool + right_leg_extended: bool + left_elbow_angle: float + right_elbow_angle: float + left_knee_angle: float + right_knee_angle: float + torso_tilt: float + feedback: list[str] + + +@dataclass +class DeadBugResult: + rep_count: int + phase: DeadBugPhase + side: str | None + is_standard: bool + feedback: list[str] + metrics: DeadBugMetrics | None + + +class DeadBugDetector: + """MediaPipe Pose based dead bug detector and counter. + + The rules are intentionally conservative because a phone stream only gives + us 2D landmarks. A rep is counted when one diagonal pair extends cleanly and + the body returns to the bent-knee ready position. + """ + + LEFT_SHOULDER = 11 + RIGHT_SHOULDER = 12 + LEFT_ELBOW = 13 + RIGHT_ELBOW = 14 + LEFT_WRIST = 15 + RIGHT_WRIST = 16 + LEFT_HIP = 23 + RIGHT_HIP = 24 + LEFT_KNEE = 25 + RIGHT_KNEE = 26 + LEFT_ANKLE = 27 + RIGHT_ANKLE = 28 + + REQUIRED_LANDMARKS = ( + LEFT_SHOULDER, + RIGHT_SHOULDER, + LEFT_ELBOW, + RIGHT_ELBOW, + LEFT_WRIST, + RIGHT_WRIST, + LEFT_HIP, + RIGHT_HIP, + LEFT_KNEE, + RIGHT_KNEE, + LEFT_ANKLE, + RIGHT_ANKLE, + ) + + def __init__( + self, + model_path: str | Path | None = None, + *, + visibility_threshold: float = 0.45, + extension_confirm_frames: int = 4, + reset_confirm_frames: int = 3, + ) -> None: + if model_path is None: + model_path = Path(__file__).resolve().parent / "pose_models" / "pose_landmarker_full.task" + + self.model_path = str(model_path) + self.visibility_threshold = visibility_threshold + self.extension_confirm_frames = extension_confirm_frames + self.reset_confirm_frames = reset_confirm_frames + + options = PoseLandmarkerOptions( + base_options=BaseOptions(model_asset_path=self.model_path), + running_mode=VisionRunningMode.VIDEO, + num_poses=1, + min_pose_detection_confidence=0.5, + min_pose_presence_confidence=0.5, + min_tracking_confidence=0.5, + ) + self._landmarker = PoseLandmarker.create_from_options(options) + + self.rep_count = 0 + self.phase = DeadBugPhase.READY + self.active_side: str | None = None + self._candidate_side: str | None = None + self._candidate_frames = 0 + self._reset_frames = 0 + self._last_timestamp_ms = -1 + + def close(self) -> None: + self._landmarker.close() + + def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]: + timestamp_ms = self._normalize_timestamp(timestamp_ms) + rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB) + mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame) + pose_result = self._landmarker.detect_for_video(mp_image, timestamp_ms) + + annotated = bgr_frame.copy() + if not pose_result.pose_landmarks: + result = DeadBugResult( + rep_count=self.rep_count, + phase=DeadBugPhase.NO_POSE, + side=self.active_side, + is_standard=False, + feedback=["No full body detected"], + metrics=None, + ) + self._draw_status(annotated, result) + return annotated, result + + landmarks = [Point(lm.x, lm.y, lm.z, getattr(lm, "visibility", 1.0)) for lm in pose_result.pose_landmarks[0]] + self._draw_landmarks(annotated, landmarks) + + if not self._has_required_visibility(landmarks): + result = DeadBugResult( + rep_count=self.rep_count, + phase=DeadBugPhase.NO_POSE, + side=self.active_side, + is_standard=False, + feedback=["Keep shoulders, elbows, wrists, hips, knees, ankles visible"], + metrics=None, + ) + self._draw_status(annotated, result) + return annotated, result + + metrics = self._calculate_metrics(landmarks) + result = self._update_state(metrics) + self._draw_status(annotated, result) + return annotated, result + + def _normalize_timestamp(self, timestamp_ms: int) -> int: + if timestamp_ms <= self._last_timestamp_ms: + timestamp_ms = self._last_timestamp_ms + 1 + self._last_timestamp_ms = timestamp_ms + return timestamp_ms + + def _has_required_visibility(self, landmarks: list[Point]) -> bool: + return all(landmarks[i].visibility >= self.visibility_threshold for i in self.REQUIRED_LANDMARKS) + + def _calculate_metrics(self, lm: list[Point]) -> DeadBugMetrics: + left_elbow = angle(lm[self.LEFT_SHOULDER], lm[self.LEFT_ELBOW], lm[self.LEFT_WRIST]) + right_elbow = angle(lm[self.RIGHT_SHOULDER], lm[self.RIGHT_ELBOW], lm[self.RIGHT_WRIST]) + left_knee = angle(lm[self.LEFT_HIP], lm[self.LEFT_KNEE], lm[self.LEFT_ANKLE]) + right_knee = angle(lm[self.RIGHT_HIP], lm[self.RIGHT_KNEE], lm[self.RIGHT_ANKLE]) + + shoulder_width = distance(lm[self.LEFT_SHOULDER], lm[self.RIGHT_SHOULDER]) + hip_width = distance(lm[self.LEFT_HIP], lm[self.RIGHT_HIP]) + scale = max(shoulder_width, hip_width, 0.08) + + left_arm_extended = ( + left_elbow >= 145 + and distance(lm[self.LEFT_SHOULDER], lm[self.LEFT_WRIST]) >= scale * 1.15 + and lm[self.LEFT_WRIST].y <= lm[self.LEFT_SHOULDER].y + scale * 0.35 + ) + right_arm_extended = ( + right_elbow >= 145 + and distance(lm[self.RIGHT_SHOULDER], lm[self.RIGHT_WRIST]) >= scale * 1.15 + and lm[self.RIGHT_WRIST].y <= lm[self.RIGHT_SHOULDER].y + scale * 0.35 + ) + + left_leg_extended = ( + left_knee >= 150 + and distance(lm[self.LEFT_HIP], lm[self.LEFT_ANKLE]) >= scale * 1.55 + and lm[self.LEFT_ANKLE].y >= lm[self.LEFT_KNEE].y - scale * 0.2 + ) + right_leg_extended = ( + right_knee >= 150 + and distance(lm[self.RIGHT_HIP], lm[self.RIGHT_ANKLE]) >= scale * 1.55 + and lm[self.RIGHT_ANKLE].y >= lm[self.RIGHT_KNEE].y - scale * 0.2 + ) + + torso_tilt = abs(lm[self.LEFT_HIP].y - lm[self.RIGHT_HIP].y) / scale + feedback: list[str] = [] + if torso_tilt > 0.35: + feedback.append("Keep pelvis level and core stable") + if left_arm_extended and left_elbow < 160: + feedback.append("Straighten left arm") + if right_arm_extended and right_elbow < 160: + feedback.append("Straighten right arm") + if left_leg_extended and left_knee < 165: + feedback.append("Straighten left leg") + if right_leg_extended and right_knee < 165: + feedback.append("Straighten right leg") + + return DeadBugMetrics( + left_arm_extended=left_arm_extended, + right_arm_extended=right_arm_extended, + left_leg_extended=left_leg_extended, + right_leg_extended=right_leg_extended, + left_elbow_angle=left_elbow, + right_elbow_angle=right_elbow, + left_knee_angle=left_knee, + right_knee_angle=right_knee, + torso_tilt=torso_tilt, + feedback=feedback, + ) + + def _update_state(self, metrics: DeadBugMetrics) -> DeadBugResult: + side = self._detect_diagonal_extension(metrics) + ready = self._is_ready_position(metrics) + + if side is None: + self._candidate_side = None + self._candidate_frames = 0 + elif side == self._candidate_side: + self._candidate_frames += 1 + else: + self._candidate_side = side + self._candidate_frames = 1 + + if self.phase in (DeadBugPhase.READY, DeadBugPhase.NO_POSE): + if self._candidate_frames >= self.extension_confirm_frames and side is not None: + self.phase = DeadBugPhase.EXTENDING + self.active_side = side + self._reset_frames = 0 + elif self.phase == DeadBugPhase.EXTENDING: + if side == self.active_side: + self.phase = DeadBugPhase.NEED_RESET + elif self.phase == DeadBugPhase.NEED_RESET: + if ready: + self._reset_frames += 1 + if self._reset_frames >= self.reset_confirm_frames: + self.rep_count += 1 + self.phase = DeadBugPhase.READY + self.active_side = None + self._candidate_side = None + self._candidate_frames = 0 + self._reset_frames = 0 + else: + self._reset_frames = 0 + + feedback = list(metrics.feedback) + if side is None and not ready: + feedback.append("Extend opposite arm and leg only") + if ready: + feedback.append("Ready position") + elif side == "left_arm_right_leg": + feedback.append("Left arm + right leg") + elif side == "right_arm_left_leg": + feedback.append("Right arm + left leg") + + is_standard = side is not None and not metrics.feedback + return DeadBugResult( + rep_count=self.rep_count, + phase=self.phase, + side=side, + is_standard=is_standard, + feedback=feedback[:3], + metrics=metrics, + ) + + def _detect_diagonal_extension(self, metrics: DeadBugMetrics) -> str | None: + left_arm_right_leg = metrics.left_arm_extended and metrics.right_leg_extended + right_arm_left_leg = metrics.right_arm_extended and metrics.left_leg_extended + same_side_noise = ( + metrics.left_arm_extended + and metrics.left_leg_extended + or metrics.right_arm_extended + and metrics.right_leg_extended + ) + if same_side_noise: + return None + if left_arm_right_leg and not right_arm_left_leg: + return "left_arm_right_leg" + if right_arm_left_leg and not left_arm_right_leg: + return "right_arm_left_leg" + return None + + def _is_ready_position(self, metrics: DeadBugMetrics) -> bool: + knees_bent = metrics.left_knee_angle <= 140 and metrics.right_knee_angle <= 140 + legs_not_extended = not metrics.left_leg_extended and not metrics.right_leg_extended + return knees_bent and legs_not_extended and self._detect_diagonal_extension(metrics) is None + + def _draw_landmarks(self, image: np.ndarray, landmarks: list[Point]) -> None: + h, w = image.shape[:2] + connections = getattr(getattr(mp, "solutions", None), "pose", None) + pose_connections = getattr(connections, "POSE_CONNECTIONS", _POSE_CONNECTIONS) + for start, end in pose_connections: + if start >= len(landmarks) or end >= len(landmarks): + continue + p1 = landmarks[start] + p2 = landmarks[end] + if p1.visibility < self.visibility_threshold or p2.visibility < self.visibility_threshold: + continue + cv2.line( + image, + (int(p1.x * w), int(p1.y * h)), + (int(p2.x * w), int(p2.y * h)), + (65, 180, 255), + 2, + ) + for idx in self.REQUIRED_LANDMARKS: + p = landmarks[idx] + if p.visibility >= self.visibility_threshold: + cv2.circle(image, (int(p.x * w), int(p.y * h)), 4, (80, 255, 120), -1) + + def _draw_status(self, image: np.ndarray, result: DeadBugResult) -> None: + color = (60, 220, 90) if result.is_standard else (50, 180, 255) + cv2.rectangle(image, (12, 12), (520, 142), (20, 20, 20), -1) + cv2.putText(image, f"Dead bug reps: {result.rep_count}", (28, 48), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2) + cv2.putText(image, f"phase: {result.phase.value}", (28, 82), cv2.FONT_HERSHEY_SIMPLEX, 0.68, (230, 230, 230), 2) + status = "standard" if result.is_standard else "adjust" + cv2.putText(image, f"status: {status}", (28, 116), cv2.FONT_HERSHEY_SIMPLEX, 0.68, color, 2) + + y = 170 + for text in result.feedback: + cv2.putText(image, text, (28, y), cv2.FONT_HERSHEY_SIMPLEX, 0.68, (255, 255, 255), 2) + y += 30 + + +def angle(a: Point, b: Point, c: Point) -> float: + ba = np.array([a.x - b.x, a.y - b.y], dtype=np.float32) + bc = np.array([c.x - b.x, c.y - b.y], dtype=np.float32) + denom = float(np.linalg.norm(ba) * np.linalg.norm(bc)) + if denom == 0: + return 0.0 + cos_value = float(np.dot(ba, bc) / denom) + return float(np.degrees(np.arccos(np.clip(cos_value, -1.0, 1.0)))) + + +def distance(a: Point, b: Point) -> float: + return float(np.hypot(a.x - b.x, a.y - b.y)) + + +_POSE_CONNECTIONS = ( + (11, 12), + (11, 13), + (13, 15), + (12, 14), + (14, 16), + (11, 23), + (12, 24), + (23, 24), + (23, 25), + (25, 27), + (24, 26), + (26, 28), +) diff --git a/handle_client.py b/handle_client.py index d69069e..f233b13 100644 --- a/handle_client.py +++ b/handle_client.py @@ -6,6 +6,8 @@ import cv2 from loguru import logger from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceCandidate +from dead_bug_detector import DeadBugDetector + async def handle_client(websocket): client = websocket.remote_address @@ -40,15 +42,25 @@ async def handle_client(websocket): async def receive_video(track): logger.info("Start receiving video frames") frame_count = 0 + detector = DeadBugDetector() try: while True: frame = await track.recv() frame_count += 1 img = frame.to_ndarray(format="bgr24") - cv2.imshow("Android Camera (WebRTC)", img) + timestamp_ms = int(frame.time * 1000) if frame.time is not None else frame_count * 33 + annotated, pose_result = detector.process_frame(img, timestamp_ms) + cv2.imshow("Android Camera (WebRTC)", annotated) if frame_count % 100 == 0: - logger.info(f"Received {frame_count} frames, shape={img.shape}") + logger.info( + "Received {} frames, shape={}, reps={}, phase={}, feedback={}", + frame_count, + img.shape, + pose_result.rep_count, + pose_result.phase.value, + " | ".join(pose_result.feedback), + ) if cv2.waitKey(1) & 0xFF == 27: logger.info("ESC pressed, closing display") @@ -57,6 +69,8 @@ async def handle_client(websocket): logger.info("Video receive task cancelled") except Exception as e: logger.error(f"Video receive error: {e}") + finally: + detector.close() @pc.on("track") async def on_track(track): diff --git a/main.py b/main.py index 5aca3c9..cb5db29 100644 --- a/main.py +++ b/main.py @@ -1,27 +1,7 @@ -import mediapipe as mp +import asyncio -def main(): - BaseOptions = mp.tasks.BaseOptions - PoseLandmarker = mp.tasks.vision.PoseLandmarker - PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions - PoseLandmarkerResult = mp.tasks.vision.PoseLandmarkerResult - VisionRunningMode = mp.tasks.vision.RunningMode - - # Create a pose landmarker instance with the live stream mode: - def print_result(result: PoseLandmarkerResult, output_image: mp.Image, timestamp_ms: int): - print('pose landmarker result: {}'.format(result)) - - options = PoseLandmarkerOptions( - base_options=BaseOptions(model_asset_path=model_path), - running_mode=VisionRunningMode.LIVE_STREAM, - result_callback=print_result) - - with PoseLandmarker.create_from_options(options) as landmarker: - +from handle_client import main - - - -if __name__ == '__main__': - main() +if __name__ == "__main__": + asyncio.run(main()) diff --git a/requirements.txt b/requirements.txt index 0311b44..9fc8336 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiortc>=1.9.0 websockets>=13.0 -opencv-python>=4.10.0 +opencv-contrib-python>=4.10.0 numpy>=2.0.0 loguru>=0.7.0 +mediapipe>=0.10.35