Compare commits
19 Commits
5e60374937
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 2cdc4c6c42 | |||
| 07862f0ea2 | |||
| e095c68f72 | |||
| c718e8c6f5 | |||
| a07cc02fb0 | |||
| ecc33054af | |||
| e6d68dd37d | |||
| 809b7823ea | |||
| 6a88ab8560 | |||
| dd3b220a4a | |||
| 82b99da50d | |||
| ce2a1ad594 | |||
| 313b6fc453 | |||
| e364598414 | |||
| cd1cf1f170 | |||
| c3761d1d08 | |||
| 85f1ea4efb | |||
| df22b3dd3d | |||
| 8c0bbceeac |
@@ -61,14 +61,14 @@ def main():
|
|||||||
setup_logger(verbose=False, log_file=LOG_FILE)
|
setup_logger(verbose=False, log_file=LOG_FILE)
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
if registry.requires_config(args.command):
|
|
||||||
require_config(logger)
|
|
||||||
config = Configuration()
|
|
||||||
config_data = config.load()
|
|
||||||
verbose = config_data.get("verbose", False)
|
|
||||||
setup_logger(verbose=verbose, log_file=LOG_FILE)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if registry.requires_config(args.command):
|
||||||
|
require_config(logger)
|
||||||
|
config = Configuration()
|
||||||
|
config_data = config.load()
|
||||||
|
verbose = config_data.get("verbose", False)
|
||||||
|
setup_logger(verbose=verbose, log_file=LOG_FILE)
|
||||||
|
|
||||||
registry.execute(
|
registry.execute(
|
||||||
args.command, skip_confirmation=getattr(args, "yes", False))
|
args.command, skip_confirmation=getattr(args, "yes", False))
|
||||||
except (ValueError, Exception) as e:
|
except (ValueError, Exception) as e:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Command implementations for Skywipe"""
|
"""Command implementations for Skywipe"""
|
||||||
|
|
||||||
from typing import Callable, Any
|
from typing import Callable, Any, TypedDict
|
||||||
from .configure import Configuration
|
from .configure import Configuration
|
||||||
from .operations import Operation
|
from .operations import Operation
|
||||||
from .post_analysis import PostAnalyzer
|
from .post_analysis import PostAnalyzer
|
||||||
@@ -11,7 +11,16 @@ from .safeguard import require_confirmation
|
|||||||
CommandHandler = Callable[..., None]
|
CommandHandler = Callable[..., None]
|
||||||
|
|
||||||
|
|
||||||
COMMAND_METADATA = {
|
class CommandMetadata(TypedDict):
|
||||||
|
confirmation: str
|
||||||
|
help_text: str
|
||||||
|
operation_name: str
|
||||||
|
strategy_type: str
|
||||||
|
collection: str | None
|
||||||
|
filter_fn: Callable[[Any], bool] | None
|
||||||
|
|
||||||
|
|
||||||
|
COMMAND_METADATA: dict[str, CommandMetadata] = {
|
||||||
"posts": {
|
"posts": {
|
||||||
"confirmation": "delete all posts",
|
"confirmation": "delete all posts",
|
||||||
"help_text": "only posts",
|
"help_text": "only posts",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Core configuration module for Skywipe"""
|
"""Core configuration module for Skywipe"""
|
||||||
|
|
||||||
import getpass
|
from getpass import getpass
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
@@ -56,65 +56,100 @@ class Configuration:
|
|||||||
def exists(self) -> bool:
|
def exists(self) -> bool:
|
||||||
return self.config_file.exists()
|
return self.config_file.exists()
|
||||||
|
|
||||||
def create(self):
|
def _confirm_overwrite(self, logger) -> bool:
|
||||||
logger = setup_logger(verbose=False)
|
if not self.exists():
|
||||||
if self.exists():
|
return True
|
||||||
overwrite = input(
|
overwrite = input(
|
||||||
"Configuration already exists. Overwrite? (y/N): ").strip().lower()
|
"Configuration already exists. Overwrite? (y/N): ").strip().lower()
|
||||||
if overwrite not in ("y", "yes"):
|
if overwrite in ("y", "yes"):
|
||||||
logger.info("Configuration creation cancelled.")
|
return True
|
||||||
return
|
logger.info("Configuration creation cancelled.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _ensure_config_dir(self) -> None:
|
||||||
config_dir = self.config_file.parent
|
config_dir = self.config_file.parent
|
||||||
config_dir.mkdir(parents=True, exist_ok=True)
|
config_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
print("Skywipe Configuration")
|
def _prompt_handle(self, logger) -> str:
|
||||||
print("=" * 50)
|
|
||||||
print("Note: You should use an app password from Bluesky settings.")
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
handle = input("Bluesky handle: ").strip()
|
handle = input("Bluesky handle: ").strip()
|
||||||
is_valid, error_msg = _validate_handle(handle)
|
is_valid, error_msg = _validate_handle(handle)
|
||||||
if is_valid:
|
if is_valid:
|
||||||
break
|
return handle
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
logger.info("Please enter a valid handle and try again.")
|
logger.info("Please enter a valid handle and try again.")
|
||||||
|
|
||||||
|
def _prompt_password(self, logger) -> str:
|
||||||
while True:
|
while True:
|
||||||
password = getpass.getpass(
|
password = getpass(
|
||||||
"Bluesky (hopefully app) password: ").strip()
|
"Bluesky (hopefully app) password: ").strip()
|
||||||
is_valid, error_msg = _validate_password(password)
|
is_valid, error_msg = _validate_password(password)
|
||||||
if is_valid:
|
if is_valid:
|
||||||
break
|
return password
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
logger.info("Please check your password and try again.")
|
logger.info("Please check your password and try again.")
|
||||||
logger.info(
|
logger.info(
|
||||||
"Generate an app password at: https://bsky.app/settings/app-passwords")
|
"Generate an app password at: https://bsky.app/settings/app-passwords")
|
||||||
|
|
||||||
|
def _parse_batch_size(self, logger) -> int | None:
|
||||||
batch_size = input("Batch size (default: 10): ").strip() or "10"
|
batch_size = input("Batch size (default: 10): ").strip() or "10"
|
||||||
delay = input(
|
|
||||||
"Delay between batches in seconds (default: 1): ").strip() or "1"
|
|
||||||
verbose_input = input(
|
|
||||||
"Verbose mode (y/n, default: y): ").strip().lower() or "y"
|
|
||||||
verbose = verbose_input in ("y", "yes", "true", "1")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
batch_size = int(batch_size)
|
batch_size_int = int(batch_size)
|
||||||
if batch_size < 1 or batch_size > 100:
|
|
||||||
logger.error("batch_size must be between 1 and 100")
|
|
||||||
return
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.error("batch_size must be an integer")
|
logger.error("batch_size must be an integer")
|
||||||
return
|
return None
|
||||||
|
if batch_size_int < 1 or batch_size_int > 100:
|
||||||
|
logger.error("batch_size must be between 1 and 100")
|
||||||
|
return None
|
||||||
|
return batch_size_int
|
||||||
|
|
||||||
|
def _parse_delay(self, logger) -> int | None:
|
||||||
|
delay = input(
|
||||||
|
"Delay between batches in seconds (default: 1): ").strip() or "1"
|
||||||
try:
|
try:
|
||||||
delay = int(delay)
|
delay_int = int(delay)
|
||||||
if delay < 0 or delay > 60:
|
|
||||||
logger.error("delay must be between 0 and 60 seconds")
|
|
||||||
return
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.error("delay must be an integer")
|
logger.error("delay must be an integer")
|
||||||
|
return None
|
||||||
|
if delay_int < 0 or delay_int > 60:
|
||||||
|
logger.error("delay must be between 0 and 60 seconds")
|
||||||
|
return None
|
||||||
|
return delay_int
|
||||||
|
|
||||||
|
def _parse_verbose(self) -> bool:
|
||||||
|
verbose_input = input(
|
||||||
|
"Verbose mode (y/n, default: y): ").strip().lower() or "y"
|
||||||
|
return verbose_input in ("y", "yes", "true", "1")
|
||||||
|
|
||||||
|
def _write_config(self, logger, config_data: dict) -> None:
|
||||||
|
try:
|
||||||
|
with open(self.config_file, "w") as f:
|
||||||
|
yaml.dump(config_data, f, default_flow_style=False)
|
||||||
|
except (IOError, OSError) as e:
|
||||||
|
logger.error(f"Failed to save configuration: {e}")
|
||||||
return
|
return
|
||||||
|
logger.info(f"Configuration saved to {self.config_file}")
|
||||||
|
|
||||||
|
def create(self):
|
||||||
|
logger = setup_logger(verbose=False)
|
||||||
|
if not self._confirm_overwrite(logger):
|
||||||
|
return
|
||||||
|
|
||||||
|
self._ensure_config_dir()
|
||||||
|
|
||||||
|
print("Skywipe Configuration")
|
||||||
|
print("=" * 50)
|
||||||
|
print("Note: You should use an app password from Bluesky settings.")
|
||||||
|
|
||||||
|
handle = self._prompt_handle(logger)
|
||||||
|
password = self._prompt_password(logger)
|
||||||
|
batch_size = self._parse_batch_size(logger)
|
||||||
|
if batch_size is None:
|
||||||
|
return
|
||||||
|
delay = self._parse_delay(logger)
|
||||||
|
if delay is None:
|
||||||
|
return
|
||||||
|
verbose = self._parse_verbose()
|
||||||
|
|
||||||
config_data = {
|
config_data = {
|
||||||
"handle": handle,
|
"handle": handle,
|
||||||
@@ -124,14 +159,7 @@ class Configuration:
|
|||||||
"verbose": verbose
|
"verbose": verbose
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
self._write_config(logger, config_data)
|
||||||
with open(self.config_file, "w") as f:
|
|
||||||
yaml.dump(config_data, f, default_flow_style=False)
|
|
||||||
except (IOError, OSError) as e:
|
|
||||||
logger.error(f"Failed to save configuration: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"Configuration saved to {self.config_file}")
|
|
||||||
|
|
||||||
def load(self) -> dict:
|
def load(self) -> dict:
|
||||||
if not self.exists():
|
if not self.exists():
|
||||||
|
|||||||
@@ -36,9 +36,14 @@ class LevelFilter(logging.Filter):
|
|||||||
|
|
||||||
def setup_logger(verbose: bool = False, log_file: Path | None = None) -> logging.Logger:
|
def setup_logger(verbose: bool = False, log_file: Path | None = None) -> logging.Logger:
|
||||||
logger = logging.getLogger("skywipe")
|
logger = logging.getLogger("skywipe")
|
||||||
|
logger.propagate = False
|
||||||
target_level = logging.DEBUG if verbose else logging.INFO
|
target_level = logging.DEBUG if verbose else logging.INFO
|
||||||
logger.setLevel(target_level)
|
logger.setLevel(target_level)
|
||||||
|
|
||||||
|
progress_logger = logging.getLogger("skywipe.progress")
|
||||||
|
if not progress_logger.handlers:
|
||||||
|
progress_logger.propagate = True
|
||||||
|
|
||||||
info_handler = None
|
info_handler = None
|
||||||
error_handler = None
|
error_handler = None
|
||||||
file_handlers = []
|
file_handlers = []
|
||||||
@@ -56,16 +61,19 @@ def setup_logger(verbose: bool = False, log_file: Path | None = None) -> logging
|
|||||||
|
|
||||||
if info_handler is None:
|
if info_handler is None:
|
||||||
info_handler = logging.StreamHandler(sys.stdout)
|
info_handler = logging.StreamHandler(sys.stdout)
|
||||||
info_handler.addFilter(LevelFilter(logging.DEBUG, logging.INFO))
|
|
||||||
info_handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(info_handler)
|
logger.addHandler(info_handler)
|
||||||
|
for existing in list(info_handler.filters):
|
||||||
|
if isinstance(existing, LevelFilter):
|
||||||
|
info_handler.removeFilter(existing)
|
||||||
|
info_handler.addFilter(LevelFilter(logging.DEBUG, logging.INFO))
|
||||||
|
info_handler.setFormatter(formatter)
|
||||||
info_handler.setLevel(target_level)
|
info_handler.setLevel(target_level)
|
||||||
|
|
||||||
if error_handler is None:
|
if error_handler is None:
|
||||||
error_handler = logging.StreamHandler(sys.stderr)
|
error_handler = logging.StreamHandler(sys.stderr)
|
||||||
error_handler.setLevel(logging.WARNING)
|
error_handler.setLevel(logging.WARNING)
|
||||||
error_handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(error_handler)
|
logger.addHandler(error_handler)
|
||||||
|
error_handler.setFormatter(formatter)
|
||||||
|
|
||||||
for handler in file_handlers:
|
for handler in file_handlers:
|
||||||
handler.close()
|
handler.close()
|
||||||
@@ -89,11 +97,18 @@ def get_logger() -> logging.Logger:
|
|||||||
return logging.getLogger("skywipe")
|
return logging.getLogger("skywipe")
|
||||||
|
|
||||||
|
|
||||||
|
def _format_error_message(error: Exception) -> str:
|
||||||
|
if isinstance(error, KeyError):
|
||||||
|
return str(error.args[0]) if error.args else str(error)
|
||||||
|
return str(error)
|
||||||
|
|
||||||
|
|
||||||
def handle_error(error: Exception, logger: logging.Logger, exit_on_error: bool = False) -> None:
|
def handle_error(error: Exception, logger: logging.Logger, exit_on_error: bool = False) -> None:
|
||||||
if isinstance(error, ValueError):
|
if isinstance(error, (KeyError, ValueError)):
|
||||||
logger.error(f"{error}")
|
logger.error(_format_error_message(error))
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unexpected error: {error}", exc_info=True)
|
logger.error(
|
||||||
|
f"Unexpected error: {_format_error_message(error)}", exc_info=True)
|
||||||
|
|
||||||
if exit_on_error:
|
if exit_on_error:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|||||||
@@ -52,7 +52,16 @@ class OperationContext:
|
|||||||
|
|
||||||
|
|
||||||
class BaseStrategy:
|
class BaseStrategy:
|
||||||
def get_cursor(self, response):
|
def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def extract_items(self, response: Any) -> list[Any]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def process_item(self, item: Any, context: OperationContext) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_cursor(self, response: Any) -> str | None:
|
||||||
return response.cursor
|
return response.cursor
|
||||||
|
|
||||||
|
|
||||||
@@ -60,7 +69,7 @@ class RecordDeletionStrategy(BaseStrategy):
|
|||||||
def __init__(self, collection: str):
|
def __init__(self, collection: str):
|
||||||
self.collection = collection
|
self.collection = collection
|
||||||
|
|
||||||
def fetch(self, context: OperationContext, cursor: str | None = None):
|
def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
|
||||||
list_params = models.ComAtprotoRepoListRecords.Params(
|
list_params = models.ComAtprotoRepoListRecords.Params(
|
||||||
repo=context.did,
|
repo=context.did,
|
||||||
collection=self.collection,
|
collection=self.collection,
|
||||||
@@ -69,10 +78,10 @@ class RecordDeletionStrategy(BaseStrategy):
|
|||||||
)
|
)
|
||||||
return context.client.com.atproto.repo.list_records(params=list_params)
|
return context.client.com.atproto.repo.list_records(params=list_params)
|
||||||
|
|
||||||
def extract_items(self, response):
|
def extract_items(self, response: Any) -> list[Any]:
|
||||||
return response.records
|
return response.records
|
||||||
|
|
||||||
def process_item(self, record, context: OperationContext):
|
def process_item(self, record: Any, context: OperationContext) -> None:
|
||||||
record_uri = record.uri
|
record_uri = record.uri
|
||||||
rkey = record_uri.rsplit("/", 1)[-1]
|
rkey = record_uri.rsplit("/", 1)[-1]
|
||||||
delete_data = {
|
delete_data = {
|
||||||
@@ -85,34 +94,34 @@ class RecordDeletionStrategy(BaseStrategy):
|
|||||||
|
|
||||||
|
|
||||||
class FeedStrategy(BaseStrategy):
|
class FeedStrategy(BaseStrategy):
|
||||||
def fetch(self, context: OperationContext, cursor: str | None = None):
|
def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
|
||||||
if cursor:
|
if cursor:
|
||||||
return context.client.get_author_feed(
|
return context.client.get_author_feed(
|
||||||
actor=context.did, limit=context.batch_size, cursor=cursor
|
actor=context.did, limit=context.batch_size, cursor=cursor
|
||||||
)
|
)
|
||||||
return context.client.get_author_feed(actor=context.did, limit=context.batch_size)
|
return context.client.get_author_feed(actor=context.did, limit=context.batch_size)
|
||||||
|
|
||||||
def extract_items(self, response):
|
def extract_items(self, response: Any) -> list[Any]:
|
||||||
return response.feed
|
return response.feed
|
||||||
|
|
||||||
def process_item(self, post, context: OperationContext):
|
def process_item(self, post: Any, context: OperationContext) -> None:
|
||||||
uri = post.post.uri
|
uri = post.post.uri
|
||||||
context.client.delete_post(uri)
|
context.client.delete_post(uri)
|
||||||
context.logger.debug(f"Deleted post: {uri}")
|
context.logger.debug(f"Deleted post: {uri}")
|
||||||
|
|
||||||
|
|
||||||
class BookmarkStrategy(BaseStrategy):
|
class BookmarkStrategy(BaseStrategy):
|
||||||
def fetch(self, context: OperationContext, cursor: str | None = None):
|
def fetch(self, context: OperationContext, cursor: str | None = None) -> Any:
|
||||||
get_params = models.AppBskyBookmarkGetBookmarks.Params(
|
get_params = models.AppBskyBookmarkGetBookmarks.Params(
|
||||||
limit=context.batch_size,
|
limit=context.batch_size,
|
||||||
cursor=cursor
|
cursor=cursor
|
||||||
)
|
)
|
||||||
return context.client.app.bsky.bookmark.get_bookmarks(params=get_params)
|
return context.client.app.bsky.bookmark.get_bookmarks(params=get_params)
|
||||||
|
|
||||||
def extract_items(self, response):
|
def extract_items(self, response: Any) -> list[Any]:
|
||||||
return response.bookmarks
|
return response.bookmarks
|
||||||
|
|
||||||
def process_item(self, bookmark, context: OperationContext):
|
def process_item(self, bookmark: Any, context: OperationContext) -> None:
|
||||||
bookmark_uri = self._extract_bookmark_uri(bookmark)
|
bookmark_uri = self._extract_bookmark_uri(bookmark)
|
||||||
if not bookmark_uri:
|
if not bookmark_uri:
|
||||||
raise ValueError("Unable to find bookmark URI")
|
raise ValueError("Unable to find bookmark URI")
|
||||||
@@ -122,7 +131,7 @@ class BookmarkStrategy(BaseStrategy):
|
|||||||
context.client.app.bsky.bookmark.delete_bookmark(data=delete_data)
|
context.client.app.bsky.bookmark.delete_bookmark(data=delete_data)
|
||||||
context.logger.debug(f"Deleted bookmark: {bookmark_uri}")
|
context.logger.debug(f"Deleted bookmark: {bookmark_uri}")
|
||||||
|
|
||||||
def _extract_bookmark_uri(self, bookmark):
|
def _extract_bookmark_uri(self, bookmark: Any) -> str | None:
|
||||||
if hasattr(bookmark, "uri"):
|
if hasattr(bookmark, "uri"):
|
||||||
return bookmark.uri
|
return bookmark.uri
|
||||||
|
|
||||||
@@ -148,6 +157,7 @@ class Operation:
|
|||||||
self.filter_fn = filter_fn
|
self.filter_fn = filter_fn
|
||||||
self._client = client
|
self._client = client
|
||||||
self._config_data = config_data
|
self._config_data = config_data
|
||||||
|
self.strategy: BaseStrategy
|
||||||
|
|
||||||
if strategy_type == "record":
|
if strategy_type == "record":
|
||||||
if not collection:
|
if not collection:
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
|
from pathlib import Path
|
||||||
from typing import Iterable, Callable
|
from typing import Iterable, Callable
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from skywipe.configure import Configuration
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def user_input(monkeypatch) -> Callable[[Iterable[str], Iterable[str]], None]:
|
def user_input(monkeypatch) -> Callable[[Iterable[str], Iterable[str]], None]:
|
||||||
@@ -10,7 +13,13 @@ def user_input(monkeypatch) -> Callable[[Iterable[str], Iterable[str]], None]:
|
|||||||
password_iter = iter(passwords)
|
password_iter = iter(passwords)
|
||||||
|
|
||||||
monkeypatch.setattr("builtins.input", lambda _prompt: next(input_iter))
|
monkeypatch.setattr("builtins.input", lambda _prompt: next(input_iter))
|
||||||
monkeypatch.setattr("getpass.getpass",
|
monkeypatch.setattr("skywipe.configure.getpass",
|
||||||
lambda _prompt: next(password_iter))
|
lambda _prompt: next(password_iter))
|
||||||
|
|
||||||
return _set
|
return _set
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config_with_tmp_path(monkeypatch, tmp_path):
|
||||||
|
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||||
|
return Configuration()
|
||||||
|
|||||||
@@ -6,91 +6,234 @@ import pytest
|
|||||||
import skywipe.cli as cli
|
import skywipe.cli as cli
|
||||||
|
|
||||||
|
|
||||||
|
TEST_COMMAND = "posts"
|
||||||
|
TEST_ERROR_MESSAGE = "boom"
|
||||||
|
TEST_LOGGER_NAME = "test.cli"
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_parser_mocks(monkeypatch, commands=None):
|
||||||
|
if commands is None:
|
||||||
|
commands = {TEST_COMMAND: "only posts"}
|
||||||
|
monkeypatch.setattr(cli.registry, "get_all_commands", lambda: commands)
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_main_mocks(monkeypatch, calls, requires_config=False, config_data=None):
|
||||||
|
_setup_parser_mocks(monkeypatch)
|
||||||
|
monkeypatch.setattr(cli.registry, "requires_config",
|
||||||
|
lambda name: requires_config)
|
||||||
|
|
||||||
|
def mock_execute(name, skip_confirmation=False):
|
||||||
|
calls["execute"] = (name, skip_confirmation)
|
||||||
|
|
||||||
|
def mock_setup_logger(verbose, log_file):
|
||||||
|
calls["setup"].append((verbose, log_file))
|
||||||
|
|
||||||
|
def mock_require_config(logger):
|
||||||
|
calls["require_config"].append(logger)
|
||||||
|
|
||||||
|
monkeypatch.setattr(cli.registry, "execute", mock_execute)
|
||||||
|
monkeypatch.setattr(cli, "setup_logger", mock_setup_logger)
|
||||||
|
monkeypatch.setattr(cli, "require_config", mock_require_config)
|
||||||
|
monkeypatch.setattr(cli, "get_logger",
|
||||||
|
lambda: logging.getLogger(TEST_LOGGER_NAME))
|
||||||
|
|
||||||
|
if config_data is not None:
|
||||||
|
monkeypatch.setattr(cli.Configuration, "load",
|
||||||
|
lambda self: config_data)
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_error_mocks(monkeypatch, calls, error_factory):
|
||||||
|
_setup_parser_mocks(monkeypatch)
|
||||||
|
monkeypatch.setattr(cli.registry, "requires_config", lambda name: False)
|
||||||
|
monkeypatch.setattr(cli.registry, "execute", error_factory)
|
||||||
|
monkeypatch.setattr(cli, "setup_logger", lambda verbose, log_file: None)
|
||||||
|
monkeypatch.setattr(cli, "get_logger",
|
||||||
|
lambda: logging.getLogger(TEST_LOGGER_NAME))
|
||||||
|
|
||||||
|
def _format_error_message(error):
|
||||||
|
if isinstance(error, KeyError):
|
||||||
|
return error.args[0] if error.args else str(error)
|
||||||
|
return str(error)
|
||||||
|
|
||||||
|
def mock_handle_error(error, logger, exit_on_error=False):
|
||||||
|
calls["handle_error"] = (type(error).__name__,
|
||||||
|
_format_error_message(error), exit_on_error)
|
||||||
|
|
||||||
|
monkeypatch.setattr(cli, "handle_error", mock_handle_error)
|
||||||
|
|
||||||
|
|
||||||
def test_create_parser_includes_commands(monkeypatch):
|
def test_create_parser_includes_commands(monkeypatch):
|
||||||
monkeypatch.setattr(cli.registry, "get_all_commands",
|
_setup_parser_mocks(monkeypatch)
|
||||||
lambda: {"posts": "only posts"})
|
|
||||||
parser = cli.create_parser()
|
parser = cli.create_parser()
|
||||||
args = parser.parse_args(["posts"])
|
args = parser.parse_args([TEST_COMMAND])
|
||||||
assert args.command == "posts"
|
assert args.command == TEST_COMMAND
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_parser_handles_multiple_commands(monkeypatch):
|
||||||
|
commands = {
|
||||||
|
"posts": "only posts",
|
||||||
|
"likes": "only likes",
|
||||||
|
"reposts": "only reposts"
|
||||||
|
}
|
||||||
|
_setup_parser_mocks(monkeypatch, commands)
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
args1 = parser.parse_args(["posts"])
|
||||||
|
args2 = parser.parse_args(["likes"])
|
||||||
|
args3 = parser.parse_args(["reposts"])
|
||||||
|
|
||||||
|
assert args1.command == "posts"
|
||||||
|
assert args2.command == "likes"
|
||||||
|
assert args3.command == "reposts"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_parser_parses_yes_flag(monkeypatch):
|
||||||
|
_setup_parser_mocks(monkeypatch)
|
||||||
|
parser = cli.create_parser()
|
||||||
|
args = parser.parse_args(["--yes", TEST_COMMAND])
|
||||||
|
assert args.command == TEST_COMMAND
|
||||||
|
assert args.yes is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_parser_parses_without_yes_flag(monkeypatch):
|
||||||
|
_setup_parser_mocks(monkeypatch)
|
||||||
|
parser = cli.create_parser()
|
||||||
|
args = parser.parse_args([TEST_COMMAND])
|
||||||
|
assert args.command == TEST_COMMAND
|
||||||
|
assert getattr(args, "yes", False) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_parser_version_flag_exits(monkeypatch):
|
||||||
|
_setup_parser_mocks(monkeypatch)
|
||||||
|
parser = cli.create_parser()
|
||||||
|
with pytest.raises(SystemExit) as excinfo:
|
||||||
|
parser.parse_args(["--version"])
|
||||||
|
assert excinfo.value.code == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_parser_requires_command(monkeypatch):
|
||||||
|
_setup_parser_mocks(monkeypatch)
|
||||||
|
parser = cli.create_parser()
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
parser.parse_args([])
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_parser_rejects_invalid_command(monkeypatch):
|
||||||
|
_setup_parser_mocks(monkeypatch)
|
||||||
|
parser = cli.create_parser()
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
parser.parse_args(["invalid_command"])
|
||||||
|
|
||||||
|
|
||||||
def test_require_config_exits_when_missing(monkeypatch):
|
def test_require_config_exits_when_missing(monkeypatch):
|
||||||
monkeypatch.setattr(cli.Configuration, "exists", lambda self: False)
|
monkeypatch.setattr(cli.Configuration, "exists", lambda self: False)
|
||||||
logger = logging.getLogger("test.cli")
|
logger = logging.getLogger(TEST_LOGGER_NAME)
|
||||||
with pytest.raises(SystemExit) as excinfo:
|
with pytest.raises(SystemExit) as excinfo:
|
||||||
cli.require_config(logger)
|
cli.require_config(logger)
|
||||||
assert excinfo.value.code == 1
|
assert excinfo.value.code == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_config_does_not_exit_when_exists(monkeypatch):
|
||||||
|
monkeypatch.setattr(cli.Configuration, "exists", lambda self: True)
|
||||||
|
logger = logging.getLogger(TEST_LOGGER_NAME)
|
||||||
|
cli.require_config(logger)
|
||||||
|
|
||||||
|
|
||||||
def test_main_executes_without_config(monkeypatch):
|
def test_main_executes_without_config(monkeypatch):
|
||||||
calls = {"execute": None, "setup": []}
|
calls = {"execute": None, "setup": [], "require_config": []}
|
||||||
|
_setup_main_mocks(monkeypatch, calls, requires_config=False)
|
||||||
|
|
||||||
monkeypatch.setattr(cli.registry, "get_all_commands",
|
monkeypatch.setattr(sys, "argv", ["skywipe", "--yes", TEST_COMMAND])
|
||||||
lambda: {"posts": "only posts"})
|
|
||||||
monkeypatch.setattr(cli.registry, "requires_config", lambda name: False)
|
|
||||||
monkeypatch.setattr(cli.registry, "execute", lambda name, skip_confirmation=False: calls.update(
|
|
||||||
{"execute": (name, skip_confirmation)}
|
|
||||||
))
|
|
||||||
monkeypatch.setattr(cli, "setup_logger", lambda verbose,
|
|
||||||
log_file: calls["setup"].append((verbose, log_file)))
|
|
||||||
monkeypatch.setattr(cli, "get_logger",
|
|
||||||
lambda: logging.getLogger("test.cli"))
|
|
||||||
|
|
||||||
monkeypatch.setattr(sys, "argv", ["skywipe", "--yes", "posts"])
|
|
||||||
cli.main()
|
cli.main()
|
||||||
|
|
||||||
|
assert len(calls["require_config"]) == 0
|
||||||
assert calls["setup"] == [(False, cli.LOG_FILE)]
|
assert calls["setup"] == [(False, cli.LOG_FILE)]
|
||||||
assert calls["execute"] == ("posts", True)
|
assert calls["execute"] == (TEST_COMMAND, True)
|
||||||
|
|
||||||
|
|
||||||
def test_main_loads_config_and_sets_verbose(monkeypatch):
|
def test_main_loads_config_and_sets_verbose(monkeypatch):
|
||||||
calls = {"setup": [], "execute": None, "require_config": 0}
|
calls = {"setup": [], "execute": None, "require_config": []}
|
||||||
|
_setup_main_mocks(monkeypatch, calls, requires_config=True,
|
||||||
|
config_data={"verbose": True})
|
||||||
|
|
||||||
monkeypatch.setattr(cli.registry, "get_all_commands",
|
monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
|
||||||
lambda: {"posts": "only posts"})
|
|
||||||
monkeypatch.setattr(cli.registry, "requires_config", lambda name: True)
|
|
||||||
monkeypatch.setattr(cli.registry, "execute", lambda name, skip_confirmation=False: calls.update(
|
|
||||||
{"execute": (name, skip_confirmation)}
|
|
||||||
))
|
|
||||||
monkeypatch.setattr(cli, "require_config", lambda logger: calls.update(
|
|
||||||
{"require_config": calls["require_config"] + 1}
|
|
||||||
))
|
|
||||||
monkeypatch.setattr(cli.Configuration, "load",
|
|
||||||
lambda self: {"verbose": True})
|
|
||||||
monkeypatch.setattr(cli, "setup_logger", lambda verbose,
|
|
||||||
log_file: calls["setup"].append((verbose, log_file)))
|
|
||||||
monkeypatch.setattr(cli, "get_logger",
|
|
||||||
lambda: logging.getLogger("test.cli"))
|
|
||||||
|
|
||||||
monkeypatch.setattr(sys, "argv", ["skywipe", "posts"])
|
|
||||||
cli.main()
|
cli.main()
|
||||||
|
|
||||||
assert calls["require_config"] == 1
|
assert len(calls["require_config"]) == 1
|
||||||
assert calls["setup"] == [(False, cli.LOG_FILE), (True, cli.LOG_FILE)]
|
assert calls["setup"] == [(False, cli.LOG_FILE), (True, cli.LOG_FILE)]
|
||||||
assert calls["execute"] == ("posts", False)
|
assert calls["execute"] == (TEST_COMMAND, False)
|
||||||
|
|
||||||
|
|
||||||
def test_main_handles_execute_error(monkeypatch):
|
@pytest.mark.parametrize("config_data,expected_verbose", [
|
||||||
calls = {"handle_error": None}
|
({}, False),
|
||||||
|
({"verbose": False}, False),
|
||||||
|
])
|
||||||
|
def test_main_config_verbose_defaults(monkeypatch, config_data, expected_verbose):
|
||||||
|
calls = {"setup": [], "execute": None, "require_config": []}
|
||||||
|
_setup_main_mocks(monkeypatch, calls, requires_config=True,
|
||||||
|
config_data=config_data)
|
||||||
|
|
||||||
monkeypatch.setattr(cli.registry, "get_all_commands",
|
monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
|
||||||
lambda: {"posts": "only posts"})
|
cli.main()
|
||||||
monkeypatch.setattr(cli.registry, "requires_config", lambda name: False)
|
|
||||||
|
|
||||||
def raise_error(*_args, **_kwargs):
|
assert len(calls["require_config"]) == 1
|
||||||
raise ValueError("boom")
|
assert calls["setup"] == [(False, cli.LOG_FILE),
|
||||||
|
(expected_verbose, cli.LOG_FILE)]
|
||||||
|
assert calls["execute"] == (TEST_COMMAND, False)
|
||||||
|
|
||||||
monkeypatch.setattr(cli.registry, "execute", raise_error)
|
|
||||||
|
def test_main_handles_config_load_error(monkeypatch):
|
||||||
|
calls = {"handle_error": None, "require_config": []}
|
||||||
|
|
||||||
|
def mock_require_config(logger):
|
||||||
|
calls["require_config"].append(logger)
|
||||||
|
|
||||||
|
def raise_config_error(self):
|
||||||
|
raise RuntimeError("config error")
|
||||||
|
|
||||||
|
def _format_error_message(error):
|
||||||
|
if isinstance(error, KeyError):
|
||||||
|
return error.args[0] if error.args else str(error)
|
||||||
|
return str(error)
|
||||||
|
|
||||||
|
def mock_handle_error(error, logger, exit_on_error=False):
|
||||||
|
calls["handle_error"] = (type(error).__name__,
|
||||||
|
_format_error_message(error), exit_on_error)
|
||||||
|
|
||||||
|
_setup_parser_mocks(monkeypatch)
|
||||||
|
monkeypatch.setattr(cli.registry, "requires_config", lambda name: True)
|
||||||
|
monkeypatch.setattr(cli, "require_config", mock_require_config)
|
||||||
|
monkeypatch.setattr(cli.Configuration, "load", raise_config_error)
|
||||||
monkeypatch.setattr(cli, "setup_logger", lambda verbose, log_file: None)
|
monkeypatch.setattr(cli, "setup_logger", lambda verbose, log_file: None)
|
||||||
monkeypatch.setattr(cli, "get_logger",
|
monkeypatch.setattr(cli, "get_logger",
|
||||||
lambda: logging.getLogger("test.cli"))
|
lambda: logging.getLogger(TEST_LOGGER_NAME))
|
||||||
|
monkeypatch.setattr(cli, "handle_error", mock_handle_error)
|
||||||
def fake_handle_error(error, logger, exit_on_error=False):
|
|
||||||
calls["handle_error"] = (str(error), exit_on_error)
|
|
||||||
|
|
||||||
monkeypatch.setattr(cli, "handle_error", fake_handle_error)
|
|
||||||
monkeypatch.setattr(sys, "argv", ["skywipe", "posts"])
|
|
||||||
|
|
||||||
|
monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
|
||||||
cli.main()
|
cli.main()
|
||||||
|
|
||||||
assert calls["handle_error"] == ("boom", True)
|
assert len(calls["require_config"]) == 1
|
||||||
|
assert calls["handle_error"] is not None
|
||||||
|
assert calls["handle_error"][0] == "RuntimeError"
|
||||||
|
assert calls["handle_error"][2] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("error_class,error_message", [
|
||||||
|
(ValueError, TEST_ERROR_MESSAGE),
|
||||||
|
(RuntimeError, "runtime error"),
|
||||||
|
(KeyError, "missing key"),
|
||||||
|
])
|
||||||
|
def test_main_handles_execute_errors(monkeypatch, error_class, error_message):
|
||||||
|
calls = {"handle_error": None}
|
||||||
|
|
||||||
|
def raise_error(*_args, **_kwargs):
|
||||||
|
raise error_class(error_message)
|
||||||
|
|
||||||
|
_setup_error_mocks(monkeypatch, calls, raise_error)
|
||||||
|
monkeypatch.setattr(sys, "argv", ["skywipe", TEST_COMMAND])
|
||||||
|
cli.main()
|
||||||
|
|
||||||
|
assert calls["handle_error"] is not None
|
||||||
|
assert calls["handle_error"][0] == error_class.__name__
|
||||||
|
assert calls["handle_error"][1] == error_message
|
||||||
|
assert calls["handle_error"][2] is True
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from skywipe.configure import Configuration
|
|
||||||
|
|
||||||
|
def test_configuration_create_reprompts_and_writes_file(config_with_tmp_path, user_input):
|
||||||
def test_configuration_create_reprompts_and_writes_file(monkeypatch, tmp_path, user_input):
|
|
||||||
inputs = iter([
|
inputs = iter([
|
||||||
"bad handle",
|
"bad handle",
|
||||||
"alice.bsky.social",
|
"alice.bsky.social",
|
||||||
@@ -18,10 +15,9 @@ def test_configuration_create_reprompts_and_writes_file(monkeypatch, tmp_path, u
|
|||||||
"longenough",
|
"longenough",
|
||||||
])
|
])
|
||||||
|
|
||||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
|
||||||
user_input(inputs, passwords)
|
user_input(inputs, passwords)
|
||||||
|
|
||||||
config = Configuration()
|
config = config_with_tmp_path
|
||||||
config.create()
|
config.create()
|
||||||
|
|
||||||
assert config.config_file.exists() is True
|
assert config.config_file.exists() is True
|
||||||
@@ -33,7 +29,7 @@ def test_configuration_create_reprompts_and_writes_file(monkeypatch, tmp_path, u
|
|||||||
assert data["verbose"] is False
|
assert data["verbose"] is False
|
||||||
|
|
||||||
|
|
||||||
def test_configuration_create_invalid_batch_size(monkeypatch, tmp_path, user_input):
|
def test_configuration_create_invalid_batch_size(config_with_tmp_path, user_input):
|
||||||
inputs = iter([
|
inputs = iter([
|
||||||
"alice.bsky.social",
|
"alice.bsky.social",
|
||||||
"0",
|
"0",
|
||||||
@@ -42,16 +38,15 @@ def test_configuration_create_invalid_batch_size(monkeypatch, tmp_path, user_inp
|
|||||||
])
|
])
|
||||||
passwords = iter(["longenough"])
|
passwords = iter(["longenough"])
|
||||||
|
|
||||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
|
||||||
user_input(inputs, passwords)
|
user_input(inputs, passwords)
|
||||||
|
|
||||||
config = Configuration()
|
config = config_with_tmp_path
|
||||||
config.create()
|
config.create()
|
||||||
|
|
||||||
assert config.config_file.exists() is False
|
assert config.config_file.exists() is False
|
||||||
|
|
||||||
|
|
||||||
def test_configuration_create_invalid_delay(monkeypatch, tmp_path, user_input):
|
def test_configuration_create_invalid_delay(config_with_tmp_path, user_input):
|
||||||
inputs = iter([
|
inputs = iter([
|
||||||
"alice.bsky.social",
|
"alice.bsky.social",
|
||||||
"10",
|
"10",
|
||||||
@@ -60,18 +55,16 @@ def test_configuration_create_invalid_delay(monkeypatch, tmp_path, user_input):
|
|||||||
])
|
])
|
||||||
passwords = iter(["longenough"])
|
passwords = iter(["longenough"])
|
||||||
|
|
||||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
|
||||||
user_input(inputs, passwords)
|
user_input(inputs, passwords)
|
||||||
|
|
||||||
config = Configuration()
|
config = config_with_tmp_path
|
||||||
config.create()
|
config.create()
|
||||||
|
|
||||||
assert config.config_file.exists() is False
|
assert config.config_file.exists() is False
|
||||||
|
|
||||||
|
|
||||||
def test_configuration_create_overwrite_cancel(monkeypatch, tmp_path, user_input):
|
def test_configuration_create_overwrite_cancel(config_with_tmp_path, user_input):
|
||||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
config = config_with_tmp_path
|
||||||
config = Configuration()
|
|
||||||
config.config_file.parent.mkdir(parents=True, exist_ok=True)
|
config.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
config.config_file.write_text("existing")
|
config.config_file.write_text("existing")
|
||||||
|
|
||||||
@@ -82,14 +75,13 @@ def test_configuration_create_overwrite_cancel(monkeypatch, tmp_path, user_input
|
|||||||
assert config.config_file.read_text() == "existing"
|
assert config.config_file.read_text() == "existing"
|
||||||
|
|
||||||
|
|
||||||
def test_configuration_create_write_failure(monkeypatch, tmp_path, user_input):
|
def test_configuration_create_write_failure(config_with_tmp_path, user_input, monkeypatch):
|
||||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
|
||||||
user_input(
|
user_input(
|
||||||
["alice.bsky.social", "5", "0", "y"],
|
["alice.bsky.social", "5", "0", "y"],
|
||||||
["longenough"],
|
["longenough"],
|
||||||
)
|
)
|
||||||
|
|
||||||
config = Configuration()
|
config = config_with_tmp_path
|
||||||
|
|
||||||
original_open = open
|
original_open = open
|
||||||
|
|
||||||
|
|||||||
@@ -1,32 +1,59 @@
|
|||||||
from pathlib import Path
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import yaml
|
||||||
from skywipe.configure import Configuration
|
|
||||||
|
|
||||||
|
|
||||||
def test_configuration_load_missing_file(monkeypatch, tmp_path):
|
def test_configuration_load_missing_file(config_with_tmp_path):
|
||||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
with pytest.raises(FileNotFoundError, match="Configuration file not found"):
|
||||||
config = Configuration()
|
config_with_tmp_path.load()
|
||||||
with pytest.raises(FileNotFoundError):
|
|
||||||
config.load()
|
|
||||||
|
|
||||||
|
|
||||||
def test_configuration_load_empty_file(monkeypatch, tmp_path):
|
def test_configuration_load_empty_file(config_with_tmp_path):
|
||||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
config = Configuration()
|
config_with_tmp_path.config_file.write_text("")
|
||||||
config.config_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
config.config_file.write_text("")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="empty or invalid"):
|
with pytest.raises(ValueError, match="empty or invalid"):
|
||||||
config.load()
|
config_with_tmp_path.load()
|
||||||
|
|
||||||
|
|
||||||
def test_configuration_load_invalid_yaml(monkeypatch, tmp_path):
|
def test_configuration_load_invalid_yaml(config_with_tmp_path):
|
||||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
config = Configuration()
|
config_with_tmp_path.config_file.write_text(": bad")
|
||||||
config.config_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
config.config_file.write_text(": bad")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Invalid YAML"):
|
with pytest.raises(ValueError, match="Invalid YAML"):
|
||||||
config.load()
|
config_with_tmp_path.load()
|
||||||
|
|
||||||
|
|
||||||
|
def test_configuration_load_valid_file(config_with_tmp_path):
|
||||||
|
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
config_data = {
|
||||||
|
"handle": "alice.bsky.social",
|
||||||
|
"password": "password123",
|
||||||
|
"batch_size": 10,
|
||||||
|
"delay": 1,
|
||||||
|
"verbose": True
|
||||||
|
}
|
||||||
|
with open(config_with_tmp_path.config_file, "w") as f:
|
||||||
|
yaml.dump(config_data, f)
|
||||||
|
|
||||||
|
result = config_with_tmp_path.load()
|
||||||
|
assert result == config_data
|
||||||
|
|
||||||
|
|
||||||
|
def test_configuration_load_file_read_error(config_with_tmp_path):
|
||||||
|
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
config_with_tmp_path.config_file.write_text("handle: alice")
|
||||||
|
|
||||||
|
with patch("builtins.open", side_effect=IOError("Permission denied")):
|
||||||
|
with pytest.raises(ValueError, match="Failed to read configuration file"):
|
||||||
|
config_with_tmp_path.load()
|
||||||
|
|
||||||
|
|
||||||
|
def test_configuration_load_file_os_error(config_with_tmp_path):
|
||||||
|
config_with_tmp_path.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
config_with_tmp_path.config_file.write_text("handle: alice")
|
||||||
|
|
||||||
|
with patch("builtins.open", side_effect=OSError("File is locked")):
|
||||||
|
with pytest.raises(ValueError, match="Failed to read configuration file"):
|
||||||
|
config_with_tmp_path.load()
|
||||||
|
|||||||
@@ -124,6 +124,74 @@ def test_setup_logger_does_not_duplicate_handlers():
|
|||||||
assert first_count == second_count
|
assert first_count == second_count
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_logger_resets_stream_formatters():
|
||||||
|
logger = logging.getLogger("skywipe")
|
||||||
|
original_handlers = list(logger.handlers)
|
||||||
|
for handler in original_handlers:
|
||||||
|
logger.removeHandler(handler)
|
||||||
|
|
||||||
|
try:
|
||||||
|
setup_logger(verbose=False)
|
||||||
|
alt_formatter = logging.Formatter(fmt="%(message)s")
|
||||||
|
for handler in logger.handlers:
|
||||||
|
if isinstance(handler, logging.StreamHandler):
|
||||||
|
handler.setFormatter(alt_formatter)
|
||||||
|
|
||||||
|
setup_logger(verbose=False)
|
||||||
|
|
||||||
|
for handler in logger.handlers:
|
||||||
|
if isinstance(handler, logging.StreamHandler):
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == "%(levelname)s: %(message)s"
|
||||||
|
finally:
|
||||||
|
for handler in list(logger.handlers):
|
||||||
|
handler.close()
|
||||||
|
logger.removeHandler(handler)
|
||||||
|
for handler in original_handlers:
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_logger_disables_propagation():
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
root_messages = []
|
||||||
|
original_root_level = root_logger.level
|
||||||
|
|
||||||
|
class RootHandler(logging.Handler):
|
||||||
|
def emit(self, record):
|
||||||
|
root_messages.append(self.format(record))
|
||||||
|
|
||||||
|
root_handler = RootHandler()
|
||||||
|
root_handler.setLevel(logging.INFO)
|
||||||
|
root_logger.addHandler(root_handler)
|
||||||
|
root_logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
logger = logging.getLogger("skywipe")
|
||||||
|
original_handlers = list(logger.handlers)
|
||||||
|
for handler in original_handlers:
|
||||||
|
logger.removeHandler(handler)
|
||||||
|
|
||||||
|
try:
|
||||||
|
setup_logger(verbose=False)
|
||||||
|
assert logger.propagate is False
|
||||||
|
|
||||||
|
progress_logger = logging.getLogger("skywipe.progress")
|
||||||
|
assert progress_logger.propagate is True
|
||||||
|
|
||||||
|
logger.info("Test message")
|
||||||
|
progress_logger.info("Progress message")
|
||||||
|
|
||||||
|
assert len(root_messages) == 0
|
||||||
|
finally:
|
||||||
|
root_handler.close()
|
||||||
|
root_logger.removeHandler(root_handler)
|
||||||
|
root_logger.setLevel(original_root_level)
|
||||||
|
for handler in list(logger.handlers):
|
||||||
|
handler.close()
|
||||||
|
logger.removeHandler(handler)
|
||||||
|
for handler in original_handlers:
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
def test_setup_logger_file_handler_lifecycle(tmp_path):
|
def test_setup_logger_file_handler_lifecycle(tmp_path):
|
||||||
logger = logging.getLogger("skywipe")
|
logger = logging.getLogger("skywipe")
|
||||||
original_handlers = list(logger.handlers)
|
original_handlers = list(logger.handlers)
|
||||||
|
|||||||
Reference in New Issue
Block a user