Compare commits

..

32 Commits

Author SHA1 Message Date
2cdc4c6c42 refactor: tidy logger handler setup 2026-01-15 16:06:32 +01:00
07862f0ea2 refactor: simplify configuration create flow 2026-01-15 16:02:21 +01:00
e095c68f72 test: share error message formatting helper 2026-01-08 06:51:56 +01:00
c718e8c6f5 test: patch configure getpass correctly 2026-01-08 06:51:46 +01:00
a07cc02fb0 fix: format KeyError messages cleanly 2026-01-08 06:51:37 +01:00
ecc33054af fix: handle config load errors via handler 2026-01-08 06:51:23 +01:00
e6d68dd37d refactor: type strategy interface 2026-01-05 20:38:20 +01:00
809b7823ea refactor: type command metadata 2026-01-05 20:38:09 +01:00
6a88ab8560 refactor: import getpass directly 2025-12-31 08:22:27 +01:00
dd3b220a4a test: use shared config fixture 2025-12-30 23:09:01 +01:00
82b99da50d test: add config_with_tmp_path fixture 2025-12-30 23:08:51 +01:00
ce2a1ad594 test: expand configure load tests with fixture and error cases 2025-12-30 23:04:45 +01:00
313b6fc453 test: refactor CLI tests and expand coverage 2025-12-30 18:29:34 +01:00
e364598414 test: expect progress propagation 2025-12-30 18:16:51 +01:00
cd1cf1f170 fix: restore progress logger propagation 2025-12-30 18:16:44 +01:00
c3761d1d08 test: cover formatter reset on reuse 2025-12-30 18:13:56 +01:00
85f1ea4efb fix: reset stream formatters in setup_logger 2025-12-30 18:13:51 +01:00
df22b3dd3d test: cover logger propagation behavior 2025-12-30 18:11:48 +01:00
8c0bbceeac fix: disable skywipe log propagation 2025-12-30 18:11:38 +01:00
5e60374937 test: add tests for ProgressTracker.batch with total_batches=0 2025-12-30 17:46:29 +01:00
fd62bb5ea2 fix: use "is not None" check for total_batches in ProgressTracker.batch 2025-12-30 17:46:22 +01:00
6785ecd45a test: strengthen logger handler duplication test 2025-12-30 08:53:00 +01:00
7828989150 refactor: dedupe file handler cleanup in setup_logger 2025-12-30 08:45:39 +01:00
9eb2ed0097 test: verify FileHandler replacement when log_file path changes 2025-12-30 08:31:14 +01:00
5c8932599c fix: replace existing FileHandler when log_file changes 2025-12-30 08:30:59 +01:00
b2af41d5fb style: prefer PEP 604/585 type hints 2025-12-23 05:14:38 +01:00
6de91e2bb9 build: package = true 2025-12-23 04:55:49 +01:00
d026c53c0a chore: update uv.lock 2025-12-23 04:55:42 +01:00
d6ce77ab15 test: cover operation run + bookmark parsing 2025-12-23 04:50:21 +01:00
b6e0c55c3e test: cover operation contect error paths 2025-12-23 04:50:10 +01:00
3b84be90b7 test: cover run_all ordering and errors 2025-12-23 04:49:46 +01:00
b8f6953a17 test: cover config create flows 2025-12-23 04:49:32 +01:00
16 changed files with 936 additions and 172 deletions

View File

@@ -12,6 +12,8 @@ dev = ["pytest>=8.0"]
[project.scripts] [project.scripts]
skywipe = "skywipe.cli:main" skywipe = "skywipe.cli:main"
[tool.uv]
package = true
[tool.pytest.ini_options] [tool.pytest.ini_options]
testpaths = ["tests"] testpaths = ["tests"]
pythonpath = ["."]

View File

@@ -61,14 +61,14 @@ def main():
setup_logger(verbose=False, log_file=LOG_FILE) setup_logger(verbose=False, log_file=LOG_FILE)
logger = get_logger() logger = get_logger()
if registry.requires_config(args.command):
require_config(logger)
config = Configuration()
config_data = config.load()
verbose = config_data.get("verbose", False)
setup_logger(verbose=verbose, log_file=LOG_FILE)
try: try:
if registry.requires_config(args.command):
require_config(logger)
config = Configuration()
config_data = config.load()
verbose = config_data.get("verbose", False)
setup_logger(verbose=verbose, log_file=LOG_FILE)
registry.execute( registry.execute(
args.command, skip_confirmation=getattr(args, "yes", False)) args.command, skip_confirmation=getattr(args, "yes", False))
except (ValueError, Exception) as e: except (ValueError, Exception) as e:

View File

@@ -1,6 +1,6 @@
"""Command implementations for Skywipe""" """Command implementations for Skywipe"""
from typing import Callable, Dict, Optional, Any from typing import Callable, Any, TypedDict
from .configure import Configuration from .configure import Configuration
from .operations import Operation from .operations import Operation
from .post_analysis import PostAnalyzer from .post_analysis import PostAnalyzer
@@ -11,7 +11,16 @@ from .safeguard import require_confirmation
CommandHandler = Callable[..., None] CommandHandler = Callable[..., None]
COMMAND_METADATA = { class CommandMetadata(TypedDict):
confirmation: str
help_text: str
operation_name: str
strategy_type: str
collection: str | None
filter_fn: Callable[[Any], bool] | None
COMMAND_METADATA: dict[str, CommandMetadata] = {
"posts": { "posts": {
"confirmation": "delete all posts", "confirmation": "delete all posts",
"help_text": "only posts", "help_text": "only posts",
@@ -98,16 +107,16 @@ class CommandRegistry:
self._help_texts[name] = help_text self._help_texts[name] = help_text
self._requires_config[name] = requires_config self._requires_config[name] = requires_config
def get_handler(self, name: str) -> Optional[CommandHandler]: def get_handler(self, name: str) -> CommandHandler | None:
return self._commands.get(name) return self._commands.get(name)
def get_help_text(self, name: str) -> Optional[str]: def get_help_text(self, name: str) -> str | None:
return self._help_texts.get(name) return self._help_texts.get(name)
def requires_config(self, name: str) -> bool: def requires_config(self, name: str) -> bool:
return self._requires_config.get(name, True) return self._requires_config.get(name, True)
def get_all_commands(self) -> Dict[str, str]: def get_all_commands(self) -> dict[str, str]:
return self._help_texts.copy() return self._help_texts.copy()
def execute(self, name: str, skip_confirmation: bool = False): def execute(self, name: str, skip_confirmation: bool = False):
@@ -128,8 +137,8 @@ def _create_operation_handler(
confirmation_message: str, confirmation_message: str,
operation_name: str, operation_name: str,
strategy_type: str = "feed", strategy_type: str = "feed",
collection: Optional[str] = None, collection: str | None = None,
filter_fn: Optional[Callable[[Any], bool]] = None filter_fn: Callable[[Any], bool] | None = None
) -> CommandHandler: ) -> CommandHandler:
logger = get_logger() logger = get_logger()

View File

@@ -1,6 +1,6 @@
"""Core configuration module for Skywipe""" """Core configuration module for Skywipe"""
import getpass from getpass import getpass
import re import re
from pathlib import Path from pathlib import Path
from typing import NamedTuple from typing import NamedTuple
@@ -56,65 +56,100 @@ class Configuration:
def exists(self) -> bool: def exists(self) -> bool:
return self.config_file.exists() return self.config_file.exists()
def create(self): def _confirm_overwrite(self, logger) -> bool:
logger = setup_logger(verbose=False) if not self.exists():
if self.exists(): return True
overwrite = input( overwrite = input(
"Configuration already exists. Overwrite? (y/N): ").strip().lower() "Configuration already exists. Overwrite? (y/N): ").strip().lower()
if overwrite not in ("y", "yes"): if overwrite in ("y", "yes"):
logger.info("Configuration creation cancelled.") return True
return logger.info("Configuration creation cancelled.")
return False
def _ensure_config_dir(self) -> None:
config_dir = self.config_file.parent config_dir = self.config_file.parent
config_dir.mkdir(parents=True, exist_ok=True) config_dir.mkdir(parents=True, exist_ok=True)
print("Skywipe Configuration") def _prompt_handle(self, logger) -> str:
print("=" * 50)
print("Note: You should use an app password from Bluesky settings.")
while True: while True:
handle = input("Bluesky handle: ").strip() handle = input("Bluesky handle: ").strip()
is_valid, error_msg = _validate_handle(handle) is_valid, error_msg = _validate_handle(handle)
if is_valid: if is_valid:
break return handle
logger.error(error_msg) logger.error(error_msg)
logger.info("Please enter a valid handle and try again.") logger.info("Please enter a valid handle and try again.")
def _prompt_password(self, logger) -> str:
while True: while True:
password = getpass.getpass( password = getpass(
"Bluesky (hopefully app) password: ").strip() "Bluesky (hopefully app) password: ").strip()
is_valid, error_msg = _validate_password(password) is_valid, error_msg = _validate_password(password)
if is_valid: if is_valid:
break return password
logger.error(error_msg) logger.error(error_msg)
logger.info("Please check your password and try again.") logger.info("Please check your password and try again.")
logger.info( logger.info(
"Generate an app password at: https://bsky.app/settings/app-passwords") "Generate an app password at: https://bsky.app/settings/app-passwords")
def _parse_batch_size(self, logger) -> int | None:
batch_size = input("Batch size (default: 10): ").strip() or "10" batch_size = input("Batch size (default: 10): ").strip() or "10"
delay = input(
"Delay between batches in seconds (default: 1): ").strip() or "1"
verbose_input = input(
"Verbose mode (y/n, default: y): ").strip().lower() or "y"
verbose = verbose_input in ("y", "yes", "true", "1")
try: try:
batch_size = int(batch_size) batch_size_int = int(batch_size)
if batch_size < 1 or batch_size > 100:
logger.error("batch_size must be between 1 and 100")
return
except ValueError: except ValueError:
logger.error("batch_size must be an integer") logger.error("batch_size must be an integer")
return return None
if batch_size_int < 1 or batch_size_int > 100:
logger.error("batch_size must be between 1 and 100")
return None
return batch_size_int
def _parse_delay(self, logger) -> int | None:
delay = input(
"Delay between batches in seconds (default: 1): ").strip() or "1"
try: try:
delay = int(delay) delay_int = int(delay)
if delay < 0 or delay > 60:
logger.error("delay must be between 0 and 60 seconds")
return
except ValueError: except ValueError:
logger.error("delay must be an integer") logger.error("delay must be an integer")
return None
if delay_int < 0 or delay_int > 60:
logger.error("delay must be between 0 and 60 seconds")
return None
return delay_int
def _parse_verbose(self) -> bool:
verbose_input = input(
"Verbose mode (y/n, default: y): ").strip().lower() or "y"
return verbose_input in ("y", "yes", "true", "1")
def _write_config(self, logger, config_data: dict) -> None:
try:
with open(self.config_file, "w") as f:
yaml.dump(config_data, f, default_flow_style=False)
except (IOError, OSError) as e:
logger.error(f"Failed to save configuration: {e}")
return return
logger.info(f"Configuration saved to {self.config_file}")
def create(self):
logger = setup_logger(verbose=False)
if not self._confirm_overwrite(logger):
return
self._ensure_config_dir()
print("Skywipe Configuration")
print("=" * 50)
print("Note: You should use an app password from Bluesky settings.")
handle = self._prompt_handle(logger)
password = self._prompt_password(logger)
batch_size = self._parse_batch_size(logger)
if batch_size is None:
return
delay = self._parse_delay(logger)
if delay is None:
return
verbose = self._parse_verbose()
config_data = { config_data = {
"handle": handle, "handle": handle,
@@ -124,14 +159,7 @@ class Configuration:
"verbose": verbose "verbose": verbose
} }
try: self._write_config(logger, config_data)
with open(self.config_file, "w") as f:
yaml.dump(config_data, f, default_flow_style=False)
except (IOError, OSError) as e:
logger.error(f"Failed to save configuration: {e}")
return
logger.info(f"Configuration saved to {self.config_file}")
def load(self) -> dict: def load(self) -> dict:
if not self.exists(): if not self.exists():

View File

@@ -3,7 +3,6 @@
import logging import logging
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Optional
class ProgressTracker: class ProgressTracker:
@@ -14,9 +13,9 @@ class ProgressTracker:
def update(self, count: int = 1): def update(self, count: int = 1):
self.current += count self.current += count
def batch(self, batch_num: int, batch_size: int, total_batches: Optional[int] = None): def batch(self, batch_num: int, batch_size: int, total_batches: int | None = None):
logger = logging.getLogger("skywipe.progress") logger = logging.getLogger("skywipe.progress")
if total_batches: if total_batches is not None:
logger.info( logger.info(
f"{self.operation} - batch {batch_num}/{total_batches} ({batch_size} items)" f"{self.operation} - batch {batch_num}/{total_batches} ({batch_size} items)"
) )
@@ -35,11 +34,16 @@ class LevelFilter(logging.Filter):
return self.min_level <= record.levelno <= self.max_level return self.min_level <= record.levelno <= self.max_level
def setup_logger(verbose: bool = False, log_file: Optional[Path] = None) -> logging.Logger: def setup_logger(verbose: bool = False, log_file: Path | None = None) -> logging.Logger:
logger = logging.getLogger("skywipe") logger = logging.getLogger("skywipe")
logger.propagate = False
target_level = logging.DEBUG if verbose else logging.INFO target_level = logging.DEBUG if verbose else logging.INFO
logger.setLevel(target_level) logger.setLevel(target_level)
progress_logger = logging.getLogger("skywipe.progress")
if not progress_logger.handlers:
progress_logger.propagate = True
info_handler = None info_handler = None
error_handler = None error_handler = None
file_handlers = [] file_handlers = []
@@ -57,32 +61,34 @@ def setup_logger(verbose: bool = False, log_file: Optional[Path] = None) -> logg
if info_handler is None: if info_handler is None:
info_handler = logging.StreamHandler(sys.stdout) info_handler = logging.StreamHandler(sys.stdout)
info_handler.addFilter(LevelFilter(logging.DEBUG, logging.INFO))
info_handler.setFormatter(formatter)
logger.addHandler(info_handler) logger.addHandler(info_handler)
for existing in list(info_handler.filters):
if isinstance(existing, LevelFilter):
info_handler.removeFilter(existing)
info_handler.addFilter(LevelFilter(logging.DEBUG, logging.INFO))
info_handler.setFormatter(formatter)
info_handler.setLevel(target_level) info_handler.setLevel(target_level)
if error_handler is None: if error_handler is None:
error_handler = logging.StreamHandler(sys.stderr) error_handler = logging.StreamHandler(sys.stderr)
error_handler.setLevel(logging.WARNING) error_handler.setLevel(logging.WARNING)
error_handler.setFormatter(formatter)
logger.addHandler(error_handler) logger.addHandler(error_handler)
error_handler.setFormatter(formatter)
for handler in file_handlers:
handler.close()
logger.removeHandler(handler)
if log_file: if log_file:
if not file_handlers: log_file.parent.mkdir(parents=True, exist_ok=True)
log_file.parent.mkdir(parents=True, exist_ok=True) file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler = logging.FileHandler(log_file, encoding="utf-8") file_handler.setLevel(logging.DEBUG)
file_handler.setLevel(logging.DEBUG) file_formatter = logging.Formatter(
file_formatter = logging.Formatter( fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
datefmt="%Y-%m-%d %H:%M:%S" )
) file_handler.setFormatter(file_formatter)
file_handler.setFormatter(file_formatter) logger.addHandler(file_handler)
logger.addHandler(file_handler)
else:
for handler in file_handlers:
handler.close()
logger.removeHandler(handler)
return logger return logger
@@ -91,11 +97,18 @@ def get_logger() -> logging.Logger:
return logging.getLogger("skywipe") return logging.getLogger("skywipe")
def _format_error_message(error: Exception) -> str:
if isinstance(error, KeyError):
return str(error.args[0]) if error.args else str(error)
return str(error)
def handle_error(error: Exception, logger: logging.Logger, exit_on_error: bool = False) -> None: def handle_error(error: Exception, logger: logging.Logger, exit_on_error: bool = False) -> None:
if isinstance(error, ValueError): if isinstance(error, (KeyError, ValueError)):
logger.error(f"{error}") logger.error(_format_error_message(error))
else: else:
logger.error(f"Unexpected error: {error}", exc_info=True) logger.error(
f"Unexpected error: {_format_error_message(error)}", exc_info=True)
if exit_on_error: if exit_on_error:
sys.exit(1) sys.exit(1)

View File

@@ -1,7 +1,7 @@
"""Shared operation utilities and strategies for Skywipe""" """Shared operation utilities and strategies for Skywipe"""
import time import time
from typing import Callable, Optional, Any from typing import Callable, Any
from atproto import models from atproto import models
from .auth import Auth from .auth import Auth
@@ -52,7 +52,16 @@ class OperationContext:
class BaseStrategy: class BaseStrategy:
def get_cursor(self, response): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
raise NotImplementedError
def extract_items(self, response: Any) -> list[Any]:
raise NotImplementedError
def process_item(self, item: Any, context: OperationContext) -> None:
raise NotImplementedError
def get_cursor(self, response: Any) -> str | None:
return response.cursor return response.cursor
@@ -60,7 +69,7 @@ class RecordDeletionStrategy(BaseStrategy):
def __init__(self, collection: str): def __init__(self, collection: str):
self.collection = collection self.collection = collection
def fetch(self, context: OperationContext, cursor: Optional[str] = None): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
list_params = models.ComAtprotoRepoListRecords.Params( list_params = models.ComAtprotoRepoListRecords.Params(
repo=context.did, repo=context.did,
collection=self.collection, collection=self.collection,
@@ -69,10 +78,10 @@ class RecordDeletionStrategy(BaseStrategy):
) )
return context.client.com.atproto.repo.list_records(params=list_params) return context.client.com.atproto.repo.list_records(params=list_params)
def extract_items(self, response): def extract_items(self, response: Any) -> list[Any]:
return response.records return response.records
def process_item(self, record, context: OperationContext): def process_item(self, record: Any, context: OperationContext) -> None:
record_uri = record.uri record_uri = record.uri
rkey = record_uri.rsplit("/", 1)[-1] rkey = record_uri.rsplit("/", 1)[-1]
delete_data = { delete_data = {
@@ -85,34 +94,34 @@ class RecordDeletionStrategy(BaseStrategy):
class FeedStrategy(BaseStrategy): class FeedStrategy(BaseStrategy):
def fetch(self, context: OperationContext, cursor: Optional[str] = None): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
if cursor: if cursor:
return context.client.get_author_feed( return context.client.get_author_feed(
actor=context.did, limit=context.batch_size, cursor=cursor actor=context.did, limit=context.batch_size, cursor=cursor
) )
return context.client.get_author_feed(actor=context.did, limit=context.batch_size) return context.client.get_author_feed(actor=context.did, limit=context.batch_size)
def extract_items(self, response): def extract_items(self, response: Any) -> list[Any]:
return response.feed return response.feed
def process_item(self, post, context: OperationContext): def process_item(self, post: Any, context: OperationContext) -> None:
uri = post.post.uri uri = post.post.uri
context.client.delete_post(uri) context.client.delete_post(uri)
context.logger.debug(f"Deleted post: {uri}") context.logger.debug(f"Deleted post: {uri}")
class BookmarkStrategy(BaseStrategy): class BookmarkStrategy(BaseStrategy):
def fetch(self, context: OperationContext, cursor: Optional[str] = None): def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
get_params = models.AppBskyBookmarkGetBookmarks.Params( get_params = models.AppBskyBookmarkGetBookmarks.Params(
limit=context.batch_size, limit=context.batch_size,
cursor=cursor cursor=cursor
) )
return context.client.app.bsky.bookmark.get_bookmarks(params=get_params) return context.client.app.bsky.bookmark.get_bookmarks(params=get_params)
def extract_items(self, response): def extract_items(self, response: Any) -> list[Any]:
return response.bookmarks return response.bookmarks
def process_item(self, bookmark, context: OperationContext): def process_item(self, bookmark: Any, context: OperationContext) -> None:
bookmark_uri = self._extract_bookmark_uri(bookmark) bookmark_uri = self._extract_bookmark_uri(bookmark)
if not bookmark_uri: if not bookmark_uri:
raise ValueError("Unable to find bookmark URI") raise ValueError("Unable to find bookmark URI")
@@ -122,7 +131,7 @@ class BookmarkStrategy(BaseStrategy):
context.client.app.bsky.bookmark.delete_bookmark(data=delete_data) context.client.app.bsky.bookmark.delete_bookmark(data=delete_data)
context.logger.debug(f"Deleted bookmark: {bookmark_uri}") context.logger.debug(f"Deleted bookmark: {bookmark_uri}")
def _extract_bookmark_uri(self, bookmark): def _extract_bookmark_uri(self, bookmark: Any) -> str | None:
if hasattr(bookmark, "uri"): if hasattr(bookmark, "uri"):
return bookmark.uri return bookmark.uri
@@ -139,8 +148,8 @@ class Operation:
self, self,
operation_name: str, operation_name: str,
strategy_type: str = "feed", strategy_type: str = "feed",
collection: Optional[str] = None, collection: str | None = None,
filter_fn: Optional[Callable[[Any], bool]] = None, filter_fn: Callable[[Any], bool] | None = None,
client=None, client=None,
config_data=None config_data=None
): ):
@@ -148,6 +157,7 @@ class Operation:
self.filter_fn = filter_fn self.filter_fn = filter_fn
self._client = client self._client = client
self._config_data = config_data self._config_data = config_data
self.strategy: BaseStrategy
if strategy_type == "record": if strategy_type == "record":
if not collection: if not collection:

View File

@@ -2,14 +2,17 @@
import sys import sys
import logging import logging
from typing import Optional
from .logger import get_logger from .logger import get_logger
CONFIRM_RESPONSES = {"yes", "y"} CONFIRM_RESPONSES = {"yes", "y"}
def require_confirmation(operation: str, skip_confirmation: bool = False, logger: Optional[logging.Logger] = None) -> None: def require_confirmation(
operation: str,
skip_confirmation: bool = False,
logger: logging.Logger | None = None
) -> None:
if skip_confirmation: if skip_confirmation:
return return

View File

@@ -1,7 +1,10 @@
from pathlib import Path
from typing import Iterable, Callable from typing import Iterable, Callable
import pytest import pytest
from skywipe.configure import Configuration
@pytest.fixture @pytest.fixture
def user_input(monkeypatch) -> Callable[[Iterable[str], Iterable[str]], None]: def user_input(monkeypatch) -> Callable[[Iterable[str], Iterable[str]], None]:
@@ -10,7 +13,13 @@ def user_input(monkeypatch) -> Callable[[Iterable[str], Iterable[str]], None]:
password_iter = iter(passwords) password_iter = iter(passwords)
monkeypatch.setattr("builtins.input", lambda _prompt: next(input_iter)) monkeypatch.setattr("builtins.input", lambda _prompt: next(input_iter))
monkeypatch.setattr("getpass.getpass", monkeypatch.setattr("skywipe.configure.getpass",
lambda _prompt: next(password_iter)) lambda _prompt: next(password_iter))
return _set return _set
@pytest.fixture
def config_with_tmp_path(monkeypatch, tmp_path):
monkeypatch.setattr(Path, "home", lambda: tmp_path)
return Configuration()

View File

@@ -6,91 +6,234 @@ import pytest
import skywipe.cli as cli import skywipe.cli as cli
TEST_COMMAND = "posts"
TEST_ERROR_MESSAGE = "boom"
TEST_LOGGER_NAME = "test.cli"
def _setup_parser_mocks(monkeypatch, commands=None):
if commands is None:
commands = {TEST_COMMAND: "only posts"}
monkeypatch.setattr(cli.registry, "get_all_commands", lambda: commands)
def _setup_main_mocks(monkeypatch, calls, requires_config=False, config_data=None):
_setup_parser_mocks(monkeypatch)
monkeypatch.setattr(cli.registry, "requires_config",
lambda name: requires_config)
def mock_execute(name, skip_confirmation=False):
calls["execute"] = (name, skip_confirmation)
def mock_setup_logger(verbose, log_file):
calls["setup"].append((verbose, log_file))
def mock_require_config(logger):
calls["require_config"].append(logger)
monkeypatch.setattr(cli.registry, "execute", mock_execute)
monkeypatch.setattr(cli, "setup_logger", mock_setup_logger)
monkeypatch.setattr(cli, "require_config", mock_require_config)
monkeypatch.setattr(cli, "get_logger",
lambda: logging.getLogger(TEST_LOGGER_NAME))
if config_data is not None:
monkeypatch.setattr(cli.Configuration, "load",
lambda self: config_data)
def _setup_error_mocks(monkeypatch, calls, error_factory):
_setup_parser_mocks(monkeypatch)
monkeypatch.setattr(cli.registry, "requires_config", lambda name: False)
monkeypatch.setattr(cli.registry, "execute", error_factory)
monkeypatch.setattr(cli, "setup_logger", lambda verbose, log_file: None)
monkeypatch.setattr(cli, "get_logger",
lambda: logging.getLogger(TEST_LOGGER_NAME))
def _format_error_message(error):
if isinstance(error, KeyError):
return error.args[0] if error.args else str(error)
return str(error)
def mock_handle_error(error, logger, exit_on_error=False):
calls["handle_error"] = (type(error).__name__,
_format_error_message(error), exit_on_error)
monkeypatch.setattr(cli, "handle_error", mock_handle_error)
def test_create_parser_includes_commands(monkeypatch): def test_create_parser_includes_commands(monkeypatch):
monkeypatch.setattr(cli.registry, "get_all_commands", _setup_parser_mocks(monkeypatch)
lambda: {"posts": "only posts"})
parser = cli.create_parser() parser = cli.create_parser()
args = parser.parse_args(["posts"]) args = parser.parse_args([TEST_COMMAND])
assert args.command == "posts" assert args.command == TEST_COMMAND
def test_create_parser_handles_multiple_commands(monkeypatch):
commands = {
"posts": "only posts",
"likes": "only likes",
"reposts": "only reposts"
}
_setup_parser_mocks(monkeypatch, commands)
parser = cli.create_parser()
args1 = parser.parse_args(["posts"])
args2 = parser.parse_args(["likes"])
args3 = parser.parse_args(["reposts"])
assert args1.command == "posts"
assert args2.command == "likes"
assert args3.command == "reposts"
def test_create_parser_parses_yes_flag(monkeypatch):
_setup_parser_mocks(monkeypatch)
parser = cli.create_parser()
args = parser.parse_args(["--yes", TEST_COMMAND])
assert args.command == TEST_COMMAND
assert args.yes is True
def test_create_parser_parses_without_yes_flag(monkeypatch):
_setup_parser_mocks(monkeypatch)
parser = cli.create_parser()
args = parser.parse_args([TEST_COMMAND])
assert args.command == TEST_COMMAND
assert getattr(args, "yes", False) is False
def test_create_parser_version_flag_exits(monkeypatch):
_setup_parser_mocks(monkeypatch)
parser = cli.create_parser()
with pytest.raises(SystemExit) as excinfo:
parser.parse_args(["--version"])
assert excinfo.value.code == 0
def test_create_parser_requires_command(monkeypatch):
_setup_parser_mocks(monkeypatch)
parser = cli.create_parser()
with pytest.raises(SystemExit):
parser.parse_args([])
def test_create_parser_rejects_invalid_command(monkeypatch):
_setup_parser_mocks(monkeypatch)
parser = cli.create_parser()
with pytest.raises(SystemExit):
parser.parse_args(["invalid_command"])
def test_require_config_exits_when_missing(monkeypatch): def test_require_config_exits_when_missing(monkeypatch):
monkeypatch.setattr(cli.Configuration, "exists", lambda self: False) monkeypatch.setattr(cli.Configuration, "exists", lambda self: False)
logger = logging.getLogger("test.cli") logger = logging.getLogger(TEST_LOGGER_NAME)
with pytest.raises(SystemExit) as excinfo: with pytest.raises(SystemExit) as excinfo:
cli.require_config(logger) cli.require_config(logger)
assert excinfo.value.code == 1 assert excinfo.value.code == 1
def test_require_config_does_not_exit_when_exists(monkeypatch):
monkeypatch.setattr(cli.Configuration, "exists", lambda self: True)
logger = logging.getLogger(TEST_LOGGER_NAME)
cli.require_config(logger)
def test_main_executes_without_config(monkeypatch): def test_main_executes_without_config(monkeypatch):
calls = {"execute": None, "setup": []} calls = {"execute": None, "setup": [], "require_config": []}
_setup_main_mocks(monkeypatch, calls, requires_config=False)
monkeypatch.setattr(cli.registry, "get_all_commands", monkeypatch.setattr(sys, "argv", ["skywipe", "--yes", TEST_COMMAND])
lambda: {"posts": "only posts"})
monkeypatch.setattr(cli.registry, "requires_config", lambda name: False)
monkeypatch.setattr(cli.registry, "execute", lambda name, skip_confirmation=False: calls.update(
{"execute": (name, skip_confirmation)}
))
monkeypatch.setattr(cli, "setup_logger", lambda verbose,
log_file: calls["setup"].append((verbose, log_file)))
monkeypatch.setattr(cli, "get_logger",
lambda: logging.getLogger("test.cli"))
monkeypatch.setattr(sys, "argv", ["skywipe", "--yes", "posts"])
cli.main() cli.main()
assert len(calls["require_config"]) == 0
assert calls["setup"] == [(False, cli.LOG_FILE)] assert calls["setup"] == [(False, cli.LOG_FILE)]
assert calls["execute"] == ("posts", True) assert calls["execute"] == (TEST_COMMAND, True)
def test_main_loads_config_and_sets_verbose(monkeypatch): def test_main_loads_config_and_sets_verbose(monkeypatch):
calls = {"setup": [], "execute": None, "require_config": 0} calls = {"setup": [], "execute": None, "require_config": []}
_setup_main_mocks(monkeypatch, calls, requires_config=True,
config_data={"verbose": True})
monkeypatch.setattr(cli.registry, "get_all_commands", monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
lambda: {"posts": "only posts"})
monkeypatch.setattr(cli.registry, "requires_config", lambda name: True)
monkeypatch.setattr(cli.registry, "execute", lambda name, skip_confirmation=False: calls.update(
{"execute": (name, skip_confirmation)}
))
monkeypatch.setattr(cli, "require_config", lambda logger: calls.update(
{"require_config": calls["require_config"] + 1}
))
monkeypatch.setattr(cli.Configuration, "load",
lambda self: {"verbose": True})
monkeypatch.setattr(cli, "setup_logger", lambda verbose,
log_file: calls["setup"].append((verbose, log_file)))
monkeypatch.setattr(cli, "get_logger",
lambda: logging.getLogger("test.cli"))
monkeypatch.setattr(sys, "argv", ["skywipe", "posts"])
cli.main() cli.main()
assert calls["require_config"] == 1 assert len(calls["require_config"]) == 1
assert calls["setup"] == [(False, cli.LOG_FILE), (True, cli.LOG_FILE)] assert calls["setup"] == [(False, cli.LOG_FILE), (True, cli.LOG_FILE)]
assert calls["execute"] == ("posts", False) assert calls["execute"] == (TEST_COMMAND, False)
def test_main_handles_execute_error(monkeypatch): @pytest.mark.parametrize("config_data,expected_verbose", [
calls = {"handle_error": None} ({}, False),
({"verbose": False}, False),
])
def test_main_config_verbose_defaults(monkeypatch, config_data, expected_verbose):
calls = {"setup": [], "execute": None, "require_config": []}
_setup_main_mocks(monkeypatch, calls, requires_config=True,
config_data=config_data)
monkeypatch.setattr(cli.registry, "get_all_commands", monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
lambda: {"posts": "only posts"}) cli.main()
monkeypatch.setattr(cli.registry, "requires_config", lambda name: False)
def raise_error(*_args, **_kwargs): assert len(calls["require_config"]) == 1
raise ValueError("boom") assert calls["setup"] == [(False, cli.LOG_FILE),
(expected_verbose, cli.LOG_FILE)]
assert calls["execute"] == (TEST_COMMAND, False)
monkeypatch.setattr(cli.registry, "execute", raise_error)
def test_main_handles_config_load_error(monkeypatch):
calls = {"handle_error": None, "require_config": []}
def mock_require_config(logger):
calls["require_config"].append(logger)
def raise_config_error(self):
raise RuntimeError("config error")
def _format_error_message(error):
if isinstance(error, KeyError):
return error.args[0] if error.args else str(error)
return str(error)
def mock_handle_error(error, logger, exit_on_error=False):
calls["handle_error"] = (type(error).__name__,
_format_error_message(error), exit_on_error)
_setup_parser_mocks(monkeypatch)
monkeypatch.setattr(cli.registry, "requires_config", lambda name: True)
monkeypatch.setattr(cli, "require_config", mock_require_config)
monkeypatch.setattr(cli.Configuration, "load", raise_config_error)
monkeypatch.setattr(cli, "setup_logger", lambda verbose, log_file: None) monkeypatch.setattr(cli, "setup_logger", lambda verbose, log_file: None)
monkeypatch.setattr(cli, "get_logger", monkeypatch.setattr(cli, "get_logger",
lambda: logging.getLogger("test.cli")) lambda: logging.getLogger(TEST_LOGGER_NAME))
monkeypatch.setattr(cli, "handle_error", mock_handle_error)
def fake_handle_error(error, logger, exit_on_error=False):
calls["handle_error"] = (str(error), exit_on_error)
monkeypatch.setattr(cli, "handle_error", fake_handle_error)
monkeypatch.setattr(sys, "argv", ["skywipe", "posts"])
monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
cli.main() cli.main()
assert calls["handle_error"] == ("boom", True) assert len(calls["require_config"]) == 1
assert calls["handle_error"] is not None
assert calls["handle_error"][0] == "RuntimeError"
assert calls["handle_error"][2] is True
@pytest.mark.parametrize("error_class,error_message", [
(ValueError, TEST_ERROR_MESSAGE),
(RuntimeError, "runtime error"),
(KeyError, "missing key"),
])
def test_main_handles_execute_errors(monkeypatch, error_class, error_message):
calls = {"handle_error": None}
def raise_error(*_args, **_kwargs):
raise error_class(error_message)
_setup_error_mocks(monkeypatch, calls, raise_error)
monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
cli.main()
assert calls["handle_error"] is not None
assert calls["handle_error"][0] == error_class.__name__
assert calls["handle_error"][1] == error_message
assert calls["handle_error"][2] is True

View File

@@ -0,0 +1,132 @@
import logging
import pytest
import skywipe.commands as commands
import skywipe.operations as operations
def test_create_operation_handler_calls_confirmation_and_run(monkeypatch):
calls = {"confirmed": False, "run": 0}
def fake_confirm(message, skip_confirmation, logger):
calls["confirmed"] = True
assert message == "do the thing"
assert skip_confirmation is True
assert isinstance(logger, logging.Logger)
class FakeOperation:
def __init__(self, *args, **kwargs):
pass
def run(self):
calls["run"] += 1
monkeypatch.setattr(commands, "require_confirmation", fake_confirm)
monkeypatch.setattr(commands, "Operation", FakeOperation)
handler = commands._create_operation_handler(
"do the thing", "Test", strategy_type="feed"
)
handler(skip_confirmation=True)
assert calls["confirmed"] is True
assert calls["run"] == 1
def test_run_all_runs_in_order(monkeypatch):
ran = []
fake_commands = {
"posts": "only posts",
"likes": "only likes",
"medias": "only medias",
"extra": "extra",
}
fake_metadata = {
"posts": {
"operation_name": "Deleting posts",
"strategy_type": "feed",
"collection": None,
"filter_fn": None,
},
"likes": {
"operation_name": "Undoing likes",
"strategy_type": "record",
"collection": "app.bsky.feed.like",
"filter_fn": None,
},
"medias": {
"operation_name": "Deleting posts with media",
"strategy_type": "feed",
"collection": None,
"filter_fn": None,
},
"extra": {
"operation_name": "Extra op",
"strategy_type": "feed",
"collection": None,
"filter_fn": None,
},
}
fake_order = ["medias", "posts", "likes"]
class FakeOperation:
def __init__(self, operation_name, **kwargs):
self.operation_name = operation_name
def run(self):
ran.append(self.operation_name)
class FakeOperationContext:
def __init__(self):
self.client = object()
self.config_data = {"batch_size": 1, "delay": 0}
monkeypatch.setattr(commands, "require_confirmation",
lambda *args, **kwargs: None)
monkeypatch.setattr(commands, "Operation", FakeOperation)
monkeypatch.setattr(operations, "OperationContext", FakeOperationContext)
monkeypatch.setattr(commands.registry,
"get_all_commands", lambda: fake_commands)
monkeypatch.setattr(commands, "COMMAND_METADATA", fake_metadata)
monkeypatch.setattr(commands, "COMMAND_EXECUTION_ORDER", fake_order)
commands.run_all(skip_confirmation=True)
expected = [
"Deleting posts with media",
"Deleting posts",
"Undoing likes",
"Extra op",
]
assert ran == expected
def test_run_all_continues_on_error(monkeypatch):
ran = []
class FakeOperation:
def __init__(self, operation_name, **kwargs):
self.operation_name = operation_name
def run(self):
ran.append(self.operation_name)
if self.operation_name == "Deleting posts":
raise RuntimeError("fail")
class FakeOperationContext:
def __init__(self):
self.client = object()
self.config_data = {"batch_size": 1, "delay": 0}
monkeypatch.setattr(commands, "require_confirmation",
lambda *args, **kwargs: None)
monkeypatch.setattr(commands, "Operation", FakeOperation)
monkeypatch.setattr(operations, "OperationContext", FakeOperationContext)
commands.run_all(skip_confirmation=True)
assert "Deleting posts" in ran
assert len(ran) >= 2

View File

@@ -0,0 +1,97 @@
from pathlib import Path
import yaml
def test_configuration_create_reprompts_and_writes_file(config_with_tmp_path, user_input):
inputs = iter([
"bad handle",
"alice.bsky.social",
"5",
"0",
"n",
])
passwords = iter([
"short",
"longenough",
])
user_input(inputs, passwords)
config = config_with_tmp_path
config.create()
assert config.config_file.exists() is True
data = yaml.safe_load(config.config_file.read_text())
assert data["handle"] == "alice.bsky.social"
assert data["password"] == "longenough"
assert data["batch_size"] == 5
assert data["delay"] == 0
assert data["verbose"] is False
def test_configuration_create_invalid_batch_size(config_with_tmp_path, user_input):
inputs = iter([
"alice.bsky.social",
"0",
"1",
"y",
])
passwords = iter(["longenough"])
user_input(inputs, passwords)
config = config_with_tmp_path
config.create()
assert config.config_file.exists() is False
def test_configuration_create_invalid_delay(config_with_tmp_path, user_input):
inputs = iter([
"alice.bsky.social",
"10",
"61",
"y",
])
passwords = iter(["longenough"])
user_input(inputs, passwords)
config = config_with_tmp_path
config.create()
assert config.config_file.exists() is False
def test_configuration_create_overwrite_cancel(config_with_tmp_path, user_input):
config = config_with_tmp_path
config.config_file.parent.mkdir(parents=True, exist_ok=True)
config.config_file.write_text("existing")
user_input(["n"], [])
config.create()
assert config.config_file.read_text() == "existing"
def test_configuration_create_write_failure(config_with_tmp_path, user_input, monkeypatch):
user_input(
["alice.bsky.social", "5", "0", "y"],
["longenough"],
)
config = config_with_tmp_path
original_open = open
def fake_open(path, mode="r", *args, **kwargs):
if Path(path) == config.config_file and "w" in mode:
raise OSError("disk full")
return original_open(path, mode, *args, **kwargs)
monkeypatch.setattr("builtins.open", fake_open)
config.create()
assert config.config_file.exists() is False

View File

@@ -1,32 +1,59 @@
from pathlib import Path from unittest.mock import patch
import pytest import pytest
import yaml
from skywipe.configure import Configuration
def test_configuration_load_missing_file(monkeypatch, tmp_path): def test_configuration_load_missing_file(config_with_tmp_path):
monkeypatch.setattr(Path, "home", lambda: tmp_path) with pytest.raises(FileNotFoundError, match="Configuration file not found"):
config = Configuration() config_with_tmp_path.load()
with pytest.raises(FileNotFoundError):
config.load()
def test_configuration_load_empty_file(monkeypatch, tmp_path): def test_configuration_load_empty_file(config_with_tmp_path):
monkeypatch.setattr(Path, "home", lambda: tmp_path) config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
config = Configuration() config_with_tmp_path.config_file.write_text("")
config.config_file.parent.mkdir(parents=True, exist_ok=True)
config.config_file.write_text("")
with pytest.raises(ValueError, match="empty or invalid"): with pytest.raises(ValueError, match="empty or invalid"):
config.load() config_with_tmp_path.load()
def test_configuration_load_invalid_yaml(monkeypatch, tmp_path): def test_configuration_load_invalid_yaml(config_with_tmp_path):
monkeypatch.setattr(Path, "home", lambda: tmp_path) config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
config = Configuration() config_with_tmp_path.config_file.write_text(": bad")
config.config_file.parent.mkdir(parents=True, exist_ok=True)
config.config_file.write_text(": bad")
with pytest.raises(ValueError, match="Invalid YAML"): with pytest.raises(ValueError, match="Invalid YAML"):
config.load() config_with_tmp_path.load()
def test_configuration_load_valid_file(config_with_tmp_path):
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
config_data = {
"handle": "alice.bsky.social",
"password": "password123",
"batch_size": 10,
"delay": 1,
"verbose": True
}
with open(config_with_tmp_path.config_file, "w") as f:
yaml.dump(config_data, f)
result = config_with_tmp_path.load()
assert result == config_data
def test_configuration_load_file_read_error(config_with_tmp_path):
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
config_with_tmp_path.config_file.write_text("handle: alice")
with patch("builtins.open", side_effect=IOError("Permission denied")):
with pytest.raises(ValueError, match="Failed to read configuration file"):
config_with_tmp_path.load()
def test_configuration_load_file_os_error(config_with_tmp_path):
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
config_with_tmp_path.config_file.write_text("handle: alice")
with patch("builtins.open", side_effect=OSError("File is locked")):
with pytest.raises(ValueError, match="Failed to read configuration file"):
config_with_tmp_path.load()

View File

@@ -1,4 +1,5 @@
import logging import logging
import sys
from skywipe.logger import LevelFilter, ProgressTracker, setup_logger from skywipe.logger import LevelFilter, ProgressTracker, setup_logger
@@ -40,6 +41,53 @@ def test_progress_tracker_updates_counts():
assert tracker.current == 3 assert tracker.current == 3
def test_progress_tracker_batch_with_total():
tracker = ProgressTracker(operation="Testing")
logger = logging.getLogger("skywipe.progress")
messages = []
class MessageHandler(logging.Handler):
def emit(self, record):
messages.append(self.format(record))
handler = MessageHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
try:
tracker.batch(1, 10, total_batches=5)
assert messages[-1] == "Testing - batch 1/5 (10 items)"
tracker.batch(0, 5, total_batches=0)
assert messages[-1] == "Testing - batch 0/0 (5 items)"
finally:
handler.close()
logger.removeHandler(handler)
def test_progress_tracker_batch_without_total():
tracker = ProgressTracker(operation="Testing")
logger = logging.getLogger("skywipe.progress")
messages = []
class MessageHandler(logging.Handler):
def emit(self, record):
messages.append(self.format(record))
handler = MessageHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
try:
tracker.batch(1, 10, total_batches=None)
assert messages[-1] == "Testing - batch 1 (10 items)"
finally:
handler.close()
logger.removeHandler(handler)
def test_setup_logger_does_not_duplicate_handlers(): def test_setup_logger_does_not_duplicate_handlers():
logger = logging.getLogger("skywipe") logger = logging.getLogger("skywipe")
original_handlers = list(logger.handlers) original_handlers = list(logger.handlers)
@@ -48,11 +96,27 @@ def test_setup_logger_does_not_duplicate_handlers():
try: try:
setup_logger(verbose=False) setup_logger(verbose=False)
first_count = len(logger.handlers) first_handlers = list(logger.handlers)
stream_handlers = [
handler for handler in first_handlers
if isinstance(handler, logging.StreamHandler)
]
assert len(stream_handlers) == 2
assert {handler.stream for handler in stream_handlers} == {
sys.stdout,
sys.stderr,
}
assert not any(
isinstance(handler, logging.FileHandler)
for handler in first_handlers
)
first_count = len(first_handlers)
setup_logger(verbose=False) setup_logger(verbose=False)
second_count = len(logger.handlers) second_count = len(logger.handlers)
finally: finally:
for handler in list(logger.handlers): for handler in list(logger.handlers):
handler.close()
logger.removeHandler(handler) logger.removeHandler(handler)
for handler in original_handlers: for handler in original_handlers:
logger.addHandler(handler) logger.addHandler(handler)
@@ -60,6 +124,74 @@ def test_setup_logger_does_not_duplicate_handlers():
assert first_count == second_count assert first_count == second_count
def test_setup_logger_resets_stream_formatters():
logger = logging.getLogger("skywipe")
original_handlers = list(logger.handlers)
for handler in original_handlers:
logger.removeHandler(handler)
try:
setup_logger(verbose=False)
alt_formatter = logging.Formatter(fmt="%(message)s")
for handler in logger.handlers:
if isinstance(handler, logging.StreamHandler):
handler.setFormatter(alt_formatter)
setup_logger(verbose=False)
for handler in logger.handlers:
if isinstance(handler, logging.StreamHandler):
assert handler.formatter is not None
assert handler.formatter._fmt == "%(levelname)s: %(message)s"
finally:
for handler in list(logger.handlers):
handler.close()
logger.removeHandler(handler)
for handler in original_handlers:
logger.addHandler(handler)
def test_setup_logger_disables_propagation():
root_logger = logging.getLogger()
root_messages = []
original_root_level = root_logger.level
class RootHandler(logging.Handler):
def emit(self, record):
root_messages.append(self.format(record))
root_handler = RootHandler()
root_handler.setLevel(logging.INFO)
root_logger.addHandler(root_handler)
root_logger.setLevel(logging.INFO)
logger = logging.getLogger("skywipe")
original_handlers = list(logger.handlers)
for handler in original_handlers:
logger.removeHandler(handler)
try:
setup_logger(verbose=False)
assert logger.propagate is False
progress_logger = logging.getLogger("skywipe.progress")
assert progress_logger.propagate is True
logger.info("Test message")
progress_logger.info("Progress message")
assert len(root_messages) == 0
finally:
root_handler.close()
root_logger.removeHandler(root_handler)
root_logger.setLevel(original_root_level)
for handler in list(logger.handlers):
handler.close()
logger.removeHandler(handler)
for handler in original_handlers:
logger.addHandler(handler)
def test_setup_logger_file_handler_lifecycle(tmp_path): def test_setup_logger_file_handler_lifecycle(tmp_path):
logger = logging.getLogger("skywipe") logger = logging.getLogger("skywipe")
original_handlers = list(logger.handlers) original_handlers = list(logger.handlers)
@@ -87,3 +219,34 @@ def test_setup_logger_file_handler_lifecycle(tmp_path):
logger.removeHandler(handler) logger.removeHandler(handler)
for handler in original_handlers: for handler in original_handlers:
logger.addHandler(handler) logger.addHandler(handler)
def test_setup_logger_replaces_file_handler_when_path_changes(tmp_path):
logger = logging.getLogger("skywipe")
original_handlers = list(logger.handlers)
for handler in original_handlers:
logger.removeHandler(handler)
log_file1 = tmp_path / "skywipe1.log"
log_file2 = tmp_path / "skywipe2.log"
try:
setup_logger(verbose=False, log_file=log_file1)
file_handlers = [
handler for handler in logger.handlers
if isinstance(handler, logging.FileHandler)
]
assert len(file_handlers) == 1
assert file_handlers[0].baseFilename == str(log_file1)
setup_logger(verbose=False, log_file=log_file2)
file_handlers = [
handler for handler in logger.handlers
if isinstance(handler, logging.FileHandler)
]
assert len(file_handlers) == 1
assert file_handlers[0].baseFilename == str(log_file2)
finally:
for handler in list(logger.handlers):
logger.removeHandler(handler)
for handler in original_handlers:
logger.addHandler(handler)

View File

@@ -0,0 +1,30 @@
import pytest
import skywipe.operations as operations
def test_operation_context_raises_on_auth_error(monkeypatch):
class FakeAuth:
def login(self):
raise ValueError("bad auth")
monkeypatch.setattr(operations, "Auth", FakeAuth)
with pytest.raises(ValueError, match="bad auth"):
operations.OperationContext()
def test_operation_context_raises_on_config_error(monkeypatch):
class FakeClient:
class Me:
did = "did:plc:fake"
me = Me()
def fake_load(self):
raise ValueError("bad config")
monkeypatch.setattr(operations.Configuration, "load", fake_load)
with pytest.raises(ValueError, match="bad config"):
operations.OperationContext(client=FakeClient())

98
tests/test_operations.py Normal file
View File

@@ -0,0 +1,98 @@
import time
import pytest
from skywipe.operations import Operation, BookmarkStrategy
class FakeClient:
class Me:
did = "did:plc:fake"
me = Me()
class FakeResponse:
def __init__(self, items, cursor=None):
self.items = items
self.cursor = cursor
class FakeStrategy:
def __init__(self, responses, fail_on=None):
self._responses = list(responses)
self._fail_on = fail_on
def fetch(self, context, cursor=None):
return self._responses.pop(0)
def extract_items(self, response):
return response.items
def process_item(self, item, context):
if self._fail_on is not None and item == self._fail_on:
raise ValueError("boom")
def get_cursor(self, response):
return response.cursor
def test_operation_run_batches_filters_and_sleeps(monkeypatch):
responses = [
FakeResponse(items=[1, 2, 3], cursor="next"),
FakeResponse(items=[4], cursor=None),
]
operation = Operation(
"Testing",
strategy_type="feed",
client=FakeClient(),
config_data={"batch_size": 2, "delay": 1},
filter_fn=lambda item: item != 2,
)
operation.strategy = FakeStrategy(responses, fail_on=3)
slept = []
def fake_sleep(seconds):
slept.append(seconds)
monkeypatch.setattr(time, "sleep", fake_sleep)
total = operation.run()
assert total == 2
assert slept == [1]
def test_bookmark_strategy_extracts_uri_from_shapes():
strategy = BookmarkStrategy()
class Obj:
pass
direct = Obj()
direct.uri = "direct"
assert strategy._extract_bookmark_uri(direct) == "direct"
subject = Obj()
subject.subject = Obj()
subject.subject.uri = "subject"
assert strategy._extract_bookmark_uri(subject) == "subject"
record = Obj()
record.record = Obj()
record.record.uri = "record"
assert strategy._extract_bookmark_uri(record) == "record"
post = Obj()
post.post = Obj()
post.post.uri = "post"
assert strategy._extract_bookmark_uri(post) == "post"
item = Obj()
item.item = Obj()
item.item.uri = "item"
assert strategy._extract_bookmark_uri(item) == "item"
missing = Obj()
assert strategy._extract_bookmark_uri(missing) is None

2
uv.lock generated
View File

@@ -430,7 +430,7 @@ wheels = [
[[package]] [[package]]
name = "skywipe" name = "skywipe"
version = "0.1.0" version = "0.1.0"
source = { virtual = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "atproto" }, { name = "atproto" },
{ name = "pyyaml" }, { name = "pyyaml" },