From 1b8b32027cb3d37b231b549046947adb339d21b1 Mon Sep 17 00:00:00 2001 From: Kharec Date: Sat, 20 Dec 2025 21:07:50 +0100 Subject: [PATCH] feat: dependency injection to allow reusing an existing auth client --- skywipe/operations.py | 57 ++++++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/skywipe/operations.py b/skywipe/operations.py index 601c516..765a3cb 100644 --- a/skywipe/operations.py +++ b/skywipe/operations.py @@ -10,25 +10,43 @@ from .logger import get_logger, ProgressTracker class OperationContext: - def __init__(self): + def __init__(self, client=None, config_data=None): self.logger = get_logger() - try: - self.auth = Auth() - self.client = self.auth.login() - self.config = Configuration() - self.config_data = self.config.load() - except (ValueError, FileNotFoundError) as e: - self.logger.error(f"Configuration error: {e}") - raise - except Exception as e: - self.logger.error( - f"Unexpected error during initialization: {e}", exc_info=True) - raise ValueError( - f"Failed to initialize operation context: {e}") from e + + if client is not None: + self.client = client + self.did = client.me.did + else: + try: + self.auth = Auth() + self.client = self.auth.login() + self.did = self.client.me.did + except (ValueError, FileNotFoundError) as e: + self.logger.error(f"Configuration error: {e}") + raise + except Exception as e: + self.logger.error( + f"Unexpected error during initialization: {e}", exc_info=True) + raise ValueError( + f"Failed to initialize operation context: {e}") from e + + if config_data is not None: + self.config_data = config_data + else: + try: + self.config = Configuration() + self.config_data = self.config.load() + except (ValueError, FileNotFoundError) as e: + self.logger.error(f"Configuration error: {e}") + raise + except Exception as e: + self.logger.error( + f"Unexpected error loading configuration: {e}", exc_info=True) + raise ValueError( + f"Failed to load configuration: {e}") from e self.batch_size = self.config_data.get("batch_size", 10) self.delay = self.config_data.get("delay", 1) - self.did = self.client.me.did class BaseStrategy: @@ -120,10 +138,14 @@ class Operation: operation_name: str, strategy_type: str = "feed", collection: Optional[str] = None, - filter_fn: Optional[Callable[[Any], bool]] = None + filter_fn: Optional[Callable[[Any], bool]] = None, + client=None, + config_data=None ): self.operation_name = operation_name self.filter_fn = filter_fn + self._client = client + self._config_data = config_data if strategy_type == "record": if not collection: @@ -135,7 +157,8 @@ class Operation: self.strategy = FeedStrategy() def run(self) -> int: - context = OperationContext() + context = OperationContext( + client=self._client, config_data=self._config_data) progress = ProgressTracker(operation=self.operation_name) context.logger.info(