Source code for yt_framework.yt.client_dev

"""Local filesystem stand-in for Cypress tables and subprocess-backed jobs."""

import json
import os
import logging
import subprocess
import shutil
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, Tuple, Literal

from yt.wrapper import Operation  # pyright: ignore[reportMissingImports]
from typing import TYPE_CHECKING

from yt_framework.yt.client_base import BaseYTClient, OperationResources
from yt_framework.yt.operation_secure_env import pop_secure_env_client_kwargs
from yt_framework.operations.job_command import is_typed_job, resolve_aliased_job
from yt_framework.yt.max_row_weight import ensure_max_row_weight_pragma

if TYPE_CHECKING:
    from yt.wrapper.schema import TableSchema  # pyright: ignore[reportMissingImports]


class _DevOperation:
    """Fake operation for dev run_map; implements wait, get_state, get_error."""

    def __init__(self, returncode: int, stderr_message: str = ""):
        self._returncode = returncode
        self._stderr = stderr_message
        self.id = f"dev-operation-{id(self)}"  # Fake operation ID for dev mode

    def wait(self) -> None:
        pass

    def get_state(self) -> str:
        return "completed" if self._returncode == 0 else "failed"

    def get_error(self) -> Optional[str]:
        if self._returncode == 0:
            return None
        return self._stderr or f"Mapper exited with code {self._returncode}"


class YTDevClient(BaseYTClient):
    """
    Development YT client implementation.

    Uses the local file system for all operations, simulating YT behavior.
    Tables are stored as ``.jsonl`` files (one JSON object per line) in the
    ``.dev/`` directory under ``pipeline_dir``. Map and vanilla jobs run as
    local ``bash -c`` subprocesses inside sandbox directories under ``.dev/``.

    Note: a table's local file is keyed only by the last component of its YT
    path, so ``//a/t`` and ``//b/t`` both map to ``.dev/t.jsonl``.
    """

    def __init__(
        self,
        logger: logging.Logger,
        pipeline_dir: Optional[Path] = None,
    ) -> None:
        """
        Initialize development YT client.

        Args:
            logger: Logger instance.
            pipeline_dir: Pipeline directory. If omitted, ``YT_PIPELINE_DIR``
                is used. If that is also unset, the current working directory
                is used and a warning is logged.
        """
        if pipeline_dir is not None:
            resolved_pipeline_dir = Path(pipeline_dir).resolve()
        else:
            pd = os.environ.get("YT_PIPELINE_DIR")
            if pd:
                resolved_pipeline_dir = Path(pd).resolve()
            else:
                resolved_pipeline_dir = Path.cwd()
                logger.warning(
                    "mode=dev but pipeline_dir not set and YT_PIPELINE_DIR not set; using cwd as pipeline_dir"
                )
        super().__init__(logger, pipeline_dir=resolved_pipeline_dir)

    # ------------------------------------------------------------------
    # Local path helpers
    # ------------------------------------------------------------------

    def _dev_dir(self) -> Path:
        """Return .dev directory under pipeline_dir. Caller should mkdir when writing."""
        assert self.pipeline_dir is not None, "pipeline_dir is required in dev mode"
        return self.pipeline_dir / ".dev"

    def _table_basename(self, yt_path: str) -> str:
        """Last path component of a YT table path, e.g. //home/.../name -> name."""
        return yt_path.rstrip("/").split("/")[-1]

    def _table_local_path(self, yt_path: str) -> Path:
        """Local .jsonl path for a YT table in dev: {pipeline_dir}/.dev/{basename}.jsonl."""
        return self._dev_dir() / f"{self._table_basename(yt_path)}.jsonl"

    # ------------------------------------------------------------------
    # Cypress / table I/O
    # ------------------------------------------------------------------

    def create_path(
        self,
        path: str,
        node_type: Literal[
            "table", "file", "map_node", "list_node", "document"
        ] = "map_node",
    ) -> None:
        """Create a path in YT (no-op in dev mode).

        Args:
            path: YT path to create (not used in dev mode).
            node_type: Type of node to create (not used in dev mode).

        Returns:
            None
        """

    def exists(self, path: str) -> bool:
        """
        Check if a path exists in YT.

        In dev mode, always returns True (assumes files exist locally).

        Args:
            path: YT path to check.

        Returns:
            bool: Always True in dev mode.
        """
        return True

    def write_table(
        self,
        table_path: str,
        rows: List[Dict[str, Any]],
        append: bool = False,
        replication_factor: int = 1,
    ) -> None:
        """Write rows to a YT table (saves as local .jsonl file).

        In dev mode, tables are stored as UTF-8 JSONL files in the .dev
        directory. Each row is written as a JSON object on its own line.

        Args:
            table_path: YT table path (e.g., "//tmp/my_table").
            rows: List of dictionaries representing table rows.
            append: If True, append to existing file; otherwise overwrite.
            replication_factor: Not used in dev mode (kept for API compatibility).

        Returns:
            None

        Example::

            client.write_table("//tmp/data", [{"id": 1, "name": "Alice"}])
            # Creates .dev/data.jsonl containing the line:
            # {"id": 1, "name": "Alice"}
        """
        mode_str = "append" if append else "overwrite"
        self.logger.info(f"Writing {len(rows)} rows → {table_path} ({mode_str})")
        p = self._table_local_path(table_path)
        self._dev_dir().mkdir(parents=True, exist_ok=True)
        # ensure_ascii=False emits raw non-ASCII text, so pin the encoding
        # instead of depending on the platform/locale default.
        with open(p, "a" if append else "w", encoding="utf-8") as f:
            for row in rows:
                f.write(json.dumps(row, ensure_ascii=False) + "\n")

    def read_table(self, table_path: str) -> List[Dict[str, Any]]:
        """Read rows from a YT table (reads from local .jsonl file).

        Args:
            table_path: YT table path (e.g., "//tmp/my_table").

        Returns:
            List[Dict[str, Any]]: List of dictionaries representing table rows.
                Returns an empty list if the file doesn't exist. Blank lines
                are skipped.
        """
        self.logger.info(f"Reading table: {table_path}")
        p = self._table_local_path(table_path)
        if not p.exists():
            self.logger.warning(f"Table file not found: {p}, returning empty list")
            return []
        with open(p, "r", encoding="utf-8") as f:
            results = [json.loads(line) for line in f if line.strip()]
        self.logger.info(f"✓ Read {len(results)} rows")
        return results

    def row_count(self, table_path: str) -> int:
        """Get number of rows in a YT table (counts lines in local .jsonl file).

        Args:
            table_path: YT table path (e.g., "//tmp/my_table").

        Returns:
            int: Number of non-empty lines in the JSONL file. Returns 0 if the
                file doesn't exist.
        """
        p = self._table_local_path(table_path)
        if not p.exists():
            return 0
        with open(p, encoding="utf-8") as f:
            n = sum(1 for line in f if line.strip())
        self.logger.debug(f"Row count: {n}")
        return n

    def _read_first_row(self, table_path: str) -> Optional[Dict[str, Any]]:
        """Return the first row of a dev table, or None if it is missing or empty.

        Only the first non-blank line is parsed, so column discovery does not
        load the whole table into memory.
        """
        p = self._table_local_path(table_path)
        if not p.exists():
            self.logger.warning(f"Table file not found: {p}, returning empty list")
            return None
        with open(p, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    return json.loads(line)
        return None

    def _get_table_columns(self, table_path: str) -> List[str]:
        """
        Get column names from a table by reading one row.

        In dev mode, tables are stored as JSONL files, so binary columns are
        less likely. This implementation matches the production client
        structure for consistency.

        Args:
            table_path: Path to YT table

        Returns:
            List of column names. Columns starting with "_" (internal YQL
            columns such as _other, _yql_column_*) are excluded, unless that
            would leave no columns at all.

        Raises:
            ValueError: If table is empty or cannot be read
        """
        try:
            first_row = self._read_first_row(table_path)
            if first_row is None:
                raise ValueError(
                    f"Table {table_path} is empty, cannot determine columns"
                )
            all_columns = list(first_row.keys())
            # Filter out internal YQL columns like _other, _yql_column_*
            columns = [col for col in all_columns if not col.startswith("_")]
            # If all columns were filtered out, use all keys (fallback)
            return columns or all_columns
        except Exception as e:
            self.logger.error(f"Failed to get table columns: {e}")
            raise

    # ------------------------------------------------------------------
    # YQL (simulated with DuckDB)
    # ------------------------------------------------------------------

    def run_yql(
        self,
        query: str,
        pool: str = "default",
        max_row_weight: Optional[str] = None,
    ) -> None:
        """
        Execute a YQL query locally using DuckDB simulation.

        Input tables referenced by the query are loaded from their local
        .jsonl files. If the query has an output table and produces results,
        those are written back via ``write_table`` (overwrite).

        Args:
            query: YQL query string to execute
            pool: YT pool name (default: 'default'); only logged in dev mode
            max_row_weight: Optional max row weight override

        Raises:
            Exception: Any error from the simulator is logged and re-raised.
        """
        self.logger.info("Executing YQL query (dev mode - DuckDB simulation)")
        self.logger.debug(f"Pool: {pool}")
        query_with_max_row_weight = ensure_max_row_weight_pragma(
            query=query,
            max_row_weight=max_row_weight,
        )
        self.logger.debug(f"Query:\n{query_with_max_row_weight}")

        from yt_framework.yt.dev_simulator import (
            DuckDBSimulator,
            extract_table_references,
            extract_output_table,
        )

        simulator = DuckDBSimulator(dev_dir=self._dev_dir(), logger=self.logger)
        try:
            input_tables = extract_table_references(query_with_max_row_weight)
            output_table = extract_output_table(query_with_max_row_weight)
            self.logger.debug(f"Input tables: {input_tables}")
            self.logger.debug(f"Output table: {output_table}")

            for table_path in input_tables:
                local_path = self._table_local_path(table_path)
                if local_path.exists():
                    simulator.load_table(table_path, local_path)
                else:
                    self.logger.warning(f"Input table not found: {local_path}")

            results, _ = simulator.execute_yql(query_with_max_row_weight)

            if output_table and results is not None:
                self.write_table(output_table, results, append=False)
                self.logger.info(f"Wrote {len(results)} rows to {output_table}")

            self.logger.info("✓ YQL query executed successfully")
        except Exception as e:
            self.logger.error(f"Failed to execute YQL query in dev mode: {e}")
            raise
        finally:
            simulator.close()

    # Convenience methods for common YQL operations.
    # Each returns the generated query when dry_run=True, otherwise runs it
    # through run_yql and returns None.

    def join_tables(
        self,
        left_table: str,
        right_table: str,
        output_table: str,
        on: Union[str, List[str], Dict[str, str]],
        how: Literal["inner", "left", "right", "full"] = "left",
        select_columns: Optional[List[str]] = None,
        dry_run: bool = False,
        max_row_weight: Optional[str] = None,
    ) -> Optional[str]:
        """Join two tables using YQL (executed locally with DuckDB in dev mode)."""
        from yt_framework.yt.yql_builder import build_join_query

        query = build_join_query(
            left_table=left_table,
            right_table=right_table,
            output_table=output_table,
            on=on,
            how=how,
            select_columns=select_columns,
            max_row_weight=max_row_weight,
        )
        if dry_run:
            return query
        self.run_yql(query, max_row_weight=max_row_weight)
        return None

    def filter_table(
        self,
        input_table: str,
        output_table: str,
        condition: str,
        dry_run: bool = False,
        max_row_weight: Optional[str] = None,
    ) -> Optional[str]:
        """Filter table rows using WHERE condition (executed locally with DuckDB in dev mode)."""
        from yt_framework.yt.yql_builder import build_filter_query

        # Explicit column list avoids _other column issues
        columns = self._get_table_columns(input_table)
        query = build_filter_query(
            input_table=input_table,
            output_table=output_table,
            condition=condition,
            columns=columns,
            max_row_weight=max_row_weight,
        )
        if dry_run:
            return query
        self.run_yql(query, max_row_weight=max_row_weight)
        return None

    def select_columns(
        self,
        input_table: str,
        output_table: str,
        columns: List[str],
        dry_run: bool = False,
        max_row_weight: Optional[str] = None,
    ) -> Optional[str]:
        """Select specific columns from a table (executed locally with DuckDB in dev mode)."""
        from yt_framework.yt.yql_builder import build_select_query

        query = build_select_query(
            input_table=input_table,
            output_table=output_table,
            columns=columns,
            max_row_weight=max_row_weight,
        )
        if dry_run:
            return query
        self.run_yql(query, max_row_weight=max_row_weight)
        return None

    def group_by_aggregate(
        self,
        input_table: str,
        output_table: str,
        group_by: Union[str, List[str]],
        aggregations: Dict[str, Union[str, Tuple[str, str]]],
        dry_run: bool = False,
        max_row_weight: Optional[str] = None,
    ) -> Optional[str]:
        """Group by columns and compute aggregations (executed locally with DuckDB in dev mode)."""
        from yt_framework.yt.yql_builder import build_group_by_query

        query = build_group_by_query(
            input_table=input_table,
            output_table=output_table,
            group_by=group_by,
            aggregations=aggregations,
            max_row_weight=max_row_weight,
        )
        if dry_run:
            return query
        self.run_yql(query, max_row_weight=max_row_weight)
        return None

    def union_tables(
        self,
        tables: List[str],
        output_table: str,
        dry_run: bool = False,
        max_row_weight: Optional[str] = None,
    ) -> Optional[str]:
        """Union multiple tables (executed locally with DuckDB in dev mode).

        Raises:
            ValueError: If ``tables`` is empty.
        """
        from yt_framework.yt.yql_builder import build_union_query

        if not tables:
            raise ValueError("union_tables requires at least one input table")
        # Columns come from the first table; all tables in the union are
        # expected to share them. This avoids _other column issues.
        columns = self._get_table_columns(tables[0])
        query = build_union_query(
            tables=tables,
            output_table=output_table,
            columns=columns,
            max_row_weight=max_row_weight,
        )
        if dry_run:
            return query
        self.run_yql(query, max_row_weight=max_row_weight)
        return None

    def distinct(
        self,
        input_table: str,
        output_table: str,
        columns: Optional[List[str]] = None,
        dry_run: bool = False,
        max_row_weight: Optional[str] = None,
    ) -> Optional[str]:
        """Get distinct rows from a table (executed locally with DuckDB in dev mode)."""
        from yt_framework.yt.yql_builder import build_distinct_query

        query = build_distinct_query(
            input_table=input_table,
            output_table=output_table,
            columns=columns,
            max_row_weight=max_row_weight,
        )
        if dry_run:
            return query
        self.run_yql(query, max_row_weight=max_row_weight)
        return None

    def sort_table(
        self,
        input_table: str,
        output_table: str,
        order_by: Union[str, List[str]],
        ascending: bool = True,
        dry_run: bool = False,
        max_row_weight: Optional[str] = None,
    ) -> Optional[str]:
        """Sort table by columns (executed locally with DuckDB in dev mode)."""
        from yt_framework.yt.yql_builder import build_sort_query

        # Explicit column list avoids _other column issues
        columns = self._get_table_columns(input_table)
        query = build_sort_query(
            input_table=input_table,
            output_table=output_table,
            order_by=order_by,
            columns=columns,
            ascending=ascending,
            max_row_weight=max_row_weight,
        )
        if dry_run:
            return query
        self.run_yql(query, max_row_weight=max_row_weight)
        return None

    def limit_table(
        self,
        input_table: str,
        output_table: str,
        limit: int,
        dry_run: bool = False,
        max_row_weight: Optional[str] = None,
    ) -> Optional[str]:
        """Limit number of rows from a table (executed locally with DuckDB in dev mode)."""
        from yt_framework.yt.yql_builder import build_limit_query

        # Explicit column list avoids _other column issues
        columns = self._get_table_columns(input_table)
        query = build_limit_query(
            input_table=input_table,
            output_table=output_table,
            limit=limit,
            columns=columns,
            max_row_weight=max_row_weight,
        )
        if dry_run:
            return query
        self.run_yql(query, max_row_weight=max_row_weight)
        return None

    # ------------------------------------------------------------------
    # File uploads (no-ops in dev mode)
    # ------------------------------------------------------------------

    def upload_file(
        self, local_path: Path, yt_path: str, create_parent_dir: bool = False
    ) -> None:
        """
        Upload a file to YT (no-op in dev mode).

        Args:
            local_path: Local file path to upload
            yt_path: YT destination path
            create_parent_dir: If True, create parent directory if it doesn't
                exist (default: False); unused in dev mode
        """
        self.logger.debug(f"Dev: upload_file no-op {local_path.name} → {yt_path}")

    def upload_directory(
        self, local_dir: Path, yt_dir: str, pattern: str = "*"
    ) -> List[str]:
        """Upload a directory to YT (no-op in dev mode).

        Args:
            local_dir: Local directory path to upload.
            yt_dir: YT destination directory path.
            pattern: File pattern to match (not used in dev mode).

        Returns:
            List[str]: Empty list in dev mode.
        """
        self.logger.debug(f"Dev: upload_directory no-op {local_dir} → {yt_dir}")
        return []

    # ------------------------------------------------------------------
    # Operations
    # ------------------------------------------------------------------

    def run_map(
        self,
        command: Any,
        input_table: str,
        output_table: str,
        files: List[Tuple[str, str]],
        resources: OperationResources,
        env: Dict[str, str],
        output_schema: Optional["TableSchema"] = None,
        max_failed_jobs: int = 1,
        docker_auth: Optional[Dict[str, str]] = None,
        job: Any = None,
        append: bool = False,
        **kwargs: Any,
    ) -> Operation:
        """Run a map operation locally using subprocess.

        In dev mode, the mapper command runs via ``bash -c`` in a sandbox
        directory under ``.dev/``. The input table's JSONL is fed on stdin,
        stdout is captured as the output JSONL, and stderr goes to
        ``.dev/<output basename>.log``. The output table is updated only when
        the command exits with code 0.

        Args:
            command: Mapper job (command string in dev mode).
            input_table: Input YT table path (read from local JSONL).
            output_table: Output YT table path (written to local JSONL).
            files: List of (yt_path, local_path) tuples for dependencies.
            resources: Operation resource configuration (not used in dev mode).
            env: Environment variables dictionary.
            output_schema: Optional output table schema (not used in dev mode).
            max_failed_jobs: Maximum failed jobs allowed (not used in dev mode).
            docker_auth: Optional Docker authentication (not used in dev mode).
            job: Preferred alias for ``command``; takes precedence when not None.
            append: If True and output JSONL exists, append mapper stdout lines to it.
            **kwargs: Additional arguments (not used in dev mode).

        Returns:
            Operation: Mock operation object that simulates YT operation.

        Raises:
            NotImplementedError: If the mapper is not a command string.
            FileNotFoundError: If the input table's local JSONL is missing.

        Example::

            op = client.run_map(
                command="python3 mapper.py",
                input_table="//tmp/input",
                output_table="//tmp/output",
                files=[],
                resources=OperationResources(),
                env={},
            )
        """
        assert self.pipeline_dir is not None
        _kw = dict(kwargs)
        pop_secure_env_client_kwargs(_kw)
        self.logger.info("Submitting map operation")
        self.logger.info(f" Input: {input_table}")
        self.logger.info(f" Output: {output_table}")
        mapper_job = job if job is not None else command
        self.logger.info(f" Command: {mapper_job}")
        if not isinstance(mapper_job, str):
            raise NotImplementedError(
                "Dev mode run_map supports only string commands; "
                "TypedJob mappers are supported in prod mode."
            )

        sandbox_dir, sandbox_input, sandbox_output = self._prepare_map_sandbox(
            input_table, output_table
        )
        self._upload_files(files, sandbox_dir)
        env_merged = self._setup_map_environment(env)
        logs_path = self._dev_dir() / f"{self._table_basename(output_table)}.log"

        with (
            open(sandbox_input) as fin,
            open(sandbox_output, "w") as fout,
            open(logs_path, "w") as ferr,
        ):
            proc = subprocess.run(
                ["bash", "-c", mapper_job],
                stdin=fin,
                stdout=fout,
                stderr=ferr,
                env=env_merged,
                cwd=str(sandbox_dir),
            )

        # Copy output back only on success, so a failed run leaves the
        # previous table contents untouched.
        output_path = self._table_local_path(output_table)
        if proc.returncode == 0 and sandbox_output.exists():
            if append and output_path.exists():
                self._append_file(sandbox_output, output_path)
            else:
                shutil.copy2(sandbox_output, output_path)

        err_hint = f"Stderr written to {logs_path}" if proc.returncode != 0 else ""
        return _DevOperation(proc.returncode, err_hint)  # type: ignore[return-value]

    @staticmethod
    def _append_file(source: Path, dest: Path) -> None:
        """Stream ``source`` onto the end of ``dest``, keeping JSONL rows line-separated.

        If ``dest`` is non-empty and does not end with a newline, one is
        inserted first. Otherwise its last row would merge with the first
        appended row.
        """
        needs_newline = False
        if dest.stat().st_size > 0:
            with open(dest, "rb") as existing:
                existing.seek(-1, os.SEEK_END)
                needs_newline = existing.read(1) != b"\n"
        with open(dest, "ab") as out, open(source, "rb") as src:
            if needs_newline:
                out.write(b"\n")
            shutil.copyfileobj(src, out)

    @staticmethod
    def _localize_build_paths(command: str) -> str:
        """Rewrite YT build paths in ``command`` to sandbox-relative paths.

        Every ``//<dirs>/build/<rest>`` occurrence loses its prefix up to and
        including the first ``/build/``. For example,
        ``python3 //tmp/ex/build/stages/x/src/job.py`` becomes
        ``python3 stages/x/src/job.py``. The rest of the command (interpreter,
        flags, other arguments) is left untouched. Commands without such
        paths are returned unchanged.
        """
        import re

        # Lazy segment match stops at the first "/build/"; the lookahead
        # requires a non-empty relative path after it.
        return re.sub(r"//[^/\s]+(?:/[^/\s]+)*?/build/(?=\S)", "", command)

    def run_vanilla(
        self,
        command: str,
        files: List[Tuple[str, str]],
        env: Dict[str, str],
        task_name: str = "main",
        job: Optional[str] = None,
        **kwargs,
    ) -> Operation:
        """Run a vanilla operation locally using subprocess.

        In dev mode, the command runs via ``bash -c`` in
        ``.dev/<task_name>_sandbox`` with dependencies copied in. No
        input/output tables are involved. YT build paths in the command are
        rewritten to sandbox-relative paths. Stderr goes to
        ``.dev/<task_name>.log``; stdout is inherited from this process.

        Args:
            command: Command to execute (typically bash command with script path).
            files: List of (yt_path, local_path) tuples for dependencies.
            env: Environment variables dictionary.
            task_name: Task name for logging and sandbox/config lookup (default: "main").
            job: Preferred alias for ``command``; takes precedence when not None.
            **kwargs: Additional arguments (not used in dev mode).

        Returns:
            Operation: Mock operation object that simulates YT operation.
        """
        self.logger.info("Submitting vanilla operation")
        _kw = dict(kwargs)
        pop_secure_env_client_kwargs(_kw)
        vanilla_job = job if job is not None else command
        self.logger.info(f" Command: {vanilla_job}")
        self.logger.info(f" Task: {task_name}")
        assert self.pipeline_dir is not None

        self._dev_dir().mkdir(parents=True, exist_ok=True)
        sandbox_dir = self._dev_dir() / f"{task_name}_sandbox"
        sandbox_dir.mkdir(parents=True, exist_ok=True)
        self._upload_files(files, sandbox_dir)

        # The config.yaml dependency is uploaded as local_name="config.yaml",
        # but the job expects it at stages/{task_name}/config.yaml.
        stage_config_source = self.pipeline_dir / "stages" / task_name / "config.yaml"
        if stage_config_source.exists():
            stage_config_dest = sandbox_dir / "stages" / task_name / "config.yaml"
            stage_config_dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(stage_config_source, stage_config_dest)
            self.logger.debug(f" Dev: copied config.yaml to {stage_config_dest}")

        local_command = self._localize_build_paths(vanilla_job)
        if local_command != vanilla_job:
            self.logger.debug(
                f" Dev: converted command: {vanilla_job} -> {local_command}"
            )

        logs_path = self._dev_dir() / f"{task_name}.log"

        # Point JOB_CONFIG_PATH at the config copied into the sandbox.
        env_merged = self._build_env(env)
        config_path_in_sandbox = sandbox_dir / "stages" / task_name / "config.yaml"
        if config_path_in_sandbox.exists():
            env_merged["JOB_CONFIG_PATH"] = str(config_path_in_sandbox)
            self.logger.debug(f" Dev: JOB_CONFIG_PATH={config_path_in_sandbox}")
        else:
            self.logger.warning(
                f" Dev: config file not found at {config_path_in_sandbox}"
            )

        self.logger.info(f" Dev: sandbox={sandbox_dir}")
        self.logger.info(f" Dev: stderr={logs_path}")

        with open(logs_path, "w") as ferr:
            proc = subprocess.run(
                ["bash", "-c", local_command],
                stderr=ferr,
                env=env_merged,
                cwd=str(sandbox_dir),
            )

        err_hint = f"Output written to {logs_path}" if proc.returncode != 0 else ""
        return _DevOperation(proc.returncode, err_hint)  # type: ignore[return-value]

    @staticmethod
    def _describe_job_leg(obj: Any) -> str:
        """Human-readable description of a map/reduce leg for dev-mode logging."""
        if is_typed_job(obj):
            return "TypedJob"
        if isinstance(obj, str):
            return "command (prod uses JsonFormat on this leg)"
        return f"invalid leg type {type(obj).__name__} (expected TypedJob or str)"

    def _copy_table_passthrough(self, input_table: str, output_table: str) -> None:
        """Copy the input table's local file to the output table's local file.

        If the input file is missing, an empty output file is created. When
        both tables map to the same local file (same basename), this is a
        no-op, because ``shutil.copy2`` would raise ``SameFileError``.
        """
        self._dev_dir().mkdir(parents=True, exist_ok=True)
        input_path = self._table_local_path(input_table)
        output_path = self._table_local_path(output_table)
        if input_path == output_path:
            return
        if input_path.exists():
            shutil.copy2(input_path, output_path)
        else:
            output_path.write_text("")

    def run_map_reduce(
        self,
        mapper: Any,
        reducer: Any,
        input_table: str,
        output_table: str,
        reduce_by: List[str],
        files: List[Tuple[str, str]],
        resources: OperationResources,
        env: Dict[str, str],
        sort_by: Optional[List[str]] = None,
        output_schema: Optional["TableSchema"] = None,
        max_failed_jobs: int = 1,
        docker_auth: Optional[Dict[str, str]] = None,
        map_job: Any = None,
        reduce_job: Any = None,
        **kwargs: Any,
    ) -> Operation:
        """Dev: no-op; copy input table to output table.

        The mapper/reducer legs are resolved (legacy vs. preferred names) and
        logged, but not executed.
        """
        _kw = dict(kwargs)
        pop_secure_env_client_kwargs(_kw)
        mapper_leg = resolve_aliased_job(
            legacy_name="mapper",
            legacy_value=mapper,
            preferred_name="map_job",
            preferred_value=map_job,
        )
        reducer_leg = resolve_aliased_job(
            legacy_name="reducer",
            legacy_value=reducer,
            preferred_name="reduce_job",
            preferred_value=reduce_job,
        )
        self.logger.info(
            "Dev: map-reduce mapper leg: %s; reducer leg: %s",
            self._describe_job_leg(mapper_leg),
            self._describe_job_leg(reducer_leg),
        )
        self.logger.info("Dev: map-reduce no-op (copying input -> output)")
        assert self.pipeline_dir is not None
        self._copy_table_passthrough(input_table, output_table)
        return _DevOperation(0)  # type: ignore[return-value]

    def run_reduce(
        self,
        reducer: Any,
        input_table: str,
        output_table: str,
        reduce_by: List[str],
        files: List[Tuple[str, str]],
        resources: OperationResources,
        env: Dict[str, str],
        output_schema: Optional["TableSchema"] = None,
        max_failed_jobs: int = 1,
        docker_auth: Optional[Dict[str, str]] = None,
        job: Any = None,
        **kwargs: Any,
    ) -> Operation:
        """Dev: no-op; copy input table to output table.

        The reducer leg is resolved (legacy vs. preferred name) and logged,
        but not executed.
        """
        _kw = dict(kwargs)
        pop_secure_env_client_kwargs(_kw)
        reducer_leg = resolve_aliased_job(
            legacy_name="reducer",
            legacy_value=reducer,
            preferred_name="job",
            preferred_value=job,
        )
        self.logger.info("Dev: reduce leg: %s", self._describe_job_leg(reducer_leg))
        self.logger.info("Dev: reduce no-op (copying input -> output)")
        assert self.pipeline_dir is not None
        self._copy_table_passthrough(input_table, output_table)
        return _DevOperation(0)  # type: ignore[return-value]

    def run_sort(
        self,
        table_path: str,
        sort_by: List[str],
        pool: Optional[str] = None,
        pool_tree: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Dev: no-op (table unchanged)."""
        self.logger.info(f"Dev: run_sort no-op for {table_path} by {sort_by}")

    # ------------------------------------------------------------------
    # Sandbox / environment helpers
    # ------------------------------------------------------------------

    def _build_env(self, env: Dict[str, str]) -> Dict[str, str]:
        """Build environment variables for subprocess.

        Merges ``os.environ`` with ``env`` (``env`` wins). PYTHONPATH is then
        set to: the pipeline dir, the parent of the installed ``yt_framework``
        package, the parent of ``ytjobs`` (when importable), and finally any
        pre-existing PYTHONPATH.
        """
        env_merged = {**os.environ, **(env or {})}
        pp_parts = [str(self.pipeline_dir)]

        import yt_framework

        package_roots = [Path(yt_framework.__file__).parent.parent]
        # ytjobs is optional (_upload_files treats it the same way); a missing
        # package should not abort environment setup.
        try:
            import ytjobs
        except ImportError:
            self.logger.debug(" Dev: ytjobs not importable; not adding it to PYTHONPATH")
        else:
            package_roots.append(Path(ytjobs.__file__).parent.parent)

        for root in package_roots:
            if root not in [Path(p) for p in pp_parts]:
                pp_parts.append(str(root))

        if env_merged.get("PYTHONPATH"):
            pp_parts.append(env_merged["PYTHONPATH"])
        env_merged["PYTHONPATH"] = os.pathsep.join(pp_parts)
        return env_merged

    @staticmethod
    def _copy_to(source: Path, dest: Path) -> None:
        """Copy ``source`` to ``dest``, creating ``dest``'s parent directories."""
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(source, dest)

    @staticmethod
    def _installed_ytjobs_file(local_name: str) -> Optional[Path]:
        """Resolve a ``ytjobs/...`` dependency inside the installed ytjobs package.

        Returns None if ytjobs is not importable or the file does not exist.
        """
        try:
            import ytjobs
        except ImportError:
            return None
        # Strip only the leading "ytjobs/" (str.replace would strip every occurrence).
        relative = local_name[len("ytjobs/"):]
        candidate = Path(ytjobs.__file__).parent / relative
        return candidate if candidate.exists() else None

    def _upload_files(self, files: List[Tuple[str, str]], sandbox_dir: Path) -> None:
        """Copy job dependencies into ``sandbox_dir`` under their ``local_name``.

        For each ``(yt_path, local_name)`` pair, the first source found wins:

        1. the configured local checkpoint, when its filename matches the
           basename of ``yt_path`` or ``local_name``;
        2. ``{pipeline_dir}/.build/<basename>`` for ``.tar.gz`` archives;
        3. ``{pipeline_dir}/<local_name>``;
        4. for ``ytjobs/...`` names, the file inside the installed ytjobs package.

        Dependencies not found locally are skipped with a debug log.
        """
        assert self.pipeline_dir is not None

        local_checkpoint_path = self._get_local_checkpoint_path()
        checkpoint_path: Optional[Path] = None
        if local_checkpoint_path:
            self.logger.debug(
                f" Dev: local_checkpoint_path available: {local_checkpoint_path}"
            )
            checkpoint_path = Path(local_checkpoint_path)

        for yt_path, local_name in files:
            dest_file = sandbox_dir / local_name
            copied = False

            # 1. Checkpoint files, matched by filename.
            if checkpoint_path is not None and checkpoint_path.name in (
                Path(yt_path).name,
                local_name,
            ):
                if checkpoint_path.exists():
                    self.logger.info(
                        f" Dev: copying checkpoint {checkpoint_path} -> {dest_file}"
                    )
                    self._copy_to(checkpoint_path, dest_file)
                    copied = True
                else:
                    self.logger.warning(
                        f" Dev: checkpoint path does not exist: {checkpoint_path}"
                    )

            # 2. Build archives (source.tar.gz, etc.) from the .build directory.
            if not copied and yt_path.endswith(".tar.gz"):
                source_file = self.pipeline_dir / ".build" / Path(yt_path).name
                if source_file.exists():
                    self.logger.debug(f" Dev: copying {source_file} -> {dest_file}")
                    self._copy_to(source_file, dest_file)
                    copied = True

            # 3./4. Stage files relative to pipeline_dir, then installed ytjobs.
            # local_name is like "stages/run_vanilla/src/vanilla.py" or "ytjobs/...".
            if not copied:
                source_file = self.pipeline_dir / local_name
                if source_file.exists():
                    self.logger.debug(f" Dev: copying {source_file} -> {dest_file}")
                    self._copy_to(source_file, dest_file)
                    copied = True
                elif local_name.startswith("ytjobs/"):
                    ytjobs_file = self._installed_ytjobs_file(local_name)
                    if ytjobs_file is not None:
                        self.logger.debug(
                            f" Dev: copying ytjobs {ytjobs_file} -> {dest_file}"
                        )
                        self._copy_to(ytjobs_file, dest_file)
                        copied = True

            if not copied:
                self.logger.debug(
                    f" Dev: skipping file {yt_path} -> {local_name} (not found locally)"
                )

    def _prepare_map_sandbox(
        self, input_table: str, output_table: str
    ) -> Tuple[Path, Path, Path]:
        """Create the map sandbox and copy the input table into it.

        Returns:
            (sandbox_dir, sandbox_input, sandbox_output), where the input is
            ``input.jsonl`` and the output is ``output.jsonl`` inside the sandbox.

        Raises:
            FileNotFoundError: If the input table's local JSONL does not exist.
        """
        assert self.pipeline_dir is not None
        input_path = self._table_local_path(input_table)
        if not input_path.exists():
            raise FileNotFoundError(
                f"Dev: input table file not found: {input_path}. "
                "Create it (e.g. run a previous stage or add .jsonl manually)."
            )
        self._dev_dir().mkdir(parents=True, exist_ok=True)

        sandbox_dir = (
            self._dev_dir()
            / f"sandbox_{self._table_basename(input_table)}->{self._table_basename(output_table)}"
        )
        sandbox_dir.mkdir(parents=True, exist_ok=True)

        sandbox_input = sandbox_dir / "input.jsonl"
        sandbox_output = sandbox_dir / "output.jsonl"
        shutil.copy2(input_path, sandbox_input)

        self.logger.info(f" Dev: sandbox={sandbox_dir}")
        self.logger.info(f" Dev: stdin={sandbox_input}, stdout={sandbox_output}")
        return sandbox_dir, sandbox_input, sandbox_output

    def _setup_map_environment(self, env: Dict[str, str]) -> Dict[str, str]:
        """Build the map subprocess environment, including checkpoint config fallbacks.

        ``_setup_checkpoint_config`` may set JOB_CONFIG_PATH and
        CHECKPOINT_FILE from a stage config found under ``stages/``.
        """
        env_merged = self._build_env(env)
        self._setup_checkpoint_config(env_merged)
        return env_merged

    def _find_checkpoint_in_config(self, stage_config) -> Optional[str]:
        """
        Find checkpoint local_checkpoint_path in stage config.

        Checks the legacy ``client.local_checkpoint_path`` first. It then
        searches ``client.operations.<op>.checkpoint.local_checkpoint_path``
        for each operation, returning the first truthy value found.

        Args:
            stage_config: OmegaConf DictConfig for the stage

        Returns:
            Local checkpoint path string if found, None otherwise
        """
        from omegaconf import OmegaConf

        local_checkpoint = OmegaConf.select(
            stage_config, "client.local_checkpoint_path"
        )
        if local_checkpoint:
            return str(local_checkpoint)

        operations = OmegaConf.select(stage_config, "client.operations")
        if operations:
            for op_name in operations.keys():
                checkpoint_path = (
                    f"client.operations.{op_name}.checkpoint.local_checkpoint_path"
                )
                local_checkpoint = OmegaConf.select(stage_config, checkpoint_path)
                if local_checkpoint:
                    return str(local_checkpoint)
        return None

    def _stage_config_paths(self) -> List[Path]:
        """Return existing ``stages/*/config.yaml`` paths, sorted by stage directory.

        ``Path.iterdir`` yields entries in arbitrary, filesystem-dependent
        order. Sorting makes "first stage config" choices reproducible.
        """
        assert self.pipeline_dir is not None
        stages_dir = self.pipeline_dir / "stages"
        if not stages_dir.exists():
            return []
        return [
            stage_dir / "config.yaml"
            for stage_dir in sorted(stages_dir.iterdir())
            if stage_dir.is_dir() and (stage_dir / "config.yaml").exists()
        ]

    def _get_local_checkpoint_path(self) -> Optional[str]:
        """Return the first existing local checkpoint configured in any stage config.

        Stage configs are scanned in sorted stage-directory order. Relative
        checkpoint paths are resolved against the current working directory.
        Unreadable configs are skipped with a debug log.
        """
        assert self.pipeline_dir is not None
        try:
            from omegaconf import OmegaConf

            config_paths = self._stage_config_paths()
        except Exception as e:
            self.logger.debug(f" Dev: error scanning stages directory: {e}")
            return None

        for stage_config_path in config_paths:
            try:
                stage_config = OmegaConf.load(stage_config_path)
                local_checkpoint = self._find_checkpoint_in_config(stage_config)
                if local_checkpoint:
                    checkpoint_path = Path(local_checkpoint).resolve()
                    if checkpoint_path.exists():
                        self.logger.debug(
                            f" Dev: found local_checkpoint_path: {checkpoint_path}"
                        )
                        return str(checkpoint_path)
            except Exception as e:
                # Continue to next stage config
                self.logger.debug(f" Dev: error reading {stage_config_path}: {e}")
        return None

    def _setup_checkpoint_config(self, env_merged: Dict[str, str]) -> None:
        """Best-effort fallback: derive JOB_CONFIG_PATH / CHECKPOINT_FILE from a stage config.

        Uses the first stage config in sorted stage-directory order. It sets
        JOB_CONFIG_PATH to that file. If the config names an existing local
        checkpoint, it also sets CHECKPOINT_FILE to ``job.model_name`` or the
        checkpoint's filename. ``env_merged`` is modified in place.
        """
        assert self.pipeline_dir is not None
        try:
            from omegaconf import OmegaConf

            config_paths = self._stage_config_paths()
        except Exception as e:
            self.logger.debug(f" Dev: could not setup checkpoint config: {e}")
            return
        if not config_paths:
            return

        stage_config_path = config_paths[0]
        env_merged["JOB_CONFIG_PATH"] = str(stage_config_path)
        try:
            stage_config = OmegaConf.load(stage_config_path)
            local_checkpoint = self._find_checkpoint_in_config(stage_config)
            # job.model_name is used as the checkpoint filename when present.
            model_name = OmegaConf.select(stage_config, "job.model_name")
            if local_checkpoint:
                checkpoint_path = Path(local_checkpoint).resolve()
                if checkpoint_path.exists():
                    # The checkpoint is copied into the sandbox, so the job
                    # gets a bare filename rather than the full local path.
                    checkpoint_filename = model_name or checkpoint_path.name
                    env_merged["CHECKPOINT_FILE"] = checkpoint_filename
                    self.logger.info(
                        f" Dev: checkpoint file set to: {checkpoint_filename} (from {checkpoint_path})"
                    )
                else:
                    self.logger.warning(
                        f" Dev: local_checkpoint_path not found: {checkpoint_path}"
                    )
        except Exception as e:
            self.logger.warning(f" Dev: failed to load checkpoint config: {e}")