Source code for yt_framework.core.pipeline_config
"""OmegaConf normalization helpers for pipeline and upload configuration."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Literal
from omegaconf import DictConfig, ListConfig, OmegaConf
def _single_nonempty_module_name(raw: str) -> list[str]:
s = raw.strip()
return [s] if s else []
def _upload_modules_from_sequence(
raw: list[object] | tuple[object, ...] | ListConfig,
) -> list[str]:
return [str(m).strip() for m in raw if str(m).strip()]
[docs]
def normalize_upload_modules(raw: object) -> list[str]:
    """Normalize the ``upload_modules`` config value to a list of module names.

    ``None`` yields ``[]``. A single string becomes a one-element list, or ``[]``
    when it is blank. A list, tuple or OmegaConf ``ListConfig`` is stringified
    element-wise, stripped, and has blank entries dropped.

    Raises:
        ValueError: For any other input type.
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        return _single_nonempty_module_name(raw)
    if isinstance(raw, (list, tuple, ListConfig)):
        return _upload_modules_from_sequence(raw)
    error_text = "upload_modules must be a list of module names or a single string."
    raise ValueError(error_text)
def _coerce_upload_path_mapping(idx: int, element: object) -> dict[str, str]:
    """Validate one ``upload_paths`` entry and return it as a ``str -> str`` dict.

    Args:
        idx: Position of the entry within ``upload_paths``. It is used only in
            error messages.
        element: The raw entry. A ``DictConfig`` is first converted to a plain
            container with interpolations resolved.

    Returns:
        A new dict whose keys and values are converted with ``str``. Keys whose
        value is ``None`` (e.g. ``target: null``) are omitted, so optional fields
        stay absent instead of becoming the literal string ``"None"``.

    Raises:
        TypeError: If the entry is not a mapping.
        ValueError: If the entry has no ``'source'`` key, or its ``source`` is null.
    """
    item: object = element
    if isinstance(item, DictConfig):
        item = OmegaConf.to_container(item, resolve=True)
    if not isinstance(item, Mapping):
        msg = (
            f"upload_paths[{idx}] must be a mapping with at least a 'source' key, "
            f"got {type(item).__name__!r}."
        )
        raise TypeError(msg)
    if "source" not in item:
        msg = f"upload_paths[{idx}] is missing required 'source' key."
        raise ValueError(msg)
    if item["source"] is None:
        # A null source would otherwise be stringified into the path "None".
        msg = f"upload_paths[{idx}] has a null 'source' value."
        raise ValueError(msg)
    # Drop null-valued optional keys rather than emitting the string "None".
    return {str(k): str(v) for k, v in item.items() if v is not None}
[docs]
def normalize_upload_paths(raw: object) -> list[dict[str, str]]:
    """Normalize ``upload_paths`` into a list of validated string mappings.

    ``None`` yields ``[]``. Otherwise *raw* must be a list, tuple or OmegaConf
    ``ListConfig``. Each element is checked and converted by
    ``_coerce_upload_path_mapping``, which receives the element's index.

    Raises:
        TypeError: If *raw* is not one of the accepted sequence types.
    """
    if raw is None:
        return []
    if isinstance(raw, (list, tuple, ListConfig)):
        return [
            _coerce_upload_path_mapping(position, entry)
            for position, entry in enumerate(raw)
        ]
    error_text = "upload_paths must be a list of {source, target?} dicts."
    raise TypeError(error_text)
[docs]
def yt_mode_from_pipeline_config(raw: object) -> Literal["prod", "dev"] | None:
    """Map ``pipeline.mode`` to ``"prod"``/``"dev"``, or ``None`` when it is unset.

    The value is stringified, stripped and lower-cased before comparison, so
    ``" PROD "`` is accepted. A ``None`` result lets the caller choose a default.

    Raises:
        ValueError: If the normalized value is neither ``prod`` nor ``dev``.
    """
    if raw is None:
        return None
    mode = str(raw).strip().lower()
    if mode not in ("prod", "dev"):
        msg = f"pipeline.mode must be 'prod' or 'dev', got {raw!r}"
        raise ValueError(msg)
    return "prod" if mode == "prod" else "dev"
[docs]
def pickling_dict_from_config(pickling_cfg: object) -> dict[str, Any]:
    """Return a plain dict for ``create_yt_client(..., pickling=...)``.

    Accepts either an OmegaConf config, which is converted with
    ``OmegaConf.to_container(..., resolve=True)``, or an ordinary mapping such as
    a ``dict``.

    Args:
        pickling_cfg: The ``pipeline.pickling`` config node or mapping. Falsy
            input (e.g. ``None`` or an empty config/mapping) yields ``{}``.

    Returns:
        A new shallow ``dict`` copy of the resolved mapping.

    Raises:
        TypeError: If the value is neither an OmegaConf config nor a mapping, or
            if an OmegaConf config resolves to a non-mapping (e.g. a list).
    """
    if not pickling_cfg:
        return {}
    if isinstance(pickling_cfg, (DictConfig, ListConfig)):
        # OmegaConf.to_container only accepts OmegaConf containers; calling it
        # on a plain dict or scalar raises ValueError instead of our TypeError.
        raw: object = OmegaConf.to_container(pickling_cfg, resolve=True)
    else:
        raw = pickling_cfg
    if raw is None:
        return {}
    if isinstance(raw, Mapping):
        return dict(raw)
    msg = (
        "pipeline.pickling must be a mapping-compatible config, "
        f"got {type(raw).__name__}"
    )
    raise TypeError(msg)
def _enabled_from_sequence(
enabled: list[object] | tuple[object, ...] | ListConfig,
) -> list[str]:
return [str(x) for x in enabled]
[docs]
def enabled_stage_names(enabled: object) -> list[str]:
    """Normalize ``stages.enabled_stages`` to a list of stage directory names.

    ``None`` yields ``[]``. A string becomes a one-element list of its stripped
    value, or ``[]`` when blank. A list, tuple or OmegaConf ``ListConfig`` is
    delegated to ``_enabled_from_sequence``. Any other value is stringified into
    a single-element list.
    """
    if enabled is None:
        return []
    if isinstance(enabled, (list, tuple, ListConfig)):
        return _enabled_from_sequence(enabled)
    if not isinstance(enabled, str):
        return [str(enabled)]
    name = enabled.strip()
    return [name] if name else []