Source code for ytjobs.s3.client

"""Minimal boto3 wrapper for job-side list/get/put helpers."""

import logging
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, replace
from typing import Any, Literal
from urllib.parse import urlparse

import boto3
from botocore.client import Config as BotoConfig
from botocore.exceptions import ClientError


def _parse_chunk_size(line: bytes) -> int | None:
    size_part = line.split(b";")[0].strip() if b";" in line else line.strip()
    try:
        return int(size_part, 16)
    except ValueError:
        return None


def _skip_crlf_after_chunk(data: bytes, i: int) -> int:
    while i < len(data) and data[i : i + 1] in (b"\r", b"\n"):
        i += 1
    return i


def _try_decode_next_chunk(data: bytes, i: int) -> tuple[int, bytes] | None:
    line_end = data.find(b"\n", i)
    if line_end < 0:
        return None
    line = data[i:line_end].strip(b"\r")
    size = _parse_chunk_size(line)
    if size is None:
        return None
    i = line_end + 1
    if size == 0:
        return None
    if i + size > len(data):
        return None
    chunk = bytes(data[i : i + size])
    i += size
    i = _skip_crlf_after_chunk(data, i)
    return i, chunk


def _decode_all_http_chunks(data: bytes) -> bytes:
    out = bytearray()
    i = 0
    while i < len(data):
        step = _try_decode_next_chunk(data, i)
        if step is None:
            break
        i, chunk = step
        out.extend(chunk)
    return bytes(out)


def _decode_http_chunked_if_present(data: bytes, logger: logging.Logger) -> bytes:
    """Decode HTTP chunked-transfer payloads when present, else return ``data``.

    Some S3-compatible backends store chunked-encoded bodies from dev/local uploads;
    strip framing and return the raw payload. If ``data`` is not chunked, return it
    unchanged.
    """
    if not data or data[0:1] not in b"0123456789abcdefABCDEF":
        return data
    decoded = _decode_all_http_chunks(data)
    if not decoded:
        return data
    logger.debug("Decoded HTTP chunked body (%s bytes payload)", len(decoded))
    return decoded


[docs] @dataclass(frozen=True) class S3ClientOptions: """Optional runtime knobs for `S3Client` construction.""" max_retries: int = 30 timeout: int = 360 logger: logging.Logger | None = None region_name: str | None = None boto_config: BotoConfig | None = None
def _coerce_int_option(value: object, option_name: str) -> int: if isinstance(value, bool) or not isinstance(value, int): msg = f"{option_name} must be an integer, got {value!r}" raise TypeError(msg) return value _LEGACY_S3_INIT_KEYS = frozenset( { "max_retries", "timeout", "logger", "region_name", "boto_config", }, ) def _reject_unknown_s3_legacy_kwargs(legacy_kwargs: dict[str, object]) -> None: unknown = set(legacy_kwargs) - _LEGACY_S3_INIT_KEYS if unknown: msg = f"Unknown S3Client init option(s): {sorted(unknown)}" raise TypeError(msg) def _replace_options_with_legacy_kwargs( merged_options: S3ClientOptions, legacy_kwargs: dict[str, object], ) -> S3ClientOptions: return replace( merged_options, max_retries=_coerce_int_option( legacy_kwargs.get("max_retries", merged_options.max_retries), "max_retries", ), timeout=_coerce_int_option( legacy_kwargs.get("timeout", merged_options.timeout), "timeout", ), logger=legacy_kwargs.get("logger", merged_options.logger), region_name=( str(legacy_kwargs["region_name"]) if "region_name" in legacy_kwargs and legacy_kwargs["region_name"] is not None else merged_options.region_name ), boto_config=legacy_kwargs.get("boto_config", merged_options.boto_config), ) def _options_from_legacy_kwargs( options: S3ClientOptions | None, legacy_kwargs: dict[str, object], ) -> S3ClientOptions: """Build options from explicit object plus backward-compatible kwargs.""" merged_options = options or S3ClientOptions() if not legacy_kwargs: return merged_options _reject_unknown_s3_legacy_kwargs(legacy_kwargs) return _replace_options_with_legacy_kwargs(merged_options, legacy_kwargs) def _append_single_listed_key( result: list[str], key: str, *, extension: str | None, max_files: int | None, ) -> bool: """Append key if it passes extension filter; return True if max_files reached.""" if extension and not key.endswith(f".{extension}"): return False result.append(key) return max_files is not None and len(result) >= max_files def _append_keys_until_limit( result: list[str], contents: Sequence[Mapping[str, Any]], *, extension: str | None, max_files: int | None, logger: logging.Logger, ) -> bool: """Append keys from one list_objects_v2 page; return True if max_files reached.""" for obj in contents: if _append_single_listed_key( result, obj["Key"], extension=extension, max_files=max_files, ): logger.info("Reached max_files limit (%s)", max_files) return True return False
[docs] class S3Client: """Thin boto3 S3 wrapper for job code (list, download, upload, head)."""
[docs] def __init__( self, endpoint: str, access_key: str, secret_key: str, *, options: S3ClientOptions | None = None, **legacy_kwargs: object, ) -> None: """Build a boto3 S3 client for the given endpoint and credentials. Args: endpoint: S3 API endpoint URL (e.g. from ``S3_ENDPOINT``). access_key: Access key id for this client. secret_key: Secret access key for this client. options: Optional strongly-typed client options. **legacy_kwargs: Backward-compatible aliases for old init kwargs (``max_retries``, ``timeout``, ``logger``, ``region_name``, ``boto_config``). """ resolved_options = _options_from_legacy_kwargs(options, legacy_kwargs) self.logger = resolved_options.logger or logging.getLogger(__name__) if resolved_options.boto_config is None: config = BotoConfig( s3={"addressing_style": "virtual"}, retries={ "max_attempts": resolved_options.max_retries, "mode": "standard", }, read_timeout=resolved_options.timeout, max_pool_connections=1, ) else: config = resolved_options.boto_config client_kwargs: dict[str, Any] = { "service_name": "s3", "aws_access_key_id": access_key, "aws_secret_access_key": secret_key, "endpoint_url": endpoint, "config": config, } if resolved_options.region_name is not None: client_kwargs["region_name"] = resolved_options.region_name self.client = boto3.client(**client_kwargs) self.logger.debug("S3 client initialized: %s", endpoint)
[docs] @staticmethod def parse_s3_uri(uri: str) -> tuple[str, str]: """Split ``s3://bucket/key/path`` into ``(bucket, key)``. Args: uri: S3 URI with non-empty bucket and key path. Returns: ``(bucket, key)`` where ``key`` has no leading slash. Raises: ValueError: If the URI is not a valid S3 URI. """ u = urlparse(uri) if u.scheme != "s3" or not u.netloc or not u.path: msg = f"Bad s3 uri: {uri}" raise ValueError(msg) return u.netloc, u.path.lstrip("/")
@staticmethod def _access_keys_for_client_type( secrets: dict[str, str], client_type: Literal["download", "upload"], ) -> tuple[str | None, str | None]: if client_type == "upload": return secrets.get("S3_UPLOAD_ACCESS_KEY"), secrets.get( "S3_UPLOAD_SECRET_KEY" ) if client_type == "download": return secrets.get("S3_DOWNLOAD_ACCESS_KEY"), secrets.get( "S3_DOWNLOAD_SECRET_KEY" ) msg = f"Unknown client type: {client_type}" raise ValueError(msg) @staticmethod def _require_s3_connection_secrets( *, endpoint: object, access_key: object, secret_key: object, client_type: Literal["download", "upload"], ) -> tuple[str, str, str]: if not endpoint: msg = "S3_ENDPOINT is not set" raise ValueError(msg) if not access_key: msg = f"S3_{client_type.upper()}_ACCESS_KEY is not set" raise ValueError(msg) if not secret_key: msg = f"S3_{client_type.upper()}_SECRET_KEY is not set" raise ValueError(msg) return str(endpoint), str(access_key), str(secret_key)
[docs] @staticmethod def create( secrets: dict[str, str], client_type: Literal["download", "upload"] = "download", ) -> "S3Client": """Create S3 client from secrets dictionary. Args: secrets: Dictionary containing S3 credentials. Expected keys: - S3_ENDPOINT - S3_DOWNLOAD_ACCESS_KEY - S3_DOWNLOAD_SECRET_KEY - S3_UPLOAD_ACCESS_KEY - S3_UPLOAD_SECRET_KEY client_type: ``download`` or ``upload`` (default: ``download``). Returns: Configured ``S3Client`` instance. Raises: ValueError: If ``client_type`` is unknown or required secrets are missing. """ access_key, secret_key = S3Client._access_keys_for_client_type( secrets, client_type ) endpoint = secrets.get("S3_ENDPOINT") ep, ak, sk = S3Client._require_s3_connection_secrets( endpoint=endpoint, access_key=access_key, secret_key=secret_key, client_type=client_type, ) return S3Client(endpoint=ep, access_key=ak, secret_key=sk)
[docs] def list_files( self, bucket: str, prefix: str = "", extension: str | None = None, max_files: int | None = None, ) -> list[str]: """List object keys under ``prefix`` in ``bucket``. Args: bucket: Bucket name. prefix: Key prefix filter; use ``""`` to list from the bucket root. extension: If set, keep only keys ending with ``.<extension>``. max_files: If set, stop after this many keys (best-effort). Returns: List of S3 object keys (not full ``s3://`` URIs). Raises: Exception: Propagates boto3/client errors after logging. """ self.logger.info("Listing files: s3://%s/%s", bucket, prefix) result = [] truncated = True token = None while truncated: params = {"Bucket": bucket, "Prefix": prefix} if token: params["ContinuationToken"] = token try: response = self.client.list_objects_v2(**params) except Exception: self.logger.exception("Failed to list objects") raise if _append_keys_until_limit( result, response.get("Contents", []), extension=extension, max_files=max_files, logger=self.logger, ): return result truncated = response.get("IsTruncated", False) token = response.get("NextContinuationToken") self.logger.info("Found %s files", len(result)) return result
[docs] def download(self, bucket: str, key: str) -> bytes: """Download one object body as bytes. Args: bucket: Bucket name. key: Object key. Returns: Raw object bytes (HTTP-chunked payloads may be normalized). Raises: Exception: Propagates boto3/client errors after logging. """ self.logger.debug("Downloading: s3://%s/%s", bucket, key) try: response = self.client.get_object(Bucket=bucket, Key=key) data = response["Body"].read() data = _decode_http_chunked_if_present(data, self.logger) self.logger.debug("Downloaded %s bytes", len(data)) except Exception: self.logger.exception("Failed to download") raise else: return data
[docs] def download_by_uri(self, s3_uri: str) -> bytes: """Download object bytes from ``s3://bucket/key``. Args: s3_uri: Valid S3 URI. Returns: Same as ``download``. Raises: ValueError: If ``s3_uri`` is invalid (via ``parse_s3_uri``). Exception: Propagates boto3/client errors from ``download``. """ bucket, key = S3Client.parse_s3_uri(s3_uri) return self.download(bucket, key)
[docs] def upload( self, data: bytes, bucket: str, key: str, content_type: str | None = None, ) -> None: """Upload bytes to ``s3://bucket/key`` via ``put_object``. Args: data: Object body. bucket: Bucket name. key: Object key. content_type: Optional ``ContentType`` header. Raises: Exception: Propagates boto3/client errors after logging. """ self.logger.debug("Uploading %s bytes to s3://%s/%s", len(data), bucket, key) try: params = {"Bucket": bucket, "Key": key, "Body": data} if content_type: params["ContentType"] = content_type self.client.put_object(**params) except Exception: self.logger.exception("Failed to upload") raise else: self.logger.debug("Upload completed")
[docs] def exists(self, bucket: str, key: str) -> bool: """Return whether an object exists (``head_object`` succeeds). Args: bucket: Bucket name. key: Object key. Returns: ``True`` if the object exists; ``False`` on any ``head_object`` error. """ try: self.client.head_object(Bucket=bucket, Key=key) except (ClientError, OSError): return False else: return True