Compare commits

...

2 Commits

Author SHA1 Message Date
e6d68dd37d refactor: type strategy interface 2026-01-05 20:38:20 +01:00
809b7823ea refactor: type command metadata 2026-01-05 20:38:09 +01:00
2 changed files with 32 additions and 13 deletions

View File

@@ -1,6 +1,6 @@
"""Command implementations for Skywipe""" """Command implementations for Skywipe"""
from typing import Callable, Any from typing import Callable, Any, TypedDict
from .configure import Configuration from .configure import Configuration
from .operations import Operation from .operations import Operation
from .post_analysis import PostAnalyzer from .post_analysis import PostAnalyzer
@@ -11,7 +11,16 @@ from .safeguard import require_confirmation
CommandHandler = Callable[..., None] CommandHandler = Callable[..., None]
COMMAND_METADATA = { class CommandMetadata(TypedDict):
confirmation: str
help_text: str
operation_name: str
strategy_type: str
collection: str | None
filter_fn: Callable[[Any], bool] | None
COMMAND_METADATA: dict[str, CommandMetadata] = {
"posts": { "posts": {
"confirmation": "delete all posts", "confirmation": "delete all posts",
"help_text": "only posts", "help_text": "only posts",

View File

@@ -52,7 +52,16 @@ class OperationContext:
class BaseStrategy: class BaseStrategy:
def get_cursor(self, response): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
raise NotImplementedError
def extract_items(self, response: Any) -> list[Any]:
raise NotImplementedError
def process_item(self, item: Any, context: OperationContext) -> None:
raise NotImplementedError
def get_cursor(self, response: Any) -> str | None:
return response.cursor return response.cursor
@@ -60,7 +69,7 @@ class RecordDeletionStrategy(BaseStrategy):
def __init__(self, collection: str): def __init__(self, collection: str):
self.collection = collection self.collection = collection
def fetch(self, context: OperationContext, cursor: str | None = None): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
list_params = models.ComAtprotoRepoListRecords.Params( list_params = models.ComAtprotoRepoListRecords.Params(
repo=context.did, repo=context.did,
collection=self.collection, collection=self.collection,
@@ -69,10 +78,10 @@ class RecordDeletionStrategy(BaseStrategy):
) )
return context.client.com.atproto.repo.list_records(params=list_params) return context.client.com.atproto.repo.list_records(params=list_params)
def extract_items(self, response): def extract_items(self, response: Any) -> list[Any]:
return response.records return response.records
def process_item(self, record, context: OperationContext): def process_item(self, record: Any, context: OperationContext) -> None:
record_uri = record.uri record_uri = record.uri
rkey = record_uri.rsplit("/", 1)[-1] rkey = record_uri.rsplit("/", 1)[-1]
delete_data = { delete_data = {
@@ -85,34 +94,34 @@ class RecordDeletionStrategy(BaseStrategy):
class FeedStrategy(BaseStrategy): class FeedStrategy(BaseStrategy):
def fetch(self, context: OperationContext, cursor: str | None = None): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
if cursor: if cursor:
return context.client.get_author_feed( return context.client.get_author_feed(
actor=context.did, limit=context.batch_size, cursor=cursor actor=context.did, limit=context.batch_size, cursor=cursor
) )
return context.client.get_author_feed(actor=context.did, limit=context.batch_size) return context.client.get_author_feed(actor=context.did, limit=context.batch_size)
def extract_items(self, response): def extract_items(self, response: Any) -> list[Any]:
return response.feed return response.feed
def process_item(self, post, context: OperationContext): def process_item(self, post: Any, context: OperationContext) -> None:
uri = post.post.uri uri = post.post.uri
context.client.delete_post(uri) context.client.delete_post(uri)
context.logger.debug(f"Deleted post: {uri}") context.logger.debug(f"Deleted post: {uri}")
class BookmarkStrategy(BaseStrategy): class BookmarkStrategy(BaseStrategy):
def fetch(self, context: OperationContext, cursor: str | None = None): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
get_params = models.AppBskyBookmarkGetBookmarks.Params( get_params = models.AppBskyBookmarkGetBookmarks.Params(
limit=context.batch_size, limit=context.batch_size,
cursor=cursor cursor=cursor
) )
return context.client.app.bsky.bookmark.get_bookmarks(params=get_params) return context.client.app.bsky.bookmark.get_bookmarks(params=get_params)
def extract_items(self, response): def extract_items(self, response: Any) -> list[Any]:
return response.bookmarks return response.bookmarks
def process_item(self, bookmark, context: OperationContext): def process_item(self, bookmark: Any, context: OperationContext) -> None:
bookmark_uri = self._extract_bookmark_uri(bookmark) bookmark_uri = self._extract_bookmark_uri(bookmark)
if not bookmark_uri: if not bookmark_uri:
raise ValueError("Unable to find bookmark URI") raise ValueError("Unable to find bookmark URI")
@@ -122,7 +131,7 @@ class BookmarkStrategy(BaseStrategy):
context.client.app.bsky.bookmark.delete_bookmark(data=delete_data) context.client.app.bsky.bookmark.delete_bookmark(data=delete_data)
context.logger.debug(f"Deleted bookmark: {bookmark_uri}") context.logger.debug(f"Deleted bookmark: {bookmark_uri}")
def _extract_bookmark_uri(self, bookmark): def _extract_bookmark_uri(self, bookmark: Any) -> str | None:
if hasattr(bookmark, "uri"): if hasattr(bookmark, "uri"):
return bookmark.uri return bookmark.uri
@@ -148,6 +157,7 @@ class Operation:
self.filter_fn = filter_fn self.filter_fn = filter_fn
self._client = client self._client = client
self._config_data = config_data self._config_data = config_data
self.strategy: BaseStrategy
if strategy_type == "record": if strategy_type == "record":
if not collection: if not collection: