Source code for yt_framework.operations.tokenizer_artifact

"""Pack tokenizer/processor tarballs, upload to Cypress, and expose sandbox env vars."""

from __future__ import annotations

import tarfile
import tempfile
from pathlib import Path
from typing import Optional, TYPE_CHECKING

from omegaconf import DictConfig

if TYPE_CHECKING:
    from yt_framework.core.stage import StageContext


def resolve_tokenizer_artifact_name(
    stage_config: DictConfig,
    tokenizer_artifact_config: DictConfig,
) -> Optional[str]:
    """Resolve the logical tokenizer artifact name from config.

    Candidates are tried in priority order; the first non-blank one wins:

    1. ``tokenizer_artifact_config["artifact_name"]``;
    2. ``stage_config.job["tokenizer_name"]``;
    3. the last ``/``-separated component of ``stage_config.job["model_name"]``
       (trailing slashes are ignored);
    4. the basename of ``tokenizer_artifact_config["local_artifact_path"]``
       with a trailing ``.tar.gz`` suffix removed.

    Args:
        stage_config: Stage configuration; only its optional ``job`` section
            is consulted.
        tokenizer_artifact_config: The tokenizer artifact section of the config.

    Returns:
        The whitespace-stripped artifact name, or ``None`` when no candidate
        yields a non-empty name.
    """

    def _clean(value: object) -> str:
        # Treat missing/falsy values and whitespace-only strings alike.
        return str(value).strip() if value else ""

    explicit = _clean(tokenizer_artifact_config.get("artifact_name"))
    if explicit:
        return explicit

    # A `job` key may be present but set to null; skip it instead of
    # failing on `None.get(...)`.
    job_config = stage_config.job if "job" in stage_config else None
    if job_config is not None:
        tokenizer_name = _clean(job_config.get("tokenizer_name"))
        if tokenizer_name:
            return tokenizer_name

        model_name = _clean(job_config.get("model_name"))
        if model_name:
            # rstrip so "/models/bert/" resolves to "bert" rather than "".
            short_name = model_name.rstrip("/").split("/")[-1]
            if short_name:
                return short_name

    local_path = _clean(tokenizer_artifact_config.get("local_artifact_path"))
    if local_path:
        base_name = Path(local_path).name
        # Only a trailing ".tar.gz" is an archive marker; leave any other
        # occurrence inside the name untouched.
        if base_name.endswith(".tar.gz"):
            base_name = base_name[: -len(".tar.gz")]
        if base_name:
            return base_name

    return None
def resolve_tokenizer_archive_name(artifact_name: str) -> str:
    """Return the tarball filename that corresponds to ``artifact_name``.

    A name that already ends with ``.tar.gz`` is returned unchanged; any
    other name gets the ``.tar.gz`` suffix appended.
    """
    suffix = ".tar.gz"
    already_archive = artifact_name.endswith(suffix)
    return artifact_name if already_archive else artifact_name + suffix
def resolve_tokenizer_artifact_yt_path(
    stage_config: DictConfig,
    tokenizer_artifact_config: DictConfig,
) -> Optional[str]:
    """Build the full YT file path of the tokenizer artifact tarball.

    The path is ``<artifact_base>/<archive name>``, where the archive name is
    derived from the resolved logical artifact name.

    Returns:
        The YT path, or ``None`` when ``artifact_base`` is not configured or
        no artifact name can be resolved.
    """
    yt_dir = tokenizer_artifact_config.get("artifact_base")
    if yt_dir:
        logical_name = resolve_tokenizer_artifact_name(
            stage_config=stage_config,
            tokenizer_artifact_config=tokenizer_artifact_config,
        )
        if logical_name:
            tarball = resolve_tokenizer_archive_name(logical_name)
            return f"{yt_dir}/{tarball}"
    return None
def _tar_directory(source_dir: Path, target_tar_gz: Path) -> None:
    """Create a ``.tar.gz`` from directory contents (without parent dir wrapper).

    Every regular file found recursively under ``source_dir`` is stored under
    its path relative to ``source_dir``. Directories themselves are not added
    as members, so empty directories do not appear in the archive.

    Args:
        source_dir: Directory whose files are packed.
        target_tar_gz: Path of the archive to create (overwritten if present).
    """
    # Collect and sort first so member order does not depend on the order in
    # which the filesystem happens to list entries.
    files = sorted(path for path in source_dir.rglob("*") if path.is_file())

    # `is_file()` follows symlinks, so store the linked content as well:
    # a bare link would dangle once the archive is unpacked elsewhere.
    with tarfile.open(target_tar_gz, "w:gz", dereference=True) as tar:
        for path in files:
            tar.add(
                path,
                arcname=path.relative_to(source_dir).as_posix(),
                recursive=False,
            )


def _prepare_local_archive(local_artifact_path: Path, artifact_name: str) -> Path:
    """Prepare a local ``.tar.gz`` path from ``local_artifact_path``.

    - If the source is a directory, pack it into a temporary ``.tar.gz``.
      The caller owns that file and is responsible for deleting it.
    - If the source is a ``.tar.gz`` file, return it unchanged.

    Args:
        local_artifact_path: Directory or ``.tar.gz`` file to upload.
        artifact_name: Logical artifact name, used as the temp file prefix.

    Returns:
        Path to a ``.tar.gz`` file ready for upload.

    Raises:
        ValueError: If the source is neither a directory nor a ``.tar.gz`` file.
    """
    if local_artifact_path.is_dir():
        # The name may contain path separators (e.g. "org/model"); keep the
        # prefix a plain filename fragment so the temp file can be created.
        safe_prefix = artifact_name.replace("/", "_").replace("\\", "_")

        # Reserve the temp path and close the handle right away. The file is
        # kept (delete=False) and then overwritten in place by the packer, so
        # no descriptor leaks and the name is never released for reuse.
        with tempfile.NamedTemporaryFile(
            prefix=f"{safe_prefix}_", suffix=".tar.gz", delete=False
        ) as handle:
            tmp_archive = Path(handle.name)

        try:
            _tar_directory(local_artifact_path, tmp_archive)
        except BaseException:
            # Do not leave a partial archive behind; the error is re-raised.
            tmp_archive.unlink(missing_ok=True)
            raise
        return tmp_archive

    if local_artifact_path.is_file() and local_artifact_path.name.endswith(".tar.gz"):
        return local_artifact_path

    raise ValueError(
        "local_artifact_path must point to a directory or to a .tar.gz file, "
        f"got: {local_artifact_path}"
    )
def init_tokenizer_artifact_directory(
    context: "StageContext",
    tokenizer_artifact_config: DictConfig,
) -> None:
    """Make sure the tokenizer artifact tarball is present in YT.

    Does nothing when ``artifact_base`` is not configured. Otherwise:

    - creates the ``artifact_base`` map node;
    - if ``local_artifact_path`` is set, exists locally, and the artifact is
      not yet in YT, uploads it (a directory is packed into a temporary
      ``.tar.gz`` first, which is removed afterwards);
    - finally checks that the artifact exists in YT.

    Raises:
        ValueError: If ``artifact_base`` is set but no artifact name can be
            resolved from the configuration.
        FileNotFoundError: If the artifact is absent from YT after the
            optional upload step.
    """
    artifact_base = tokenizer_artifact_config.get("artifact_base")
    if not artifact_base:
        return

    artifact_name = resolve_tokenizer_artifact_name(
        stage_config=context.config,
        tokenizer_artifact_config=tokenizer_artifact_config,
    )
    if not artifact_name:
        raise ValueError(
            "tokenizer_artifact is configured but artifact_name cannot be resolved. "
            "Set tokenizer_artifact.artifact_name or job.tokenizer_name/model_name."
        )

    remote_path = f"{artifact_base}/{resolve_tokenizer_archive_name(artifact_name)}"
    configured_local = tokenizer_artifact_config.get("local_artifact_path")

    yt = context.deps.yt_client
    log = context.logger

    yt.create_path(artifact_base, node_type="map_node")
    log.info(f"Tokenizer artifact directory ready: {artifact_base}")

    # Set only when a directory had to be packed into a throwaway tarball.
    packed_tmp: Optional[Path] = None
    try:
        if configured_local:
            source = Path(str(configured_local))
            if source.exists():
                if yt.exists(remote_path):
                    log.info(
                        f"Tokenizer artifact already exists in YT: {remote_path} (skipping upload)"
                    )
                else:
                    upload_from = _prepare_local_archive(source, artifact_name)
                    packed_tmp = upload_from if upload_from != source else None
                    log.info(
                        f"Uploading tokenizer artifact: {upload_from} -> {remote_path}"
                    )
                    yt.upload_file(upload_from, remote_path, create_parent_dir=True)
            else:
                # A missing local source is not fatal: the artifact may
                # already be in YT, which the check below decides.
                log.warning(
                    f"tokenizer_artifact.local_artifact_path does not exist: {source}"
                )

        if not yt.exists(remote_path):
            raise FileNotFoundError(
                f"Tokenizer artifact not found in YT: {remote_path}. \n"
                "Provide tokenizer_artifact.local_artifact_path or upload manually."
            )
        log.info(f"Tokenizer artifact verified: {remote_path}")
    finally:
        if packed_tmp is not None and packed_tmp.exists():
            packed_tmp.unlink(missing_ok=True)