# posefit-server/handle_client.py

import asyncio
import json
import os
import re
import websockets
import cv2
from loguru import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceCandidate
from aiortc.mediastreams import MediaStreamError
from dead_bug_detector import DeadBugDetector
from rep_announcer import RepAnnouncer
PROCESS_EVERY_N_FRAMES = max(1, int(os.getenv("POSEFIT_PROCESS_EVERY_N_FRAMES", "1")))
TARGET_FRAME_WIDTH = max(1, int(os.getenv("POSEFIT_FRAME_WIDTH", "1080")))
TARGET_FRAME_HEIGHT = max(1, int(os.getenv("POSEFIT_FRAME_HEIGHT", "720")))
def format_pose_debug(pose_result):
    """Build a compact, single-line summary of a pose result for log output.

    Args:
        pose_result: Object exposing ``metrics``, ``side`` and ``is_standard``.
            ``metrics`` may be ``None``.

    Returns:
        ``"metrics=None"`` when no metrics are attached. Otherwise the string
        lists the side, the standard flag, the four joint angles (one decimal
        place) and the four limb-extension flags.
    """
    m = pose_result.metrics
    if m is not None:
        header = f"side={pose_result.side}, standard={pose_result.is_standard}"
        angles = (
            f"angles(le={m.left_elbow_angle:.1f}, re={m.right_elbow_angle:.1f}, "
            f"lk={m.left_knee_angle:.1f}, rk={m.right_knee_angle:.1f})"
        )
        extended = (
            f"extended(la={m.left_arm_extended}, ra={m.right_arm_extended}, "
            f"ll={m.left_leg_extended}, rl={m.right_leg_extended})"
        )
        return ", ".join((header, angles, extended))
    return "metrics=None"
# Parser for the SDP "candidate:" string sent over the signaling socket.
# Groups: foundation, component, protocol, priority, ip, port, type, and the
# optional related address / related port.
_ICE_CANDIDATE_RE = re.compile(
    r'candidate:(\S+) (\d) (\S+) (\d+) (\S+) (\d+) typ (\S+)(?: raddr (\S+) rport (\d+))?'
)


def _parse_ice(data):
    """Convert a signaling ``candidate`` message into an ``RTCIceCandidate``.

    Args:
        data: Decoded JSON object. Reads ``candidate`` (the SDP candidate
            string), ``sdpMid``, and ``sdpMLineIndex`` (0 when absent).

    Returns:
        The parsed candidate. Returns ``None`` when ``candidate`` is missing,
        is not a string, or does not match the expected format (for example
        an empty string).
    """
    raw = data.get("candidate")
    if not isinstance(raw, str):
        return None
    match = _ICE_CANDIDATE_RE.match(raw)
    if match is None:
        return None
    (foundation, component, protocol, priority, ip, port,
     cand_type, related_address, related_port) = match.groups()
    cand = RTCIceCandidate(
        foundation=foundation,
        component=int(component),
        protocol=protocol.lower(),
        priority=int(priority),
        ip=ip,
        port=int(port),
        type=cand_type,
        relatedAddress=related_address,
        relatedPort=int(related_port) if related_port else None,
    )
    cand.sdpMid = data.get("sdpMid")
    cand.sdpMLineIndex = data.get("sdpMLineIndex", 0)
    return cand


async def _receive_video(track):
    """Consume frames from a WebRTC video track, count reps and show a preview.

    Every frame is resized with ``normalize_frame``. Every
    ``PROCESS_EVERY_N_FRAMES``-th frame is run through a ``DeadBugDetector``,
    as is any frame before the first result exists. When the detector's rep
    count increases, the new count is passed to
    ``RepAnnouncer.announce_count``. The latest annotated frame is shown in
    an OpenCV window, and a status line is logged every 100 frames.

    The loop ends when the track ends, ESC is pressed in the preview window,
    the task is cancelled (re-raised), or any other error occurs (logged).
    Whatever detector/announcer was created is always closed.

    NOTE(review): ``detector.process_frame``, ``cv2.imshow`` and ``cv2.waitKey``
    run directly on the event loop, so time spent in them delays every other
    coroutine, including signaling for other clients.
    NOTE(review): after ESC the track is no longer read; confirm incoming
    frames do not keep queueing until the client disconnects.
    """
    logger.info(
        "Start receiving video frames, process_every_n_frames={}, target_frame={}x{}",
        PROCESS_EVERY_N_FRAMES,
        TARGET_FRAME_WIDTH,
        TARGET_FRAME_HEIGHT,
    )
    frame_count = 0
    processed_count = 0
    last_announced_rep = 0
    last_pose_result = None
    last_annotated = None
    detector = None
    announcer = None
    try:
        # Built inside the try so a constructor failure is logged below and
        # anything already constructed is still closed in ``finally``.
        detector = DeadBugDetector()
        announcer = RepAnnouncer()
        while True:
            frame = await track.recv()
            frame_count += 1
            raw_img = frame.to_ndarray(format="bgr24")
            img = normalize_frame(raw_img)
            # Without a frame time, fall back to a synthetic ~30 fps clock.
            timestamp_ms = int(frame.time * 1000) if frame.time is not None else frame_count * 33
            if frame_count % PROCESS_EVERY_N_FRAMES == 0 or last_pose_result is None:
                processed_count += 1
                last_annotated, last_pose_result = detector.process_frame(img, timestamp_ms)
                if last_pose_result.rep_count > last_announced_rep:
                    last_announced_rep = last_pose_result.rep_count
                    announcer.announce_count(last_announced_rep)
            # Skipped frames re-display the most recent annotated frame, if any.
            cv2.imshow("Android Camera (WebRTC)", last_annotated if last_annotated is not None else img)
            if frame_count % 100 == 0:
                logger.info(
                    "Received {} frames, processed={}, raw_shape={}, shape={}, reps={}, phase={}, feedback={}, {}",
                    frame_count,
                    processed_count,
                    raw_img.shape,
                    img.shape,
                    last_pose_result.rep_count if last_pose_result is not None else 0,
                    last_pose_result.phase.value if last_pose_result is not None else "none",
                    " | ".join(last_pose_result.feedback) if last_pose_result is not None else "",
                    format_pose_debug(last_pose_result) if last_pose_result is not None else "metrics=None",
                )
            if cv2.waitKey(1) & 0xFF == 27:
                logger.info("ESC pressed, closing display")
                break
    except asyncio.CancelledError:
        logger.info("Video receive task cancelled")
        # Propagate so the task is reported as cancelled rather than finished.
        raise
    except MediaStreamError:
        logger.info("Video track ended")
    except Exception as e:
        logger.exception(f"Video receive error: {e!r}")
    finally:
        # Nested try/finally so the detector is closed even if the announcer's close() raises.
        try:
            if announcer is not None:
                announcer.close()
        finally:
            if detector is not None:
                detector.close()


async def handle_client(websocket):
    """Serve one signaling WebSocket connection and its WebRTC peer connection.

    Expected JSON messages:
      * ``{"type": "offer", "sdp": ...}``: applied as the remote description,
        and ``{"type": "answer", "sdp": ...}`` is sent back.
      * ``{"type": "candidate", "candidate": ..., "sdpMid": ...,
        "sdpMLineIndex": ...}``: parsed and added as a remote ICE candidate.

    Other message types are ignored. Malformed messages are logged and skipped
    rather than ending the session. Each incoming video track is handed to
    ``_receive_video``. On exit the video task is cancelled and awaited, the
    peer connection is closed, and OpenCV windows are destroyed.
    """
    client = websocket.remote_address
    logger.info(f"Client connected: {client}")
    pc = RTCPeerConnection()
    video_task = None

    @pc.on("track")
    async def on_track(track):
        nonlocal video_task
        logger.info(f"Track received: kind={track.kind}")
        if track.kind == "video":
            video_task = asyncio.ensure_future(_receive_video(track))

    @pc.on("iceconnectionstatechange")
    async def on_iceconnectionstatechange():
        logger.info(f"ICE state: {pc.iceConnectionState}")
        if pc.iceConnectionState in ("failed", "closed", "disconnected"):
            await pc.close()

    try:
        async for message in websocket:
            try:
                data = json.loads(message)
            except ValueError:
                # Covers JSONDecodeError and UnicodeDecodeError (binary frames).
                logger.warning(f"Ignoring malformed signaling message from {client}")
                continue
            if not isinstance(data, dict):
                logger.warning(f"Ignoring non-object signaling message from {client}")
                continue
            msg_type = data.get("type")
            if msg_type == "offer":
                sdp = data.get("sdp")
                if not isinstance(sdp, str):
                    logger.warning(f"Ignoring offer without SDP from {client}")
                    continue
                offer = RTCSessionDescription(sdp=sdp, type="offer")
                await pc.setRemoteDescription(offer)
                answer = await pc.createAnswer()
                await pc.setLocalDescription(answer)
                await websocket.send(json.dumps({
                    "type": "answer",
                    "sdp": pc.localDescription.sdp,
                }))
            elif msg_type == "candidate":
                cand = _parse_ice(data)
                if cand is not None:
                    await pc.addIceCandidate(cand)
    except websockets.ConnectionClosed:
        logger.info(f"Client disconnected: {client}")
    except Exception as e:
        logger.exception(f"Error: {e}")
    finally:
        if video_task is not None:
            video_task.cancel()
            try:
                await video_task
            except asyncio.CancelledError:
                pass
            except Exception:
                # Must not escape: pc.close() and window cleanup below still have to run.
                logger.exception(f"Video task for {client} failed")
        await pc.close()
        # NOTE(review): this destroys every OpenCV window in the process, and all
        # clients draw to the same window title.
        cv2.destroyAllWindows()
        logger.info(f"Connection closed: {client}")
async def main(host="0.0.0.0", port=8765):
    """Run the WebRTC signaling WebSocket server until this coroutine is cancelled.

    Args:
        host: Interface address to bind. The default listens on all IPv4 interfaces.
        port: TCP port to listen on.
    """
    logger.info(f"WebRTC signaling server: ws://{host}:{port}")
    # Raise websockets' per-message size limit to 10 MiB.
    async with websockets.serve(handle_client, host, port, max_size=10 * 1024 * 1024):
        # A never-completing future keeps the server alive until cancellation.
        await asyncio.Future()
def normalize_frame(image):
    """Resize *image* to ``TARGET_FRAME_WIDTH`` x ``TARGET_FRAME_HEIGHT``.

    A frame already at the target size is returned unchanged (same object).

    Interpolation is chosen per OpenCV's guidance:
      * ``INTER_AREA`` when neither dimension grows (downscaling).
      * ``INTER_LINEAR`` when either dimension is enlarged. ``INTER_AREA``
        degrades to nearest-neighbour-like output when zooming.

    NOTE(review): the aspect ratio is not preserved, so frames whose aspect
    differs from the target (e.g. portrait frames) are stretched. Confirm the
    pose detector tolerates this.
    """
    height, width = image.shape[:2]
    if width == TARGET_FRAME_WIDTH and height == TARGET_FRAME_HEIGHT:
        return image
    downscaling = TARGET_FRAME_WIDTH <= width and TARGET_FRAME_HEIGHT <= height
    interpolation = cv2.INTER_AREA if downscaling else cv2.INTER_LINEAR
    return cv2.resize(image, (TARGET_FRAME_WIDTH, TARGET_FRAME_HEIGHT), interpolation=interpolation)
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C: asyncio.run cancels outstanding tasks before re-raising.
        # Exit with a log line instead of a traceback.
        logger.info("Server stopped by user")