Compare commits

..

5 Commits

5 changed files with 196 additions and 54 deletions

View File

@@ -2,7 +2,7 @@
from typing import Callable, Any, TypedDict from typing import Callable, Any, TypedDict
from .configure import Configuration from .configure import Configuration
from .operations import Operation from .operations import Operation, OperationFetchError
from .post_analysis import PostAnalyzer from .post_analysis import PostAnalyzer
from .logger import get_logger, handle_error from .logger import get_logger, handle_error
from .safeguard import require_confirmation from .safeguard import require_confirmation
@@ -35,7 +35,7 @@ COMMAND_METADATA: dict[str, CommandMetadata] = {
"operation_name": "Deleting posts with media", "operation_name": "Deleting posts with media",
"strategy_type": "feed", "strategy_type": "feed",
"collection": None, "collection": None,
"filter_fn": lambda post: PostAnalyzer.has_media(post.post), "filter_fn": lambda post: bool(PostAnalyzer.has_media(post.post)),
}, },
"likes": { "likes": {
"confirmation": "undo all likes", "confirmation": "undo all likes",
@@ -59,7 +59,7 @@ COMMAND_METADATA: dict[str, CommandMetadata] = {
"operation_name": "Deleting quote posts", "operation_name": "Deleting quote posts",
"strategy_type": "feed", "strategy_type": "feed",
"collection": None, "collection": None,
"filter_fn": lambda post: PostAnalyzer.has_quote(post.post), "filter_fn": lambda post: bool(PostAnalyzer.has_quote(post.post)),
}, },
"follows": { "follows": {
"confirmation": "unfollow all accounts", "confirmation": "unfollow all accounts",
@@ -101,7 +101,7 @@ class CommandRegistry:
name: str, name: str,
handler: CommandHandler, handler: CommandHandler,
help_text: str, help_text: str,
requires_config: bool = True requires_config: bool = True,
): ):
self._commands[name] = handler self._commands[name] = handler
self._help_texts[name] = help_text self._help_texts[name] = help_text
@@ -138,7 +138,7 @@ def _create_operation_handler(
operation_name: str, operation_name: str,
strategy_type: str = "feed", strategy_type: str = "feed",
collection: str | None = None, collection: str | None = None,
filter_fn: Callable[[Any], bool] | None = None filter_fn: Callable[[Any], bool] | None = None,
) -> CommandHandler: ) -> CommandHandler:
logger = get_logger() logger = get_logger()
@@ -149,10 +149,11 @@ def _create_operation_handler(
operation_name, operation_name,
strategy_type=strategy_type, strategy_type=strategy_type,
collection=collection, collection=collection,
filter_fn=filter_fn filter_fn=filter_fn,
).run() ).run()
except (ValueError, Exception) as e: except Exception as e:
handle_error(e, logger) handle_error(e, logger)
return handler return handler
@@ -169,7 +170,7 @@ def _create_command_handlers():
metadata["operation_name"], metadata["operation_name"],
strategy_type=metadata["strategy_type"], strategy_type=metadata["strategy_type"],
collection=metadata["collection"], collection=metadata["collection"],
filter_fn=metadata["filter_fn"] filter_fn=metadata["filter_fn"],
) )
return handlers return handlers
@@ -179,16 +180,19 @@ _command_handlers = _create_command_handlers()
def run_all(skip_confirmation: bool = False): def run_all(skip_confirmation: bool = False):
logger = get_logger() logger = get_logger()
fetch_failures: list[str] = []
all_commands = registry.get_all_commands() all_commands = registry.get_all_commands()
available_commands = [cmd for cmd in all_commands.keys() available_commands = [
if cmd not in ("configure", "all")] cmd for cmd in all_commands.keys() if cmd not in ("configure", "all")
]
commands = [cmd for cmd in COMMAND_EXECUTION_ORDER commands = [
if cmd in available_commands] cmd for cmd in COMMAND_EXECUTION_ORDER if cmd in available_commands]
commands.extend([cmd for cmd in available_commands commands.extend(
if cmd not in COMMAND_EXECUTION_ORDER]) [cmd for cmd in available_commands if cmd not in COMMAND_EXECUTION_ORDER]
)
commands_str = ", ".join(commands) commands_str = ", ".join(commands)
all_confirmation = f"run all cleanup commands ({commands_str})" all_confirmation = f"run all cleanup commands ({commands_str})"
@@ -197,6 +201,7 @@ def run_all(skip_confirmation: bool = False):
logger.info("Running all cleanup commands...") logger.info("Running all cleanup commands...")
from .operations import OperationContext from .operations import OperationContext
try: try:
context = OperationContext() context = OperationContext()
shared_client = context.client shared_client = context.client
@@ -217,19 +222,32 @@ def run_all(skip_confirmation: bool = False):
collection=metadata["collection"], collection=metadata["collection"],
filter_fn=metadata["filter_fn"], filter_fn=metadata["filter_fn"],
client=shared_client, client=shared_client,
config_data=shared_config_data config_data=shared_config_data,
).run() ).run()
else: else:
registry.execute(cmd, skip_confirmation=True) registry.execute(cmd, skip_confirmation=True)
logger.info(f"Completed command: {cmd}") logger.info(f"Completed command: {cmd}")
except OperationFetchError as e:
fetch_failures.append(cmd)
logger.error(
f"Fetch error while running '{cmd}': {e}", exc_info=True)
continue
except Exception as e: except Exception as e:
logger.error(f"Error running '{cmd}': {e}", exc_info=True) logger.error(f"Error running '{cmd}': {e}", exc_info=True)
continue continue
if fetch_failures:
failed_commands = ", ".join(fetch_failures)
raise RuntimeError(
f"One or more commands failed while fetching items: {failed_commands}"
)
logger.info("All commands completed.") logger.info("All commands completed.")
registry.register("configure", run_configure, registry.register(
"create configuration", requires_config=False) "configure", run_configure, "create configuration", requires_config=False
)
for cmd, metadata in COMMAND_METADATA.items(): for cmd, metadata in COMMAND_METADATA.items():
registry.register(cmd, _command_handlers[cmd], metadata["help_text"]) registry.register(cmd, _command_handlers[cmd], metadata["help_text"])
registry.register("all", run_all, "target everything") registry.register("all", run_all, "target everything")

View File

@@ -1,6 +1,7 @@
"""Shared operation utilities and strategies for Skywipe""" """Shared operation utilities and strategies for Skywipe"""
import time import time
from dataclasses import dataclass
from typing import Callable, Any from typing import Callable, Any
from atproto import models from atproto import models
@@ -9,6 +10,29 @@ from .configure import Configuration
from .logger import get_logger, ProgressTracker from .logger import get_logger, ProgressTracker
@dataclass(slots=True)
class OperationResult:
processed: int = 0
failed: int = 0
skipped: int = 0
class OperationFetchError(RuntimeError):
    """Raised when fetching a batch of items for an operation fails.

    Attributes:
        operation_name: Name of the operation whose fetch failed.
        batch_num: Number of the batch that was being fetched.
        error: The underlying exception that made the fetch fail.
        result: The counters object handed over by the raiser
            (an ``OperationResult``).
    """

    def __init__(
        self,
        operation_name: str,
        batch_num: int,
        error: Exception,
        result: "OperationResult",
    ):
        self.operation_name = operation_name
        self.batch_num = batch_num
        # Keep the cause reachable even when the raiser does not chain it
        # with ``raise ... from``.
        self.error = error
        self.result = result
        super().__init__(
            f"Failed to fetch items for '{operation_name}' at batch {batch_num}: {error}"
        )

    def __reduce__(self):
        # The default exception reduce rebuilds the object as cls(*self.args),
        # and args only holds the formatted message. That does not match this
        # four-argument __init__, so copy/pickle would raise TypeError.
        # Rebuild from the original constructor arguments instead.
        return (
            type(self),
            (self.operation_name, self.batch_num, self.error, self.result),
        )
class OperationContext: class OperationContext:
def __init__(self, client=None, config_data=None): def __init__(self, client=None, config_data=None):
self.logger = get_logger() self.logger = get_logger()
@@ -19,21 +43,31 @@ class OperationContext:
def _initialize_client(self, client): def _initialize_client(self, client):
if client is not None: if client is not None:
return client, client.me.did did = self._extract_did(client)
return client, did
try: try:
auth = Auth() auth = Auth()
client = auth.login() client = auth.login()
return client, client.me.did did = self._extract_did(client)
return client, did
except (ValueError, FileNotFoundError) as e: except (ValueError, FileNotFoundError) as e:
self.logger.error(f"Configuration error: {e}") self.logger.error(f"Configuration error: {e}")
raise raise
except Exception as e: except Exception as e:
self.logger.error( self.logger.error(
f"Unexpected error during initialization: {e}", exc_info=True) f"Unexpected error during initialization: {e}", exc_info=True
)
raise ValueError( raise ValueError(
f"Failed to initialize operation context: {e}") from e f"Failed to initialize operation context: {e}") from e
def _extract_did(self, client) -> str:
me = getattr(client, "me", None)
did = getattr(me, "did", None)
if not did:
raise ValueError("Authenticated client does not expose a DID")
return did
def _initialize_config(self, config_data): def _initialize_config(self, config_data):
if config_data is not None: if config_data is not None:
return config_data return config_data
@@ -46,9 +80,9 @@ class OperationContext:
raise raise
except Exception as e: except Exception as e:
self.logger.error( self.logger.error(
f"Unexpected error loading configuration: {e}", exc_info=True) f"Unexpected error loading configuration: {e}", exc_info=True
raise ValueError( )
f"Failed to load configuration: {e}") from e raise ValueError(f"Failed to load configuration: {e}") from e
class BaseStrategy: class BaseStrategy:
@@ -74,21 +108,21 @@ class RecordDeletionStrategy(BaseStrategy):
repo=context.did, repo=context.did,
collection=self.collection, collection=self.collection,
limit=context.batch_size, limit=context.batch_size,
cursor=cursor cursor=cursor,
) )
return context.client.com.atproto.repo.list_records(params=list_params) return context.client.com.atproto.repo.list_records(params=list_params)
def extract_items(self, response: Any) -> list[Any]: def extract_items(self, response: Any) -> list[Any]:
return response.records return response.records
def process_item(self, record: Any, context: OperationContext) -> None: def process_item(self, item: Any, context: OperationContext) -> None:
record_uri = record.uri record_uri = item.uri
rkey = record_uri.rsplit("/", 1)[-1] rkey = record_uri.rsplit("/", 1)[-1]
delete_data = { delete_data = models.ComAtprotoRepoDeleteRecord.Data(
"repo": context.did, repo=context.did,
"collection": self.collection, collection=self.collection,
"rkey": rkey rkey=rkey,
} )
context.client.com.atproto.repo.delete_record(data=delete_data) context.client.com.atproto.repo.delete_record(data=delete_data)
context.logger.debug(f"Deleted: {record_uri}") context.logger.debug(f"Deleted: {record_uri}")
@@ -99,13 +133,15 @@ class FeedStrategy(BaseStrategy):
return context.client.get_author_feed( return context.client.get_author_feed(
actor=context.did, limit=context.batch_size, cursor=cursor actor=context.did, limit=context.batch_size, cursor=cursor
) )
return context.client.get_author_feed(actor=context.did, limit=context.batch_size) return context.client.get_author_feed(
actor=context.did, limit=context.batch_size
)
def extract_items(self, response: Any) -> list[Any]: def extract_items(self, response: Any) -> list[Any]:
return response.feed return response.feed
def process_item(self, post: Any, context: OperationContext) -> None: def process_item(self, item: Any, context: OperationContext) -> None:
uri = post.post.uri uri = item.post.uri
context.client.delete_post(uri) context.client.delete_post(uri)
context.logger.debug(f"Deleted post: {uri}") context.logger.debug(f"Deleted post: {uri}")
@@ -113,16 +149,15 @@ class FeedStrategy(BaseStrategy):
class BookmarkStrategy(BaseStrategy): class BookmarkStrategy(BaseStrategy):
def fetch(self, context: OperationContext, cursor: str | None = None) -> Any: def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
get_params = models.AppBskyBookmarkGetBookmarks.Params( get_params = models.AppBskyBookmarkGetBookmarks.Params(
limit=context.batch_size, limit=context.batch_size, cursor=cursor
cursor=cursor
) )
return context.client.app.bsky.bookmark.get_bookmarks(params=get_params) return context.client.app.bsky.bookmark.get_bookmarks(params=get_params)
def extract_items(self, response: Any) -> list[Any]: def extract_items(self, response: Any) -> list[Any]:
return response.bookmarks return response.bookmarks
def process_item(self, bookmark: Any, context: OperationContext) -> None: def process_item(self, item: Any, context: OperationContext) -> None:
bookmark_uri = self._extract_bookmark_uri(bookmark) bookmark_uri = self._extract_bookmark_uri(item)
if not bookmark_uri: if not bookmark_uri:
raise ValueError("Unable to find bookmark URI") raise ValueError("Unable to find bookmark URI")
@@ -151,7 +186,7 @@ class Operation:
collection: str | None = None, collection: str | None = None,
filter_fn: Callable[[Any], bool] | None = None, filter_fn: Callable[[Any], bool] | None = None,
client=None, client=None,
config_data=None config_data=None,
): ):
self.operation_name = operation_name self.operation_name = operation_name
self.filter_fn = filter_fn self.filter_fn = filter_fn
@@ -168,7 +203,7 @@ class Operation:
else: else:
self.strategy = FeedStrategy() self.strategy = FeedStrategy()
def run(self) -> int: def run(self) -> OperationResult:
context = OperationContext( context = OperationContext(
client=self._client, config_data=self._config_data) client=self._client, config_data=self._config_data)
progress = ProgressTracker(operation=self.operation_name) progress = ProgressTracker(operation=self.operation_name)
@@ -178,7 +213,7 @@ class Operation:
) )
cursor = None cursor = None
total_processed = 0 result = OperationResult()
batch_num = 0 batch_num = 0
while True: while True:
@@ -187,9 +222,16 @@ class Operation:
response = self.strategy.fetch(context, cursor) response = self.strategy.fetch(context, cursor)
items = self.strategy.extract_items(response) items = self.strategy.extract_items(response)
except Exception as e: except Exception as e:
context.logger.error( raise OperationFetchError(
f"Error fetching items for batch {batch_num}: {e}", exc_info=True) self.operation_name,
break batch_num,
e,
OperationResult(
processed=result.processed,
failed=result.failed,
skipped=result.skipped,
),
) from e
if not items: if not items:
break break
@@ -198,13 +240,15 @@ class Operation:
for item in items: for item in items:
if self.filter_fn and not self.filter_fn(item): if self.filter_fn and not self.filter_fn(item):
result.skipped += 1
continue continue
try: try:
self.strategy.process_item(item, context) self.strategy.process_item(item, context)
total_processed += 1 result.processed += 1
progress.update(1) progress.update(1)
except Exception as e: except Exception as e:
result.failed += 1
context.logger.error(f"Error processing item: {e}") context.logger.error(f"Error processing item: {e}")
cursor = self.strategy.get_cursor(response) cursor = self.strategy.get_cursor(response)
@@ -215,5 +259,6 @@ class Operation:
time.sleep(context.delay) time.sleep(context.delay)
context.logger.info( context.logger.info(
f"{self.operation_name}: {total_processed} items processed.") f"{self.operation_name}: processed={result.processed}, failed={result.failed}, skipped={result.skipped}."
return total_processed )
return result

View File

@@ -1,4 +1,5 @@
import logging import logging
from types import SimpleNamespace
import pytest import pytest
@@ -21,6 +22,7 @@ def test_create_operation_handler_calls_confirmation_and_run(monkeypatch):
def run(self): def run(self):
calls["run"] += 1 calls["run"] += 1
return SimpleNamespace(processed=1, failed=0, skipped=0)
monkeypatch.setattr(commands, "require_confirmation", fake_confirm) monkeypatch.setattr(commands, "require_confirmation", fake_confirm)
monkeypatch.setattr(commands, "Operation", FakeOperation) monkeypatch.setattr(commands, "Operation", FakeOperation)
@@ -77,6 +79,7 @@ def test_run_all_runs_in_order(monkeypatch):
def run(self): def run(self):
ran.append(self.operation_name) ran.append(self.operation_name)
return SimpleNamespace(processed=1, failed=0, skipped=0)
class FakeOperationContext: class FakeOperationContext:
def __init__(self): def __init__(self):
@@ -115,6 +118,7 @@ def test_run_all_continues_on_error(monkeypatch):
ran.append(self.operation_name) ran.append(self.operation_name)
if self.operation_name == "Deleting posts": if self.operation_name == "Deleting posts":
raise RuntimeError("fail") raise RuntimeError("fail")
return SimpleNamespace(processed=1, failed=0, skipped=0)
class FakeOperationContext: class FakeOperationContext:
def __init__(self): def __init__(self):
@@ -130,3 +134,48 @@ def test_run_all_continues_on_error(monkeypatch):
assert "Deleting posts" in ran assert "Deleting posts" in ran
assert len(ran) >= 2 assert len(ran) >= 2
def test_run_all_raises_on_fetch_failure(monkeypatch):
    """run_all() raises RuntimeError when a command fails while fetching items."""

    class FakeOperation:
        # Stand-in for Operation: the "Deleting posts" run fails at fetch time.
        def __init__(self, operation_name, **kwargs):
            self.operation_name = operation_name

        def run(self):
            if self.operation_name != "Deleting posts":
                return SimpleNamespace(processed=1, failed=0, skipped=0)
            raise operations.OperationFetchError(
                "Deleting posts",
                1,
                RuntimeError("api unavailable"),
                operations.OperationResult(),
            )

    class FakeOperationContext:
        # Minimal context exposing the client and config that run_all shares.
        def __init__(self):
            self.client = object()
            self.config_data = {"batch_size": 1, "delay": 0}

    # Registry and metadata are narrowed to a single "posts" command.
    posts_only_metadata = {
        "posts": {
            "operation_name": "Deleting posts",
            "strategy_type": "feed",
            "collection": None,
            "filter_fn": None,
        }
    }

    monkeypatch.setattr(
        commands, "require_confirmation", lambda *args, **kwargs: None
    )
    monkeypatch.setattr(commands, "Operation", FakeOperation)
    monkeypatch.setattr(operations, "OperationContext", FakeOperationContext)
    monkeypatch.setattr(commands, "COMMAND_METADATA", posts_only_metadata)
    monkeypatch.setattr(
        commands.registry, "get_all_commands", lambda: {"posts": "only posts"}
    )
    monkeypatch.setattr(commands, "COMMAND_EXECUTION_ORDER", ["posts"])

    with pytest.raises(RuntimeError, match="failed while fetching items"):
        commands.run_all(skip_confirmation=True)

View File

@@ -2,7 +2,7 @@ import time
import pytest import pytest
from skywipe.operations import Operation, BookmarkStrategy from skywipe.operations import Operation, BookmarkStrategy, OperationFetchError
class FakeClient: class FakeClient:
@@ -60,10 +60,42 @@ def test_operation_run_batches_filters_and_sleeps(monkeypatch):
total = operation.run() total = operation.run()
assert total == 2 assert total.processed == 2
assert total.failed == 1
assert total.skipped == 1
assert slept == [1] assert slept == [1]
def test_operation_run_raises_on_fetch_error(monkeypatch):
    """Operation.run() raises OperationFetchError when the strategy's fetch fails."""

    class FetchErrorStrategy:
        # Strategy stub whose fetch always fails; the other hooks are inert.
        def fetch(self, context, cursor=None):
            raise RuntimeError("api down")

        def extract_items(self, response):
            return []

        def process_item(self, item, context):
            return None

        def get_cursor(self, response):
            return None

    settings = {"batch_size": 2, "delay": 0}
    operation = Operation(
        "Testing",
        strategy_type="feed",
        client=FakeClient(),
        config_data=settings,
    )
    # Swap in the failing strategy after construction.
    operation.strategy = FetchErrorStrategy()

    with pytest.raises(OperationFetchError, match="Failed to fetch items") as exc:
        operation.run()

    # The error carries the counters gathered so far: all zero here.
    counters = exc.value.result
    assert counters.processed == 0
    assert counters.failed == 0
    assert counters.skipped == 0
def test_bookmark_strategy_extracts_uri_from_shapes(): def test_bookmark_strategy_extracts_uri_from_shapes():
strategy = BookmarkStrategy() strategy = BookmarkStrategy()

View File

@@ -55,11 +55,9 @@ def test_record_deletion_strategy_fetch_and_process(monkeypatch):
record = SimpleNamespace(uri="at://did:plc:fake/app.bsky.feed.like/abc123") record = SimpleNamespace(uri="at://did:plc:fake/app.bsky.feed.like/abc123")
strategy.process_item(record, context) strategy.process_item(record, context)
assert captured["delete"] == { assert captured["delete"].repo == "did:plc:fake"
"repo": "did:plc:fake", assert captured["delete"].collection == "app.bsky.feed.like"
"collection": "app.bsky.feed.like", assert captured["delete"].rkey == "abc123"
"rkey": "abc123",
}
def test_feed_strategy_fetch_and_process(): def test_feed_strategy_fetch_and_process():