Parse config.yaml into typed AppConfig dataclass

- ServerConfig, VideoConfig, ModelConfig, DeadBugConfig,
  AudioConfig, LoggingConfig as nested dataclasses
- Consumers use config.server.host, config.model.resolved_path etc.
- env var overrides preserved via _apply_env_overrides()
This commit is contained in:
2026-06-10 10:23:51 +08:00
parent c8fd057129
commit f9384f7bc1
7 changed files with 116 additions and 109 deletions
+84 -70
View File
@@ -1,83 +1,97 @@
from __future__ import annotations
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import yaml
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
_ENV_MAP = {
"POSEFIT_WS_HOST": ("server", "host"),
"POSEFIT_WS_PORT": ("server", "port", int),
"POSEFIT_WS_MAX_SIZE": ("server", "max_ws_size", int),
"POSEFIT_PROCESS_EVERY_N_FRAMES": ("video", "process_every_n_frames", int),
"POSEFIT_MODEL_PATH": ("model", "path"),
"POSEFIT_PREFER_GPU": ("model", "prefer_gpu", lambda v: v not in ("0", "false", "False")),
"POSEFIT_VISIBILITY_THRESHOLD": ("dead_bug", "visibility_threshold", float),
"POSEFIT_EXTENSION_CONFIRM_FRAMES": ("dead_bug", "extension_confirm_frames", int),
"POSEFIT_RESET_CONFIRM_FRAMES": ("dead_bug", "reset_confirm_frames", int),
"POSEFIT_REP_ANNOUNCER_ENABLED": ("audio", "rep_announcer_enabled", lambda v: v not in ("0", "false", "False")),
"POSEFIT_REP_ANNOUNCER_RATE": ("audio", "rep_announcer_rate", int),
"POSEFIT_REP_ANNOUNCER_VOLUME": ("audio", "rep_announcer_volume", float),
"POSEFIT_LOG_ROTATION": ("logging", "rotation"),
"POSEFIT_LOG_RETENTION": ("logging", "retention"),
"POSEFIT_LOG_DIR": ("logging", "dir"),
}
@dataclass
class ServerConfig:
host: str = "0.0.0.0"
port: int = 8765
max_ws_size: int = 10_485_760
def _load_yaml() -> dict[str, Any]:
config_path = _PROJECT_ROOT / "config.yaml"
if config_path.exists():
@dataclass
class VideoConfig:
process_every_n_frames: int = 1
@dataclass
class ModelConfig:
path: str = ""
prefer_gpu: bool = True
@property
def resolved_path(self) -> str:
if self.path:
return self.path
return str(Path(__file__).resolve().parent.parent / "pose_models" / "pose_landmarker_full.task")
@dataclass
class DeadBugConfig:
visibility_threshold: float = 0.45
extension_confirm_frames: int = 4
reset_confirm_frames: int = 3
@dataclass
class AudioConfig:
rep_announcer_enabled: bool = True
rep_announcer_rate: int = 185
rep_announcer_volume: float = 1.0
@dataclass
class LoggingConfig:
dir: str = "logs"
rotation: str = "20 MB"
retention: str = "14 days"
@property
def dir_path(self) -> Path:
return Path(__file__).resolve().parent.parent / self.dir
@dataclass
class AppConfig:
server: ServerConfig = field(default_factory=ServerConfig)
video: VideoConfig = field(default_factory=VideoConfig)
model: ModelConfig = field(default_factory=ModelConfig)
dead_bug: DeadBugConfig = field(default_factory=DeadBugConfig)
audio: AudioConfig = field(default_factory=AudioConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
def _dict_to_dataclass(cls: type, data: dict[str, Any] | None) -> dict[str, Any]:
"""Convert a dict to dataclass constructor kwargs, using only known fields."""
import dataclasses
if data is None:
return {}
fields = {f.name for f in dataclasses.fields(cls)}
return {k: v for k, v in data.items() if k in fields}
def load_config(config_path: str | Path | None = None) -> AppConfig:
if config_path is None:
config_path = Path(__file__).resolve().parent.parent / "config.yaml"
raw: dict[str, Any] = {}
if Path(config_path).exists():
with open(config_path, encoding="utf-8") as f:
return yaml.safe_load(f) or {}
return {}
raw = yaml.safe_load(f) or {}
return AppConfig(
server=ServerConfig(**_dict_to_dataclass(ServerConfig, raw.get("server"))),
video=VideoConfig(**_dict_to_dataclass(VideoConfig, raw.get("video"))),
model=ModelConfig(**_dict_to_dataclass(ModelConfig, raw.get("model"))),
dead_bug=DeadBugConfig(**_dict_to_dataclass(DeadBugConfig, raw.get("dead_bug"))),
audio=AudioConfig(**_dict_to_dataclass(AudioConfig, raw.get("audio"))),
logging=LoggingConfig(**_dict_to_dataclass(LoggingConfig, raw.get("logging"))),
)
def _apply_env_overrides(config: dict) -> None:
for env_var, (section, key, *rest) in _ENV_MAP.items():
value = os.getenv(env_var)
if value is None:
continue
if rest:
value = rest[0](value)
config.setdefault(section, {})[key] = value
_cfg = _load_yaml()
_apply_env_overrides(_cfg)
def _get(section: str, key: str, default: Any = None) -> Any:
return _cfg.get(section, {}).get(key, default)
# ── Server ──────────────────────────────────────────────────────────────────
WS_HOST = _get("server", "host", "0.0.0.0")
WS_PORT = _get("server", "port", 8765)
WS_MAX_SIZE = _get("server", "max_ws_size", 10485760)
# ── Video processing ────────────────────────────────────────────────────────
PROCESS_EVERY_N_FRAMES = max(1, _get("video", "process_every_n_frames", 1))
# ── Model ───────────────────────────────────────────────────────────────────
MODEL_DIR: Path = _PROJECT_ROOT / "pose_models"
_model_path = _get("model", "path", "")
MODEL_PATH = _model_path if _model_path else str(MODEL_DIR / "pose_landmarker_full.task")
PREFER_GPU = bool(_get("model", "prefer_gpu", True))
# ── Dead bug exercise ───────────────────────────────────────────────────────
VISIBILITY_THRESHOLD = float(_get("dead_bug", "visibility_threshold", 0.45))
EXTENSION_CONFIRM_FRAMES = int(_get("dead_bug", "extension_confirm_frames", 4))
RESET_CONFIRM_FRAMES = int(_get("dead_bug", "reset_confirm_frames", 3))
# ── Audio ───────────────────────────────────────────────────────────────────
REP_ANNOUNCER_ENABLED = bool(_get("audio", "rep_announcer_enabled", True))
REP_ANNOUNCER_RATE = int(_get("audio", "rep_announcer_rate", 185))
REP_ANNOUNCER_VOLUME = float(_get("audio", "rep_announcer_volume", 1.0))
# ── Logging ─────────────────────────────────────────────────────────────────
LOG_DIR: Path = _PROJECT_ROOT / _get("logging", "dir", "logs")
LOG_ROTATION = _get("logging", "rotation", "20 MB")
LOG_RETENTION = _get("logging", "retention", "14 days")
config = load_config()