From 4485cbf702a6b66b549d676c9031b9ce034094bd Mon Sep 17 00:00:00 2001 From: hjwang <2392948297@qq.com> Date: Wed, 10 Jun 2026 10:14:43 +0800 Subject: [PATCH] Refactor into modular app structure Split monolithic files into focused modules: - app/core: settings, logging, lifecycle - app/signaling: websocket server, ICE parser, message models - app/webrtc: peer session, video receiver, frame source - app/vision: pose landmarker wrapper, model config, pose types - app/exercises/dead_bug: detector, metrics, rules, state machine, types - app/rendering: skeleton renderer, status overlay, window display - app/audio: rep announcer - app/diagnostics: perf timer, crash handler - configs: environment-based settings - tests: unit tests for rules, state machine, ICE parser - run.py: entry point --- .env.example | 24 + README.md | 14 + config/load_config.py => app/__init__.py | 0 app/audio/__init__.py | 0 .../audio/rep_announcer.py | 0 app/core/__init__.py | 0 app/core/lifecycle.py | 10 + app/core/logging.py | 20 + app/diagnostics/__init__.py | 0 app/diagnostics/crash_handler.py | 11 + app/diagnostics/perf_timer.py | 34 ++ app/exercises/__init__.py | 0 app/exercises/dead_bug/__init__.py | 0 app/exercises/dead_bug/detector.py | 172 +++++++ app/exercises/dead_bug/metrics.py | 92 ++++ app/exercises/dead_bug/rules.py | 24 + app/exercises/dead_bug/state_machine.py | 71 +++ app/exercises/dead_bug/types.py | 42 ++ app/main.py | 19 + app/rendering/__init__.py | 0 app/rendering/overlay_renderer.py | 20 + app/rendering/skeleton_renderer.py | 46 ++ app/rendering/window_display.py | 21 + app/signaling/__init__.py | 0 app/signaling/ice_parser.py | 30 ++ app/signaling/message_models.py | 12 + app/signaling/websocket_server.py | 26 ++ app/vision/__init__.py | 0 app/vision/frame_utils.py | 23 + app/vision/pose_landmarker.py | 57 +++ app/vision/pose_models.py | 14 + app/vision/pose_types.py | 80 ++++ app/webrtc/__init__.py | 0 app/webrtc/frame_source.py | 14 + app/webrtc/peer_session.py | 70 +++ app/webrtc/video_receiver.py | 106 +++++ configs/default.py | 32 ++ dead_bug_detector.py | 431 ------------------ handle_client.py | 188 -------- main.py | 29 -- run.py | 4 + tests/test_dead_bug_rules.py | 65 +++ tests/test_dead_bug_state_machine.py | 42 ++ tests/test_ice_parser.py | 35 ++ 44 files changed, 1230 insertions(+), 648 deletions(-) create mode 100644 .env.example create mode 100644 README.md rename config/load_config.py => app/__init__.py (100%) create mode 100644 app/audio/__init__.py rename rep_announcer.py => app/audio/rep_announcer.py (100%) create mode 100644 app/core/__init__.py create mode 100644 app/core/lifecycle.py create mode 100644 app/core/logging.py create mode 100644 app/diagnostics/__init__.py create mode 100644 app/diagnostics/crash_handler.py create mode 100644 app/diagnostics/perf_timer.py create mode 100644 app/exercises/__init__.py create mode 100644 app/exercises/dead_bug/__init__.py create mode 100644 app/exercises/dead_bug/detector.py create mode 100644 app/exercises/dead_bug/metrics.py create mode 100644 app/exercises/dead_bug/rules.py create mode 100644 app/exercises/dead_bug/state_machine.py create mode 100644 app/exercises/dead_bug/types.py create mode 100644 app/main.py create mode 100644 app/rendering/__init__.py create mode 100644 app/rendering/overlay_renderer.py create mode 100644 app/rendering/skeleton_renderer.py create mode 100644 app/rendering/window_display.py create mode 100644 app/signaling/__init__.py create mode 100644 app/signaling/ice_parser.py create mode 100644 app/signaling/message_models.py create mode 100644 app/signaling/websocket_server.py create mode 100644 app/vision/__init__.py create mode 100644 app/vision/frame_utils.py create mode 100644 app/vision/pose_landmarker.py create mode 100644 app/vision/pose_models.py create mode 100644 app/vision/pose_types.py create mode 100644 app/webrtc/__init__.py create mode 100644 app/webrtc/frame_source.py create mode 100644 app/webrtc/peer_session.py create mode 100644 app/webrtc/video_receiver.py create mode 100644 configs/default.py delete mode 100644 dead_bug_detector.py delete mode 100644 handle_client.py delete mode 100644 main.py create mode 100644 run.py create mode 100644 tests/test_dead_bug_rules.py create mode 100644 tests/test_dead_bug_state_machine.py create mode 100644 tests/test_ice_parser.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..8e2f62f --- /dev/null +++ b/.env.example @@ -0,0 +1,24 @@ +# Server +POSEFIT_WS_HOST=0.0.0.0 +POSEFIT_WS_PORT=8765 + +# Video processing +POSEFIT_PROCESS_EVERY_N_FRAMES=1 + +# Model +POSEFIT_MODEL_PATH=pose_models/pose_landmarker_full.task +POSEFIT_PREFER_GPU=1 + +# Dead bug exercise +POSEFIT_VISIBILITY_THRESHOLD=0.45 +POSEFIT_EXTENSION_CONFIRM_FRAMES=4 +POSEFIT_RESET_CONFIRM_FRAMES=3 + +# Audio +POSEFIT_REP_ANNOUNCER_ENABLED=1 +POSEFIT_REP_ANNOUNCER_RATE=185 +POSEFIT_REP_ANNOUNCER_VOLUME=1.0 + +# Logging +POSEFIT_LOG_ROTATION=20 MB +POSEFIT_LOG_RETENTION=14 days diff --git a/README.md b/README.md new file mode 100644 index 0000000..d3d8295 --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +# PoseFit Server + +Real-time exercise pose detection and coaching via WebRTC. + +## Quick Start + +``` +pip install -r requirements.txt +python run.py +``` + +## Configuration + +Copy `.env.example` to `.env` and adjust settings, or set environment variables directly. diff --git a/config/load_config.py b/app/__init__.py similarity index 100% rename from config/load_config.py rename to app/__init__.py diff --git a/app/audio/__init__.py b/app/audio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rep_announcer.py b/app/audio/rep_announcer.py similarity index 100% rename from rep_announcer.py rename to app/audio/rep_announcer.py diff --git a/app/core/__init__.py b/app/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/core/lifecycle.py b/app/core/lifecycle.py new file mode 100644 index 0000000..71637a4 --- /dev/null +++ b/app/core/lifecycle.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from app.diagnostics.crash_handler import enable_crash_handler +from configs.default import LOG_DIR + + +def startup() -> None: + enable_crash_handler(LOG_DIR) + from app.core.logging import setup_logging + setup_logging() diff --git a/app/core/logging.py b/app/core/logging.py new file mode 100644 index 0000000..533b11c --- /dev/null +++ b/app/core/logging.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from pathlib import Path + +from loguru import logger + +from configs.default import LOG_DIR, LOG_RETENTION, LOG_ROTATION + + +def setup_logging() -> None: + LOG_DIR.mkdir(parents=True, exist_ok=True) + + logger.add( + LOG_DIR / "posefit-server_{time:YYYY-MM-DD}.log", + rotation=LOG_ROTATION, + retention=LOG_RETENTION, + enqueue=True, + backtrace=True, + diagnose=True, + ) diff --git a/app/diagnostics/__init__.py b/app/diagnostics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/diagnostics/crash_handler.py b/app/diagnostics/crash_handler.py new file mode 100644 index 0000000..a5486bf --- /dev/null +++ b/app/diagnostics/crash_handler.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import faulthandler +from pathlib import Path + + +def enable_crash_handler(log_dir: str | Path) -> None: + log_dir = Path(log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + crash_log = open(log_dir / "posefit-crash.log", "a", buffering=1) + faulthandler.enable(file=crash_log, all_threads=True) diff --git a/app/diagnostics/perf_timer.py b/app/diagnostics/perf_timer.py new file mode 100644 index 0000000..25c2edd --- /dev/null +++ b/app/diagnostics/perf_timer.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import time +from contextlib import contextmanager + +from loguru import logger + + +class PerfTimer: + def __init__(self, name: str = "") -> None: + self.name = name + self._start = 0.0 + self._elapsed = 0.0 + + def start(self) -> PerfTimer: + self._start = time.perf_counter() + return self + + def stop(self) -> float: + self._elapsed = time.perf_counter() - self._start + return self._elapsed + + @property + def elapsed_ms(self) -> float: + return self._elapsed * 1000 + + +@contextmanager +def measure(name: str = ""): + timer = PerfTimer(name).start() + yield timer + elapsed = timer.stop() + if name: + logger.debug("{} took {:.1f}ms", name, timer.elapsed_ms) diff --git a/app/exercises/__init__.py b/app/exercises/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/exercises/dead_bug/__init__.py b/app/exercises/dead_bug/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/exercises/dead_bug/detector.py b/app/exercises/dead_bug/detector.py new file mode 100644 index 0000000..8f05094 --- /dev/null +++ b/app/exercises/dead_bug/detector.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import threading +import time + +import cv2 +import mediapipe as mp +import numpy as np +from loguru import logger + +from app.exercises.dead_bug.metrics import calculate_metrics +from app.exercises.dead_bug.rules import has_required_visibility +from app.exercises.dead_bug.state_machine import DeadBugStateMachine +from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase, DeadBugResult, Point +from app.rendering.overlay_renderer import draw_status_overlay +from app.rendering.skeleton_renderer import draw_landmarks +from app.vision.frame_utils import bgr_to_rgba +from app.vision.pose_landmarker import PoseLandmarkerWrapper +from app.vision.pose_types import ( + LEFT_ANKLE, + LEFT_ELBOW, + LEFT_HIP, + LEFT_KNEE, + LEFT_SHOULDER, + LEFT_WRIST, + REQUIRED_LANDMARKS, + RIGHT_ANKLE, + RIGHT_ELBOW, + RIGHT_HIP, + RIGHT_KNEE, + RIGHT_SHOULDER, + RIGHT_WRIST, +) + + +class DeadBugDetector: + def __init__( + self, + *, + model_path: str | None = None, + visibility_threshold: float = 0.45, + extension_confirm_frames: int = 4, + reset_confirm_frames: int = 3, + prefer_gpu: bool = True, + ) -> None: + self.visibility_threshold = visibility_threshold + + self._latest_result = None + self._result_lock = threading.Lock() + self._result_event = threading.Event() + self._inflight = False + self._inflight_started_at = 0.0 + + def on_result(pose_result, _image, _timestamp_ms): + with self._result_lock: + self._latest_result = pose_result + self._inflight = False + self._inflight_started_at = 0.0 + self._result_event.set() + + self._landmarker = PoseLandmarkerWrapper( + model_path=model_path, + prefer_gpu=prefer_gpu, + result_callback=on_result, + ) + + self._state = DeadBugStateMachine( + extension_confirm_frames=extension_confirm_frames, + reset_confirm_frames=reset_confirm_frames, + ) + + self._last_timestamp_ms = -1 + + def close(self) -> None: + self._landmarker.close() + + def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]: + timestamp_ms = self._normalize_timestamp(timestamp_ms) + + with self._result_lock: + if self._inflight and time.monotonic() - self._inflight_started_at > 0.5: + logger.warning("MediaPipe detect_async timed out; allowing next frame submission") + self._inflight = False + self._inflight_started_at = 0.0 + should_submit = not self._inflight + if should_submit: + self._inflight = True + self._inflight_started_at = time.monotonic() + + if should_submit: + rgba_frame = bgr_to_rgba(bgr_frame) + mp_image = mp.Image(image_format=mp.ImageFormat.SRGBA, data=rgba_frame) + self._result_event.clear() + try: + self._landmarker.detect_async(mp_image, timestamp_ms) + except Exception: + with self._result_lock: + self._inflight = False + self._inflight_started_at = 0.0 + raise + self._result_event.wait(timeout=0.08) + + with self._result_lock: + pose_result = self._latest_result + + annotated = bgr_frame.copy() + + if pose_result is None or not pose_result.pose_landmarks: + result = DeadBugResult( + rep_count=self._state.rep_count, + phase=DeadBugPhase.NO_POSE, + side=self._state.active_side, + is_standard=False, + feedback=["No full body detected"], + metrics=None, + ) + draw_status_overlay(annotated, result) + return annotated, result + + landmarks = [Point(lm.x, lm.y, lm.z, getattr(lm, "visibility", 1.0)) for lm in pose_result.pose_landmarks[0]] + draw_landmarks(annotated, landmarks, REQUIRED_LANDMARKS, visibility_threshold=self.visibility_threshold) + + if not has_required_visibility(landmarks, REQUIRED_LANDMARKS, self.visibility_threshold): + result = DeadBugResult( + rep_count=self._state.rep_count, + phase=DeadBugPhase.NO_POSE, + side=self._state.active_side, + is_standard=False, + feedback=["Keep shoulders, elbows, wrists, hips, knees, ankles visible"], + metrics=None, + ) + draw_status_overlay(annotated, result) + return annotated, result + + raw = calculate_metrics( + landmarks, + left_shoulder=LEFT_SHOULDER, + right_shoulder=RIGHT_SHOULDER, + left_elbow=LEFT_ELBOW, + right_elbow=RIGHT_ELBOW, + left_wrist=LEFT_WRIST, + right_wrist=RIGHT_WRIST, + left_hip=LEFT_HIP, + right_hip=RIGHT_HIP, + left_knee=LEFT_KNEE, + right_knee=RIGHT_KNEE, + left_ankle=LEFT_ANKLE, + right_ankle=RIGHT_ANKLE, + visibility_threshold=self.visibility_threshold, + ) + + metrics = DeadBugMetrics( + left_arm_extended=raw["left_arm_extended"], + right_arm_extended=raw["right_arm_extended"], + left_leg_extended=raw["left_leg_extended"], + right_leg_extended=raw["right_leg_extended"], + left_elbow_angle=raw["left_elbow_angle"], + right_elbow_angle=raw["right_elbow_angle"], + left_knee_angle=raw["left_knee_angle"], + right_knee_angle=raw["right_knee_angle"], + feedback=raw["feedback"], + ) + + result = self._state.update(metrics) + draw_status_overlay(annotated, result) + return annotated, result + + def _normalize_timestamp(self, timestamp_ms: int) -> int: + if timestamp_ms <= self._last_timestamp_ms: + timestamp_ms = self._last_timestamp_ms + 1 + self._last_timestamp_ms = timestamp_ms + return timestamp_ms diff --git a/app/exercises/dead_bug/metrics.py b/app/exercises/dead_bug/metrics.py new file mode 100644 index 0000000..ad5601b --- /dev/null +++ b/app/exercises/dead_bug/metrics.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import cv2 +import numpy as np + +from app.exercises.dead_bug.types import Point + + +def angle(a: Point, b: Point, c: Point) -> float: + ba = np.array([a.x - b.x, a.y - b.y], dtype=np.float32) + bc = np.array([c.x - b.x, c.y - b.y], dtype=np.float32) + denom = float(np.linalg.norm(ba) * np.linalg.norm(bc)) + if denom == 0: + return 0.0 + cos_value = float(np.dot(ba, bc) / denom) + return float(np.degrees(np.arccos(np.clip(cos_value, -1.0, 1.0)))) + + +def distance(a: Point, b: Point) -> float: + return float(np.hypot(a.x - b.x, a.y - b.y)) + + +def calculate_metrics( + lm: list[Point], + *, + left_shoulder: int, + right_shoulder: int, + left_elbow: int, + right_elbow: int, + left_wrist: int, + right_wrist: int, + left_hip: int, + right_hip: int, + left_knee: int, + right_knee: int, + left_ankle: int, + right_ankle: int, + visibility_threshold: float = 0.45, +) -> dict: + left_elbow_angle = angle(lm[left_shoulder], lm[left_elbow], lm[left_wrist]) + right_elbow_angle = angle(lm[right_shoulder], lm[right_elbow], lm[right_wrist]) + left_knee_angle = angle(lm[left_hip], lm[left_knee], lm[left_ankle]) + right_knee_angle = angle(lm[right_hip], lm[right_knee], lm[right_ankle]) + + shoulder_width = distance(lm[left_shoulder], lm[right_shoulder]) + hip_width = distance(lm[left_hip], lm[right_hip]) + scale = max(shoulder_width, hip_width, 0.08) + + left_arm_extended = ( + left_elbow_angle >= 145 + and distance(lm[left_shoulder], lm[left_wrist]) >= scale * 1.15 + and lm[left_wrist].y <= lm[left_shoulder].y + scale * 0.35 + ) + right_arm_extended = ( + right_elbow_angle >= 145 + and distance(lm[right_shoulder], lm[right_wrist]) >= scale * 1.15 + and lm[right_wrist].y <= lm[right_shoulder].y + scale * 0.35 + ) + + left_leg_extended = ( + left_knee_angle >= 150 + and distance(lm[left_hip], lm[left_ankle]) >= scale * 1.55 + and lm[left_ankle].y >= lm[left_knee].y - scale * 0.2 + ) + right_leg_extended = ( + right_knee_angle >= 150 + and distance(lm[right_hip], lm[right_ankle]) >= scale * 1.55 + and lm[right_ankle].y >= lm[right_knee].y - scale * 0.2 + ) + + feedback: list[str] = [] + if left_arm_extended and left_elbow_angle < 160: + feedback.append("Straighten left arm") + if right_arm_extended and right_elbow_angle < 160: + feedback.append("Straighten right arm") + if left_leg_extended and left_knee_angle < 165: + feedback.append("Straighten left leg") + if right_leg_extended and right_knee_angle < 165: + feedback.append("Straighten right leg") + + return { + "left_arm_extended": left_arm_extended, + "right_arm_extended": right_arm_extended, + "left_leg_extended": left_leg_extended, + "right_leg_extended": right_leg_extended, + "left_elbow_angle": left_elbow_angle, + "right_elbow_angle": right_elbow_angle, + "left_knee_angle": left_knee_angle, + "right_knee_angle": right_knee_angle, + "scale": scale, + "feedback": feedback, + } diff --git a/app/exercises/dead_bug/rules.py b/app/exercises/dead_bug/rules.py new file mode 100644 index 0000000..2256d81 --- /dev/null +++ b/app/exercises/dead_bug/rules.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from app.exercises.dead_bug.types import DeadBugMetrics, Point + + +def has_required_visibility(landmarks: list[Point], required_indices: tuple[int, ...], visibility_threshold: float) -> bool: + return all(landmarks[i].visibility >= visibility_threshold for i in required_indices) + + +def detect_diagonal_extension(metrics: DeadBugMetrics) -> str | None: + if metrics.left_leg_extended and metrics.right_leg_extended: + return None + + if metrics.right_leg_extended and metrics.left_arm_extended: + return "left_arm_right_leg" + if metrics.left_leg_extended and metrics.right_arm_extended: + return "right_arm_left_leg" + return None + + +def is_ready_position(metrics: DeadBugMetrics) -> bool: + knees_bent = metrics.left_knee_angle <= 140 and metrics.right_knee_angle <= 140 + legs_not_extended = not metrics.left_leg_extended and not metrics.right_leg_extended + return knees_bent and legs_not_extended and detect_diagonal_extension(metrics) is None diff --git a/app/exercises/dead_bug/state_machine.py b/app/exercises/dead_bug/state_machine.py new file mode 100644 index 0000000..d79073a --- /dev/null +++ b/app/exercises/dead_bug/state_machine.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from app.exercises.dead_bug.rules import detect_diagonal_extension, is_ready_position +from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase, DeadBugResult + + +class DeadBugStateMachine: + def __init__(self, *, extension_confirm_frames: int = 4, reset_confirm_frames: int = 3) -> None: + self.extension_confirm_frames = extension_confirm_frames + self.reset_confirm_frames = reset_confirm_frames + + self.rep_count = 0 + self.phase = DeadBugPhase.READY + self.active_side: str | None = None + self._candidate_side: str | None = None + self._candidate_frames = 0 + self._reset_frames = 0 + + def update(self, metrics: DeadBugMetrics) -> DeadBugResult: + side = detect_diagonal_extension(metrics) + ready = is_ready_position(metrics) + + if side is None: + self._candidate_side = None + self._candidate_frames = 0 + elif side == self._candidate_side: + self._candidate_frames += 1 + else: + self._candidate_side = side + self._candidate_frames = 1 + + if self.phase in (DeadBugPhase.READY, DeadBugPhase.NO_POSE): + if self._candidate_frames >= self.extension_confirm_frames and side is not None: + self.phase = DeadBugPhase.EXTENDING + self.active_side = side + self._reset_frames = 0 + elif self.phase == DeadBugPhase.EXTENDING: + if side == self.active_side: + self.phase = DeadBugPhase.NEED_RESET + elif self.phase == DeadBugPhase.NEED_RESET: + if ready: + self._reset_frames += 1 + if self._reset_frames >= self.reset_confirm_frames: + self.rep_count += 1 + self.phase = DeadBugPhase.READY + self.active_side = None + self._candidate_side = None + self._candidate_frames = 0 + self._reset_frames = 0 + else: + self._reset_frames = 0 + + feedback = list(metrics.feedback) + if side is None and not ready: + feedback.append("Extend opposite arm and leg only") + if ready: + feedback.append("Ready position") + elif side == "left_arm_right_leg": + feedback.append("Left arm + right leg") + elif side == "right_arm_left_leg": + feedback.append("Right arm + left leg") + + is_standard = side is not None and not metrics.feedback + return DeadBugResult( + rep_count=self.rep_count, + phase=self.phase, + side=side, + is_standard=is_standard, + feedback=feedback[:3], + metrics=metrics, + ) diff --git a/app/exercises/dead_bug/types.py b/app/exercises/dead_bug/types.py new file mode 100644 index 0000000..15c10ec --- /dev/null +++ b/app/exercises/dead_bug/types.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + + +class DeadBugPhase(str, Enum): + READY = "ready" + EXTENDING = "extending" + NEED_RESET = "need_reset" + NO_POSE = "no_pose" + + +@dataclass(frozen=True) +class Point: + x: float + y: float + z: float + visibility: float + + +@dataclass +class DeadBugMetrics: + left_arm_extended: bool + right_arm_extended: bool + left_leg_extended: bool + right_leg_extended: bool + left_elbow_angle: float + right_elbow_angle: float + left_knee_angle: float + right_knee_angle: float + feedback: list[str] + + +@dataclass +class DeadBugResult: + rep_count: int + phase: DeadBugPhase + side: str | None + is_standard: bool + feedback: list[str] + metrics: DeadBugMetrics | None diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..d21a230 --- /dev/null +++ b/app/main.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import os + +os.environ["MEDIAPIPE_DISABLE_LOGGING"] = "1" +os.environ["GLOG_minloglevel"] = "3" + +import asyncio + +from loguru import logger + +from app.core.lifecycle import startup +from app.signaling.websocket_server import main as serve + + +if __name__ == "__main__": + startup() + logger.info("Starting server...") + asyncio.run(serve()) diff --git a/app/rendering/__init__.py b/app/rendering/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/rendering/overlay_renderer.py b/app/rendering/overlay_renderer.py new file mode 100644 index 0000000..3e30ad5 --- /dev/null +++ b/app/rendering/overlay_renderer.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import cv2 +import numpy as np + +from app.exercises.dead_bug.types import DeadBugResult + + +def draw_status_overlay(image: np.ndarray, result: DeadBugResult) -> None: + color = (60, 220, 90) if result.is_standard else (50, 180, 255) + cv2.rectangle(image, (12, 12), (520, 142), (20, 20, 20), -1) + cv2.putText(image, f"Dead bug reps: {result.rep_count}", (28, 48), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2) + cv2.putText(image, f"phase: {result.phase.value}", (28, 82), cv2.FONT_HERSHEY_SIMPLEX, 0.68, (230, 230, 230), 2) + status = "standard" if result.is_standard else "adjust" + cv2.putText(image, f"status: {status}", (28, 116), cv2.FONT_HERSHEY_SIMPLEX, 0.68, color, 2) + + y = 170 + for text in result.feedback: + cv2.putText(image, text, (28, y), cv2.FONT_HERSHEY_SIMPLEX, 0.68, (255, 255, 255), 2) + y += 30 diff --git a/app/rendering/skeleton_renderer.py b/app/rendering/skeleton_renderer.py new file mode 100644 index 0000000..7e033e4 --- /dev/null +++ b/app/rendering/skeleton_renderer.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import cv2 +import numpy as np + +from app.exercises.dead_bug.types import DeadBugResult, Point +from app.vision.pose_types import _POSE_CONNECTIONS + + +def draw_landmarks( + image: np.ndarray, + landmarks: list[Point], + required_indices: tuple[int, ...], + connections: tuple[tuple[int, int], ...] | None = None, + visibility_threshold: float = 0.45, + line_color: tuple[int, int, int] = (65, 180, 255), + point_color: tuple[int, int, int] = (80, 255, 120), + line_thickness: int = 2, + point_radius: int = 4, +) -> None: + if connections is None: + connections = _POSE_CONNECTIONS + + h, w = image.shape[:2] + + for start, end in connections: + if start >= len(landmarks) or end >= len(landmarks): + continue + p1 = landmarks[start] + p2 = landmarks[end] + if p1.visibility < visibility_threshold or p2.visibility < visibility_threshold: + continue + cv2.line( + image, + (int(p1.x * w), int(p1.y * h)), + (int(p2.x * w), int(p2.y * h)), + line_color, + line_thickness, + ) + + for idx in required_indices: + if idx >= len(landmarks): + continue + p = landmarks[idx] + if p.visibility >= visibility_threshold: + cv2.circle(image, (int(p.x * w), int(p.y * h)), point_radius, point_color, -1) diff --git a/app/rendering/window_display.py b/app/rendering/window_display.py new file mode 100644 index 0000000..58c0e0c --- /dev/null +++ b/app/rendering/window_display.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import cv2 + +WINDOW_NAME = "Android Camera (WebRTC)" + + +def show_frame(image, window_name: str = WINDOW_NAME) -> None: + cv2.imshow(window_name, image) + + +def wait_key(delay_ms: int = 1) -> int: + return cv2.waitKey(delay_ms) & 0xFF + + +def is_esc_pressed() -> bool: + return wait_key(1) == 27 + + +def close_window() -> None: + cv2.destroyAllWindows() diff --git a/app/signaling/__init__.py b/app/signaling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/signaling/ice_parser.py b/app/signaling/ice_parser.py new file mode 100644 index 0000000..fcf4af4 --- /dev/null +++ b/app/signaling/ice_parser.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import re +from typing import Any + +from aiortc import RTCIceCandidate + + +def parse_ice(data: dict[str, Any]) -> RTCIceCandidate | None: + match = re.match( + r'candidate:(\S+) (\d) (\S+) (\d+) (\S+) (\d+) typ (\S+)(?: raddr (\S+) rport (\d+))?', + data["candidate"], + ) + if not match: + return None + g = match.groups() + cand = RTCIceCandidate( + foundation=g[0], + component=int(g[1]), + protocol=g[2].lower(), + priority=int(g[3]), + ip=g[4], + port=int(g[5]), + type=g[6], + relatedAddress=g[7], + relatedPort=int(g[8]) if g[8] else None, + ) + cand.sdpMid = data.get("sdpMid") + cand.sdpMLineIndex = data.get("sdpMLineIndex", 0) + return cand diff --git a/app/signaling/message_models.py b/app/signaling/message_models.py new file mode 100644 index 0000000..16f4bfc --- /dev/null +++ b/app/signaling/message_models.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class SignalingMessage: + type: str + sdp: str = "" + candidate: str = "" + sdpMid: str | None = None + sdpMLineIndex: int = 0 diff --git a/app/signaling/websocket_server.py b/app/signaling/websocket_server.py new file mode 100644 index 0000000..ccc44de --- /dev/null +++ b/app/signaling/websocket_server.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import asyncio +import json + +import websockets +from loguru import logger + +from app.webrtc.peer_session import PeerSession +from configs.default import WS_HOST, WS_MAX_SIZE, WS_PORT + + +async def handle_client(websocket): + client = websocket.remote_address + logger.info(f"Client connected: {client}") + + session = PeerSession() + await session.handle(websocket) + + logger.info(f"Connection closed: {client}") + + +async def main(): + logger.info(f"WebRTC signaling server: ws://{WS_HOST}:{WS_PORT}") + async with websockets.serve(handle_client, WS_HOST, WS_PORT, max_size=WS_MAX_SIZE): + await asyncio.Future() diff --git a/app/vision/__init__.py b/app/vision/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/vision/frame_utils.py b/app/vision/frame_utils.py new file mode 100644 index 0000000..b7a894c --- /dev/null +++ b/app/vision/frame_utils.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import cv2 +import numpy as np + + +TARGET_WIDTH = 1280 +TARGET_HEIGHT = 720 + + +def resize_to_target(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> np.ndarray: + h, w = image.shape[:2] + if w == width and h == height: + return image + return cv2.resize(image, (width, height)) + + +def bgr_to_rgba(bgr: np.ndarray) -> np.ndarray: + return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGBA) + + +def bgr_to_rgb(bgr: np.ndarray) -> np.ndarray: + return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) diff --git a/app/vision/pose_landmarker.py b/app/vision/pose_landmarker.py new file mode 100644 index 0000000..7257c46 --- /dev/null +++ b/app/vision/pose_landmarker.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import threading +import time +from typing import Callable + +import mediapipe as mp +from loguru import logger + +from app.vision.pose_models import DEFAULT_MODEL_PATH + +PoseLandmarker = mp.tasks.vision.PoseLandmarker +PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions +VisionRunningMode = mp.tasks.vision.RunningMode +BaseOptions = mp.tasks.BaseOptions + + +class PoseLandmarkerWrapper: + def __init__( + self, + *, + model_path: str | None = None, + prefer_gpu: bool = True, + result_callback: Callable | None = None, + ) -> None: + self.model_path = model_path or DEFAULT_MODEL_PATH + + if prefer_gpu: + try: + self.delegate = BaseOptions.Delegate.GPU + self._landmarker = self._create(PoseLandmarker.Delegate.GPU) + logger.info("MediaPipe PoseLandmarker initialized with GPU delegate") + return + except Exception as exc: + logger.warning("MediaPipe GPU delegate unavailable, falling back to CPU: {}", exc) + + self.delegate = BaseOptions.Delegate.CPU + self._landmarker = self._create(PoseLandmarker.Delegate.CPU, result_callback) + logger.info("MediaPipe PoseLandmarker initialized with CPU delegate") + + def _create(self, delegate, result_callback=None): + options = PoseLandmarkerOptions( + base_options=BaseOptions(model_asset_path=self.model_path, delegate=delegate), + running_mode=VisionRunningMode.LIVE_STREAM, + result_callback=result_callback, + num_poses=1, + min_pose_detection_confidence=0.5, + min_pose_presence_confidence=0.5, + min_tracking_confidence=0.5, + ) + return PoseLandmarker.create_from_options(options) + + def detect_async(self, mp_image, timestamp_ms: int) -> None: + return self._landmarker.detect_async(mp_image, timestamp_ms) + + def close(self) -> None: + self._landmarker.close() diff --git a/app/vision/pose_models.py b/app/vision/pose_models.py new file mode 100644 index 0000000..01aff62 --- /dev/null +++ b/app/vision/pose_models.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from pathlib import Path + +import mediapipe as mp + +BaseOptions = mp.tasks.BaseOptions + +_MODELS_DIR = Path(__file__).resolve().parent.parent.parent / "pose_models" + +POSE_LANDMARKER_FULL = _MODELS_DIR / "pose_landmarker_full.task" +POSE_LANDMARKER_LITE = _MODELS_DIR / "pose_landmarker_lite.task" + +DEFAULT_MODEL_PATH = str(POSE_LANDMARKER_FULL) if POSE_LANDMARKER_FULL.exists() else str(POSE_LANDMARKER_LITE) diff --git a/app/vision/pose_types.py b/app/vision/pose_types.py new file mode 100644 index 0000000..a969929 --- /dev/null +++ b/app/vision/pose_types.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +_POSE_CONNECTIONS = ( + (11, 12), + (11, 13), + (13, 15), + (12, 14), + (14, 16), + (11, 23), + (12, 24), + (23, 24), + (23, 25), + (25, 27), + (24, 26), + (26, 28), +) + +LANDMARK_NAMES: dict[int, str] = { + 0: "nose", + 1: "left_eye_inner", + 2: "left_eye", + 3: "left_eye_outer", + 4: "right_eye_inner", + 5: "right_eye", + 6: "right_eye_outer", + 7: "left_ear", + 8: "right_ear", + 9: "mouth_left", + 10: "mouth_right", + 11: "left_shoulder", + 12: "right_shoulder", + 13: "left_elbow", + 14: "right_elbow", + 15: "left_wrist", + 16: "right_wrist", + 17: "left_pinky", + 18: "right_pinky", + 19: "left_index", + 20: "right_index", + 21: "left_thumb", + 22: "right_thumb", + 23: "left_hip", + 24: "right_hip", + 25: "left_knee", + 26: "right_knee", + 27: "left_ankle", + 28: "right_ankle", + 29: "left_heel", + 30: "right_heel", + 31: "left_foot_index", + 32: "right_foot_index", +} + +LEFT_SHOULDER = 11 +RIGHT_SHOULDER = 12 +LEFT_ELBOW = 13 +RIGHT_ELBOW = 14 +LEFT_WRIST = 15 +RIGHT_WRIST = 16 +LEFT_HIP = 23 +RIGHT_HIP = 24 +LEFT_KNEE = 25 +RIGHT_KNEE = 26 +LEFT_ANKLE = 27 +RIGHT_ANKLE = 28 + +REQUIRED_LANDMARKS = ( + LEFT_SHOULDER, + RIGHT_SHOULDER, + LEFT_ELBOW, + RIGHT_ELBOW, + LEFT_WRIST, + RIGHT_WRIST, + LEFT_HIP, + RIGHT_HIP, + LEFT_KNEE, + RIGHT_KNEE, + LEFT_ANKLE, + RIGHT_ANKLE, +) diff --git a/app/webrtc/__init__.py b/app/webrtc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/webrtc/frame_source.py b/app/webrtc/frame_source.py new file mode 100644 index 0000000..80d7ed1 --- /dev/null +++ b/app/webrtc/frame_source.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import numpy as np +from loguru import logger + + +TARGET_WIDTH = 1280 +TARGET_HEIGHT = 720 + + +def validate_frame_size(image: np.ndarray, width: int = TARGET_WIDTH, height: int = TARGET_HEIGHT) -> None: + h, w = image.shape[:2] + if w != width or h != height: + logger.warning("Unexpected frame size: {}x{}", w, h) diff --git a/app/webrtc/peer_session.py b/app/webrtc/peer_session.py new file mode 100644 index 0000000..611324a --- /dev/null +++ b/app/webrtc/peer_session.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import asyncio +import json + +import websockets +from aiortc import RTCPeerConnection, RTCSessionDescription +from loguru import logger + +from app.signaling.ice_parser import parse_ice +from app.webrtc.video_receiver import VideoReceiver + + +class PeerSession: + def __init__(self) -> None: + self._pc = RTCPeerConnection() + self._video_task: asyncio.Task | None = None + + async def handle(self, websocket) -> None: + self._setup_events() + + try: + async for message in websocket: + data = json.loads(message) + msg_type = data.get("type") + + if msg_type == "offer": + offer = RTCSessionDescription(sdp=data["sdp"], type="offer") + await self._pc.setRemoteDescription(offer) + answer = await self._pc.createAnswer() + await self._pc.setLocalDescription(answer) + await websocket.send(json.dumps({ + "type": "answer", + "sdp": self._pc.localDescription.sdp, + })) + + elif msg_type == "candidate": + cand = parse_ice(data) + if cand: + await self._pc.addIceCandidate(cand) + + except websockets.ConnectionClosed: + pass + except Exception as e: + logger.exception(f"Error: {e}") + finally: + await self._cleanup() + + def _setup_events(self) -> None: + @self._pc.on("track") + async def on_track(track): + logger.info(f"Track received: kind={track.kind}") + if track.kind == "video": + receiver = VideoReceiver(track) + self._video_task = asyncio.ensure_future(receiver.run()) + + @self._pc.on("iceconnectionstatechange") + async def on_iceconnectionstatechange(): + logger.info(f"ICE state: {self._pc.iceConnectionState}") + if self._pc.iceConnectionState in ("failed", "closed", "disconnected"): + await self._pc.close() + + async def _cleanup(self) -> None: + if self._video_task: + self._video_task.cancel() + try: + await self._video_task + except asyncio.CancelledError: + pass + await self._pc.close() diff --git a/app/webrtc/video_receiver.py b/app/webrtc/video_receiver.py new file mode 100644 index 0000000..766f509 --- /dev/null +++ b/app/webrtc/video_receiver.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import asyncio +import os + +import cv2 +from aiortc.mediastreams import MediaStreamError +from loguru import logger + +from app.audio.rep_announcer import RepAnnouncer +from app.exercises.dead_bug.detector import DeadBugDetector +from app.rendering.window_display import close_window, is_esc_pressed, show_frame +from configs.default import ( + EXTENSION_CONFIRM_FRAMES, + MODEL_PATH, + PREFER_GPU, + PROCESS_EVERY_N_FRAMES, + REP_ANNOUNCER_ENABLED, + REP_ANNOUNCER_RATE, + REP_ANNOUNCER_VOLUME, + RESET_CONFIRM_FRAMES, + VISIBILITY_THRESHOLD, +) + + +def _format_pose_debug(pose_result) -> str: + metrics = pose_result.metrics + if metrics is None: + return "metrics=None" + return ( + f"side={pose_result.side}, standard={pose_result.is_standard}, " + f"angles(le={metrics.left_elbow_angle:.1f}, re={metrics.right_elbow_angle:.1f}, " + f"lk={metrics.left_knee_angle:.1f}, rk={metrics.right_knee_angle:.1f}), " + f"extended(la={metrics.left_arm_extended}, ra={metrics.right_arm_extended}, " + f"ll={metrics.left_leg_extended}, rl={metrics.right_leg_extended})" + ) + + +class VideoReceiver: + def __init__(self, track) -> None: + self._track = track + + async def run(self) -> None: + logger.info("Start receiving video frames, process_every_n={}", PROCESS_EVERY_N_FRAMES) + + frame_count = 0 + processed_count = 0 + detector = DeadBugDetector( + model_path=MODEL_PATH, + visibility_threshold=VISIBILITY_THRESHOLD, + extension_confirm_frames=EXTENSION_CONFIRM_FRAMES, + reset_confirm_frames=RESET_CONFIRM_FRAMES, + prefer_gpu=PREFER_GPU, + ) + announcer = RepAnnouncer( + enabled=REP_ANNOUNCER_ENABLED, + rate=REP_ANNOUNCER_RATE, + volume=REP_ANNOUNCER_VOLUME, + ) + last_announced_rep = 0 + last_pose_result = None + last_annotated = None + + try: + while True: + frame = await self._track.recv() + frame_count += 1 + raw_img = frame.to_ndarray(format="bgr24") + timestamp_ms = int(frame.time * 1000) if frame.time is not None else frame_count * 33 + + if frame_count % PROCESS_EVERY_N_FRAMES == 0 or last_pose_result is None: + processed_count += 1 + last_annotated, last_pose_result = detector.process_frame(raw_img, timestamp_ms) + if last_pose_result.rep_count > last_announced_rep: + last_announced_rep = last_pose_result.rep_count + announcer.announce_count(last_announced_rep) + + display_img = last_annotated if last_annotated is not None else raw_img + show_frame(display_img) + + if frame_count % 100 == 0: + logger.info( + "Received {} frames, processed={}, raw_shape={}, reps={}, phase={}, feedback={}, {}", + frame_count, + processed_count, + raw_img.shape, + last_pose_result.rep_count if last_pose_result is not None else 0, + last_pose_result.phase.value if last_pose_result is not None else "none", + " | ".join(last_pose_result.feedback) if last_pose_result is not None else "", + _format_pose_debug(last_pose_result) if last_pose_result is not None else "metrics=None", + ) + + if is_esc_pressed(): + logger.info("ESC pressed, closing display") + break + + except asyncio.CancelledError: + logger.info("Video receive task cancelled") + except MediaStreamError: + logger.info("Video track ended") + except Exception as e: + logger.exception(f"Video receive error: {e!r}") + finally: + announcer.close() + detector.close() + close_window() diff --git a/configs/default.py b/configs/default.py new file mode 100644 index 0000000..0c2a7ba --- /dev/null +++ b/configs/default.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import os +from pathlib import Path + +# ── Server ────────────────────────────────────────────────────────────────── +WS_HOST = os.getenv("POSEFIT_WS_HOST", "0.0.0.0") +WS_PORT = int(os.getenv("POSEFIT_WS_PORT", "8765")) +WS_MAX_SIZE = int(os.getenv("POSEFIT_WS_MAX_SIZE", str(10 * 1024 * 1024))) + +# ── Video processing ──────────────────────────────────────────────────────── +PROCESS_EVERY_N_FRAMES = max(1, int(os.getenv("POSEFIT_PROCESS_EVERY_N_FRAMES", "1"))) + +# ── Model ─────────────────────────────────────────────────────────────────── +MODEL_DIR: Path = Path(__file__).resolve().parent.parent / "pose_models" +MODEL_PATH = os.getenv("POSEFIT_MODEL_PATH", str(MODEL_DIR / "pose_landmarker_full.task")) +PREFER_GPU = os.getenv("POSEFIT_PREFER_GPU", "1") not in ("0", "false", "False") + +# ── Dead bug exercise ─────────────────────────────────────────────────────── +VISIBILITY_THRESHOLD = float(os.getenv("POSEFIT_VISIBILITY_THRESHOLD", "0.45")) +EXTENSION_CONFIRM_FRAMES = int(os.getenv("POSEFIT_EXTENSION_CONFIRM_FRAMES", "4")) +RESET_CONFIRM_FRAMES = int(os.getenv("POSEFIT_RESET_CONFIRM_FRAMES", "3")) + +# ── Audio ─────────────────────────────────────────────────────────────────── +REP_ANNOUNCER_ENABLED = os.getenv("POSEFIT_REP_ANNOUNCER_ENABLED", "1") not in ("0", "false", "False") +REP_ANNOUNCER_RATE = int(os.getenv("POSEFIT_REP_ANNOUNCER_RATE", "185")) +REP_ANNOUNCER_VOLUME = float(os.getenv("POSEFIT_REP_ANNOUNCER_VOLUME", "1.0")) + +# ── Logging ───────────────────────────────────────────────────────────────── +LOG_DIR: Path = Path(__file__).resolve().parent.parent / "logs" +LOG_ROTATION = os.getenv("POSEFIT_LOG_ROTATION", "20 MB") +LOG_RETENTION = os.getenv("POSEFIT_LOG_RETENTION", "14 days") diff --git a/dead_bug_detector.py b/dead_bug_detector.py deleted file mode 100644 index f964669..0000000 --- a/dead_bug_detector.py +++ /dev/null @@ -1,431 +0,0 @@ -from __future__ import annotations - -import threading -import time -from dataclasses import dataclass -from enum import Enum -from pathlib import Path - -import cv2 -import mediapipe as mp -import numpy as np -from loguru import logger - - -PoseLandmarker = mp.tasks.vision.PoseLandmarker -PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions -VisionRunningMode = mp.tasks.vision.RunningMode -BaseOptions = mp.tasks.BaseOptions - - -class DeadBugPhase(str, Enum): - READY = "ready" - EXTENDING = "extending" - NEED_RESET = "need_reset" - NO_POSE = "no_pose" - - -@dataclass(frozen=True) -class Point: - x: float - y: float - z: float - visibility: float - - -@dataclass -class DeadBugMetrics: - left_arm_extended: bool - right_arm_extended: bool - left_leg_extended: bool - right_leg_extended: bool - left_elbow_angle: float - right_elbow_angle: float - left_knee_angle: float - right_knee_angle: float - feedback: list[str] - - -@dataclass -class DeadBugResult: - rep_count: int - phase: DeadBugPhase - side: str | None - is_standard: bool - feedback: list[str] - metrics: DeadBugMetrics | None - - -class DeadBugDetector: - """MediaPipe Pose based dead bug detector and counter. - - The rules are intentionally conservative because a phone stream only gives - us 2D landmarks. A rep is counted when one diagonal pair extends cleanly and - the body returns to the bent-knee ready position. - """ - - LEFT_SHOULDER = 11 - RIGHT_SHOULDER = 12 - LEFT_ELBOW = 13 - RIGHT_ELBOW = 14 - LEFT_WRIST = 15 - RIGHT_WRIST = 16 - LEFT_HIP = 23 - RIGHT_HIP = 24 - LEFT_KNEE = 25 - RIGHT_KNEE = 26 - LEFT_ANKLE = 27 - RIGHT_ANKLE = 28 - - REQUIRED_LANDMARKS = ( - LEFT_SHOULDER, - RIGHT_SHOULDER, - LEFT_ELBOW, - RIGHT_ELBOW, - LEFT_WRIST, - RIGHT_WRIST, - LEFT_HIP, - RIGHT_HIP, - LEFT_KNEE, - RIGHT_KNEE, - LEFT_ANKLE, - RIGHT_ANKLE, - ) - - def __init__( - self, - model_path: str | Path | None = None, - *, - visibility_threshold: float = 0.45, - extension_confirm_frames: int = 4, - reset_confirm_frames: int = 3, - prefer_gpu: bool = True, - ) -> None: - if model_path is None: - model_path = Path(__file__).resolve().parent / "pose_models" / "pose_landmarker_full.task" - - self.model_path = str(model_path) - self.visibility_threshold = visibility_threshold - self.extension_confirm_frames = extension_confirm_frames - self.reset_confirm_frames = reset_confirm_frames - self.delegate = BaseOptions.Delegate.GPU if prefer_gpu else BaseOptions.Delegate.CPU - - self._latest_result = None - self._result_lock = threading.Lock() - self._result_event = threading.Event() - self._inflight = False - self._inflight_started_at = 0.0 - - def on_result(pose_result, _image, _timestamp_ms): - with self._result_lock: - self._latest_result = pose_result - self._inflight = False - self._inflight_started_at = 0.0 - self._result_event.set() - - self._landmarker = self._create_landmarker(on_result) - - self.rep_count = 0 - self.phase = DeadBugPhase.READY - self.active_side: str | None = None - self._candidate_side: str | None = None - self._candidate_frames = 0 - self._reset_frames = 0 - self._last_timestamp_ms = -1 - - def close(self) -> None: - self._landmarker.close() - - def _create_landmarker(self, result_callback): - try: - landmarker = PoseLandmarker.create_from_options( - self._build_options(self.delegate, result_callback) - ) - logger.info("MediaPipe PoseLandmarker initialized with {} delegate", self.delegate.name) - return landmarker - except Exception as exc: - if self.delegate == BaseOptions.Delegate.CPU: - raise - - logger.warning("MediaPipe GPU delegate unavailable, falling back to CPU: {}", exc) - self.delegate = BaseOptions.Delegate.CPU - landmarker = PoseLandmarker.create_from_options( - self._build_options(self.delegate, result_callback) - ) - logger.info("MediaPipe PoseLandmarker initialized with CPU delegate") - return landmarker - - def _build_options(self, delegate, result_callback): - return PoseLandmarkerOptions( - base_options=BaseOptions(model_asset_path=self.model_path, delegate=delegate), - running_mode=VisionRunningMode.LIVE_STREAM, - result_callback=result_callback, - num_poses=1, - min_pose_detection_confidence=0.5, - min_pose_presence_confidence=0.5, - min_tracking_confidence=0.5, - ) - - def process_frame(self, bgr_frame: np.ndarray, timestamp_ms: int) -> tuple[np.ndarray, DeadBugResult]: - timestamp_ms = self._normalize_timestamp(timestamp_ms) - - with self._result_lock: - if self._inflight and time.monotonic() - self._inflight_started_at > 0.5: - logger.warning("MediaPipe detect_async timed out; allowing next frame submission") - self._inflight = False - self._inflight_started_at = 0.0 - should_submit = not self._inflight - if should_submit: - self._inflight = True - self._inflight_started_at = time.monotonic() - - if should_submit: - rgba_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGBA) - mp_image = mp.Image(image_format=mp.ImageFormat.SRGBA, data=rgba_frame) - self._result_event.clear() - try: - self._landmarker.detect_async(mp_image, timestamp_ms) - except Exception: - with self._result_lock: - self._inflight = False - self._inflight_started_at = 0.0 - raise - self._result_event.wait(timeout=0.08) - - with self._result_lock: - pose_result = self._latest_result - - annotated = bgr_frame.copy() - if pose_result is None or not pose_result.pose_landmarks: - result = DeadBugResult( - rep_count=self.rep_count, - phase=DeadBugPhase.NO_POSE, - side=self.active_side, - is_standard=False, - feedback=["No full body detected"], - metrics=None, - ) - self._draw_status(annotated, result) - return annotated, result - - landmarks = [Point(lm.x, lm.y, lm.z, getattr(lm, "visibility", 1.0)) for lm in pose_result.pose_landmarks[0]] - self._draw_landmarks(annotated, landmarks) - - if not self._has_required_visibility(landmarks): - result = DeadBugResult( - rep_count=self.rep_count, - phase=DeadBugPhase.NO_POSE, - side=self.active_side, - is_standard=False, - feedback=["Keep shoulders, elbows, wrists, hips, knees, ankles visible"], - metrics=None, - ) - self._draw_status(annotated, result) - return annotated, result - - metrics = self._calculate_metrics(landmarks) - result = self._update_state(metrics) - self._draw_status(annotated, result) - return annotated, result - - def _normalize_timestamp(self, timestamp_ms: int) -> int: - if timestamp_ms <= self._last_timestamp_ms: - timestamp_ms = self._last_timestamp_ms + 1 - self._last_timestamp_ms = timestamp_ms - return timestamp_ms - - def _has_required_visibility(self, landmarks: list[Point]) -> bool: - return all(landmarks[i].visibility >= self.visibility_threshold for i in self.REQUIRED_LANDMARKS) - - def _calculate_metrics(self, lm: list[Point]) -> DeadBugMetrics: - left_elbow = angle(lm[self.LEFT_SHOULDER], lm[self.LEFT_ELBOW], lm[self.LEFT_WRIST]) - right_elbow = angle(lm[self.RIGHT_SHOULDER], lm[self.RIGHT_ELBOW], lm[self.RIGHT_WRIST]) - left_knee = angle(lm[self.LEFT_HIP], lm[self.LEFT_KNEE], lm[self.LEFT_ANKLE]) - right_knee = angle(lm[self.RIGHT_HIP], lm[self.RIGHT_KNEE], lm[self.RIGHT_ANKLE]) - - shoulder_width = distance(lm[self.LEFT_SHOULDER], lm[self.RIGHT_SHOULDER]) - hip_width = distance(lm[self.LEFT_HIP], lm[self.RIGHT_HIP]) - scale = max(shoulder_width, hip_width, 0.08) - - left_arm_extended = ( - left_elbow >= 145 - and distance(lm[self.LEFT_SHOULDER], lm[self.LEFT_WRIST]) >= scale * 1.15 - and lm[self.LEFT_WRIST].y <= lm[self.LEFT_SHOULDER].y + scale * 0.35 - ) - right_arm_extended = ( - right_elbow >= 145 - and distance(lm[self.RIGHT_SHOULDER], lm[self.RIGHT_WRIST]) >= scale * 1.15 - and lm[self.RIGHT_WRIST].y <= lm[self.RIGHT_SHOULDER].y + scale * 0.35 - ) - - left_leg_extended = ( - left_knee >= 150 - and distance(lm[self.LEFT_HIP], lm[self.LEFT_ANKLE]) >= scale * 1.55 - and lm[self.LEFT_ANKLE].y >= lm[self.LEFT_KNEE].y - scale * 0.2 - ) - right_leg_extended = ( - right_knee >= 150 - and distance(lm[self.RIGHT_HIP], lm[self.RIGHT_ANKLE]) >= scale * 1.55 - and lm[self.RIGHT_ANKLE].y >= lm[self.RIGHT_KNEE].y - scale * 0.2 - ) - - feedback: list[str] = [] - if left_arm_extended and left_elbow < 160: - feedback.append("Straighten left arm") - if right_arm_extended and right_elbow < 160: - feedback.append("Straighten right arm") - if left_leg_extended and left_knee < 165: - feedback.append("Straighten left leg") - if right_leg_extended and right_knee < 165: - feedback.append("Straighten right leg") - - return DeadBugMetrics( - left_arm_extended=left_arm_extended, - right_arm_extended=right_arm_extended, - left_leg_extended=left_leg_extended, - right_leg_extended=right_leg_extended, - left_elbow_angle=left_elbow, - right_elbow_angle=right_elbow, - left_knee_angle=left_knee, - right_knee_angle=right_knee, - feedback=feedback, - ) - - def _update_state(self, metrics: DeadBugMetrics) -> DeadBugResult: - side = self._detect_diagonal_extension(metrics) - ready = self._is_ready_position(metrics) - - if side is None: - self._candidate_side = None - self._candidate_frames = 0 - elif side == self._candidate_side: - self._candidate_frames += 1 - else: - self._candidate_side = side - self._candidate_frames = 1 - - if self.phase in (DeadBugPhase.READY, DeadBugPhase.NO_POSE): - if self._candidate_frames >= self.extension_confirm_frames and side is not None: - self.phase = DeadBugPhase.EXTENDING - self.active_side = side - self._reset_frames = 0 - elif self.phase == DeadBugPhase.EXTENDING: - if side == self.active_side: - self.phase = DeadBugPhase.NEED_RESET - elif self.phase == DeadBugPhase.NEED_RESET: - if ready: - self._reset_frames += 1 - if self._reset_frames >= self.reset_confirm_frames: - self.rep_count += 1 - self.phase = DeadBugPhase.READY - self.active_side = None - self._candidate_side = None - self._candidate_frames = 0 - self._reset_frames = 0 - else: - self._reset_frames = 0 - - feedback = list(metrics.feedback) - if side is None and not ready: - feedback.append("Extend opposite arm and leg only") - if ready: - feedback.append("Ready position") - elif side == "left_arm_right_leg": - feedback.append("Left arm + right leg") - elif side == "right_arm_left_leg": - feedback.append("Right arm + left leg") - - is_standard = side is not None and not metrics.feedback - return DeadBugResult( - rep_count=self.rep_count, - phase=self.phase, - side=side, - is_standard=is_standard, - feedback=feedback[:3], - metrics=metrics, - ) - - def _detect_diagonal_extension(self, metrics: DeadBugMetrics) -> str | None: - if metrics.left_leg_extended and metrics.right_leg_extended: - return None - - # Dead bug starts with both arms raised, so the non-moving arm may also - # look "extended" in 2D. Infer the rep from the single extended leg and - # require the opposite arm to be extended, instead of rejecting both-arm - # frames as same-side noise. - if metrics.right_leg_extended and metrics.left_arm_extended: - return "left_arm_right_leg" - if metrics.left_leg_extended and metrics.right_arm_extended: - return "right_arm_left_leg" - return None - - def _is_ready_position(self, metrics: DeadBugMetrics) -> bool: - knees_bent = metrics.left_knee_angle <= 140 and metrics.right_knee_angle <= 140 - legs_not_extended = not metrics.left_leg_extended and not metrics.right_leg_extended - return knees_bent and legs_not_extended and self._detect_diagonal_extension(metrics) is None - - def _draw_landmarks(self, image: np.ndarray, landmarks: list[Point]) -> None: - h, w = image.shape[:2] - connections = getattr(getattr(mp, "solutions", None), "pose", None) - pose_connections = getattr(connections, "POSE_CONNECTIONS", _POSE_CONNECTIONS) - for start, end in pose_connections: - if start >= len(landmarks) or end >= len(landmarks): - continue - p1 = landmarks[start] - p2 = landmarks[end] - if p1.visibility < self.visibility_threshold or p2.visibility < self.visibility_threshold: - continue - cv2.line( - image, - (int(p1.x * w), int(p1.y * h)), - (int(p2.x * w), int(p2.y * h)), - (65, 180, 255), - 2, - ) - for idx in self.REQUIRED_LANDMARKS: - p = landmarks[idx] - if p.visibility >= self.visibility_threshold: - cv2.circle(image, (int(p.x * w), int(p.y * h)), 4, (80, 255, 120), -1) - - def _draw_status(self, image: np.ndarray, result: DeadBugResult) -> None: - color = (60, 220, 90) if result.is_standard else (50, 180, 255) - cv2.rectangle(image, (12, 12), (520, 142), (20, 20, 20), -1) - cv2.putText(image, f"Dead bug reps: {result.rep_count}", (28, 48), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2) - cv2.putText(image, f"phase: {result.phase.value}", (28, 82), cv2.FONT_HERSHEY_SIMPLEX, 0.68, (230, 230, 230), 2) - status = "standard" if result.is_standard else "adjust" - cv2.putText(image, f"status: {status}", (28, 116), cv2.FONT_HERSHEY_SIMPLEX, 0.68, color, 2) - - y = 170 - for text in result.feedback: - cv2.putText(image, text, (28, y), cv2.FONT_HERSHEY_SIMPLEX, 0.68, (255, 255, 255), 2) - y += 30 - - -def angle(a: Point, b: Point, c: Point) -> float: - ba = np.array([a.x - b.x, a.y - b.y], dtype=np.float32) - bc = np.array([c.x - b.x, c.y - b.y], dtype=np.float32) - denom = float(np.linalg.norm(ba) * np.linalg.norm(bc)) - if denom == 0: - return 0.0 - cos_value = float(np.dot(ba, bc) / denom) - return float(np.degrees(np.arccos(np.clip(cos_value, -1.0, 1.0)))) - - -def distance(a: Point, b: Point) -> float: - return float(np.hypot(a.x - b.x, a.y - b.y)) - - -_POSE_CONNECTIONS = ( - (11, 12), - (11, 13), - (13, 15), - (12, 14), - (14, 16), - (11, 23), - (12, 24), - (23, 24), - (23, 25), - (25, 27), - (24, 26), - (26, 28), -) diff --git a/handle_client.py b/handle_client.py deleted file mode 100644 index 69f316e..0000000 --- a/handle_client.py +++ /dev/null @@ -1,188 +0,0 @@ -import asyncio -import json -import os -import re -import websockets -import cv2 -from loguru import logger -from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceCandidate -from aiortc.mediastreams import MediaStreamError - -from dead_bug_detector import DeadBugDetector -from rep_announcer import RepAnnouncer - - -PROCESS_EVERY_N_FRAMES = max(1, int(os.getenv("POSEFIT_PROCESS_EVERY_N_FRAMES", "1"))) -TARGET_FRAME_WIDTH = max(1, int(os.getenv("POSEFIT_FRAME_WIDTH", "1080"))) -TARGET_FRAME_HEIGHT = max(1, int(os.getenv("POSEFIT_FRAME_HEIGHT", "720"))) - - -def format_pose_debug(pose_result): - metrics = pose_result.metrics - if metrics is None: - return "metrics=None" - - return ( - f"side={pose_result.side}, standard={pose_result.is_standard}, " - f"angles(le={metrics.left_elbow_angle:.1f}, re={metrics.right_elbow_angle:.1f}, " - f"lk={metrics.left_knee_angle:.1f}, rk={metrics.right_knee_angle:.1f}), " - f"extended(la={metrics.left_arm_extended}, ra={metrics.right_arm_extended}, " - f"ll={metrics.left_leg_extended}, rl={metrics.right_leg_extended})" - ) - - -async def handle_client(websocket): - client = websocket.remote_address - logger.info(f"Client connected: {client}") - - pc = RTCPeerConnection() - video_task = None - - def parse_ice(data): - match = re.match( - r'candidate:(\S+) (\d) (\S+) (\d+) (\S+) (\d+) typ (\S+)(?: raddr (\S+) rport (\d+))?', - data["candidate"] - ) - if not match: - return None - g = match.groups() - cand = RTCIceCandidate( - foundation=g[0], - component=int(g[1]), - protocol=g[2].lower(), - priority=int(g[3]), - ip=g[4], - port=int(g[5]), - type=g[6], - relatedAddress=g[7], - relatedPort=int(g[8]) if g[8] else None, - ) - cand.sdpMid = data.get("sdpMid") - cand.sdpMLineIndex = data.get("sdpMLineIndex", 0) - return cand - - async def receive_video(track): - logger.info( - "Start receiving video frames, process_every_n_frames={}, target_frame={}x{}", - PROCESS_EVERY_N_FRAMES, - TARGET_FRAME_WIDTH, - TARGET_FRAME_HEIGHT, - ) - frame_count = 0 - processed_count = 0 - detector = DeadBugDetector() - announcer = RepAnnouncer() - last_announced_rep = 0 - last_pose_result = None - last_annotated = None - try: - while True: - frame = await track.recv() - frame_count += 1 - raw_img = frame.to_ndarray(format="bgr24") - img = normalize_frame(raw_img) - timestamp_ms = int(frame.time * 1000) if frame.time is not None else frame_count * 33 - - if frame_count % PROCESS_EVERY_N_FRAMES == 0 or last_pose_result is None: - processed_count += 1 - last_annotated, last_pose_result = detector.process_frame(img, timestamp_ms) - if last_pose_result.rep_count > last_announced_rep: - last_announced_rep = last_pose_result.rep_count - announcer.announce_count(last_announced_rep) - - cv2.imshow("Android Camera (WebRTC)", last_annotated if last_annotated is not None else img) - - if frame_count % 100 == 0: - logger.info( - "Received {} frames, processed={}, raw_shape={}, shape={}, reps={}, phase={}, feedback={}, {}", - frame_count, - processed_count, - raw_img.shape, - img.shape, - last_pose_result.rep_count if last_pose_result is not None else 0, - last_pose_result.phase.value if last_pose_result is not None else "none", - " | ".join(last_pose_result.feedback) if last_pose_result is not None else "", - format_pose_debug(last_pose_result) if last_pose_result is not None else "metrics=None", - ) - - if cv2.waitKey(1) & 0xFF == 27: - logger.info("ESC pressed, closing display") - break - except asyncio.CancelledError: - logger.info("Video receive task cancelled") - except MediaStreamError: - logger.info("Video track ended") - except Exception as e: - logger.exception(f"Video receive error: {e!r}") - finally: - announcer.close() - detector.close() - - @pc.on("track") - async def on_track(track): - logger.info(f"Track received: kind={track.kind}") - if track.kind == "video": - nonlocal video_task - video_task = asyncio.ensure_future(receive_video(track)) - - @pc.on("iceconnectionstatechange") - async def on_iceconnectionstatechange(): - logger.info(f"ICE state: {pc.iceConnectionState}") - if pc.iceConnectionState in ("failed", "closed", "disconnected"): - await pc.close() - - try: - async for message in websocket: - data = json.loads(message) - msg_type = data.get("type") - - if msg_type == "offer": - offer = RTCSessionDescription(sdp=data["sdp"], type="offer") - await pc.setRemoteDescription(offer) - - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - - await websocket.send(json.dumps({ - "type": "answer", - "sdp": pc.localDescription.sdp, - })) - - elif msg_type == "candidate": - cand = parse_ice(data) - if cand: - await pc.addIceCandidate(cand) - - except websockets.ConnectionClosed: - logger.info(f"Client disconnected: {client}") - except Exception as e: - logger.exception(f"Error: {e}") - finally: - if video_task: - video_task.cancel() - try: - await video_task - except asyncio.CancelledError: - pass - await pc.close() - cv2.destroyAllWindows() - logger.info(f"Connection closed: {client}") - - -async def main(): - host = "0.0.0.0" - port = 8765 - logger.info(f"WebRTC signaling server: ws://{host}:{port}") - async with websockets.serve(handle_client, host, port, max_size=10 * 1024 * 1024): - await asyncio.Future() - - -def normalize_frame(image): - height, width = image.shape[:2] - if width == TARGET_FRAME_WIDTH and height == TARGET_FRAME_HEIGHT: - return image - return cv2.resize(image, (TARGET_FRAME_WIDTH, TARGET_FRAME_HEIGHT), interpolation=cv2.INTER_AREA) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/main.py b/main.py deleted file mode 100644 index b6c2300..0000000 --- a/main.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -import faulthandler -from pathlib import Path - -from loguru import logger - -os.environ["MEDIAPIPE_DISABLE_LOGGING"] = "1" -os.environ["GLOG_minloglevel"] = "3" - -import asyncio - -from handle_client import main - - -if __name__ == "__main__": - log_dir = Path(__file__).resolve().parent / "logs" - log_dir.mkdir(exist_ok=True) - crash_log = open(log_dir / "posefit-crash.log", "a", buffering=1) - faulthandler.enable(file=crash_log, all_threads=True) - logger.add( - log_dir / "posefit-server_{time:YYYY-MM-DD}.log", - rotation="20 MB", - retention="14 days", - enqueue=True, - backtrace=True, - diagnose=True, - ) - logger.info("Starting server...") - asyncio.run(main()) diff --git a/run.py b/run.py new file mode 100644 index 0000000..ee20995 --- /dev/null +++ b/run.py @@ -0,0 +1,4 @@ +from app.main import main + +if __name__ == "__main__": + main() diff --git a/tests/test_dead_bug_rules.py b/tests/test_dead_bug_rules.py new file mode 100644 index 0000000..8923e1e --- /dev/null +++ b/tests/test_dead_bug_rules.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from app.exercises.dead_bug.metrics import calculate_metrics +from app.exercises.dead_bug.rules import detect_diagonal_extension, has_required_visibility, is_ready_position +from app.exercises.dead_bug.state_machine import DeadBugStateMachine +from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase, Point + + +class TestDeadBugRules: + def _make_landmark(self, x=0.5, y=0.5, z=0.0, visibility=1.0): + return Point(x, y, z, visibility) + + def _make_visible_landmarks(self): + return [self._make_landmark() for _ in range(33)] + + def test_has_required_visibility_all_visible(self): + lm = self._make_visible_landmarks() + indices = (11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28) + assert has_required_visibility(lm, indices, 0.45) + + def test_has_required_visibility_low(self): + lm = self._make_visible_landmarks() + lm[11] = self._make_landmark(visibility=0.1) + indices = (11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28) + assert not has_required_visibility(lm, indices, 0.45) + + def test_detect_diagonal_extension_none(self): + metrics = DeadBugMetrics( + left_arm_extended=False, right_arm_extended=False, + left_leg_extended=False, right_leg_extended=False, + left_elbow_angle=90, right_elbow_angle=90, + left_knee_angle=90, right_knee_angle=90, + feedback=[], + ) + assert detect_diagonal_extension(metrics) is None + + def test_detect_diagonal_extension_left_arm_right_leg(self): + metrics = DeadBugMetrics( + left_arm_extended=True, right_arm_extended=False, + left_leg_extended=False, right_leg_extended=True, + left_elbow_angle=160, right_elbow_angle=90, + left_knee_angle=90, right_knee_angle=160, + feedback=[], + ) + assert detect_diagonal_extension(metrics) == "left_arm_right_leg" + + def test_is_ready_position(self): + metrics = DeadBugMetrics( + left_arm_extended=False, right_arm_extended=False, + left_leg_extended=False, right_leg_extended=False, + left_elbow_angle=90, right_elbow_angle=90, + left_knee_angle=100, right_knee_angle=100, + feedback=[], + ) + assert is_ready_position(metrics) + + def test_is_not_ready_legs_extended(self): + metrics = DeadBugMetrics( + left_arm_extended=False, right_arm_extended=False, + left_leg_extended=True, right_leg_extended=False, + left_elbow_angle=90, right_elbow_angle=90, + left_knee_angle=100, right_knee_angle=100, + feedback=[], + ) + assert not is_ready_position(metrics) diff --git a/tests/test_dead_bug_state_machine.py b/tests/test_dead_bug_state_machine.py new file mode 100644 index 0000000..a7a10de --- /dev/null +++ b/tests/test_dead_bug_state_machine.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from app.exercises.dead_bug.state_machine import DeadBugStateMachine +from app.exercises.dead_bug.types import DeadBugMetrics, DeadBugPhase + + +class TestDeadBugStateMachine: + def _ready_metrics(self) -> DeadBugMetrics: + return DeadBugMetrics( + left_arm_extended=False, right_arm_extended=False, + left_leg_extended=False, right_leg_extended=False, + left_elbow_angle=90, right_elbow_angle=90, + left_knee_angle=100, right_knee_angle=100, + feedback=[], + ) + + def _extended_left(self) -> DeadBugMetrics: + return DeadBugMetrics( + left_arm_extended=True, right_arm_extended=False, + left_leg_extended=False, right_leg_extended=True, + left_elbow_angle=160, right_elbow_angle=90, + left_knee_angle=90, right_knee_angle=160, + feedback=[], + ) + + def test_initial_state(self): + sm = DeadBugStateMachine() + assert sm.phase == DeadBugPhase.READY + assert sm.rep_count == 0 + + def test_no_transition_in_ready(self): + sm = DeadBugStateMachine() + result = sm.update(self._ready_metrics()) + assert sm.phase == DeadBugPhase.READY + assert result.rep_count == 0 + + def test_confirm_extension(self): + sm = DeadBugStateMachine(extension_confirm_frames=2, reset_confirm_frames=2) + sm.update(self._extended_left()) + assert sm.phase == DeadBugPhase.READY + sm.update(self._extended_left()) + assert sm.phase == DeadBugPhase.EXTENDING diff --git a/tests/test_ice_parser.py b/tests/test_ice_parser.py new file mode 100644 index 0000000..e69e3ba --- /dev/null +++ b/tests/test_ice_parser.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from app.signaling.ice_parser import parse_ice + + +class TestIceParser: + def test_parse_valid_ice(self): + data = { + "candidate": "1234567890 1 UDP 2130706431 192.168.1.1 12345 typ host", + "sdpMid": "0", + "sdpMLineIndex": 0, + } + cand = parse_ice(data) + assert cand is not None + assert cand.foundation == "1234567890" + assert cand.component == 1 + assert cand.protocol == "udp" + assert cand.ip == "192.168.1.1" + assert cand.port == 12345 + assert cand.type == "host" + + def test_parse_invalid_ice(self): + assert parse_ice({"candidate": "invalid"}) is None + + def test_parse_srflx(self): + data = { + "candidate": "abcdef 1 UDP 1686052607 203.0.113.1 50000 typ srflx raddr 192.168.1.1 rport 12345", + "sdpMid": "0", + "sdpMLineIndex": 0, + } + cand = parse_ice(data) + assert cand is not None + assert cand.type == "srflx" + assert cand.relatedAddress == "192.168.1.1" + assert cand.relatedPort == 12345