feat: dependency injection to allow reusing an existing auth client
This commit is contained in:
@@ -10,13 +10,17 @@ from .logger import get_logger, ProgressTracker
|
|||||||
|
|
||||||
|
|
||||||
class OperationContext:
|
class OperationContext:
|
||||||
def __init__(self):
|
def __init__(self, client=None, config_data=None):
|
||||||
self.logger = get_logger()
|
self.logger = get_logger()
|
||||||
|
|
||||||
|
if client is not None:
|
||||||
|
self.client = client
|
||||||
|
self.did = client.me.did
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
self.auth = Auth()
|
self.auth = Auth()
|
||||||
self.client = self.auth.login()
|
self.client = self.auth.login()
|
||||||
self.config = Configuration()
|
self.did = self.client.me.did
|
||||||
self.config_data = self.config.load()
|
|
||||||
except (ValueError, FileNotFoundError) as e:
|
except (ValueError, FileNotFoundError) as e:
|
||||||
self.logger.error(f"Configuration error: {e}")
|
self.logger.error(f"Configuration error: {e}")
|
||||||
raise
|
raise
|
||||||
@@ -26,9 +30,23 @@ class OperationContext:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Failed to initialize operation context: {e}") from e
|
f"Failed to initialize operation context: {e}") from e
|
||||||
|
|
||||||
|
if config_data is not None:
|
||||||
|
self.config_data = config_data
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self.config = Configuration()
|
||||||
|
self.config_data = self.config.load()
|
||||||
|
except (ValueError, FileNotFoundError) as e:
|
||||||
|
self.logger.error(f"Configuration error: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(
|
||||||
|
f"Unexpected error loading configuration: {e}", exc_info=True)
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to load configuration: {e}") from e
|
||||||
|
|
||||||
self.batch_size = self.config_data.get("batch_size", 10)
|
self.batch_size = self.config_data.get("batch_size", 10)
|
||||||
self.delay = self.config_data.get("delay", 1)
|
self.delay = self.config_data.get("delay", 1)
|
||||||
self.did = self.client.me.did
|
|
||||||
|
|
||||||
|
|
||||||
class BaseStrategy:
|
class BaseStrategy:
|
||||||
@@ -120,10 +138,14 @@ class Operation:
|
|||||||
operation_name: str,
|
operation_name: str,
|
||||||
strategy_type: str = "feed",
|
strategy_type: str = "feed",
|
||||||
collection: Optional[str] = None,
|
collection: Optional[str] = None,
|
||||||
filter_fn: Optional[Callable[[Any], bool]] = None
|
filter_fn: Optional[Callable[[Any], bool]] = None,
|
||||||
|
client=None,
|
||||||
|
config_data=None
|
||||||
):
|
):
|
||||||
self.operation_name = operation_name
|
self.operation_name = operation_name
|
||||||
self.filter_fn = filter_fn
|
self.filter_fn = filter_fn
|
||||||
|
self._client = client
|
||||||
|
self._config_data = config_data
|
||||||
|
|
||||||
if strategy_type == "record":
|
if strategy_type == "record":
|
||||||
if not collection:
|
if not collection:
|
||||||
@@ -135,7 +157,8 @@ class Operation:
|
|||||||
self.strategy = FeedStrategy()
|
self.strategy = FeedStrategy()
|
||||||
|
|
||||||
def run(self) -> int:
|
def run(self) -> int:
|
||||||
context = OperationContext()
|
context = OperationContext(
|
||||||
|
client=self._client, config_data=self._config_data)
|
||||||
progress = ProgressTracker(operation=self.operation_name)
|
progress = ProgressTracker(operation=self.operation_name)
|
||||||
|
|
||||||
context.logger.info(
|
context.logger.info(
|
||||||
|
|||||||
Reference in New Issue
Block a user