"""
Production YT Client
====================
Production implementation of the YT client, backed by the real YTsaurus Python client (``yt.wrapper``).
"""
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, Tuple, Literal
from yt.wrapper import ( # pyright: ignore[reportMissingImports]
YtClient,
FilePath,
TablePath,
Operation,
MapSpecBuilder,
VanillaSpecBuilder,
format as yt_format,
)
from yt.wrapper.schema import TableSchema # pyright: ignore[reportMissingImports]
from yt_framework.yt.client_base import BaseYTClient, OperationResources
from yt_framework.utils.ignore import YTIgnoreMatcher
class YTProdClient(BaseYTClient):
    """
    Production YT client implementation.

    Uses the actual YTsaurus client (``yt.wrapper.YtClient``) for all
    operations and adds logging plus convenience helpers for table I/O,
    YQL queries, file uploads and map/vanilla operation submission.
    """

    def __init__(
        self,
        logger: logging.Logger,
        secrets: Dict[str, str],
    ) -> None:
        """
        Initialize production YT client.

        Args:
            logger: Logger instance
            secrets: Dictionary containing YT credentials. Expected keys:
                - YT_PROXY
                - YT_TOKEN

        Raises:
            ValueError: If YT_PROXY or YT_TOKEN is missing or empty.
        """
        super().__init__(logger)
        yt_proxy = secrets.get("YT_PROXY")
        if not yt_proxy:
            raise ValueError(
                "YT_PROXY is not set (check secrets.env or environment variables)"
            )
        yt_token = secrets.get("YT_TOKEN")
        if not yt_token:
            raise ValueError(
                "YT_TOKEN is not set (check secrets.env or environment variables)"
            )
        self.client = YtClient(proxy=yt_proxy, token=yt_token)
        # Best-effort: turn off proxy discovery in the client config. A failure
        # here is logged and the client keeps its default settings.
        try:
            if "proxy" in self.client.config:
                self.client.config["proxy"]["enable_proxy_discovery"] = False  # type: ignore[index]
                self.logger.debug(
                    f"YT Client initialized with proxy: {yt_proxy} (proxy discovery disabled)"
                )
            else:
                self.logger.debug(f"YT Client initialized with proxy: {yt_proxy}")
        except Exception as e:
            self.logger.warning(
                f"Could not disable proxy discovery: {e}. Continuing with default settings."
            )
            self.logger.debug(f"YT Client initialized with proxy: {yt_proxy}")

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _parent_path(path: str) -> Optional[str]:
        """Return the parent of a slash-separated YT path, or None if it has none.

        Example: ``"//home/a/b"`` -> ``"//home/a"``; ``"table"`` -> ``None``.
        """
        if "/" not in path:
            return None
        parent = "/".join(path.rstrip("/").split("/")[:-1])
        return parent or None

    def _ensure_parent_dir(self, path: str) -> None:
        """Create the parent map_node of ``path`` (recursively) if it has one."""
        parent_dir = self._parent_path(path)
        if parent_dir:
            self.logger.debug(f"Ensuring parent directory exists: {parent_dir}")
            self.create_path(parent_dir, node_type="map_node")

    def _remove_quietly(self, path: str) -> None:
        """Best-effort removal of a temporary node; failures are logged, not raised."""
        try:
            # force=True: do not fail if the node was never created.
            self.client.remove(path, force=True)
        except Exception as e:
            self.logger.debug(f"Could not remove temporary node {path}: {e}")

    @staticmethod
    def _visible_columns(schema: Any) -> List[str]:
        """Extract column names from a YT schema list.

        Names starting with ``_`` (internal YQL columns such as ``_other`` or
        ``_yql_column_*``) are dropped. Returns an empty list when ``schema``
        is empty or not a list.
        """
        if not schema or not isinstance(schema, list):
            return []
        names = [
            col["name"] for col in schema if isinstance(col, dict) and "name" in col
        ]
        return [name for name in names if not name.startswith("_")]

    def _schema_columns(self, table_path: str) -> List[str]:
        """Return visible column names from the table's ``schema`` attribute."""
        # The schema lives in the node's attribute, not in the value returned by
        # ``get(table_path)`` (for a table that value is an entity, not a dict).
        schema = self.client.get_attribute(table_path, "schema", default=None)
        return self._visible_columns(schema)

    def _infer_columns_via_yql(self, table_path: str) -> List[str]:
        """Infer column names by materializing an empty copy of the table via YQL.

        Runs ``SELECT * ... LIMIT 0`` with schema inference into a uniquely
        named temporary table, reads that table's schema and always removes
        the temporary table afterwards. Returns an empty list on any failure.
        """
        import uuid

        temp_output = f"{table_path}.temp_schema_{uuid.uuid4().hex[:8]}"
        query = (
            "PRAGMA yt.InferSchema = '1';\n"
            f"INSERT INTO `{temp_output}` WITH TRUNCATE\n"
            f"SELECT * FROM `{table_path}` LIMIT 0;"
        )
        try:
            self.logger.debug(
                "Reading failed due to binary columns, using YQL to infer schema"
            )
            self.run_yql(query)
            return self._schema_columns(temp_output)
        except Exception as yql_error:
            self.logger.debug("YQL schema inference failed: %s", yql_error)
            return []
        finally:
            self._remove_quietly(temp_output)

    @staticmethod
    def _build_file_paths(files: List[Tuple[str, str]]) -> List[Any]:
        """Convert ``(yt_path, local_name)`` pairs into ``FilePath`` objects."""
        return [FilePath(yt_path, file_name=local_path) for yt_path, local_path in files]

    # ------------------------------------------------------------------ #
    # Cypress / table primitives
    # ------------------------------------------------------------------ #

    def create_path(
        self,
        path: str,
        node_type: Literal[
            "table", "file", "map_node", "list_node", "document"
        ] = "map_node",
    ) -> None:
        """Create a path in YT (recursively; existing nodes are left as-is).

        Args:
            path: YT path to create.
            node_type: Type of node to create (default: "map_node").

        Raises:
            Exception: If path creation fails.
        """
        try:
            self.client.create(node_type, path, recursive=True, ignore_existing=True)
        except Exception as e:
            self.logger.error(f"Failed to create path: {e}")
            raise

    def exists(self, path: str) -> bool:
        """Check if a path exists in YT.

        Args:
            path: YT path to check.

        Returns:
            bool: True if path exists, False otherwise.

        Raises:
            Exception: If check fails.
        """
        try:
            return self.client.exists(path)
        except Exception as e:
            self.logger.error(f"Failed to check if path exists: {e}")
            raise

    def write_table(
        self,
        table_path: str,
        rows: List[Dict[str, Any]],
        append: bool = False,
        replication_factor: int = 1,
        make_parents: bool = True,
    ) -> None:
        """Write rows to a YT table.

        Args:
            table_path: YT table path
            rows: List of dictionaries representing table rows
            append: If True, append to existing table (default: False)
            replication_factor: Replication factor used when the table has to
                be created (default: 1)
            make_parents: If True, create parent directories if they don't exist (default: True)

        Raises:
            Exception: If any YT call fails.
        """
        mode_str = "append" if append else "overwrite"
        self.logger.info(f"Writing {len(rows)} rows → {table_path} ({mode_str})")
        try:
            if make_parents:
                self._ensure_parent_dir(table_path)
            if not append and self.client.exists(table_path):
                # Overwrite mode: drop the old table so it is re-created below
                # with the requested attributes.
                self.client.remove(table_path, force=True)
            # Create the table if it doesn't exist (in both modes), so the
            # replication factor is applied to any newly created table.
            self.client.create(
                "table",
                table_path,
                attributes={"replication_factor": replication_factor},
                ignore_existing=True,
            )
            self.client.write_table(
                TablePath(table_path, append=append),
                rows,
                format=yt_format.JsonFormat(),
            )
        except Exception as e:
            self.logger.error(f"Failed to write table: {e}")
            raise

    def read_table(self, table_path: str) -> List[Dict[str, Any]]:
        """Read all rows from a YT table into memory.

        Args:
            table_path: YT table path to read.

        Returns:
            List[Dict[str, Any]]: List of dictionaries representing table rows.

        Raises:
            Exception: If table read fails.
        """
        self.logger.info(f"Reading table: {table_path}")
        try:
            # With JsonFormat() the reader yields dicts; the YT client's
            # declared return type is broader, hence the type ignore.
            table_iterator = self.client.read_table(
                TablePath(table_path), format=yt_format.JsonFormat()
            )
            results: List[Dict[str, Any]] = list(table_iterator)  # type: ignore[arg-type]
            self.logger.info(f"✓ Read {len(results)} rows")
            return results
        except Exception as e:
            self.logger.error(f"Failed to read table: {e}")
            raise

    def row_count(self, table_path: str) -> int:
        """Get number of rows in a YT table.

        Args:
            table_path: YT table path.

        Returns:
            int: Number of rows in the table.

        Raises:
            Exception: If row count query fails.
        """
        try:
            count = self.client.row_count(table_path)
            self.logger.debug(f"Row count: {count}")
            return count
        except Exception as e:
            self.logger.error(f"Failed to get row count: {e}")
            raise

    def _get_table_columns(self, table_path: str) -> List[str]:
        """
        Get column names from a table.

        Tries, in order:
        1. The table's ``schema`` attribute (no data is read; handles binary columns).
        2. Reading only the first row of the table.
        3. A YQL ``LIMIT 0`` query with schema inference, used only when step 2
           fails with a decoding error (e.g. binary columns).

        Args:
            table_path: Path to YT table

        Returns:
            List of column names, excluding names starting with "_" (internal
            YQL columns). If the first row has only such columns, all of its
            keys are returned instead.

        Raises:
            ValueError: If the table is empty or its columns cannot be determined.
        """
        # Method 1: schema attribute.
        try:
            columns = self._schema_columns(table_path)
            if columns:
                return columns
        except Exception as e:
            self.logger.debug(
                f"Could not get schema from attributes: {e}, trying to read table"
            )

        # Method 2: read just the first row (may fail with binary columns).
        try:
            first_rows = list(
                self.client.read_table(
                    TablePath(table_path, end_index=1),
                    format=yt_format.JsonFormat(),
                )
            )
            if not first_rows:
                raise ValueError(
                    f"Table {table_path} is empty, cannot determine columns"
                )
            all_columns = list(first_rows[0].keys())
            visible = [col for col in all_columns if not col.startswith("_")]
            # If every column was internal, fall back to all keys.
            return visible or all_columns
        except Exception as read_error:
            # Method 3: decoding failures point at binary columns; ask YQL.
            error_str = str(read_error)
            if (
                "Failed to decode string" in error_str
                or "encoding" in error_str.lower()
            ):
                columns = self._infer_columns_via_yql(table_path)
                if columns:
                    return columns
                raise ValueError(
                    f"Table {table_path} contains binary columns that cannot be decoded. "
                    f"This usually happens when a table was created with SELECT * and contains "
                    f"internal YQL columns like _yql_column_0. Please recreate the table with "
                    f"explicit column selection, or delete and recreate it. Original error: {read_error}"
                ) from read_error
            raise ValueError(
                f"Failed to get table columns from {table_path}: {read_error}"
            ) from read_error

    # ------------------------------------------------------------------ #
    # YQL
    # ------------------------------------------------------------------ #

    def run_yql(
        self,
        query: str,
        pool: str = "default",
    ) -> None:
        """
        Execute a YQL query on YT cluster.

        Args:
            query: YQL query string to execute
            pool: YT pool name (default: 'default')

        Raises:
            RuntimeError: If the query finishes in any state other than "completed".
            Exception: If query submission fails.
        """
        self.logger.info("Executing YQL query on YT cluster")
        self.logger.debug(f"Pool: {pool}")
        self.logger.debug(f"Query:\n{query}")
        try:
            query_obj = self.client.run_query(
                engine="yql",
                query=query,
                settings={"pool": pool},
            )
            self.logger.info(f"Query started: {query_obj.id}")
            # Only a "completed" state counts as success.
            state = query_obj.get_state()
            if state == "completed":
                self.logger.info("✓ YQL query completed successfully")
            else:
                error = query_obj.get_error()
                raise RuntimeError(f"Query failed with state {state}: {error}")
        except Exception as e:
            self.logger.error(f"Failed to execute YQL query: {e}")
            raise

    # Convenience methods for common YQL operations

    def join_tables(
        self,
        left_table: str,
        right_table: str,
        output_table: str,
        on: Union[str, List[str], Dict[str, str]],
        how: Literal["inner", "left", "right", "full"] = "left",
        select_columns: Optional[List[str]] = None,
        dry_run: bool = False,
    ) -> Optional[str]:
        """
        Join two tables using YQL.

        Args:
            left_table: Path to left table
            right_table: Path to right table
            output_table: Path to output table
            on: Join key(s) - column name(s) to join on
                - str: Same column name on both sides (e.g., "user_id")
                - List[str]: Multiple columns with same names (e.g., ["user_id", "region"])
                - Dict[str, str]: Different column names (e.g., {"left": "input_s3_path", "right": "path"})
            how: Join type - "inner", "left", "right", or "full"
            select_columns: Optional list of columns to select (with table aliases)
            dry_run: If True, return the YQL query without executing

        Returns:
            YQL query string if dry_run=True, None otherwise
        """
        from yt_framework.yt.yql_builder import build_join_query

        query = build_join_query(
            left_table=left_table,
            right_table=right_table,
            output_table=output_table,
            on=on,
            how=how,
            select_columns=select_columns,
        )
        if dry_run:
            return query
        self.run_yql(query)
        return None

    def filter_table(
        self,
        input_table: str,
        output_table: str,
        condition: str,
        dry_run: bool = False,
    ) -> Optional[str]:
        """
        Filter table rows using WHERE condition.

        Args:
            input_table: Path to input table
            output_table: Path to output table
            condition: WHERE condition (e.g., "status = 'active' AND total > 100")
            dry_run: If True, return the YQL query without executing

        Returns:
            YQL query string if dry_run=True, None otherwise
        """
        from yt_framework.yt.yql_builder import build_filter_query

        # Explicit columns avoid copying internal columns such as _other.
        columns = self._get_table_columns(input_table)
        query = build_filter_query(
            input_table=input_table,
            output_table=output_table,
            condition=condition,
            columns=columns,
        )
        if dry_run:
            return query
        self.run_yql(query)
        return None

    def select_columns(
        self,
        input_table: str,
        output_table: str,
        columns: List[str],
        dry_run: bool = False,
    ) -> Optional[str]:
        """
        Select specific columns from a table.

        Args:
            input_table: Path to input table
            output_table: Path to output table
            columns: List of column names to select
            dry_run: If True, return the YQL query without executing

        Returns:
            YQL query string if dry_run=True, None otherwise
        """
        from yt_framework.yt.yql_builder import build_select_query

        query = build_select_query(
            input_table=input_table,
            output_table=output_table,
            columns=columns,
        )
        if dry_run:
            return query
        self.run_yql(query)
        return None

    def group_by_aggregate(
        self,
        input_table: str,
        output_table: str,
        group_by: Union[str, List[str]],
        aggregations: Dict[str, Union[str, Tuple[str, str]]],
        dry_run: bool = False,
    ) -> Optional[str]:
        """
        Group by columns and compute aggregations.

        Args:
            input_table: Path to input table
            output_table: Path to output table
            group_by: Column(s) to group by
            aggregations: Dict mapping output column names to aggregation functions
                e.g., {"order_count": "count", "total_amount": "sum"}
            dry_run: If True, return the YQL query without executing

        Returns:
            YQL query string if dry_run=True, None otherwise
        """
        from yt_framework.yt.yql_builder import build_group_by_query

        query = build_group_by_query(
            input_table=input_table,
            output_table=output_table,
            group_by=group_by,
            aggregations=aggregations,
        )
        if dry_run:
            return query
        self.run_yql(query)
        return None

    def union_tables(
        self,
        tables: List[str],
        output_table: str,
        dry_run: bool = False,
    ) -> Optional[str]:
        """
        Union multiple tables.

        Args:
            tables: Non-empty list of table paths to union. Columns are taken
                from the first table; all tables are expected to share them.
            output_table: Path to output table
            dry_run: If True, return the YQL query without executing

        Returns:
            YQL query string if dry_run=True, None otherwise

        Raises:
            ValueError: If ``tables`` is empty.
        """
        from yt_framework.yt.yql_builder import build_union_query

        if not tables:
            raise ValueError("union_tables requires at least one input table")
        # Explicit columns avoid copying internal columns such as _other.
        columns = self._get_table_columns(tables[0])
        query = build_union_query(
            tables=tables,
            output_table=output_table,
            columns=columns,
        )
        if dry_run:
            return query
        self.run_yql(query)
        return None

    def distinct(
        self,
        input_table: str,
        output_table: str,
        columns: Optional[List[str]] = None,
        dry_run: bool = False,
    ) -> Optional[str]:
        """
        Get distinct rows from a table.

        Args:
            input_table: Path to input table
            output_table: Path to output table
            columns: Optional list of columns to select (if None, selects all)
            dry_run: If True, return the YQL query without executing

        Returns:
            YQL query string if dry_run=True, None otherwise
        """
        from yt_framework.yt.yql_builder import build_distinct_query

        query = build_distinct_query(
            input_table=input_table,
            output_table=output_table,
            columns=columns,
        )
        if dry_run:
            return query
        self.run_yql(query)
        return None

    def sort_table(
        self,
        input_table: str,
        output_table: str,
        order_by: Union[str, List[str]],
        ascending: bool = True,
        dry_run: bool = False,
    ) -> Optional[str]:
        """
        Sort table by columns.

        WARNING: Sorting large tables can be expensive. Use with caution.

        Args:
            input_table: Path to input table
            output_table: Path to output table
            order_by: Column(s) to sort by
            ascending: Sort direction (True for ASC, False for DESC)
            dry_run: If True, return the YQL query without executing

        Returns:
            YQL query string if dry_run=True, None otherwise
        """
        from yt_framework.yt.yql_builder import build_sort_query

        # Explicit columns avoid copying internal columns such as _other.
        columns = self._get_table_columns(input_table)
        query = build_sort_query(
            input_table=input_table,
            output_table=output_table,
            order_by=order_by,
            columns=columns,
            ascending=ascending,
        )
        if dry_run:
            return query
        self.run_yql(query)
        return None

    def limit_table(
        self,
        input_table: str,
        output_table: str,
        limit: int,
        dry_run: bool = False,
    ) -> Optional[str]:
        """
        Limit number of rows from a table.

        Args:
            input_table: Path to input table
            output_table: Path to output table
            limit: Maximum number of rows to return
            dry_run: If True, return the YQL query without executing

        Returns:
            YQL query string if dry_run=True, None otherwise
        """
        from yt_framework.yt.yql_builder import build_limit_query

        # Explicit columns avoid copying internal columns such as _other.
        columns = self._get_table_columns(input_table)
        query = build_limit_query(
            input_table=input_table,
            output_table=output_table,
            limit=limit,
            columns=columns,
        )
        if dry_run:
            return query
        self.run_yql(query)
        return None

    # ------------------------------------------------------------------ #
    # Files
    # ------------------------------------------------------------------ #

    def upload_file(
        self, local_path: Path, yt_path: str, create_parent_dir: bool = False
    ) -> None:
        """
        Upload a file to YT (an existing file at ``yt_path`` is replaced).

        Args:
            local_path: Local file path to upload
            yt_path: YT destination path
            create_parent_dir: If True, create parent directory if it doesn't exist (default: False)

        Raises:
            Exception: If the upload fails.
        """
        self.logger.info(f"Uploading {local_path.name} → {yt_path}")
        try:
            if create_parent_dir:
                self._ensure_parent_dir(yt_path)
            with open(local_path, "rb") as f:
                self.client.write_file(
                    yt_path,
                    f,
                    force_create=True,
                    compute_md5=True,
                )
            self.logger.debug(f"Upload completed: {yt_path}")
        except Exception as e:
            self.logger.error(f"Failed to upload file: {e}")
            raise

    def upload_directory(
        self, local_dir: Path, yt_dir: str, pattern: str = "*"
    ) -> List[str]:
        """
        Upload a directory to YT cluster.

        Recursively uploads all files from a local directory to a YT directory,
        respecting .ytignore patterns if present.

        Args:
            local_dir: Local directory path to upload
            yt_dir: YT destination directory path
            pattern: File pattern to match (default: "*" for all files)

        Returns:
            List of uploaded YT file paths

        Raises:
            Exception: If directory upload fails
        """
        self.logger.info(f"Uploading directory {local_dir} → {yt_dir}")
        self.create_path(yt_dir, node_type="map_node")
        ignore_matcher = YTIgnoreMatcher(local_dir)
        uploaded: List[str] = []
        ignored_count = 0
        # Directories already created during this call, so each parent is
        # created once instead of once per file.
        ensured_dirs = {yt_dir}
        for local_file in local_dir.rglob(pattern):
            if not local_file.is_file():
                continue
            if ignore_matcher.should_ignore(local_file):
                self.logger.debug(
                    f"Ignoring file (matched .ytignore): {local_file}"
                )
                ignored_count += 1
                continue
            rel_path = local_file.relative_to(local_dir)
            # Normalize Windows separators to YT's "/".
            yt_path = f"{yt_dir}/{rel_path}".replace("\\", "/")
            parent = self._parent_path(yt_path)
            if parent and parent not in ensured_dirs:
                self.create_path(parent, node_type="map_node")
                ensured_dirs.add(parent)
            self.upload_file(local_file, yt_path)
            uploaded.append(yt_path)
        self.logger.info(f"Uploaded {len(uploaded)} files")
        if ignored_count > 0:
            self.logger.info(
                f"Ignored {ignored_count} files (matched .ytignore patterns)"
            )
        return uploaded

    # ------------------------------------------------------------------ #
    # Operations
    # ------------------------------------------------------------------ #

    def run_map(
        self,
        command: str,
        input_table: str,
        output_table: str,
        files: List[Tuple[str, str]],
        resources: OperationResources,
        env: Dict[str, str],
        output_schema: Optional[TableSchema] = None,
        max_failed_jobs: int = 1,
        docker_auth: Optional[Dict[str, str]] = None,
    ) -> Operation:
        """Run a map operation on YT cluster (submitted asynchronously).

        Submits a map operation that processes each row of the input table
        and writes results to the output table (overwritten, not appended).

        Args:
            command: Command to execute (typically bash command with script path).
            input_table: Input YT table path.
            output_table: Output YT table path.
            files: List of (yt_path, local_path) tuples for dependencies.
            resources: Operation resource configuration (memory, CPU, GPU, etc.).
            env: Environment variables dictionary.
            output_schema: Optional output table schema for typed output.
            max_failed_jobs: Maximum failed jobs allowed before operation fails.
            docker_auth: Optional Docker authentication for private registries;
                placed into the operation's secure vault under "docker_auth"
                when provided.

        Returns:
            Operation: YT operation object that can be monitored and waited on.

        Raises:
            RuntimeError: If ``run_operation`` returns None.
            Exception: If operation submission fails.
        """
        self.logger.info("Submitting map operation")
        self.logger.info(f" Input: {input_table}")
        self.logger.info(f" Output: {output_table}")
        self.logger.info(f" Output Schema: {output_schema}")
        self.logger.info(f" Command: {command}")
        self.logger.info(f" Files: {files}")
        self.logger.info(f" Resources: {resources}")
        try:
            file_paths = self._build_file_paths(files)
            output_path = TablePath(output_table, append=False, schema=output_schema)
            spec_builder = (
                MapSpecBuilder()
                .pool(resources.pool)
                .resource_limits({"user_slots": resources.user_slots})
                .max_failed_job_count(max_failed_jobs)
                .job_count(resources.job_count)
                .input_table_paths([input_table])
                .output_table_paths([output_path])
            )
            if resources.pool_tree:
                spec_builder = spec_builder.pool_trees([resources.pool_tree])
                self.logger.debug(f"Set pool tree to {resources.pool_tree}")
            if docker_auth is not None:
                spec_builder.secure_vault({"docker_auth": docker_auth})
            mapper_builder = (
                spec_builder.begin_mapper()
                .command(command)
                .format(yt_format.JsonFormat(encode_utf8=False))
                .file_paths(file_paths)
                .environment(env)
                # memory_limit is in bytes; int() guards fractional memory_gb.
                .memory_limit(int(resources.memory_gb * 1024**3))
                .cpu_limit(resources.cpu_limit)
                .gpu_limit(resources.gpu_limit)
            )
            if resources.docker_image:
                mapper_builder = mapper_builder.docker_image(resources.docker_image)
            mapper_builder.end_mapper()
            operation = self.client.run_operation(spec_builder, sync=False)
            if operation is None:
                raise RuntimeError(
                    "Failed to submit operation: run_operation returned None"
                )
            self.logger.info(f"Operation submitted: {operation.id}")
            return operation
        except Exception as e:
            self.logger.error(f"Failed to submit operation: {e}")
            raise

    def run_vanilla(
        self,
        command: str,
        files: List[Tuple[str, str]],
        env: Dict[str, str],
        task_name: str,
        resources: OperationResources,
        docker_auth: Optional[Dict[str, str]] = None,
        max_failed_jobs: int = 1,
    ) -> Operation:
        """Run a vanilla operation on YT cluster (submitted asynchronously).

        Submits a vanilla operation with a single task and no input/output tables.

        Args:
            command: Command to execute (typically bash command with script path).
            files: List of (yt_path, local_path) tuples for dependencies.
            env: Environment variables dictionary.
            task_name: Task name for the operation.
            resources: Operation resource configuration (memory, CPU, GPU, etc.).
            docker_auth: Optional Docker authentication for private registries;
                placed into the operation's secure vault under "docker_auth"
                when provided.
            max_failed_jobs: Maximum failed jobs allowed before operation fails.

        Returns:
            Operation: YT operation object that can be monitored and waited on.

        Raises:
            RuntimeError: If ``run_operation`` returns None.
            Exception: If operation submission fails.
        """
        self.logger.info("Submitting vanilla operation")
        self.logger.info(f" Task Name: {task_name}")
        self.logger.info(f" Command: {command}")
        self.logger.info(f" Files: {files}")
        self.logger.info(f" Resources: {resources}")
        try:
            file_paths = self._build_file_paths(files)
            spec_builder = (
                VanillaSpecBuilder()
                .pool(resources.pool)
                .resource_limits({"user_slots": resources.user_slots})
                .max_failed_job_count(max_failed_jobs)
            )
            if resources.pool_tree:
                spec_builder = spec_builder.pool_trees([resources.pool_tree])
                self.logger.debug(f"Set pool tree to {resources.pool_tree}")
            if docker_auth is not None:
                spec_builder.secure_vault({"docker_auth": docker_auth})
            task_builder = (
                spec_builder.begin_task(task_name)
                .command(command)
                .file_paths(file_paths)
                .environment(env)
                # memory_limit is in bytes; int() guards fractional memory_gb.
                .memory_limit(int(resources.memory_gb * 1024**3))
                .cpu_limit(resources.cpu_limit)
                .gpu_limit(resources.gpu_limit)
                .job_count(resources.job_count)
            )
            if resources.docker_image:
                task_builder = task_builder.docker_image(resources.docker_image)
            task_builder.end_task()
            operation = self.client.run_operation(spec_builder, sync=False)
            if operation is None:
                raise RuntimeError(
                    "Failed to submit operation: run_operation returned None"
                )
            self.logger.info(f"Operation submitted: {operation.id}")
            return operation
        except Exception as e:
            self.logger.error(f"Failed to submit vanilla operation: {e}")
            raise