refactor: type strategy interface

This commit is contained in:
2026-01-05 20:38:20 +01:00
parent 809b7823ea
commit e6d68dd37d

View File

@@ -52,7 +52,16 @@ class OperationContext:
class BaseStrategy: class BaseStrategy:
def get_cursor(self, response): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
raise NotImplementedError
def extract_items(self, response: Any) -> list[Any]:
raise NotImplementedError
def process_item(self, item: Any, context: OperationContext) -> None:
raise NotImplementedError
def get_cursor(self, response: Any) -> str | None:
return response.cursor return response.cursor
@@ -60,7 +69,7 @@ class RecordDeletionStrategy(BaseStrategy):
def __init__(self, collection: str): def __init__(self, collection: str):
self.collection = collection self.collection = collection
def fetch(self, context: OperationContext, cursor: str | None = None): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
list_params = models.ComAtprotoRepoListRecords.Params( list_params = models.ComAtprotoRepoListRecords.Params(
repo=context.did, repo=context.did,
collection=self.collection, collection=self.collection,
@@ -69,10 +78,10 @@ class RecordDeletionStrategy(BaseStrategy):
) )
return context.client.com.atproto.repo.list_records(params=list_params) return context.client.com.atproto.repo.list_records(params=list_params)
def extract_items(self, response): def extract_items(self, response: Any) -> list[Any]:
return response.records return response.records
def process_item(self, record, context: OperationContext): def process_item(self, record: Any, context: OperationContext) -> None:
record_uri = record.uri record_uri = record.uri
rkey = record_uri.rsplit("/", 1)[-1] rkey = record_uri.rsplit("/", 1)[-1]
delete_data = { delete_data = {
@@ -85,34 +94,34 @@ class RecordDeletionStrategy(BaseStrategy):
class FeedStrategy(BaseStrategy): class FeedStrategy(BaseStrategy):
def fetch(self, context: OperationContext, cursor: str | None = None): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
if cursor: if cursor:
return context.client.get_author_feed( return context.client.get_author_feed(
actor=context.did, limit=context.batch_size, cursor=cursor actor=context.did, limit=context.batch_size, cursor=cursor
) )
return context.client.get_author_feed(actor=context.did, limit=context.batch_size) return context.client.get_author_feed(actor=context.did, limit=context.batch_size)
def extract_items(self, response): def extract_items(self, response: Any) -> list[Any]:
return response.feed return response.feed
def process_item(self, post, context: OperationContext): def process_item(self, post: Any, context: OperationContext) -> None:
uri = post.post.uri uri = post.post.uri
context.client.delete_post(uri) context.client.delete_post(uri)
context.logger.debug(f"Deleted post: {uri}") context.logger.debug(f"Deleted post: {uri}")
class BookmarkStrategy(BaseStrategy): class BookmarkStrategy(BaseStrategy):
def fetch(self, context: OperationContext, cursor: str | None = None): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
get_params = models.AppBskyBookmarkGetBookmarks.Params( get_params = models.AppBskyBookmarkGetBookmarks.Params(
limit=context.batch_size, limit=context.batch_size,
cursor=cursor cursor=cursor
) )
return context.client.app.bsky.bookmark.get_bookmarks(params=get_params) return context.client.app.bsky.bookmark.get_bookmarks(params=get_params)
def extract_items(self, response): def extract_items(self, response: Any) -> list[Any]:
return response.bookmarks return response.bookmarks
def process_item(self, bookmark, context: OperationContext): def process_item(self, bookmark: Any, context: OperationContext) -> None:
bookmark_uri = self._extract_bookmark_uri(bookmark) bookmark_uri = self._extract_bookmark_uri(bookmark)
if not bookmark_uri: if not bookmark_uri:
raise ValueError("Unable to find bookmark URI") raise ValueError("Unable to find bookmark URI")
@@ -122,7 +131,7 @@ class BookmarkStrategy(BaseStrategy):
context.client.app.bsky.bookmark.delete_bookmark(data=delete_data) context.client.app.bsky.bookmark.delete_bookmark(data=delete_data)
context.logger.debug(f"Deleted bookmark: {bookmark_uri}") context.logger.debug(f"Deleted bookmark: {bookmark_uri}")
def _extract_bookmark_uri(self, bookmark): def _extract_bookmark_uri(self, bookmark: Any) -> str | None:
if hasattr(bookmark, "uri"): if hasattr(bookmark, "uri"):
return bookmark.uri return bookmark.uri
@@ -148,6 +157,7 @@ class Operation:
self.filter_fn = filter_fn self.filter_fn = filter_fn
self._client = client self._client = client
self._config_data = config_data self._config_data = config_data
self.strategy: BaseStrategy
if strategy_type == "record": if strategy_type == "record":
if not collection: if not collection: