"""Thin wrapper around `yt.wrapper.YtClient` for real cluster operations."""
import contextlib
import logging
import uuid
from typing import Any, Literal, NoReturn, cast
from yt.wrapper import TablePath, YtClient
from yt.wrapper import format as yt_format
from yt_framework.yt.clients._client_split._client_prod_ops_mixin import (
ClientProdOpsMixin,
)
from yt_framework.yt.clients._client_split._client_prod_yql_mixin import (
ClientProdYqlMixin,
)
from yt_framework.yt.clients.client_base import BaseYTClient
from yt_framework.yt.support._client_prod_runtime import (
_raise_runtime_error,
disable_yt_proxy_discovery,
prod_create_table_parent,
prod_write_table_replace_create,
read_required_yt_secret,
)
from yt_framework.yt.support.max_row_weight import (
build_max_row_weight_pragma,
ensure_max_row_weight_pragma,
)
def _raise_value_error(message: str) -> NoReturn:
raise ValueError(message)
def _raise_value_error_from(cause: BaseException, message: str) -> NoReturn:
raise ValueError(message) from cause
[docs]
class YTProdClient(ClientProdYqlMixin, ClientProdOpsMixin, BaseYTClient):
"""Production YT client implementation.
Uses actual YTsaurus client for all operations.
"""
[docs]
def __init__(
    self,
    logger: logging.Logger,
    secrets: dict[str, str],
    pickling: dict[str, Any] | None = None,
) -> None:
    """Build a client bound to a real YT cluster.

    Args:
        logger: Logger passed to the base client constructor.
        secrets: Mapping holding the cluster credentials. The keys
            ``YT_PROXY`` and ``YT_TOKEN`` are looked up through
            ``read_required_yt_secret``.
        pickling: Optional pickling flags forwarded to
            ``_apply_pickling_config``; a falsy value is replaced by an
            empty dict.

    Note:
        Handling of a missing secret is delegated to
        ``read_required_yt_secret`` (presumably it raises with the given
        ``missing_message`` — confirm in that helper).
    """
    super().__init__(logger)

    proxy_address = read_required_yt_secret(
        secrets,
        key="YT_PROXY",
        missing_message="YT_PROXY is not set (check secrets.env or environment variables)",
    )
    access_token = read_required_yt_secret(
        secrets,
        key="YT_TOKEN",
        missing_message="YT_TOKEN is not set (check secrets.env or environment variables)",
    )

    self.client = YtClient(proxy=proxy_address, token=access_token)

    # Pickling flags are applied before proxy discovery is switched off,
    # matching the order the client has always been configured in.
    pickling_flags = pickling if pickling else {}
    self._apply_pickling_config(pickling_flags)
    disable_yt_proxy_discovery(self.client, self.logger, proxy_address)
def _apply_pickling_config(self, pickling: dict[str, Any]) -> None:
"""Apply pickling flags from pipeline config to the YT client.
Supported flags:
ignore_system_modules (bool): Skip stdlib/site-packages from auto-upload.
Prevents shadow packages (certifi, importlib, boto3, etc.) from polluting
the worker sandbox. Safe default for Docker-based jobs.
disable_module_upload (bool): Skip ALL automatic module uploads.
Worker relies entirely on the Docker image + source.tar.gz.
"""
if not pickling:
return
cfg = cast("dict[str, Any]", self.client.config.setdefault("pickling", {}))
if pickling.get("ignore_system_modules"):
cfg["ignore_system_modules"] = True
self.logger.debug("Pickling: ignore_system_modules=True")
if pickling.get("disable_module_upload"):
existing_module_filter = cfg.get("module_filter")
def module_filter(module: object) -> bool:
if callable(existing_module_filter):
existing_module_filter(module)
return False
cfg["module_filter"] = module_filter
self.logger.debug("Pickling: module_filter=<upload nothing>")
[docs]
def create_path(
    self,
    path: str,
    node_type: Literal[
        "table",
        "file",
        "map_node",
        "list_node",
        "document",
    ] = "map_node",
) -> None:
    """Create a Cypress node at ``path``.

    The underlying ``create`` call is issued with ``recursive=True`` and
    ``ignore_existing=True``.

    Args:
        path: YT path to create.
        node_type: Kind of node to create (default: "map_node").

    Returns:
        None

    Raises:
        Exception: Any error from the client is logged and re-raised.
    """
    create_options = {"recursive": True, "ignore_existing": True}
    try:
        self.client.create(node_type, path, **create_options)
    except Exception:
        self.logger.exception("Failed to create path")
        raise
[docs]
def exists(self, path: str) -> bool:
    """Report whether ``path`` exists, as answered by the YT client.

    Args:
        path: YT path to check.

    Returns:
        bool: The value returned by ``self.client.exists(path)``.

    Raises:
        Exception: Any error from the client is logged and re-raised.
    """
    try:
        found = self.client.exists(path)
    except Exception:
        self.logger.exception("Failed to check if path exists")
        raise
    else:
        return found
[docs]
def write_table(
    self,
    table_path: str,
    rows: list[dict[str, Any]],
    *,
    append: bool = False,
    replication_factor: int = 1,
    make_parents: bool = True,
) -> None:
    """Write ``rows`` to a YT table in JSON format.

    Before the rows are sent, ``prod_create_table_parent`` and
    ``prod_write_table_replace_create`` are called (in that order) with the
    corresponding arguments; what they do is defined in those helpers.

    Args:
        table_path: YT table path.
        rows: List of dictionaries representing table rows.
        append: If True, the table path is opened in append mode
            (default: False).
        replication_factor: Replication factor forwarded to
            ``prod_write_table_replace_create`` (default: 1).
        make_parents: Forwarded to ``prod_create_table_parent``
            (default: True).

    Raises:
        Exception: Any error from the preparation steps or the write is
            logged and re-raised.
    """
    mode_str = "overwrite"
    if append:
        mode_str = "append"
    self.logger.info("Writing %s rows → %s (%s)", len(rows), table_path, mode_str)

    try:
        prod_create_table_parent(
            make_parents=make_parents,
            table_path=table_path,
            create_path=self.create_path,
            logger=self.logger,
        )
        prod_write_table_replace_create(
            self.client,
            append=append,
            table_path=table_path,
            replication_factor=replication_factor,
        )
        destination = TablePath(table_path, append=append)
        self.client.write_table(
            destination,
            rows,
            format=yt_format.JsonFormat(),
        )
    except Exception:
        self.logger.exception("Failed to write table")
        raise
[docs]
def read_table(self, table_path: str) -> list[dict[str, Any]]:
    """Read every row of a YT table into memory.

    Args:
        table_path: YT table path to read.

    Returns:
        List[Dict[str, Any]]: All rows yielded by the client, in order.

    Raises:
        Exception: Any error while reading is logged and re-raised.
    """
    self.logger.info("Reading table: %s", table_path)
    try:
        # The client's read_table has a broad return type; with JsonFormat()
        # it yields dict rows, hence the type-ignore on the list() call.
        row_stream = self.client.read_table(
            TablePath(table_path),
            format=yt_format.JsonFormat(),
        )
        fetched: list[dict[str, Any]] = list(row_stream)  # type: ignore[arg-type]
        self.logger.info("✓ Read %s rows", len(fetched))
        return fetched
    except Exception:
        self.logger.exception("Failed to read table")
        raise
[docs]
def row_count(self, table_path: str) -> int:
    """Return the number of rows reported by the client for a table.

    Args:
        table_path: YT table path.

    Returns:
        int: The value returned by ``self.client.row_count(table_path)``.

    Raises:
        Exception: Any error from the client is logged and re-raised.
    """
    try:
        total_rows = self.client.row_count(table_path)
        self.logger.debug("Row count: %s", total_rows)
        return total_rows
    except Exception:
        self.logger.exception("Failed to get row count")
        raise
@staticmethod
def _filter_internal_yql_columns(columns: list[str]) -> list[str]:
return [col for col in columns if not col.startswith("_")]
@staticmethod
def _extract_columns_from_schema_value(schema: object) -> list[str]:
if not isinstance(schema, list):
return []
columns = [
col["name"] for col in schema if isinstance(col, dict) and "name" in col
]
return YTProdClient._filter_internal_yql_columns(columns)
def _get_columns_from_table_attributes(self, table_path: str) -> list[str]:
attrs = self.client.get(table_path, attributes=["schema"])
if not (attrs and isinstance(attrs, dict) and "schema" in attrs):
return []
return self._extract_columns_from_schema_value(attrs["schema"])
def _get_columns_from_first_row(self, table_path: str) -> list[str]:
rows = self.read_table(table_path)
if not rows:
_raise_value_error(f"Table {table_path} is empty, cannot determine columns")
filtered = self._filter_internal_yql_columns(list(rows[0].keys()))
if filtered:
return filtered
return list(rows[0].keys())
@staticmethod
def _is_binary_decode_error(error: Exception) -> bool:
error_str = str(error)
return "Failed to decode string" in error_str or "encoding" in error_str.lower()
def _infer_columns_temp_yql(self, table_path: str) -> list[str]:
    """Infer column names by writing an empty YQL copy of the table.

    A ``LIMIT 0`` select with ``yt.InferSchema`` is inserted into a
    uniquely named sibling table, the schema of that table is read, and the
    temporary table is removed again (best effort) whatever happens.

    ``client.get(path, attributes=[...])`` attaches the requested
    attributes to the returned node value's ``attributes`` mapping, so the
    schema is looked up there first; a plain mapping response with a
    ``"schema"`` key is still accepted for backward compatibility.

    Args:
        table_path: Path to the YT table whose columns are needed.

    Returns:
        Column names from the temporary table's schema, or an empty list
        when the response carries no schema.

    Raises:
        Exception: Whatever ``run_yql`` or the schema lookup raises.
    """
    temp_output = f"{table_path}.temp_schema_{uuid.uuid4().hex[:8]}"
    try:
        query = f"""{build_max_row_weight_pragma()}
PRAGMA yt.InferSchema = '1';
INSERT INTO `{temp_output}` WITH TRUNCATE
SELECT * FROM `{table_path}` LIMIT 0;"""  # noqa: S608
        self.run_yql(query)

        node = self.client.get(temp_output, attributes=["schema"])

        # Preferred location: attributes attached to the returned node value.
        node_attributes = getattr(node, "attributes", None)
        if isinstance(node_attributes, dict) and "schema" in node_attributes:
            return self._extract_columns_from_schema_value(node_attributes["schema"])

        # Legacy shape: the response itself is a mapping holding the schema.
        if isinstance(node, dict) and "schema" in node:
            return self._extract_columns_from_schema_value(node["schema"])

        return []
    finally:
        # Cleanup is best effort: the temp table may not exist if the query
        # failed, and a cleanup error must not mask the real outcome.
        with contextlib.suppress(Exception):
            self.client.remove(temp_output)
def _handle_table_column_read_error(
self,
*,
table_path: str,
read_error: Exception,
) -> list[str]:
if not self._is_binary_decode_error(read_error):
_raise_value_error_from(
read_error,
f"Failed to get table columns from {table_path}: {read_error}",
)
self.logger.debug(
"Reading failed due to binary columns, using YQL to infer schema",
)
try:
columns = self._infer_columns_temp_yql(table_path)
if columns:
return columns
except Exception as yql_error: # noqa: BLE001
self.logger.debug("YQL schema inference failed: %s", yql_error)
msg = (
f"Table {table_path} contains binary columns that cannot be decoded. "
"This usually happens when a table was created with SELECT * and contains "
"internal YQL columns like _yql_column_0. Please recreate the table with "
f"explicit column selection, or delete and recreate it. Original error: {read_error}"
)
_raise_value_error_from(read_error, msg)
def _get_table_columns(self, table_path: str) -> list[str]:
"""Get column names from a table.
Tries multiple methods:
1. Get schema from table attributes (handles binary columns)
2. Read one row from table
3. Use YQL query with LIMIT 0 to infer schema (when reading fails due to binary columns)
Args:
table_path: Path to YT table
Returns:
List of column names (filtered to exclude internal YQL columns)
Raises:
ValueError: If table is empty or cannot be read
"""
try:
columns = self._get_columns_from_table_attributes(table_path)
if columns:
return columns
except Exception as e: # noqa: BLE001
# YT attribute/schema access is version-dependent; fall through to row read.
self.logger.debug(
"Could not get schema from attributes: %s, trying to read table",
e,
)
try:
return self._get_columns_from_first_row(table_path)
except Exception as read_error: # noqa: BLE001
return self._handle_table_column_read_error(
table_path=table_path,
read_error=read_error,
)
[docs]
def run_yql(
    self,
    query: str,
    pool: str = "default",
    max_row_weight: str | None = None,
) -> None:
    """Run a YQL query on the YT cluster and verify that it completed.

    The query text is first passed through ``ensure_max_row_weight_pragma``
    and then submitted with ``engine="yql"`` and the pool in the settings.

    Args:
        query: YQL query string to execute
        pool: YT pool name (default: 'default')
        max_row_weight: Optional max row weight override

    Raises:
        Exception: If submission fails or the final state is not
            ``"completed"``; the error is logged and re-raised.
    """
    self.logger.info("Executing YQL query on YT cluster")
    self.logger.debug("Pool: %s", pool)

    prepared_query = ensure_max_row_weight_pragma(
        query=query,
        max_row_weight=max_row_weight,
    )
    self.logger.debug("Query:\n%s", prepared_query)

    try:
        # NOTE(review): no explicit wait happens here; run_query presumably
        # blocks until the query finishes with its default arguments — confirm.
        handle = self.client.run_query(
            engine="yql",
            query=prepared_query,
            settings={"pool": pool},
        )
        self.logger.info("Query started: %s", handle.id)

        state = handle.get_state()
        if state != "completed":
            error = handle.get_error()
            msg = f"Query failed with state {state}: {error}"
            _raise_runtime_error(msg)
        else:
            self.logger.info("✓ YQL query completed successfully")
    except Exception:
        self.logger.exception("Failed to execute YQL query")
        raise