189 lines
6.7 KiB
Python
189 lines
6.7 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import re
|
|
import websockets
|
|
import cv2
|
|
from loguru import logger
|
|
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceCandidate
|
|
from aiortc.mediastreams import MediaStreamError
|
|
|
|
from dead_bug_detector import DeadBugDetector
|
|
from rep_announcer import RepAnnouncer
|
|
|
|
|
|
def _positive_int_env(name, default):
    """Return environment variable *name* as an int, clamped to at least 1.

    *default* is the string used when the variable is unset. A value that
    ``int()`` cannot parse raises ``ValueError``.
    """
    return max(1, int(os.getenv(name, default)))


# Pose detection runs on every N-th received frame (see receive_video).
PROCESS_EVERY_N_FRAMES = _positive_int_env("POSEFIT_PROCESS_EVERY_N_FRAMES", "1")
# Resolution every incoming frame is resized to (see normalize_frame).
TARGET_FRAME_WIDTH = _positive_int_env("POSEFIT_FRAME_WIDTH", "1080")
TARGET_FRAME_HEIGHT = _positive_int_env("POSEFIT_FRAME_HEIGHT", "720")
|
|
|
|
|
|
def format_pose_debug(pose_result):
    """Build a one-line debug summary of a pose result for logging.

    Args:
        pose_result: Object exposing ``metrics``, ``side`` and
            ``is_standard``. When ``metrics`` is not ``None`` it must expose
            the four ``*_angle`` numbers and four ``*_extended`` flags read
            below.

    Returns:
        ``"metrics=None"`` when no metrics are attached; otherwise a string
        with the side, the standard flag, the elbow/knee angles rendered to
        one decimal place, and the arm/leg extension flags.
    """
    metrics = pose_result.metrics
    if metrics is not None:
        fragments = (
            f"side={pose_result.side}, standard={pose_result.is_standard}, ",
            f"angles(le={metrics.left_elbow_angle:.1f}, re={metrics.right_elbow_angle:.1f}, ",
            f"lk={metrics.left_knee_angle:.1f}, rk={metrics.right_knee_angle:.1f}), ",
            f"extended(la={metrics.left_arm_extended}, ra={metrics.right_arm_extended}, ",
            f"ll={metrics.left_leg_extended}, rl={metrics.right_leg_extended})",
        )
        return "".join(fragments)
    return "metrics=None"
|
|
|
|
|
|
async def handle_client(websocket):
    """Handle one WebSocket signaling session and process the peer's video.

    Each incoming message is expected to be a JSON object with a ``type``:

    * ``"offer"``: its ``sdp`` is applied as the remote description and an
      ``"answer"`` message carrying the local SDP is sent back.
    * ``"candidate"``: its ``candidate`` line is parsed and added to the
      peer connection.

    Other message types are ignored. Malformed messages (invalid JSON,
    non-object payloads, an offer without SDP) are logged and skipped
    instead of ending the session. When the message loop ends, the video
    task is cancelled, the peer connection is closed and all OpenCV windows
    are destroyed.

    Args:
        websocket: Connection passed in by ``websockets.serve``; iterated for
            incoming messages and used to send the answer.
    """
    client = websocket.remote_address
    logger.info(f"Client connected: {client}")

    pc = RTCPeerConnection()
    video_task = None

    def parse_ice(data):
        """Convert a ``candidate`` signaling message into an RTCIceCandidate.

        Returns ``None`` when the message has no candidate string or the
        string does not match the expected ``candidate:...`` layout.
        """
        line = data.get("candidate")
        if not isinstance(line, str):
            # Missing or null candidate: nothing to add.
            return None
        match = re.match(
            r'candidate:(\S+) (\d) (\S+) (\d+) (\S+) (\d+) typ (\S+)(?: raddr (\S+) rport (\d+))?',
            line,
        )
        if not match:
            return None
        g = match.groups()
        cand = RTCIceCandidate(
            foundation=g[0],
            component=int(g[1]),
            protocol=g[2].lower(),
            priority=int(g[3]),
            ip=g[4],
            port=int(g[5]),
            type=g[6],
            relatedAddress=g[7],
            relatedPort=int(g[8]) if g[8] else None,
        )
        cand.sdpMid = data.get("sdpMid")
        cand.sdpMLineIndex = data.get("sdpMLineIndex", 0)
        return cand

    async def receive_video(track):
        """Receive frames from *track*, run pose detection and display them.

        Runs until the track ends, the task is cancelled, or ESC is pressed
        in the display window. The detector and announcer are always closed
        on exit.
        """
        logger.info(
            "Start receiving video frames, process_every_n_frames={}, target_frame={}x{}",
            PROCESS_EVERY_N_FRAMES,
            TARGET_FRAME_WIDTH,
            TARGET_FRAME_HEIGHT,
        )
        frame_count = 0
        processed_count = 0
        last_announced_rep = 0
        last_pose_result = None
        last_annotated = None
        detector = None
        announcer = None
        try:
            # Constructed inside the try so a constructor failure is logged
            # below and whatever was already created still gets closed.
            detector = DeadBugDetector()
            announcer = RepAnnouncer()
            while True:
                frame = await track.recv()
                frame_count += 1
                raw_img = frame.to_ndarray(format="bgr24")
                img = normalize_frame(raw_img)
                # Fall back to a synthetic 33 ms-per-frame clock when the
                # frame carries no timestamp.
                timestamp_ms = int(frame.time * 1000) if frame.time is not None else frame_count * 33

                # Always process until a first result exists, then only
                # every N-th frame; skipped frames reuse the last annotation.
                if frame_count % PROCESS_EVERY_N_FRAMES == 0 or last_pose_result is None:
                    processed_count += 1
                    # NOTE(review): this call and the cv2 calls below run
                    # synchronously inside the coroutine - confirm they are
                    # fast enough not to starve signaling on the event loop.
                    last_annotated, last_pose_result = detector.process_frame(img, timestamp_ms)
                    if last_pose_result.rep_count > last_announced_rep:
                        last_announced_rep = last_pose_result.rep_count
                        announcer.announce_count(last_announced_rep)

                cv2.imshow("Android Camera (WebRTC)", last_annotated if last_annotated is not None else img)

                if frame_count % 100 == 0:
                    logger.info(
                        "Received {} frames, processed={}, raw_shape={}, shape={}, reps={}, phase={}, feedback={}, {}",
                        frame_count,
                        processed_count,
                        raw_img.shape,
                        img.shape,
                        last_pose_result.rep_count if last_pose_result is not None else 0,
                        last_pose_result.phase.value if last_pose_result is not None else "none",
                        " | ".join(last_pose_result.feedback) if last_pose_result is not None else "",
                        format_pose_debug(last_pose_result) if last_pose_result is not None else "metrics=None",
                    )

                # 27 is the ESC key code.
                if cv2.waitKey(1) & 0xFF == 27:
                    logger.info("ESC pressed, closing display")
                    break
        except asyncio.CancelledError:
            logger.info("Video receive task cancelled")
        except MediaStreamError:
            logger.info("Video track ended")
        except Exception as e:
            logger.exception(f"Video receive error: {e!r}")
        finally:
            # Nested so a failing announcer.close() cannot skip
            # detector.close().
            try:
                if announcer is not None:
                    announcer.close()
            finally:
                if detector is not None:
                    detector.close()

    @pc.on("track")
    async def on_track(track):
        nonlocal video_task
        logger.info(f"Track received: kind={track.kind}")
        if track.kind == "video":
            video_task = asyncio.create_task(receive_video(track))

    @pc.on("iceconnectionstatechange")
    async def on_iceconnectionstatechange():
        logger.info(f"ICE state: {pc.iceConnectionState}")
        if pc.iceConnectionState in ("failed", "closed", "disconnected"):
            await pc.close()

    try:
        async for message in websocket:
            # A single bad message should not tear down the session.
            try:
                data = json.loads(message)
            except ValueError as e:
                logger.warning("Ignoring malformed signaling message from {}: {!r}", client, e)
                continue
            if not isinstance(data, dict):
                logger.warning("Ignoring non-object signaling message from {}", client)
                continue

            msg_type = data.get("type")

            if msg_type == "offer":
                sdp = data.get("sdp")
                if not isinstance(sdp, str):
                    logger.warning("Ignoring offer without SDP from {}", client)
                    continue
                offer = RTCSessionDescription(sdp=sdp, type="offer")
                await pc.setRemoteDescription(offer)

                answer = await pc.createAnswer()
                await pc.setLocalDescription(answer)

                await websocket.send(json.dumps({
                    "type": "answer",
                    "sdp": pc.localDescription.sdp,
                }))

            elif msg_type == "candidate":
                cand = parse_ice(data)
                if cand is not None:
                    await pc.addIceCandidate(cand)

    except websockets.ConnectionClosed:
        logger.info(f"Client disconnected: {client}")
    except Exception as e:
        logger.exception(f"Error: {e}")
    finally:
        if video_task is not None:
            video_task.cancel()
            try:
                await video_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # Log instead of propagating so the cleanup below still runs.
                logger.exception(f"Video task failed during shutdown: {e!r}")
        await pc.close()
        cv2.destroyAllWindows()
        logger.info(f"Connection closed: {client}")
|
|
|
|
|
|
async def main(host="0.0.0.0", port=8765):
    """Run the WebSocket signaling server until the coroutine is cancelled.

    Args:
        host: Address to bind. The default, ``"0.0.0.0"``, matches the
            previously hard-coded value.
        port: TCP port to listen on. Defaults to the previously hard-coded
            8765.
    """
    logger.info(f"WebRTC signaling server: ws://{host}:{port}")
    # max_size allows signaling messages of up to 10 MiB.
    async with websockets.serve(handle_client, host, port, max_size=10 * 1024 * 1024):
        # Wait on a future that is never resolved: serve until cancelled.
        await asyncio.Future()
|
|
|
|
|
|
def normalize_frame(image):
    """Return *image* at the configured target resolution.

    Args:
        image: Array-like frame whose ``shape`` starts with (height, width).

    Returns:
        The same object when it already measures
        ``TARGET_FRAME_WIDTH`` x ``TARGET_FRAME_HEIGHT``; otherwise the
        result of ``cv2.resize`` to that size using ``cv2.INTER_AREA``.
    """
    target_size = (TARGET_FRAME_WIDTH, TARGET_FRAME_HEIGHT)
    rows, cols = image.shape[:2]
    if (cols, rows) != target_size:
        # cv2.resize takes the size as (width, height).
        return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
    return image
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the signaling server until interrupted.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop the server; exit without a
        # traceback.
        logger.info("Interrupted, shutting down")
|