"""Driver helpers to package `src/vanilla.py` and submit YT vanilla operations."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from omegaconf import DictConfig, OmegaConf
from yt_framework.operations._internal.dependency_strategy import (
DependencyBuildContext,
TarArchiveDependencyBuilder,
)
from yt_framework.operations.common import (
build_operation_environment,
collect_passthrough_kwargs,
docker_auth_from_op_config,
extract_max_failed_jobs,
extract_operation_resources,
extract_secure_env_client_kwargs,
)
from yt_framework.utils.logging import log_header, log_success
from yt_framework.yt.clients.operation_specs import (
VanillaSubmitSpec,
docker_auth_tuple,
env_pairs_tuple,
extras_tuple,
file_pairs_tuple,
)
if TYPE_CHECKING:
import logging
from pathlib import Path
from yt_framework.contracts import StageContext
[docs]
@dataclass
class VanillaOperationData:
    """Packaged-operation description consumed by ``run_vanilla``.

    ``_prepare_vanilla_operation`` creates it with an empty ``environment`` and
    ``docker_auth=None``. ``run_vanilla`` then assigns both fields before
    building the submit spec.
    """

    # Path of the vanilla.py entry script in YT (a placeholder in tar-archive mode).
    script_path: str
    # (yt_path, local_path) pairs of files to upload alongside the operation.
    dependencies: list[tuple[str, str]]
    # Environment variables handed to the job.
    environment: dict[str, str]
    # Private Docker registry credentials, or None when no auth is configured.
    docker_auth: dict[str, str] | None
    # Command the job executes (populated in tar-archive mode); optional.
    command: str | None = None
def _prepare_vanilla_operation(
    pipeline_config: DictConfig,
    operation_config: DictConfig,
    stage_config: DictConfig,
    stage_dir: Path,
    logger: logging.Logger,
) -> VanillaOperationData:
    """Package the stage as ``source.tar.gz`` and describe the resulting operation.

    Only the dependency-derived fields (``script_path``, ``dependencies``,
    ``command``) come from the tar-archive builder. ``environment`` is returned
    as an empty dict and ``docker_auth`` as ``None``. The caller builds the real
    values (via ``build_operation_environment``) and assigns them afterwards.

    Args:
        pipeline_config: Pipeline-level config; ``pipeline.build_folder`` is read.
        operation_config: Config section of the vanilla operation.
        stage_config: Full stage config, forwarded to the dependency builder.
        stage_dir: Directory of the stage being packaged.
        logger: Logger forwarded to the dependency builder.

    Returns:
        A ``VanillaOperationData`` with dependencies and command filled in.
    """
    tar_builder = TarArchiveDependencyBuilder()
    build_context = DependencyBuildContext(
        operation_type="vanilla",
        stage_dir=stage_dir,
        archive_name="source.tar.gz",
        build_folder=pipeline_config.pipeline.build_folder,
        operation_config=operation_config,
        stage_config=stage_config,
        logger=logger,
    )
    built = tar_builder.build_dependencies(build_context)

    # Environment / docker auth are placeholders; the caller overwrites them.
    return VanillaOperationData(
        script_path=built.script_path,
        dependencies=built.dependencies,
        environment={},
        docker_auth=None,
        command=built.command,
    )
def _vanilla_operation_description_kwargs(
operation_config: DictConfig,
logger: logging.Logger,
) -> dict[str, Any]:
out: dict[str, Any] = {}
od = operation_config.get("operation_description")
if not od:
return out
if isinstance(od, str):
logger.info("Operation label: %s", od)
out["title"] = od
return out
out["operation_description"] = OmegaConf.to_container(od, resolve=True)
return out
[docs]
def run_vanilla(
    context: StageContext,
    operation_config: DictConfig,
    job: str | None = None,
) -> bool:
    """Submit a YT vanilla operation for the current stage and wait for it.

    Resources, max failed jobs, secure-env client options, the operation
    title/description and the remaining pass-through keys are all derived from
    ``operation_config``. The stage name is used as the YT task name.

    Args:
        context: Stage context. ``logger``, ``name``, ``config``,
            ``stage_dir`` and ``deps`` (``pipeline_config``, ``yt_client``)
            are read from it.
        operation_config: Config section of this stage's vanilla operation.
        job: Optional command override. When given it is submitted as the
            command and also forwarded as ``job`` on the spec. When ``None``,
            the command produced by the tar-archive dependency builder is used.

    Returns:
        ``False`` if submission returned ``None``. Otherwise, the result of
        ``yt_client.wait_for_operation`` (truthy on success).

    Raises:
        ValueError: If the dependency builder produced no command. This check
            runs even when ``job`` is supplied.
    """
    log = context.logger
    task_name = context.name  # the YT task is named after the stage

    environment = build_operation_environment(
        context=context,
        operation_config=operation_config,
        logger=log,
        include_stage_name=True,
        include_tokenizer_artifact=False,
    )

    op_data = _prepare_vanilla_operation(
        pipeline_config=context.deps.pipeline_config,
        operation_config=operation_config,
        stage_config=context.config,
        stage_dir=context.stage_dir,
        logger=log,
    )
    # _prepare_vanilla_operation leaves these blank; the caller fills them in.
    op_data.environment = environment
    op_data.docker_auth = docker_auth_from_op_config(
        operation_config,
        environment,
    )

    log_header(
        log,
        "Vanilla Operation",
        f"Task: {task_name} | Script: {op_data.script_path}",
    )
    log.debug("Dependencies: %s files", len(op_data.dependencies))

    # Sanity check: the tar-archive builder is expected to always set a command.
    if not op_data.command:
        msg = "Command not provided by dependency builder"
        raise ValueError(msg)
    command = op_data.command if job is None else job

    log.debug("Extracting operation resources from config")
    resources = extract_operation_resources(operation_config, log)
    max_failed_jobs = extract_max_failed_jobs(operation_config, log)

    # Keys the framework handles itself; handed over as ``reserved_keys`` so
    # they are (presumably) not forwarded again as pass-through extras.
    reserved = {
        "resources",
        "env",
        "max_failed_job_count",
        "file_paths",
        "checkpoint",
        "tokenizer_artifact",
        "tar_command_bootstrap",
        "operation_description",
        "environment_public_keys",
        "use_plain_environment_for_secrets",
    }
    extra_kwargs = _vanilla_operation_description_kwargs(operation_config, log)
    extra_kwargs.update(
        collect_passthrough_kwargs(operation_config, reserved_keys=reserved),
    )
    # On key clashes, operation-level kwargs override secure-env client kwargs.
    extras: dict[str, object] = {
        **extract_secure_env_client_kwargs(operation_config),
        **extra_kwargs,
    }

    spec = VanillaSubmitSpec(
        command=command,
        files=file_pairs_tuple(op_data.dependencies),
        env=env_pairs_tuple(op_data.environment),
        task_name=task_name,
        resources=resources,
        docker_auth=docker_auth_tuple(op_data.docker_auth),
        max_failed_jobs=max_failed_jobs,
        job=job,
        extras=extras_tuple(extras),
    )
    operation = context.deps.yt_client.run_vanilla_submit(spec)
    if operation is None:
        log.error("Failed to submit vanilla operation: returned None")
        return False
    log.debug("Operation submitted: %s", operation.id)

    # Block until YT reports a final state.
    succeeded = context.deps.yt_client.wait_for_operation(operation)
    if not succeeded:
        log.error("Vanilla operation failed")
        return succeeded
    log_success(log, "Vanilla operation completed successfully")
    return succeeded