Source code for yt_framework.core.discovery
"""Filesystem scan that imports `stages/*/stage.py` and collects `BaseStage` types."""
import importlib
import sys
from pathlib import Path
from typing import List, Type
import logging
from yt_framework.core.stage import BaseStage
[docs]
def discover_stages(
    pipeline_dir: Path,
    logger: logging.Logger,
) -> List[Type[BaseStage]]:
    """
    Automatically discover all stage classes from the ``stages`` directory tree.

    Each immediate child directory of ``pipeline_dir/stages`` that contains a
    ``stage.py`` file is imported as the module ``stages.<stage_name>.stage``
    and one ``BaseStage`` subclass is taken from it. Child directories are
    visited in sorted order so the result order is deterministic.

    Expected layout: ``pipeline_dir/stages/<stage_name>/stage.py`` with a
    ``BaseStage`` subclass in each module. When a module exposes several
    subclasses, one defined in that module is preferred over one it merely
    imports (e.g. a shared intermediate base class); remaining ties are broken
    by attribute name.

    Side effect: ``pipeline_dir`` is prepended to ``sys.path`` (if not already
    present) before the first stage import and is left there afterwards.

    NOTE(review): modules are resolved through the normal import system, so if
    a ``stages`` package from a different pipeline directory was already
    imported in this process, the cached modules are returned instead. Confirm
    discovery only runs for one pipeline per process.

    Args:
        pipeline_dir: Path to pipeline directory
        logger: Logger instance

    Returns:
        List of discovered stage classes, at most one per stage directory.
        Stages whose module fails to import, or that contain no ``BaseStage``
        subclass, are logged as warnings and skipped.
    """
    stages_dir = pipeline_dir / "stages"
    # is_dir() rather than exists(): a regular file named "stages" would
    # otherwise make iterdir() raise NotADirectoryError below.
    if not stages_dir.is_dir():
        logger.warning(f"Stages directory not found: {stages_dir}")
        return []

    discovered_stages: List[Type[BaseStage]] = []
    pipeline_path = str(pipeline_dir)

    for stage_dir in sorted(stages_dir.iterdir()):  # Sort for consistent order
        if not stage_dir.is_dir():
            continue
        stage_file = stage_dir / "stage.py"
        if not stage_file.is_file():
            logger.debug(f"Skipping {stage_dir.name}: no stage.py file")
            continue

        stage_name = stage_dir.name
        module_name = f"stages.{stage_name}.stage"

        # ``stages.<name>.stage`` only resolves with pipeline_dir on sys.path.
        # The entry is deliberately not removed afterwards: stage modules may
        # perform further ``stages.*`` imports later at runtime.
        if pipeline_path not in sys.path:
            sys.path.insert(0, pipeline_path)

        try:
            module = importlib.import_module(module_name)
        except Exception as e:
            # Best-effort: one broken stage must not abort discovery of the rest.
            logger.warning(f"Failed to import stage from {stage_file}: {e}")
            continue

        stage_class = _select_stage_class(module)
        if stage_class is None:
            logger.warning(f"No BaseStage subclass found in {stage_file}")
            continue
        discovered_stages.append(stage_class)
        logger.debug(f"Discovered stage: {stage_name} -> {stage_class.__name__}")

    if discovered_stages:
        stage_names = [sc.__name__ for sc in discovered_stages]
        logger.info(
            f"[Discovery] Found {len(discovered_stages)} stage{'s' if len(discovered_stages) != 1 else ''}: {', '.join(stage_names)}"
        )
    return discovered_stages


def _select_stage_class(module) -> "Type[BaseStage] | None":
    """
    Pick the single ``BaseStage`` subclass a stage module provides, or None.

    Candidates are visited in ``dir(module)`` (alphabetical) order. A class
    whose ``__module__`` is this module wins over classes the module only
    imports, so an imported base class that sorts earlier is not mistaken for
    the stage. If only imported subclasses exist (e.g. a re-export), the first
    of those is returned.
    """
    candidates = [
        attr
        for attr in (getattr(module, attr_name) for attr_name in dir(module))
        if isinstance(attr, type)
        and issubclass(attr, BaseStage)
        and attr is not BaseStage
    ]
    for cls in candidates:
        if cls.__module__ == module.__name__:
            return cls
    return candidates[0] if candidates else None