# Source code for yt_framework.operations.checkpoint
"""Upload or reuse single-file model checkpoints and wire them into operation specs."""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from omegaconf import DictConfig
if TYPE_CHECKING:
from yt_framework.contracts import StageContext
def _raise_file_not_found(message: str) -> None:
    """Raise ``FileNotFoundError`` carrying *message*.

    Args:
        message: Text passed to the exception constructor.

    Raises:
        FileNotFoundError: Always; this function never returns normally.
    """
    error = FileNotFoundError(message)
    raise error
def _checkpoint_base_from_config(checkpoint_config: DictConfig) -> str | None:
    """Read ``checkpoint_base`` from the checkpoint config.

    Args:
        checkpoint_config: Mapping-like config queried with ``.get``.

    Returns:
        The configured value when it is a string containing at least one
        non-whitespace character, otherwise ``None``. The value is returned
        as stored; surrounding whitespace is not removed.
    """
    candidate = checkpoint_config.get("checkpoint_base")
    if not isinstance(candidate, str):
        return None
    # Whitespace-only strings are treated the same as a missing setting.
    return str(candidate) if candidate.strip() else None
def _local_checkpoint_path_from_config(checkpoint_config: DictConfig) -> str | None:
    """Read ``local_checkpoint_path`` from the checkpoint config.

    Args:
        checkpoint_config: Mapping-like config queried with ``.get``.

    Returns:
        The configured value when it is a string (an empty string is passed
        through unchanged), otherwise ``None``.
    """
    candidate = checkpoint_config.get("local_checkpoint_path")
    return str(candidate) if isinstance(candidate, str) else None
def _resolve_model_name(context: StageContext) -> str | None:
    """Look up the model name in the ``job`` section of the stage config.

    Args:
        context: Stage context whose ``config`` is queried for ``"job"``.

    Returns:
        ``job.model_name`` when the ``job`` entry is a ``DictConfig`` and the
        name is a string with at least one non-whitespace character;
        otherwise ``None``.
    """
    job_section = context.config.get("job")
    if isinstance(job_section, DictConfig):
        candidate = job_section.get("model_name")
        # Blank names are treated as "not configured".
        if isinstance(candidate, str) and candidate.strip():
            return candidate
    return None
def _upload_local_checkpoint_if_needed(
    context: StageContext,
    checkpoint_base: str,
    local_checkpoint_path: str | None,
) -> None:
    """Upload a local checkpoint file to YT unless it is already there.

    The destination is ``<checkpoint_base>/<file name of the local path>``.
    Nothing is uploaded when no local path is given, when the local path does
    not exist (a warning is logged), or when the destination already exists
    in YT (an info message is logged).

    Args:
        context: Stage context providing ``logger`` and ``deps.yt_client``.
        checkpoint_base: YT directory that holds checkpoints.
        local_checkpoint_path: Local file to upload; falsy values mean
            "nothing to upload".
    """
    if not local_checkpoint_path:
        return

    source = Path(local_checkpoint_path)
    if not source.exists():
        # Best effort: a missing local file is reported, not treated as fatal.
        context.logger.warning("Local checkpoint path does not exist: %s", source)
        return

    destination = f"{checkpoint_base}/{source.name}"
    yt = context.deps.yt_client
    if yt.exists(destination):
        context.logger.info(
            "Checkpoint already exists in YT: %s (skipping upload)",
            destination,
        )
        return

    context.logger.info("Uploading local checkpoint: %s → %s", source, destination)
    yt.upload_file(source, destination, create_parent_dir=True)
    context.logger.debug("Checkpoint uploaded: %s", destination)
def _validate_required_checkpoint(
    context: StageContext,
    checkpoint_base: str,
    model_name: str | None,
) -> None:
    """Ensure ``<checkpoint_base>/<model_name>`` exists in YT.

    Validation is skipped (with a debug message) when ``model_name`` is falsy.

    Args:
        context: Stage context providing ``logger`` and ``deps.yt_client``.
        checkpoint_base: YT directory that holds checkpoints.
        model_name: Name of the required checkpoint inside ``checkpoint_base``.

    Raises:
        FileNotFoundError: If the required checkpoint path does not exist in YT.
    """
    if not model_name:
        context.logger.debug("No model_name specified, skipping checkpoint validation")
        return

    required_path = f"{checkpoint_base}/{model_name}"
    if not context.deps.yt_client.exists(required_path):
        details = (
            f"Required checkpoint not found in YT: {required_path}\n"
            "Please upload the checkpoint using local_checkpoint_path in config, "
            f"or manually upload {model_name} to {checkpoint_base}"
        )
        # Log first so the message is recorded even if the caller swallows the error.
        context.logger.error(details)
        _raise_file_not_found(details)
    else:
        context.logger.debug("Required checkpoint verified: %s", required_path)
# [docs]
def init_checkpoint_directory(
    context: StageContext,
    checkpoint_config: DictConfig,
) -> None:
    """Prepare the YT checkpoint directory for an operation.

    Reads ``checkpoint_base`` and ``local_checkpoint_path`` from
    ``checkpoint_config`` and the model name from the stage config. When
    ``checkpoint_base`` is not set, a warning is logged and nothing else
    happens. Otherwise the function, in order:

    1. asks the YT client to create ``checkpoint_base`` as a ``map_node``;
    2. uploads the local checkpoint, if one is configured and not yet in YT;
    3. checks that the checkpoint named by the model name exists in YT.

    Args:
        context: Stage context (provides ``deps``, ``logger`` and ``config``).
        checkpoint_config: Checkpoint-specific config.

    Returns:
        None

    Raises:
        FileNotFoundError: Propagated unchanged and without the extra
            failure log, e.g. when the required checkpoint is not in YT.
        Exception: Any other failure is logged with its traceback and
            re-raised.
    """
    base_dir = _checkpoint_base_from_config(checkpoint_config)
    local_source = _local_checkpoint_path_from_config(checkpoint_config)
    required_model = _resolve_model_name(context)

    if not base_dir:
        context.logger.warning(
            "No checkpoint_base specified in checkpoint config, skipping checkpoint initialization",
        )
        return

    try:
        # Make sure the target directory node is in place before using it.
        context.deps.yt_client.create_path(base_dir, node_type="map_node")
        context.logger.info("Checkpoint directory ready: %s", base_dir)

        _upload_local_checkpoint_if_needed(
            context=context,
            checkpoint_base=base_dir,
            local_checkpoint_path=local_source,
        )
        _validate_required_checkpoint(
            context=context,
            checkpoint_base=base_dir,
            model_name=required_model,
        )
    except Exception as failure:
        # FileNotFoundError passes through without the extra log entry.
        if not isinstance(failure, FileNotFoundError):
            context.logger.exception(
                "Could not initialize checkpoint directory %s",
                base_dir,
            )
        raise