Compare commits

..

25 Commits

Author SHA1 Message Date
2cdc4c6c42 refactor: tidy logger handler setup 2026-01-15 16:06:32 +01:00
07862f0ea2 refactor: simplify configuration create flow 2026-01-15 16:02:21 +01:00
e095c68f72 test: share error message formatting helper 2026-01-08 06:51:56 +01:00
c718e8c6f5 test: patch configure getpass correctly 2026-01-08 06:51:46 +01:00
a07cc02fb0 fix: format KeyError messages cleanly 2026-01-08 06:51:37 +01:00
ecc33054af fix: handle config load errors via handler 2026-01-08 06:51:23 +01:00
e6d68dd37d refactor: type strategy interface 2026-01-05 20:38:20 +01:00
809b7823ea refactor: type command metadata 2026-01-05 20:38:09 +01:00
6a88ab8560 refactor: import getpass directly 2025-12-31 08:22:27 +01:00
dd3b220a4a test: use shared config fixture 2025-12-30 23:09:01 +01:00
82b99da50d test: add config_with_tmp_path fixture 2025-12-30 23:08:51 +01:00
ce2a1ad594 test: expand configure load tests with fixture and error cases 2025-12-30 23:04:45 +01:00
313b6fc453 test: refactor CLI tests and expand coverage 2025-12-30 18:29:34 +01:00
e364598414 test: expect progress propagation 2025-12-30 18:16:51 +01:00
cd1cf1f170 fix: restore progress logger propagation 2025-12-30 18:16:44 +01:00
c3761d1d08 test: cover formatter reset on reuse 2025-12-30 18:13:56 +01:00
85f1ea4efb fix: reset stream formatters in setup_logger 2025-12-30 18:13:51 +01:00
df22b3dd3d test: cover logger propagation behavior 2025-12-30 18:11:48 +01:00
8c0bbceeac fix: disable skywipe log propagation 2025-12-30 18:11:38 +01:00
5e60374937 test: add tests for ProgressTracker.batch with total_batches=0 2025-12-30 17:46:29 +01:00
fd62bb5ea2 fix: use "is not None" check for total_batches in ProgressTracker.batch 2025-12-30 17:46:22 +01:00
6785ecd45a test: strengthen logger handler duplication test 2025-12-30 08:53:00 +01:00
7828989150 refactor: dedupe file handler cleanup in setup_logger 2025-12-30 08:45:39 +01:00
9eb2ed0097 test: verify FileHandler replacement when log_file path changes 2025-12-30 08:31:14 +01:00
5c8932599c fix: replace existing FileHandler when log_file changes 2025-12-30 08:30:59 +01:00
10 changed files with 570 additions and 175 deletions

View File

@@ -61,6 +61,7 @@ def main():
setup_logger(verbose=False, log_file=LOG_FILE)
logger = get_logger()
try:
if registry.requires_config(args.command):
require_config(logger)
config = Configuration()
@@ -68,7 +69,6 @@ def main():
verbose = config_data.get("verbose", False)
setup_logger(verbose=verbose, log_file=LOG_FILE)
try:
registry.execute(
args.command, skip_confirmation=getattr(args, "yes", False))
except (ValueError, Exception) as e:

View File

@@ -1,6 +1,6 @@
"""Command implementations for Skywipe"""
from typing import Callable, Any
from typing import Callable, Any, TypedDict
from .configure import Configuration
from .operations import Operation
from .post_analysis import PostAnalyzer
@@ -11,7 +11,16 @@ from .safeguard import require_confirmation
CommandHandler = Callable[..., None]
COMMAND_METADATA = {
class CommandMetadata(TypedDict):
confirmation: str
help_text: str
operation_name: str
strategy_type: str
collection: str | None
filter_fn: Callable[[Any], bool] | None
COMMAND_METADATA: dict[str, CommandMetadata] = {
"posts": {
"confirmation": "delete all posts",
"help_text": "only posts",

View File

@@ -1,6 +1,6 @@
"""Core configuration module for Skywipe"""
import getpass
from getpass import getpass
import re
from pathlib import Path
from typing import NamedTuple
@@ -56,65 +56,100 @@ class Configuration:
def exists(self) -> bool:
return self.config_file.exists()
def create(self):
logger = setup_logger(verbose=False)
if self.exists():
def _confirm_overwrite(self, logger) -> bool:
if not self.exists():
return True
overwrite = input(
"Configuration already exists. Overwrite? (y/N): ").strip().lower()
if overwrite not in ("y", "yes"):
if overwrite in ("y", "yes"):
return True
logger.info("Configuration creation cancelled.")
return
return False
def _ensure_config_dir(self) -> None:
config_dir = self.config_file.parent
config_dir.mkdir(parents=True, exist_ok=True)
print("Skywipe Configuration")
print("=" * 50)
print("Note: You should use an app password from Bluesky settings.")
def _prompt_handle(self, logger) -> str:
while True:
handle = input("Bluesky handle: ").strip()
is_valid, error_msg = _validate_handle(handle)
if is_valid:
break
return handle
logger.error(error_msg)
logger.info("Please enter a valid handle and try again.")
def _prompt_password(self, logger) -> str:
while True:
password = getpass.getpass(
password = getpass(
"Bluesky (hopefully app) password: ").strip()
is_valid, error_msg = _validate_password(password)
if is_valid:
break
return password
logger.error(error_msg)
logger.info("Please check your password and try again.")
logger.info(
"Generate an app password at: https://bsky.app/settings/app-passwords")
def _parse_batch_size(self, logger) -> int | None:
batch_size = input("Batch size (default: 10): ").strip() or "10"
delay = input(
"Delay between batches in seconds (default: 1): ").strip() or "1"
verbose_input = input(
"Verbose mode (y/n, default: y): ").strip().lower() or "y"
verbose = verbose_input in ("y", "yes", "true", "1")
try:
batch_size = int(batch_size)
if batch_size < 1 or batch_size > 100:
logger.error("batch_size must be between 1 and 100")
return
batch_size_int = int(batch_size)
except ValueError:
logger.error("batch_size must be an integer")
return
return None
if batch_size_int < 1 or batch_size_int > 100:
logger.error("batch_size must be between 1 and 100")
return None
return batch_size_int
def _parse_delay(self, logger) -> int | None:
delay = input(
"Delay between batches in seconds (default: 1): ").strip() or "1"
try:
delay = int(delay)
if delay < 0 or delay > 60:
logger.error("delay must be between 0 and 60 seconds")
return
delay_int = int(delay)
except ValueError:
logger.error("delay must be an integer")
return None
if delay_int < 0 or delay_int > 60:
logger.error("delay must be between 0 and 60 seconds")
return None
return delay_int
def _parse_verbose(self) -> bool:
verbose_input = input(
"Verbose mode (y/n, default: y): ").strip().lower() or "y"
return verbose_input in ("y", "yes", "true", "1")
def _write_config(self, logger, config_data: dict) -> None:
try:
with open(self.config_file, "w") as f:
yaml.dump(config_data, f, default_flow_style=False)
except (IOError, OSError) as e:
logger.error(f"Failed to save configuration: {e}")
return
logger.info(f"Configuration saved to {self.config_file}")
def create(self):
logger = setup_logger(verbose=False)
if not self._confirm_overwrite(logger):
return
self._ensure_config_dir()
print("Skywipe Configuration")
print("=" * 50)
print("Note: You should use an app password from Bluesky settings.")
handle = self._prompt_handle(logger)
password = self._prompt_password(logger)
batch_size = self._parse_batch_size(logger)
if batch_size is None:
return
delay = self._parse_delay(logger)
if delay is None:
return
verbose = self._parse_verbose()
config_data = {
"handle": handle,
@@ -124,14 +159,7 @@ class Configuration:
"verbose": verbose
}
try:
with open(self.config_file, "w") as f:
yaml.dump(config_data, f, default_flow_style=False)
except (IOError, OSError) as e:
logger.error(f"Failed to save configuration: {e}")
return
logger.info(f"Configuration saved to {self.config_file}")
self._write_config(logger, config_data)
def load(self) -> dict:
if not self.exists():

View File

@@ -15,7 +15,7 @@ class ProgressTracker:
def batch(self, batch_num: int, batch_size: int, total_batches: int | None = None):
logger = logging.getLogger("skywipe.progress")
if total_batches:
if total_batches is not None:
logger.info(
f"{self.operation} - batch {batch_num}/{total_batches} ({batch_size} items)"
)
@@ -36,9 +36,14 @@ class LevelFilter(logging.Filter):
def setup_logger(verbose: bool = False, log_file: Path | None = None) -> logging.Logger:
logger = logging.getLogger("skywipe")
logger.propagate = False
target_level = logging.DEBUG if verbose else logging.INFO
logger.setLevel(target_level)
progress_logger = logging.getLogger("skywipe.progress")
if not progress_logger.handlers:
progress_logger.propagate = True
info_handler = None
error_handler = None
file_handlers = []
@@ -56,19 +61,25 @@ def setup_logger(verbose: bool = False, log_file: Path | None = None) -> logging
if info_handler is None:
info_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(info_handler)
for existing in list(info_handler.filters):
if isinstance(existing, LevelFilter):
info_handler.removeFilter(existing)
info_handler.addFilter(LevelFilter(logging.DEBUG, logging.INFO))
info_handler.setFormatter(formatter)
logger.addHandler(info_handler)
info_handler.setLevel(target_level)
if error_handler is None:
error_handler = logging.StreamHandler(sys.stderr)
error_handler.setLevel(logging.WARNING)
error_handler.setFormatter(formatter)
logger.addHandler(error_handler)
error_handler.setFormatter(formatter)
for handler in file_handlers:
handler.close()
logger.removeHandler(handler)
if log_file:
if not file_handlers:
log_file.parent.mkdir(parents=True, exist_ok=True)
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setLevel(logging.DEBUG)
@@ -78,10 +89,6 @@ def setup_logger(verbose: bool = False, log_file: Path | None = None) -> logging
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
else:
for handler in file_handlers:
handler.close()
logger.removeHandler(handler)
return logger
@@ -90,11 +97,18 @@ def get_logger() -> logging.Logger:
return logging.getLogger("skywipe")
def _format_error_message(error: Exception) -> str:
if isinstance(error, KeyError):
return str(error.args[0]) if error.args else str(error)
return str(error)
def handle_error(error: Exception, logger: logging.Logger, exit_on_error: bool = False) -> None:
if isinstance(error, ValueError):
logger.error(f"{error}")
if isinstance(error, (KeyError, ValueError)):
logger.error(_format_error_message(error))
else:
logger.error(f"Unexpected error: {error}", exc_info=True)
logger.error(
f"Unexpected error: {_format_error_message(error)}", exc_info=True)
if exit_on_error:
sys.exit(1)

View File

@@ -52,7 +52,16 @@ class OperationContext:
class BaseStrategy:
def get_cursor(self, response):
def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
raise NotImplementedError
def extract_items(self, response: Any) -> list[Any]:
raise NotImplementedError
def process_item(self, item: Any, context: OperationContext) -> None:
raise NotImplementedError
def get_cursor(self, response: Any) -> str | None:
return response.cursor
@@ -60,7 +69,7 @@ class RecordDeletionStrategy(BaseStrategy):
def __init__(self, collection: str):
self.collection = collection
def fetch(self, context: OperationContext, cursor: str | None = None):
def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
list_params = models.ComAtprotoRepoListRecords.Params(
repo=context.did,
collection=self.collection,
@@ -69,10 +78,10 @@ class RecordDeletionStrategy(BaseStrategy):
)
return context.client.com.atproto.repo.list_records(params=list_params)
def extract_items(self, response):
def extract_items(self, response: Any) -> list[Any]:
return response.records
def process_item(self, record, context: OperationContext):
def process_item(self, record: Any, context: OperationContext) -> None:
record_uri = record.uri
rkey = record_uri.rsplit("/", 1)[-1]
delete_data = {
@@ -85,34 +94,34 @@ class RecordDeletionStrategy(BaseStrategy):
class FeedStrategy(BaseStrategy):
def fetch(self, context: OperationContext, cursor: str | None = None):
def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
if cursor:
return context.client.get_author_feed(
actor=context.did, limit=context.batch_size, cursor=cursor
)
return context.client.get_author_feed(actor=context.did, limit=context.batch_size)
def extract_items(self, response):
def extract_items(self, response: Any) -> list[Any]:
return response.feed
def process_item(self, post, context: OperationContext):
def process_item(self, post: Any, context: OperationContext) -> None:
uri = post.post.uri
context.client.delete_post(uri)
context.logger.debug(f"Deleted post: {uri}")
class BookmarkStrategy(BaseStrategy):
def fetch(self, context: OperationContext, cursor: str | None = None):
def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
get_params = models.AppBskyBookmarkGetBookmarks.Params(
limit=context.batch_size,
cursor=cursor
)
return context.client.app.bsky.bookmark.get_bookmarks(params=get_params)
def extract_items(self, response):
def extract_items(self, response: Any) -> list[Any]:
return response.bookmarks
def process_item(self, bookmark, context: OperationContext):
def process_item(self, bookmark: Any, context: OperationContext) -> None:
bookmark_uri = self._extract_bookmark_uri(bookmark)
if not bookmark_uri:
raise ValueError("Unable to find bookmark URI")
@@ -122,7 +131,7 @@ class BookmarkStrategy(BaseStrategy):
context.client.app.bsky.bookmark.delete_bookmark(data=delete_data)
context.logger.debug(f"Deleted bookmark: {bookmark_uri}")
def _extract_bookmark_uri(self, bookmark):
def _extract_bookmark_uri(self, bookmark: Any) -> str | None:
if hasattr(bookmark, "uri"):
return bookmark.uri
@@ -148,6 +157,7 @@ class Operation:
self.filter_fn = filter_fn
self._client = client
self._config_data = config_data
self.strategy: BaseStrategy
if strategy_type == "record":
if not collection:

View File

@@ -1,7 +1,10 @@
from pathlib import Path
from typing import Iterable, Callable
import pytest
from skywipe.configure import Configuration
@pytest.fixture
def user_input(monkeypatch) -> Callable[[Iterable[str], Iterable[str]], None]:
@@ -10,7 +13,13 @@ def user_input(monkeypatch) -> Callable[[Iterable[str], Iterable[str]], None]:
password_iter = iter(passwords)
monkeypatch.setattr("builtins.input", lambda _prompt: next(input_iter))
monkeypatch.setattr("getpass.getpass",
monkeypatch.setattr("skywipe.configure.getpass",
lambda _prompt: next(password_iter))
return _set
@pytest.fixture
def config_with_tmp_path(monkeypatch, tmp_path):
monkeypatch.setattr(Path, "home", lambda: tmp_path)
return Configuration()

View File

@@ -6,91 +6,234 @@ import pytest
import skywipe.cli as cli
TEST_COMMAND = "posts"
TEST_ERROR_MESSAGE = "boom"
TEST_LOGGER_NAME = "test.cli"
def _setup_parser_mocks(monkeypatch, commands=None):
if commands is None:
commands = {TEST_COMMAND: "only posts"}
monkeypatch.setattr(cli.registry, "get_all_commands", lambda: commands)
def _setup_main_mocks(monkeypatch, calls, requires_config=False, config_data=None):
_setup_parser_mocks(monkeypatch)
monkeypatch.setattr(cli.registry, "requires_config",
lambda name: requires_config)
def mock_execute(name, skip_confirmation=False):
calls["execute"] = (name, skip_confirmation)
def mock_setup_logger(verbose, log_file):
calls["setup"].append((verbose, log_file))
def mock_require_config(logger):
calls["require_config"].append(logger)
monkeypatch.setattr(cli.registry, "execute", mock_execute)
monkeypatch.setattr(cli, "setup_logger", mock_setup_logger)
monkeypatch.setattr(cli, "require_config", mock_require_config)
monkeypatch.setattr(cli, "get_logger",
lambda: logging.getLogger(TEST_LOGGER_NAME))
if config_data is not None:
monkeypatch.setattr(cli.Configuration, "load",
lambda self: config_data)
def _setup_error_mocks(monkeypatch, calls, error_factory):
_setup_parser_mocks(monkeypatch)
monkeypatch.setattr(cli.registry, "requires_config", lambda name: False)
monkeypatch.setattr(cli.registry, "execute", error_factory)
monkeypatch.setattr(cli, "setup_logger", lambda verbose, log_file: None)
monkeypatch.setattr(cli, "get_logger",
lambda: logging.getLogger(TEST_LOGGER_NAME))
def _format_error_message(error):
if isinstance(error, KeyError):
return error.args[0] if error.args else str(error)
return str(error)
def mock_handle_error(error, logger, exit_on_error=False):
calls["handle_error"] = (type(error).__name__,
_format_error_message(error), exit_on_error)
monkeypatch.setattr(cli, "handle_error", mock_handle_error)
def test_create_parser_includes_commands(monkeypatch):
monkeypatch.setattr(cli.registry, "get_all_commands",
lambda: {"posts": "only posts"})
_setup_parser_mocks(monkeypatch)
parser = cli.create_parser()
args = parser.parse_args(["posts"])
assert args.command == "posts"
args = parser.parse_args([TEST_COMMAND])
assert args.command == TEST_COMMAND
def test_create_parser_handles_multiple_commands(monkeypatch):
commands = {
"posts": "only posts",
"likes": "only likes",
"reposts": "only reposts"
}
_setup_parser_mocks(monkeypatch, commands)
parser = cli.create_parser()
args1 = parser.parse_args(["posts"])
args2 = parser.parse_args(["likes"])
args3 = parser.parse_args(["reposts"])
assert args1.command == "posts"
assert args2.command == "likes"
assert args3.command == "reposts"
def test_create_parser_parses_yes_flag(monkeypatch):
_setup_parser_mocks(monkeypatch)
parser = cli.create_parser()
args = parser.parse_args(["--yes", TEST_COMMAND])
assert args.command == TEST_COMMAND
assert args.yes is True
def test_create_parser_parses_without_yes_flag(monkeypatch):
_setup_parser_mocks(monkeypatch)
parser = cli.create_parser()
args = parser.parse_args([TEST_COMMAND])
assert args.command == TEST_COMMAND
assert getattr(args, "yes", False) is False
def test_create_parser_version_flag_exits(monkeypatch):
_setup_parser_mocks(monkeypatch)
parser = cli.create_parser()
with pytest.raises(SystemExit) as excinfo:
parser.parse_args(["--version"])
assert excinfo.value.code == 0
def test_create_parser_requires_command(monkeypatch):
_setup_parser_mocks(monkeypatch)
parser = cli.create_parser()
with pytest.raises(SystemExit):
parser.parse_args([])
def test_create_parser_rejects_invalid_command(monkeypatch):
_setup_parser_mocks(monkeypatch)
parser = cli.create_parser()
with pytest.raises(SystemExit):
parser.parse_args(["invalid_command"])
def test_require_config_exits_when_missing(monkeypatch):
monkeypatch.setattr(cli.Configuration, "exists", lambda self: False)
logger = logging.getLogger("test.cli")
logger = logging.getLogger(TEST_LOGGER_NAME)
with pytest.raises(SystemExit) as excinfo:
cli.require_config(logger)
assert excinfo.value.code == 1
def test_require_config_does_not_exit_when_exists(monkeypatch):
monkeypatch.setattr(cli.Configuration, "exists", lambda self: True)
logger = logging.getLogger(TEST_LOGGER_NAME)
cli.require_config(logger)
def test_main_executes_without_config(monkeypatch):
calls = {"execute": None, "setup": []}
calls = {"execute": None, "setup": [], "require_config": []}
_setup_main_mocks(monkeypatch, calls, requires_config=False)
monkeypatch.setattr(cli.registry, "get_all_commands",
lambda: {"posts": "only posts"})
monkeypatch.setattr(cli.registry, "requires_config", lambda name: False)
monkeypatch.setattr(cli.registry, "execute", lambda name, skip_confirmation=False: calls.update(
{"execute": (name, skip_confirmation)}
))
monkeypatch.setattr(cli, "setup_logger", lambda verbose,
log_file: calls["setup"].append((verbose, log_file)))
monkeypatch.setattr(cli, "get_logger",
lambda: logging.getLogger("test.cli"))
monkeypatch.setattr(sys, "argv", ["skywipe", "--yes", "posts"])
monkeypatch.setattr(sys, "argv", ["skywipe", "--yes", TEST_COMMAND])
cli.main()
assert len(calls["require_config"]) == 0
assert calls["setup"] == [(False, cli.LOG_FILE)]
assert calls["execute"] == ("posts", True)
assert calls["execute"] == (TEST_COMMAND, True)
def test_main_loads_config_and_sets_verbose(monkeypatch):
calls = {"setup": [], "execute": None, "require_config": 0}
calls = {"setup": [], "execute": None, "require_config": []}
_setup_main_mocks(monkeypatch, calls, requires_config=True,
config_data={"verbose": True})
monkeypatch.setattr(cli.registry, "get_all_commands",
lambda: {"posts": "only posts"})
monkeypatch.setattr(cli.registry, "requires_config", lambda name: True)
monkeypatch.setattr(cli.registry, "execute", lambda name, skip_confirmation=False: calls.update(
{"execute": (name, skip_confirmation)}
))
monkeypatch.setattr(cli, "require_config", lambda logger: calls.update(
{"require_config": calls["require_config"] + 1}
))
monkeypatch.setattr(cli.Configuration, "load",
lambda self: {"verbose": True})
monkeypatch.setattr(cli, "setup_logger", lambda verbose,
log_file: calls["setup"].append((verbose, log_file)))
monkeypatch.setattr(cli, "get_logger",
lambda: logging.getLogger("test.cli"))
monkeypatch.setattr(sys, "argv", ["skywipe", "posts"])
monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
cli.main()
assert calls["require_config"] == 1
assert len(calls["require_config"]) == 1
assert calls["setup"] == [(False, cli.LOG_FILE), (True, cli.LOG_FILE)]
assert calls["execute"] == ("posts", False)
assert calls["execute"] == (TEST_COMMAND, False)
def test_main_handles_execute_error(monkeypatch):
calls = {"handle_error": None}
@pytest.mark.parametrize("config_data,expected_verbose", [
({}, False),
({"verbose": False}, False),
])
def test_main_config_verbose_defaults(monkeypatch, config_data, expected_verbose):
calls = {"setup": [], "execute": None, "require_config": []}
_setup_main_mocks(monkeypatch, calls, requires_config=True,
config_data=config_data)
monkeypatch.setattr(cli.registry, "get_all_commands",
lambda: {"posts": "only posts"})
monkeypatch.setattr(cli.registry, "requires_config", lambda name: False)
monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
cli.main()
def raise_error(*_args, **_kwargs):
raise ValueError("boom")
assert len(calls["require_config"]) == 1
assert calls["setup"] == [(False, cli.LOG_FILE),
(expected_verbose, cli.LOG_FILE)]
assert calls["execute"] == (TEST_COMMAND, False)
monkeypatch.setattr(cli.registry, "execute", raise_error)
def test_main_handles_config_load_error(monkeypatch):
calls = {"handle_error": None, "require_config": []}
def mock_require_config(logger):
calls["require_config"].append(logger)
def raise_config_error(self):
raise RuntimeError("config error")
def _format_error_message(error):
if isinstance(error, KeyError):
return error.args[0] if error.args else str(error)
return str(error)
def mock_handle_error(error, logger, exit_on_error=False):
calls["handle_error"] = (type(error).__name__,
_format_error_message(error), exit_on_error)
_setup_parser_mocks(monkeypatch)
monkeypatch.setattr(cli.registry, "requires_config", lambda name: True)
monkeypatch.setattr(cli, "require_config", mock_require_config)
monkeypatch.setattr(cli.Configuration, "load", raise_config_error)
monkeypatch.setattr(cli, "setup_logger", lambda verbose, log_file: None)
monkeypatch.setattr(cli, "get_logger",
lambda: logging.getLogger("test.cli"))
def fake_handle_error(error, logger, exit_on_error=False):
calls["handle_error"] = (str(error), exit_on_error)
monkeypatch.setattr(cli, "handle_error", fake_handle_error)
monkeypatch.setattr(sys, "argv", ["skywipe", "posts"])
lambda: logging.getLogger(TEST_LOGGER_NAME))
monkeypatch.setattr(cli, "handle_error", mock_handle_error)
monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
cli.main()
assert calls["handle_error"] == ("boom", True)
assert len(calls["require_config"]) == 1
assert calls["handle_error"] is not None
assert calls["handle_error"][0] == "RuntimeError"
assert calls["handle_error"][2] is True
@pytest.mark.parametrize("error_class,error_message", [
(ValueError, TEST_ERROR_MESSAGE),
(RuntimeError, "runtime error"),
(KeyError, "missing key"),
])
def test_main_handles_execute_errors(monkeypatch, error_class, error_message):
calls = {"handle_error": None}
def raise_error(*_args, **_kwargs):
raise error_class(error_message)
_setup_error_mocks(monkeypatch, calls, raise_error)
monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
cli.main()
assert calls["handle_error"] is not None
assert calls["handle_error"][0] == error_class.__name__
assert calls["handle_error"][1] == error_message
assert calls["handle_error"][2] is True

View File

@@ -1,11 +1,8 @@
from pathlib import Path
import yaml
from skywipe.configure import Configuration
def test_configuration_create_reprompts_and_writes_file(monkeypatch, tmp_path, user_input):
def test_configuration_create_reprompts_and_writes_file(config_with_tmp_path, user_input):
inputs = iter([
"bad handle",
"alice.bsky.social",
@@ -18,10 +15,9 @@ def test_configuration_create_reprompts_and_writes_file(monkeypatch, tmp_path, u
"longenough",
])
monkeypatch.setattr(Path, "home", lambda: tmp_path)
user_input(inputs, passwords)
config = Configuration()
config = config_with_tmp_path
config.create()
assert config.config_file.exists() is True
@@ -33,7 +29,7 @@ def test_configuration_create_reprompts_and_writes_file(monkeypatch, tmp_path, u
assert data["verbose"] is False
def test_configuration_create_invalid_batch_size(monkeypatch, tmp_path, user_input):
def test_configuration_create_invalid_batch_size(config_with_tmp_path, user_input):
inputs = iter([
"alice.bsky.social",
"0",
@@ -42,16 +38,15 @@ def test_configuration_create_invalid_batch_size(monkeypatch, tmp_path, user_inp
])
passwords = iter(["longenough"])
monkeypatch.setattr(Path, "home", lambda: tmp_path)
user_input(inputs, passwords)
config = Configuration()
config = config_with_tmp_path
config.create()
assert config.config_file.exists() is False
def test_configuration_create_invalid_delay(monkeypatch, tmp_path, user_input):
def test_configuration_create_invalid_delay(config_with_tmp_path, user_input):
inputs = iter([
"alice.bsky.social",
"10",
@@ -60,18 +55,16 @@ def test_configuration_create_invalid_delay(monkeypatch, tmp_path, user_input):
])
passwords = iter(["longenough"])
monkeypatch.setattr(Path, "home", lambda: tmp_path)
user_input(inputs, passwords)
config = Configuration()
config = config_with_tmp_path
config.create()
assert config.config_file.exists() is False
def test_configuration_create_overwrite_cancel(monkeypatch, tmp_path, user_input):
monkeypatch.setattr(Path, "home", lambda: tmp_path)
config = Configuration()
def test_configuration_create_overwrite_cancel(config_with_tmp_path, user_input):
config = config_with_tmp_path
config.config_file.parent.mkdir(parents=True, exist_ok=True)
config.config_file.write_text("existing")
@@ -82,14 +75,13 @@ def test_configuration_create_overwrite_cancel(monkeypatch, tmp_path, user_input
assert config.config_file.read_text() == "existing"
def test_configuration_create_write_failure(monkeypatch, tmp_path, user_input):
monkeypatch.setattr(Path, "home", lambda: tmp_path)
def test_configuration_create_write_failure(config_with_tmp_path, user_input, monkeypatch):
user_input(
["alice.bsky.social", "5", "0", "y"],
["longenough"],
)
config = Configuration()
config = config_with_tmp_path
original_open = open

View File

@@ -1,32 +1,59 @@
from pathlib import Path
from unittest.mock import patch
import pytest
from skywipe.configure import Configuration
import yaml
def test_configuration_load_missing_file(monkeypatch, tmp_path):
monkeypatch.setattr(Path, "home", lambda: tmp_path)
config = Configuration()
with pytest.raises(FileNotFoundError):
config.load()
def test_configuration_load_missing_file(config_with_tmp_path):
with pytest.raises(FileNotFoundError, match="Configuration file not found"):
config_with_tmp_path.load()
def test_configuration_load_empty_file(monkeypatch, tmp_path):
monkeypatch.setattr(Path, "home", lambda: tmp_path)
config = Configuration()
config.config_file.parent.mkdir(parents=True, exist_ok=True)
config.config_file.write_text("")
def test_configuration_load_empty_file(config_with_tmp_path):
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
config_with_tmp_path.config_file.write_text("")
with pytest.raises(ValueError, match="empty or invalid"):
config.load()
config_with_tmp_path.load()
def test_configuration_load_invalid_yaml(monkeypatch, tmp_path):
monkeypatch.setattr(Path, "home", lambda: tmp_path)
config = Configuration()
config.config_file.parent.mkdir(parents=True, exist_ok=True)
config.config_file.write_text(": bad")
def test_configuration_load_invalid_yaml(config_with_tmp_path):
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
config_with_tmp_path.config_file.write_text(": bad")
with pytest.raises(ValueError, match="Invalid YAML"):
config.load()
config_with_tmp_path.load()
def test_configuration_load_valid_file(config_with_tmp_path):
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
config_data = {
"handle": "alice.bsky.social",
"password": "password123",
"batch_size": 10,
"delay": 1,
"verbose": True
}
with open(config_with_tmp_path.config_file, "w") as f:
yaml.dump(config_data, f)
result = config_with_tmp_path.load()
assert result == config_data
def test_configuration_load_file_read_error(config_with_tmp_path):
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
config_with_tmp_path.config_file.write_text("handle: alice")
with patch("builtins.open", side_effect=IOError("Permission denied")):
with pytest.raises(ValueError, match="Failed to read configuration file"):
config_with_tmp_path.load()
def test_configuration_load_file_os_error(config_with_tmp_path):
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
config_with_tmp_path.config_file.write_text("handle: alice")
with patch("builtins.open", side_effect=OSError("File is locked")):
with pytest.raises(ValueError, match="Failed to read configuration file"):
config_with_tmp_path.load()

View File

@@ -1,4 +1,5 @@
import logging
import sys
from skywipe.logger import LevelFilter, ProgressTracker, setup_logger
@@ -40,6 +41,53 @@ def test_progress_tracker_updates_counts():
assert tracker.current == 3
def test_progress_tracker_batch_with_total():
tracker = ProgressTracker(operation="Testing")
logger = logging.getLogger("skywipe.progress")
messages = []
class MessageHandler(logging.Handler):
def emit(self, record):
messages.append(self.format(record))
handler = MessageHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
try:
tracker.batch(1, 10, total_batches=5)
assert messages[-1] == "Testing - batch 1/5 (10 items)"
tracker.batch(0, 5, total_batches=0)
assert messages[-1] == "Testing - batch 0/0 (5 items)"
finally:
handler.close()
logger.removeHandler(handler)
def test_progress_tracker_batch_without_total():
tracker = ProgressTracker(operation="Testing")
logger = logging.getLogger("skywipe.progress")
messages = []
class MessageHandler(logging.Handler):
def emit(self, record):
messages.append(self.format(record))
handler = MessageHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
try:
tracker.batch(1, 10, total_batches=None)
assert messages[-1] == "Testing - batch 1 (10 items)"
finally:
handler.close()
logger.removeHandler(handler)
def test_setup_logger_does_not_duplicate_handlers():
logger = logging.getLogger("skywipe")
original_handlers = list(logger.handlers)
@@ -48,11 +96,27 @@ def test_setup_logger_does_not_duplicate_handlers():
try:
setup_logger(verbose=False)
first_count = len(logger.handlers)
first_handlers = list(logger.handlers)
stream_handlers = [
handler for handler in first_handlers
if isinstance(handler, logging.StreamHandler)
]
assert len(stream_handlers) == 2
assert {handler.stream for handler in stream_handlers} == {
sys.stdout,
sys.stderr,
}
assert not any(
isinstance(handler, logging.FileHandler)
for handler in first_handlers
)
first_count = len(first_handlers)
setup_logger(verbose=False)
second_count = len(logger.handlers)
finally:
for handler in list(logger.handlers):
handler.close()
logger.removeHandler(handler)
for handler in original_handlers:
logger.addHandler(handler)
@@ -60,6 +124,74 @@ def test_setup_logger_does_not_duplicate_handlers():
assert first_count == second_count
def test_setup_logger_resets_stream_formatters():
logger = logging.getLogger("skywipe")
original_handlers = list(logger.handlers)
for handler in original_handlers:
logger.removeHandler(handler)
try:
setup_logger(verbose=False)
alt_formatter = logging.Formatter(fmt="%(message)s")
for handler in logger.handlers:
if isinstance(handler, logging.StreamHandler):
handler.setFormatter(alt_formatter)
setup_logger(verbose=False)
for handler in logger.handlers:
if isinstance(handler, logging.StreamHandler):
assert handler.formatter is not None
assert handler.formatter._fmt == "%(levelname)s: %(message)s"
finally:
for handler in list(logger.handlers):
handler.close()
logger.removeHandler(handler)
for handler in original_handlers:
logger.addHandler(handler)
def test_setup_logger_disables_propagation():
root_logger = logging.getLogger()
root_messages = []
original_root_level = root_logger.level
class RootHandler(logging.Handler):
def emit(self, record):
root_messages.append(self.format(record))
root_handler = RootHandler()
root_handler.setLevel(logging.INFO)
root_logger.addHandler(root_handler)
root_logger.setLevel(logging.INFO)
logger = logging.getLogger("skywipe")
original_handlers = list(logger.handlers)
for handler in original_handlers:
logger.removeHandler(handler)
try:
setup_logger(verbose=False)
assert logger.propagate is False
progress_logger = logging.getLogger("skywipe.progress")
assert progress_logger.propagate is True
logger.info("Test message")
progress_logger.info("Progress message")
assert len(root_messages) == 0
finally:
root_handler.close()
root_logger.removeHandler(root_handler)
root_logger.setLevel(original_root_level)
for handler in list(logger.handlers):
handler.close()
logger.removeHandler(handler)
for handler in original_handlers:
logger.addHandler(handler)
def test_setup_logger_file_handler_lifecycle(tmp_path):
logger = logging.getLogger("skywipe")
original_handlers = list(logger.handlers)
@@ -87,3 +219,34 @@ def test_setup_logger_file_handler_lifecycle(tmp_path):
logger.removeHandler(handler)
for handler in original_handlers:
logger.addHandler(handler)
def test_setup_logger_replaces_file_handler_when_path_changes(tmp_path):
logger = logging.getLogger("skywipe")
original_handlers = list(logger.handlers)
for handler in original_handlers:
logger.removeHandler(handler)
log_file1 = tmp_path / "skywipe1.log"
log_file2 = tmp_path / "skywipe2.log"
try:
setup_logger(verbose=False, log_file=log_file1)
file_handlers = [
handler for handler in logger.handlers
if isinstance(handler, logging.FileHandler)
]
assert len(file_handlers) == 1
assert file_handlers[0].baseFilename == str(log_file1)
setup_logger(verbose=False, log_file=log_file2)
file_handlers = [
handler for handler in logger.handlers
if isinstance(handler, logging.FileHandler)
]
assert len(file_handlers) == 1
assert file_handlers[0].baseFilename == str(log_file2)
finally:
for handler in list(logger.handlers):
logger.removeHandler(handler)
for handler in original_handlers:
logger.addHandler(handler)