# Source code for yt_framework.core.discovery

"""
Stage Discovery
===============

Automatic discovery of stages from stages/ directory.
"""

import importlib
import sys
from pathlib import Path
from typing import List, Type
import logging

from yt_framework.core.stage import BaseStage


def discover_stages(
    pipeline_dir: Path,
    logger: logging.Logger,
) -> List[Type[BaseStage]]:
    """
    Automatically discover stage classes from the stages/ directory.

    Looks for a ``stage.py`` file in every ``stages/*/`` subdirectory, imports
    it as ``stages.<name>.stage`` and collects one BaseStage subclass per
    module. A class defined in the stage module itself is preferred over a
    BaseStage subclass that the module merely imported (for example a shared
    intermediate base class); if the module defines none of its own, the
    first imported subclass (alphabetical by attribute name) is used.

    Directory structure expected:
        pipeline_dir/
            stages/
                stage_name_1/
                    stage.py  # Contains Stage class inheriting from BaseStage
                stage_name_2/
                    stage.py

    Args:
        pipeline_dir: Path to pipeline directory
        logger: Logger instance

    Returns:
        List of discovered stage classes, in sorted order of the stage
        directories. Empty when stages/ is missing or is not a directory.

    Note:
        ``pipeline_dir`` is inserted at the front of ``sys.path`` the first
        time a stage module is imported and is left there afterwards.
        Modules that fail to import are logged as warnings and skipped.
    """
    stages_dir = pipeline_dir / "stages"

    # is_dir() rather than exists(): a regular file named "stages" would make
    # iterdir() below raise NotADirectoryError.
    if not stages_dir.is_dir():
        logger.warning(f"Stages directory not found: {stages_dir}")
        return []

    search_path = str(pipeline_dir)
    discovered_stages: List[Type[BaseStage]] = []

    # Sorted so that discovery order is the same on every run.
    for stage_dir in sorted(stages_dir.iterdir()):
        if not stage_dir.is_dir():
            continue

        stage_file = stage_dir / "stage.py"
        if not stage_file.exists():
            logger.debug(f"Skipping {stage_dir.name}: no stage.py file")
            continue

        stage_name = stage_dir.name
        # NOTE(review): every pipeline imports under the shared top-level
        # name "stages"; if a different "stages" package is already in
        # sys.modules, import_module resolves against that one - confirm
        # this cannot happen in practice.
        module_name = f"stages.{stage_name}.stage"

        try:
            # Make "stages.<name>.stage" importable. The entry is not removed
            # again after the import.
            if search_path not in sys.path:
                sys.path.insert(0, search_path)

            module = importlib.import_module(module_name)
            candidates = _stage_candidates(module)
        except Exception as e:
            # Best effort: one broken stage must not abort discovery.
            logger.warning(f"Failed to import stage from {stage_file}: {e}")
            continue

        if not candidates:
            logger.debug(f"No BaseStage subclass found in {stage_file}")
            continue

        # Only one stage class is taken per module.
        stage_class = candidates[0]
        discovered_stages.append(stage_class)
        logger.debug(f"Discovered stage: {stage_name} -> {stage_class.__name__}")

    if discovered_stages:
        stage_names = [sc.__name__ for sc in discovered_stages]
        count = len(discovered_stages)
        plural = "" if count == 1 else "s"
        logger.info(
            f"[Discovery] Found {count} stage{plural}: {', '.join(stage_names)}"
        )

    return discovered_stages


def _stage_candidates(module) -> List[Type[BaseStage]]:
    """Return the BaseStage subclasses exposed by *module*, best match first.

    Classes defined in *module* itself come before classes it only imported;
    within each group the order is alphabetical by attribute name, because
    ``dir()`` returns sorted names. BaseStage itself is never included.
    """
    defined_here: List[Type[BaseStage]] = []
    imported: List[Type[BaseStage]] = []

    for attr_name in dir(module):
        attr = getattr(module, attr_name)
        if not (isinstance(attr, type) and issubclass(attr, BaseStage)):
            continue
        if attr is BaseStage:
            continue

        if attr.__module__ == module.__name__:
            defined_here.append(attr)
        else:
            imported.append(attr)

    return defined_here + imported