Parse config.yaml into typed AppConfig dataclass
- ServerConfig, VideoConfig, ModelConfig, DeadBugConfig, AudioConfig, LoggingConfig as nested dataclasses - Consumers use config.server.host, config.model.resolved_path etc. - env var overrides preserved via _apply_env_overrides()
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.diagnostics.crash_handler import enable_crash_handler
|
||||
from configs.load import LOG_DIR
|
||||
from configs.load import config
|
||||
|
||||
|
||||
def startup() -> None:
|
||||
enable_crash_handler(LOG_DIR)
|
||||
enable_crash_handler(config.logging.dir_path)
|
||||
from app.core.logging import setup_logging
|
||||
setup_logging()
|
||||
|
||||
+6
-5
@@ -4,16 +4,17 @@ from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from configs.load import LOG_DIR, LOG_RETENTION, LOG_ROTATION
|
||||
from configs.load import config
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
log_dir = config.logging.dir_path
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.add(
|
||||
LOG_DIR / "posefit-server_{time:YYYY-MM-DD}.log",
|
||||
rotation=LOG_ROTATION,
|
||||
retention=LOG_RETENTION,
|
||||
log_dir / "posefit-server_{time:YYYY-MM-DD}.log",
|
||||
rotation=config.logging.rotation,
|
||||
retention=config.logging.retention,
|
||||
enqueue=True,
|
||||
backtrace=True,
|
||||
diagnose=True,
|
||||
|
||||
+6
-5
@@ -1,10 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
os.environ["MEDIAPIPE_DISABLE_LOGGING"] = "1"
|
||||
os.environ["GLOG_minloglevel"] = "3"
|
||||
|
||||
import asyncio
|
||||
|
||||
from loguru import logger
|
||||
@@ -16,7 +11,13 @@ from app.signaling.websocket_server import main as serve
|
||||
def main():
|
||||
startup()
|
||||
logger.info("Starting server...")
|
||||
try:
|
||||
asyncio.run(serve())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Server stopped by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Server error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -7,7 +7,7 @@ import websockets
|
||||
from loguru import logger
|
||||
|
||||
from app.webrtc.peer_session import PeerSession
|
||||
from configs.load import WS_HOST, WS_MAX_SIZE, WS_PORT
|
||||
from configs.load import config
|
||||
|
||||
|
||||
async def handle_client(websocket):
|
||||
@@ -21,6 +21,7 @@ async def handle_client(websocket):
|
||||
|
||||
|
||||
async def main():
|
||||
logger.info(f"WebRTC signaling server: ws://{WS_HOST}:{WS_PORT}")
|
||||
async with websockets.serve(handle_client, WS_HOST, WS_PORT, max_size=WS_MAX_SIZE):
|
||||
cfg = config.server
|
||||
logger.info(f"WebRTC signaling server: ws://{cfg.host}:{cfg.port}")
|
||||
async with websockets.serve(handle_client, cfg.host, cfg.port, max_size=cfg.max_ws_size):
|
||||
await asyncio.Future()
|
||||
|
||||
@@ -10,17 +10,7 @@ from loguru import logger
|
||||
from app.audio.rep_announcer import RepAnnouncer
|
||||
from app.exercises.dead_bug.detector import DeadBugDetector
|
||||
from app.rendering.window_display import close_window, is_esc_pressed, show_frame
|
||||
from configs.load import (
|
||||
EXTENSION_CONFIRM_FRAMES,
|
||||
MODEL_PATH,
|
||||
PREFER_GPU,
|
||||
PROCESS_EVERY_N_FRAMES,
|
||||
REP_ANNOUNCER_ENABLED,
|
||||
REP_ANNOUNCER_RATE,
|
||||
REP_ANNOUNCER_VOLUME,
|
||||
RESET_CONFIRM_FRAMES,
|
||||
VISIBILITY_THRESHOLD,
|
||||
)
|
||||
from configs.load import config
|
||||
|
||||
|
||||
def _format_pose_debug(pose_result) -> str:
|
||||
@@ -41,21 +31,21 @@ class VideoReceiver:
|
||||
self._track = track
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Start receiving video frames, process_every_n={}", PROCESS_EVERY_N_FRAMES)
|
||||
logger.info("Start receiving video frames, process_every_n={}", config.video.process_every_n_frames)
|
||||
|
||||
frame_count = 0
|
||||
processed_count = 0
|
||||
detector = DeadBugDetector(
|
||||
model_path=MODEL_PATH,
|
||||
visibility_threshold=VISIBILITY_THRESHOLD,
|
||||
extension_confirm_frames=EXTENSION_CONFIRM_FRAMES,
|
||||
reset_confirm_frames=RESET_CONFIRM_FRAMES,
|
||||
prefer_gpu=PREFER_GPU,
|
||||
model_path=config.model.resolved_path,
|
||||
visibility_threshold=config.dead_bug.visibility_threshold,
|
||||
extension_confirm_frames=config.dead_bug.extension_confirm_frames,
|
||||
reset_confirm_frames=config.dead_bug.reset_confirm_frames,
|
||||
prefer_gpu=config.model.prefer_gpu,
|
||||
)
|
||||
announcer = RepAnnouncer(
|
||||
enabled=REP_ANNOUNCER_ENABLED,
|
||||
rate=REP_ANNOUNCER_RATE,
|
||||
volume=REP_ANNOUNCER_VOLUME,
|
||||
enabled=config.audio.rep_announcer_enabled,
|
||||
rate=config.audio.rep_announcer_rate,
|
||||
volume=config.audio.rep_announcer_volume,
|
||||
)
|
||||
last_announced_rep = 0
|
||||
last_pose_result = None
|
||||
@@ -68,7 +58,7 @@ class VideoReceiver:
|
||||
raw_img = frame.to_ndarray(format="bgr24")
|
||||
timestamp_ms = int(frame.time * 1000) if frame.time is not None else frame_count * 33
|
||||
|
||||
if frame_count % PROCESS_EVERY_N_FRAMES == 0 or last_pose_result is None:
|
||||
if frame_count % config.video.process_every_n_frames == 0 or last_pose_result is None:
|
||||
processed_count += 1
|
||||
last_annotated, last_pose_result = detector.process_frame(raw_img, timestamp_ms)
|
||||
if last_pose_result.rep_count > last_announced_rep:
|
||||
|
||||
+1
-1
@@ -10,7 +10,7 @@ video:
|
||||
process_every_n_frames: 1
|
||||
|
||||
model:
|
||||
path: "" # empty = auto-detect pose_models/pose_landmarker_full.task
|
||||
path: "./pose_models/pose_landmarker_full.task" # empty = auto-detect pose_models/pose_landmarker_full.task
|
||||
prefer_gpu: true
|
||||
|
||||
dead_bug:
|
||||
|
||||
+82
-68
@@ -1,83 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
|
||||
_ENV_MAP = {
|
||||
"POSEFIT_WS_HOST": ("server", "host"),
|
||||
"POSEFIT_WS_PORT": ("server", "port", int),
|
||||
"POSEFIT_WS_MAX_SIZE": ("server", "max_ws_size", int),
|
||||
"POSEFIT_PROCESS_EVERY_N_FRAMES": ("video", "process_every_n_frames", int),
|
||||
"POSEFIT_MODEL_PATH": ("model", "path"),
|
||||
"POSEFIT_PREFER_GPU": ("model", "prefer_gpu", lambda v: v not in ("0", "false", "False")),
|
||||
"POSEFIT_VISIBILITY_THRESHOLD": ("dead_bug", "visibility_threshold", float),
|
||||
"POSEFIT_EXTENSION_CONFIRM_FRAMES": ("dead_bug", "extension_confirm_frames", int),
|
||||
"POSEFIT_RESET_CONFIRM_FRAMES": ("dead_bug", "reset_confirm_frames", int),
|
||||
"POSEFIT_REP_ANNOUNCER_ENABLED": ("audio", "rep_announcer_enabled", lambda v: v not in ("0", "false", "False")),
|
||||
"POSEFIT_REP_ANNOUNCER_RATE": ("audio", "rep_announcer_rate", int),
|
||||
"POSEFIT_REP_ANNOUNCER_VOLUME": ("audio", "rep_announcer_volume", float),
|
||||
"POSEFIT_LOG_ROTATION": ("logging", "rotation"),
|
||||
"POSEFIT_LOG_RETENTION": ("logging", "retention"),
|
||||
"POSEFIT_LOG_DIR": ("logging", "dir"),
|
||||
}
|
||||
@dataclass
|
||||
class ServerConfig:
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8765
|
||||
max_ws_size: int = 10_485_760
|
||||
|
||||
|
||||
def _load_yaml() -> dict[str, Any]:
|
||||
config_path = _PROJECT_ROOT / "config.yaml"
|
||||
if config_path.exists():
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
@dataclass
|
||||
class VideoConfig:
|
||||
process_every_n_frames: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
path: str = ""
|
||||
prefer_gpu: bool = True
|
||||
|
||||
@property
|
||||
def resolved_path(self) -> str:
|
||||
if self.path:
|
||||
return self.path
|
||||
return str(Path(__file__).resolve().parent.parent / "pose_models" / "pose_landmarker_full.task")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeadBugConfig:
|
||||
visibility_threshold: float = 0.45
|
||||
extension_confirm_frames: int = 4
|
||||
reset_confirm_frames: int = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioConfig:
|
||||
rep_announcer_enabled: bool = True
|
||||
rep_announcer_rate: int = 185
|
||||
rep_announcer_volume: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoggingConfig:
|
||||
dir: str = "logs"
|
||||
rotation: str = "20 MB"
|
||||
retention: str = "14 days"
|
||||
|
||||
@property
|
||||
def dir_path(self) -> Path:
|
||||
return Path(__file__).resolve().parent.parent / self.dir
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
server: ServerConfig = field(default_factory=ServerConfig)
|
||||
video: VideoConfig = field(default_factory=VideoConfig)
|
||||
model: ModelConfig = field(default_factory=ModelConfig)
|
||||
dead_bug: DeadBugConfig = field(default_factory=DeadBugConfig)
|
||||
audio: AudioConfig = field(default_factory=AudioConfig)
|
||||
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
||||
|
||||
|
||||
def _dict_to_dataclass(cls: type, data: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Convert a dict to dataclass constructor kwargs, using only known fields."""
|
||||
import dataclasses
|
||||
if data is None:
|
||||
return {}
|
||||
fields = {f.name for f in dataclasses.fields(cls)}
|
||||
return {k: v for k, v in data.items() if k in fields}
|
||||
|
||||
|
||||
def _apply_env_overrides(config: dict) -> None:
|
||||
for env_var, (section, key, *rest) in _ENV_MAP.items():
|
||||
value = os.getenv(env_var)
|
||||
if value is None:
|
||||
continue
|
||||
if rest:
|
||||
value = rest[0](value)
|
||||
config.setdefault(section, {})[key] = value
|
||||
def load_config(config_path: str | Path | None = None) -> AppConfig:
|
||||
if config_path is None:
|
||||
config_path = Path(__file__).resolve().parent.parent / "config.yaml"
|
||||
|
||||
raw: dict[str, Any] = {}
|
||||
if Path(config_path).exists():
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
raw = yaml.safe_load(f) or {}
|
||||
|
||||
return AppConfig(
|
||||
server=ServerConfig(**_dict_to_dataclass(ServerConfig, raw.get("server"))),
|
||||
video=VideoConfig(**_dict_to_dataclass(VideoConfig, raw.get("video"))),
|
||||
model=ModelConfig(**_dict_to_dataclass(ModelConfig, raw.get("model"))),
|
||||
dead_bug=DeadBugConfig(**_dict_to_dataclass(DeadBugConfig, raw.get("dead_bug"))),
|
||||
audio=AudioConfig(**_dict_to_dataclass(AudioConfig, raw.get("audio"))),
|
||||
logging=LoggingConfig(**_dict_to_dataclass(LoggingConfig, raw.get("logging"))),
|
||||
)
|
||||
|
||||
|
||||
_cfg = _load_yaml()
|
||||
_apply_env_overrides(_cfg)
|
||||
|
||||
|
||||
def _get(section: str, key: str, default: Any = None) -> Any:
|
||||
return _cfg.get(section, {}).get(key, default)
|
||||
|
||||
|
||||
# ── Server ──────────────────────────────────────────────────────────────────
|
||||
WS_HOST = _get("server", "host", "0.0.0.0")
|
||||
WS_PORT = _get("server", "port", 8765)
|
||||
WS_MAX_SIZE = _get("server", "max_ws_size", 10485760)
|
||||
|
||||
# ── Video processing ────────────────────────────────────────────────────────
|
||||
PROCESS_EVERY_N_FRAMES = max(1, _get("video", "process_every_n_frames", 1))
|
||||
|
||||
# ── Model ───────────────────────────────────────────────────────────────────
|
||||
MODEL_DIR: Path = _PROJECT_ROOT / "pose_models"
|
||||
_model_path = _get("model", "path", "")
|
||||
MODEL_PATH = _model_path if _model_path else str(MODEL_DIR / "pose_landmarker_full.task")
|
||||
PREFER_GPU = bool(_get("model", "prefer_gpu", True))
|
||||
|
||||
# ── Dead bug exercise ───────────────────────────────────────────────────────
|
||||
VISIBILITY_THRESHOLD = float(_get("dead_bug", "visibility_threshold", 0.45))
|
||||
EXTENSION_CONFIRM_FRAMES = int(_get("dead_bug", "extension_confirm_frames", 4))
|
||||
RESET_CONFIRM_FRAMES = int(_get("dead_bug", "reset_confirm_frames", 3))
|
||||
|
||||
# ── Audio ───────────────────────────────────────────────────────────────────
|
||||
REP_ANNOUNCER_ENABLED = bool(_get("audio", "rep_announcer_enabled", True))
|
||||
REP_ANNOUNCER_RATE = int(_get("audio", "rep_announcer_rate", 185))
|
||||
REP_ANNOUNCER_VOLUME = float(_get("audio", "rep_announcer_volume", 1.0))
|
||||
|
||||
# ── Logging ─────────────────────────────────────────────────────────────────
|
||||
LOG_DIR: Path = _PROJECT_ROOT / _get("logging", "dir", "logs")
|
||||
LOG_ROTATION = _get("logging", "rotation", "20 MB")
|
||||
LOG_RETENTION = _get("logging", "retention", "14 days")
|
||||
config = load_config()
|
||||
|
||||
Reference in New Issue
Block a user