diff --git a/configs/load.py b/configs/load.py index 6bac897..783d0be 100644 --- a/configs/load.py +++ b/configs/load.py @@ -1,97 +1,87 @@ from __future__ import annotations -from dataclasses import dataclass, field +import dataclasses +import os from pathlib import Path from typing import Any import yaml +from configs.models import ( + AppConfig, + AudioConfig, + DeadBugConfig, + LoggingConfig, + ModelConfig, + ServerConfig, + VideoConfig, +) -@dataclass -class ServerConfig: - host: str = "0.0.0.0" - port: int = 8765 - max_ws_size: int = 10_485_760 +_PROJECT_ROOT = Path(__file__).resolve().parent.parent +_ENV_MAP = { + "POSEFIT_WS_HOST": ("server", "host"), + "POSEFIT_WS_PORT": ("server", "port", int), + "POSEFIT_WS_MAX_SIZE": ("server", "max_ws_size", int), + "POSEFIT_PROCESS_EVERY_N_FRAMES": ("video", "process_every_n_frames", int), + "POSEFIT_MODEL_PATH": ("model", "path"), + "POSEFIT_PREFER_GPU": ("model", "prefer_gpu", lambda v: v not in ("0", "false", "False")), + "POSEFIT_VISIBILITY_THRESHOLD": ("dead_bug", "visibility_threshold", float), + "POSEFIT_EXTENSION_CONFIRM_FRAMES": ("dead_bug", "extension_confirm_frames", int), + "POSEFIT_RESET_CONFIRM_FRAMES": ("dead_bug", "reset_confirm_frames", int), + "POSEFIT_REP_ANNOUNCER_ENABLED": ("audio", "rep_announcer_enabled", lambda v: v not in ("0", "false", "False")), + "POSEFIT_REP_ANNOUNCER_RATE": ("audio", "rep_announcer_rate", int), + "POSEFIT_REP_ANNOUNCER_VOLUME": ("audio", "rep_announcer_volume", float), + "POSEFIT_LOG_ROTATION": ("logging", "rotation"), + "POSEFIT_LOG_RETENTION": ("logging", "retention"), + "POSEFIT_LOG_DIR": ("logging", "dir"), +} -@dataclass -class VideoConfig: - process_every_n_frames: int = 1 - - -@dataclass -class ModelConfig: - path: str = "" - prefer_gpu: bool = True - - @property - def resolved_path(self) -> str: - if self.path: - return self.path - return str(Path(__file__).resolve().parent.parent / "pose_models" / "pose_landmarker_full.task") - - -@dataclass -class DeadBugConfig: - visibility_threshold: float = 0.45 - extension_confirm_frames: int = 4 - reset_confirm_frames: int = 3 - - -@dataclass -class AudioConfig: - rep_announcer_enabled: bool = True - rep_announcer_rate: int = 185 - rep_announcer_volume: float = 1.0 - - -@dataclass -class LoggingConfig: - dir: str = "logs" - rotation: str = "20 MB" - retention: str = "14 days" - - @property - def dir_path(self) -> Path: - return Path(__file__).resolve().parent.parent / self.dir - - -@dataclass -class AppConfig: - server: ServerConfig = field(default_factory=ServerConfig) - video: VideoConfig = field(default_factory=VideoConfig) - model: ModelConfig = field(default_factory=ModelConfig) - dead_bug: DeadBugConfig = field(default_factory=DeadBugConfig) - audio: AudioConfig = field(default_factory=AudioConfig) - logging: LoggingConfig = field(default_factory=LoggingConfig) +_SECTION_CLASS = { + "server": ServerConfig, + "video": VideoConfig, + "model": ModelConfig, + "dead_bug": DeadBugConfig, + "audio": AudioConfig, + "logging": LoggingConfig, +} def _dict_to_dataclass(cls: type, data: dict[str, Any] | None) -> dict[str, Any]: - """Convert a dict to dataclass constructor kwargs, using only known fields.""" - import dataclasses if data is None: return {} - fields = {f.name for f in dataclasses.fields(cls)} - return {k: v for k, v in data.items() if k in fields} + field_names = {f.name for f in dataclasses.fields(cls)} + return {k: v for k, v in data.items() if k in field_names} + + +def _read_yaml(path: Path) -> dict[str, Any]: + if not path.exists(): + return {} + with open(path, encoding="utf-8") as f: + return yaml.safe_load(f) or {} + + +def _apply_env_overrides(raw: dict[str, Any]) -> None: + for env_var, (section, key, *rest) in _ENV_MAP.items(): + value = os.getenv(env_var) + if value is None: + continue + if rest: + value = rest[0](value) + raw.setdefault(section, {})[key] = value def load_config(config_path: str | Path | None = None) -> AppConfig: if config_path is None: - config_path = Path(__file__).resolve().parent.parent / "config.yaml" + config_path = _PROJECT_ROOT / "config.yaml" - raw: dict[str, Any] = {} - if Path(config_path).exists(): - with open(config_path, encoding="utf-8") as f: - raw = yaml.safe_load(f) or {} + raw = _read_yaml(Path(config_path)) + _apply_env_overrides(raw) - return AppConfig( - server=ServerConfig(**_dict_to_dataclass(ServerConfig, raw.get("server"))), - video=VideoConfig(**_dict_to_dataclass(VideoConfig, raw.get("video"))), - model=ModelConfig(**_dict_to_dataclass(ModelConfig, raw.get("model"))), - dead_bug=DeadBugConfig(**_dict_to_dataclass(DeadBugConfig, raw.get("dead_bug"))), - audio=AudioConfig(**_dict_to_dataclass(AudioConfig, raw.get("audio"))), - logging=LoggingConfig(**_dict_to_dataclass(LoggingConfig, raw.get("logging"))), - ) + return AppConfig(**{ + section: cls(**_dict_to_dataclass(cls, raw.get(section))) + for section, cls in _SECTION_CLASS.items() + }) config = load_config() diff --git a/configs/models.py b/configs/models.py new file mode 100644 index 0000000..ae31778 --- /dev/null +++ b/configs/models.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path + + +@dataclass +class ServerConfig: + host: str = "0.0.0.0" + port: int = 8765 + max_ws_size: int = 10_485_760 + + +@dataclass +class VideoConfig: + process_every_n_frames: int = 1 + + +@dataclass +class ModelConfig: + path: str = "" + prefer_gpu: bool = True + + @property + def resolved_path(self) -> str: + if self.path: + return self.path + return str(Path(__file__).resolve().parent.parent / "pose_models" / "pose_landmarker_full.task") + + +@dataclass +class DeadBugConfig: + visibility_threshold: float = 0.45 + extension_confirm_frames: int = 4 + reset_confirm_frames: int = 3 + + +@dataclass +class AudioConfig: + rep_announcer_enabled: bool = True + rep_announcer_rate: int = 185 + rep_announcer_volume: float = 1.0 + + +@dataclass +class LoggingConfig: + dir: str = "logs" + rotation: str = "20 MB" + retention: str = "14 days" + + @property + def dir_path(self) -> Path: + return Path(__file__).resolve().parent.parent / self.dir + + +@dataclass +class AppConfig: + server: ServerConfig = field(default_factory=ServerConfig) + video: VideoConfig = field(default_factory=VideoConfig) + model: ModelConfig = field(default_factory=ModelConfig) + dead_bug: DeadBugConfig = field(default_factory=DeadBugConfig) + audio: AudioConfig = field(default_factory=AudioConfig) + logging: LoggingConfig = field(default_factory=LoggingConfig)