diff --git a/tests/test_logger.py b/tests/test_logger.py index a2f9c8d..43bb824 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -41,6 +41,53 @@ def test_progress_tracker_updates_counts(): assert tracker.current == 3 +def test_progress_tracker_batch_with_total(): + tracker = ProgressTracker(operation="Testing") + logger = logging.getLogger("skywipe.progress") + messages = [] + + class MessageHandler(logging.Handler): + def emit(self, record): + messages.append(self.format(record)) + + handler = MessageHandler() + handler.setLevel(logging.INFO) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + try: + tracker.batch(1, 10, total_batches=5) + assert messages[-1] == "Testing - batch 1/5 (10 items)" + + tracker.batch(0, 5, total_batches=0) + assert messages[-1] == "Testing - batch 0/0 (5 items)" + finally: + handler.close() + logger.removeHandler(handler) + + +def test_progress_tracker_batch_without_total(): + tracker = ProgressTracker(operation="Testing") + logger = logging.getLogger("skywipe.progress") + messages = [] + + class MessageHandler(logging.Handler): + def emit(self, record): + messages.append(self.format(record)) + + handler = MessageHandler() + handler.setLevel(logging.INFO) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + try: + tracker.batch(1, 10, total_batches=None) + assert messages[-1] == "Testing - batch 1 (10 items)" + finally: + handler.close() + logger.removeHandler(handler) + + def test_setup_logger_does_not_duplicate_handlers(): logger = logging.getLogger("skywipe") original_handlers = list(logger.handlers)