Source code for yt_framework.operations.command_ops.vanilla

"""Driver helpers to package `src/vanilla.py` and submit YT vanilla operations."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from omegaconf import DictConfig, OmegaConf

from yt_framework.operations._internal.dependency_strategy import (
    DependencyBuildContext,
    TarArchiveDependencyBuilder,
)
from yt_framework.operations.common import (
    build_operation_environment,
    collect_passthrough_kwargs,
    docker_auth_from_op_config,
    extract_max_failed_jobs,
    extract_operation_resources,
    extract_secure_env_client_kwargs,
)
from yt_framework.utils.logging import log_header, log_success
from yt_framework.yt.clients.operation_specs import (
    VanillaSubmitSpec,
    docker_auth_tuple,
    env_pairs_tuple,
    extras_tuple,
    file_pairs_tuple,
)

if TYPE_CHECKING:
    import logging
    from pathlib import Path

    from yt_framework.contracts import StageContext


@dataclass
class VanillaOperationData:
    """Mutable bundle of everything needed to submit one vanilla operation.

    Instances are created by ``_prepare_vanilla_operation`` with an empty
    environment and no docker auth; ``run_vanilla`` assigns both fields
    afterwards, which is why the dataclass is not frozen.

    Attributes:
        script_path: Script location reported by the dependency builder.
        dependencies: ``(yt_path, local_path)`` pairs of files to ship with
            the job.
        environment: Environment variables passed to the operation.
        docker_auth: Docker registry credentials, or ``None`` when absent.
        command: Command line to execute, or ``None`` when the dependency
            builder did not produce one.
    """

    # Field order is part of the positional constructor signature.
    script_path: str
    dependencies: list[tuple[str, str]]
    environment: dict[str, str]
    docker_auth: dict[str, str] | None
    command: str | None = None
def _prepare_vanilla_operation(
    pipeline_config: DictConfig,
    operation_config: DictConfig,
    stage_config: DictConfig,
    stage_dir: Path,
    logger: logging.Logger,
) -> VanillaOperationData:
    """Package stage sources as a tar archive and describe the resulting job.

    Only the dependency-related fields are filled in. ``environment`` is
    returned empty and ``docker_auth`` as ``None``; the caller is expected to
    assign both on the returned object.

    Args:
        pipeline_config: Pipeline-level config; ``pipeline.build_folder`` is
            read from it.
        operation_config: Config of this particular vanilla operation.
        stage_config: Full stage config, forwarded to the dependency builder.
        stage_dir: Directory of the stage being packaged.
        logger: Logger forwarded to the dependency builder.

    Returns:
        VanillaOperationData carrying the builder's script path, dependency
        list and command.
    """
    builder = TarArchiveDependencyBuilder()
    build_context = DependencyBuildContext(
        operation_type="vanilla",
        stage_dir=stage_dir,
        archive_name="source.tar.gz",
        build_folder=pipeline_config.pipeline.build_folder,
        operation_config=operation_config,
        stage_config=stage_config,
        logger=logger,
    )
    built = builder.build_dependencies(build_context)

    # Environment and docker auth are deliberately left for the caller.
    return VanillaOperationData(
        script_path=built.script_path,
        dependencies=built.dependencies,
        environment={},
        docker_auth=None,
        command=built.command,
    )


def _vanilla_operation_description_kwargs(
    operation_config: DictConfig,
    logger: logging.Logger,
) -> dict[str, Any]:
    """Translate the ``operation_description`` setting into submit kwargs.

    Args:
        operation_config: Operation config that may hold an
            ``operation_description`` entry.
        logger: Logger used to announce a string label.

    Returns:
        ``{}`` when the entry is missing or falsy, ``{"title": <str>}`` when
        it is a string, otherwise ``{"operation_description": <container>}``
        with the value converted by ``OmegaConf.to_container`` (interpolations
        resolved).
    """
    description = operation_config.get("operation_description")
    if not description:
        return {}

    if not isinstance(description, str):
        # Structured description: hand over a plain, fully resolved container.
        return {
            "operation_description": OmegaConf.to_container(
                description,
                resolve=True,
            ),
        }

    # A bare string is treated as a human-readable label for the operation.
    logger.info("Operation label: %s", description)
    return {"title": description}
def run_vanilla(
    context: StageContext,
    operation_config: DictConfig,
    job: str | None = None,
) -> bool:
    """Submit a YT vanilla operation for the current stage and await it.

    The stage name is used as the task name. Resources, the failed-job limit,
    the operation description and any pass-through options are all read from
    ``operation_config``; the environment is produced by
    ``build_operation_environment`` and sources are packaged through
    ``_prepare_vanilla_operation``.

    Args:
        context: Stage context supplying the logger, name, config, stage
            directory and ``deps`` (pipeline config and YT client).
        operation_config: Config of this vanilla operation.
        job: Optional command override. When ``None`` the command produced by
            the dependency builder is executed. The value is also forwarded
            to the submit spec as ``job``.

    Returns:
        The result of ``wait_for_operation`` for the submitted operation, or
        ``False`` when submission returned ``None``.

    Raises:
        ValueError: If the dependency builder produced no command (checked
            even when ``job`` is supplied).
    """
    logger = context.logger
    # The task is named after the stage.
    task_name = context.name

    env = build_operation_environment(
        context=context,
        operation_config=operation_config,
        logger=logger,
        include_stage_name=True,
        include_tokenizer_artifact=False,
    )

    op_data = _prepare_vanilla_operation(
        pipeline_config=context.deps.pipeline_config,
        operation_config=operation_config,
        stage_config=context.config,
        stage_dir=context.stage_dir,
        logger=logger,
    )
    # The preparation step leaves these two fields blank on purpose.
    op_data.environment = env
    op_data.docker_auth = docker_auth_from_op_config(operation_config, env)

    log_header(
        logger,
        "Vanilla Operation",
        f"Task: {task_name} | Script: {op_data.script_path}",
    )
    logger.debug("Dependencies: %s files", len(op_data.dependencies))

    # The builder must always yield a command, even if ``job`` overrides it.
    if not op_data.command:
        msg = "Command not provided by dependency builder"
        raise ValueError(msg)
    command = op_data.command if job is None else job

    logger.debug("Extracting operation resources from config")
    resources = extract_operation_resources(operation_config, logger)
    max_failed_jobs = extract_max_failed_jobs(operation_config, logger)

    extra_kwargs = _vanilla_operation_description_kwargs(operation_config, logger)
    # Keys consumed explicitly elsewhere; they must not be passed through.
    reserved_keys = {
        "resources",
        "env",
        "max_failed_job_count",
        "file_paths",
        "checkpoint",
        "tokenizer_artifact",
        "tar_command_bootstrap",
        "operation_description",
        "environment_public_keys",
        "use_plain_environment_for_secrets",
    }
    extra_kwargs.update(
        collect_passthrough_kwargs(operation_config, reserved_keys=reserved_keys),
    )

    # Secure-env kwargs go first so description/pass-through keys win on clash.
    submit_extras: dict[str, object] = dict(
        extract_secure_env_client_kwargs(operation_config),
    )
    submit_extras.update(extra_kwargs)

    spec = VanillaSubmitSpec(
        command=command,
        files=file_pairs_tuple(op_data.dependencies),
        env=env_pairs_tuple(op_data.environment),
        task_name=task_name,
        resources=resources,
        docker_auth=docker_auth_tuple(op_data.docker_auth),
        max_failed_jobs=max_failed_jobs,
        job=job,
        extras=extras_tuple(submit_extras),
    )
    operation = context.deps.yt_client.run_vanilla_submit(spec)

    if operation is None:
        logger.error("Failed to submit vanilla operation: returned None")
        return False

    logger.debug("Operation submitted: %s", operation.id)

    # Block until the operation reaches a terminal state.
    success = context.deps.yt_client.wait_for_operation(operation)
    if not success:
        logger.error("Vanilla operation failed")
        return success

    log_success(logger, "Vanilla operation completed successfully")
    return success