fix(cli): report broken command/service loads once and cleanly

Loaders raised at import time, and since Python does not cache a failed import, every command importing services re-ran and re-reported the same errors. Build the registry once, collect failures into LOAD_ERRORS, and surface them via a single click.ClickException at the list/get chokepoints so the message renders once without a traceback or cascade.
This commit is contained in:
imSp4rky
2026-05-25 19:50:11 -06:00
parent 5e3ffdeaa1
commit b3a6db915b
2 changed files with 107 additions and 3 deletions

View File

@@ -1,3 +1,5 @@
import logging
from pathlib import Path
from typing import Optional
import click
@@ -5,11 +7,56 @@ import click
from unshackle.core.config import config
from unshackle.core.utilities import import_module_by_path
log = logging.getLogger("commands")
_COMMANDS = sorted(
(path for path in config.directories.commands.glob("*.py") if path.stem.lower() != "__init__"), key=lambda x: x.stem
)
_MODULES = {path.stem: getattr(import_module_by_path(path), path.stem) for path in _COMMANDS}
def load_command(path: Path) -> object:
"""Load one command module, returning its stem-named attribute.
Raises a concise, single-line error naming the command and the real cause so
a broken command never surfaces as a raw traceback pointing at the loader.
"""
try:
module = import_module_by_path(path)
except Exception as e:
raise RuntimeError(f"{path.stem}: failed to import — {type(e).__name__}: {e} ({path})") from e
try:
return getattr(module, path.stem)
except AttributeError as e:
raise RuntimeError(
f"{path.stem}: no object named '{path.stem}' found in {path} — it must match the filename"
) from e
def load_commands(paths: list[Path]) -> tuple[dict[str, object], list[str]]:
"""Load every command, returning the good ones plus a list of load errors.
Importing this module must never raise (it runs at CLI startup, before Rich
is installed, so a raise here prints an ugly pre-setup traceback). Instead we
collect failures and surface them once, cleanly, when the CLI is used.
"""
modules: dict[str, object] = {}
errors: list[str] = []
for path in paths:
try:
modules[path.stem] = load_command(path)
except Exception as e:
errors.append(str(e))
return modules, errors
_MODULES, LOAD_ERRORS = load_commands(_COMMANDS)
def check_load_errors() -> None:
"""Raise a single clean error if any command failed to load."""
if LOAD_ERRORS:
joined = "\n".join(f" - {err}" for err in LOAD_ERRORS)
raise click.ClickException(f"Failed to load {len(LOAD_ERRORS)} command(s):\n{joined}")
class Commands(click.MultiCommand):
@@ -17,10 +64,12 @@ class Commands(click.MultiCommand):
def list_commands(self, ctx: click.Context) -> list[str]:
"""Returns a list of command names from the command filenames."""
check_load_errors()
return [x.stem.replace("_", "-") for x in _COMMANDS]
def get_command(self, ctx: click.Context, name: str) -> Optional[click.Command]:
"""Load the command code and return the main click command function."""
check_load_errors()
module = _MODULES.get(name) or _MODULES.get(name.replace("-", "_"))
if not module:
raise click.ClickException(f"Unable to find command by the name '{name}'")

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import re
from pathlib import Path
@@ -9,6 +10,8 @@ from unshackle.core.config import config
from unshackle.core.service import Service
from unshackle.core.utilities import import_module_by_path
log = logging.getLogger("services")
_service_dirs = config.directories.services
if not isinstance(_service_dirs, list):
_service_dirs = [_service_dirs]
@@ -18,9 +21,59 @@ _SERVICES = sorted(
key=lambda x: x.parent.stem,
)
_MODULES = {path.parent.stem: getattr(import_module_by_path(path), path.parent.stem) for path in _SERVICES}
_ALIASES = {tag: getattr(module, "ALIASES") for tag, module in _MODULES.items()}
def load_service(path: Path) -> object:
"""Load one Service module, returning its tag-named class.
Raises a concise, single-line error naming the Service and the real cause so
a broken Service never surfaces as a raw traceback pointing at the loader.
"""
tag = path.parent.stem
try:
module = import_module_by_path(path)
except Exception as e:
raise RuntimeError(f"{tag}: failed to import — {type(e).__name__}: {e} ({path})") from e
try:
return getattr(module, tag)
except AttributeError as e:
raise RuntimeError(
f"{tag}: no class named '{tag}' found in {path} — the class name must match the directory name"
) from e
def load_services(paths: list[Path]) -> tuple[dict[str, object], list[str]]:
"""Load every Service, returning the good ones plus a list of load errors.
Importing this module must never raise: it is imported by several commands,
and a failed import is not cached by Python, so raising here would re-run and
re-report for every command. Instead we collect failures and let the caller
surface them once, cleanly, at the point services are actually used.
"""
modules: dict[str, object] = {}
errors: list[str] = []
for path in paths:
try:
modules[path.parent.stem] = load_service(path)
except Exception as e:
errors.append(str(e))
return modules, errors
_MODULES, LOAD_ERRORS = load_services(_SERVICES)
_ALIASES = {tag: getattr(module, "ALIASES", ()) for tag, module in _MODULES.items()}
def check_load_errors() -> None:
"""Raise a single clean error if any Service failed to load.
Called when services are actually needed (listing/resolving) so the message
is rendered once by Click, without a traceback and without cascading through
every command that imports this module.
"""
if LOAD_ERRORS:
joined = "\n".join(f" - {err}" for err in LOAD_ERRORS)
raise click.ClickException(f"Failed to load {len(LOAD_ERRORS)} service(s):\n{joined}")
class Services(click.MultiCommand):
@@ -64,10 +117,12 @@ class Services(click.MultiCommand):
if remote_tag not in tags:
tags.append(remote_tag)
return tags
check_load_errors()
return Services.get_tags()
def get_command(self, ctx: click.Context, name: str) -> click.Command:
"""Load the Service and return the Click CLI method."""
check_load_errors()
tag = Services.get_tag(name)
import_file = ctx.params.get("import_file") or (ctx.parent and ctx.parent.params.get("import_file"))