refactor: type-narrow app access and email lookup

This commit is contained in:
2026-01-04 17:52:05 +01:00
parent b6c483623d
commit 0505086e11

View File

@@ -2,13 +2,25 @@
import json import json
from datetime import date, datetime from datetime import date, datetime
from typing import Any, Protocol, TYPE_CHECKING, cast
from textual.app import ComposeResult from textual.app import ComposeResult
from textual.containers import Container, Horizontal, Vertical from textual.containers import Container, Horizontal, Vertical
from textual.screen import ModalScreen from textual.screen import ModalScreen
from textual.widgets import Label, ListItem, ListView, Static from textual.widgets import Label, ListItem, ListView, Static
from .constants import CONFIG_PATH from .constants import AUTH_PATH, CONFIG_PATH
if TYPE_CHECKING:
from textual.binding import Binding
class _AppContext(Protocol):
BINDINGS: list[tuple[str, str, str]]
client: Any
auth: Any
library_client: Any
all_items: list[dict]
KEY_DISPLAY_MAP = { KEY_DISPLAY_MAP = {
@@ -25,7 +37,14 @@ KEY_COLOR = "#f9e2af"
DESC_COLOR = "#cdd6f4" DESC_COLOR = "#cdd6f4"
class HelpScreen(ModalScreen): class AppContextMixin:
"""Mixin to provide a typed app accessor."""
def _app(self) -> _AppContext:
return cast(_AppContext, self.app)
class HelpScreen(AppContextMixin, ModalScreen):
"""Help screen displaying all available keybindings.""" """Help screen displaying all available keybindings."""
BINDINGS = [("escape", "dismiss", "Close"), ("?", "dismiss", "Close")] BINDINGS = [("escape", "dismiss", "Close"), ("?", "dismiss", "Close")]
@@ -39,13 +58,13 @@ class HelpScreen(ModalScreen):
return result return result
@staticmethod @staticmethod
def _parse_binding(binding: tuple | object) -> tuple[str, str]: def _parse_binding(binding: "Binding | tuple[str, str, str]") -> tuple[str, str]:
"""Extract key and description from a binding.""" """Extract key and description from a binding."""
if isinstance(binding, tuple): if isinstance(binding, tuple):
return binding[0], binding[2] return binding[0], binding[2]
return binding.key, binding.description return binding.key, binding.description
def _make_item(self, binding: tuple | object) -> ListItem: def _make_item(self, binding: "Binding | tuple[str, str, str]") -> ListItem:
"""Create a ListItem for a single binding.""" """Create a ListItem for a single binding."""
key, description = self._parse_binding(binding) key, description = self._parse_binding(binding)
key_display = self._format_key_display(key) key_display = self._format_key_display(key)
@@ -53,7 +72,8 @@ class HelpScreen(ModalScreen):
return ListItem(Label(text)) return ListItem(Label(text))
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
bindings = list(self.app.BINDINGS) app = self._app()
bindings = list(app.BINDINGS)
mid = (len(bindings) + 1) // 2 mid = (len(bindings) + 1) // 2
with Container(id="help_container"): with Container(id="help_container"):
@@ -72,11 +92,11 @@ class HelpScreen(ModalScreen):
id="help_footer", id="help_footer",
) )
def action_dismiss(self) -> None: async def action_dismiss(self, result: Any | None = None) -> None:
self.dismiss() await self.dismiss(result)
class StatsScreen(ModalScreen): class StatsScreen(AppContextMixin, ModalScreen):
"""Stats screen displaying listening statistics.""" """Stats screen displaying listening statistics."""
BINDINGS = [("escape", "dismiss", "Close"), ("s", "dismiss", "Close")] BINDINGS = [("escape", "dismiss", "Close"), ("s", "dismiss", "Close")]
@@ -102,13 +122,14 @@ class StatsScreen(ModalScreen):
def _get_signup_year(self) -> int: def _get_signup_year(self) -> int:
"""Get signup year using binary search on listening activity.""" """Get signup year using binary search on listening activity."""
if not self.app.client: app = self._app()
if not app.client:
return 0 return 0
current_year = date.today().year current_year = date.today().year
try: try:
stats = self.app.client.get( stats = app.client.get(
"1.0/stats/aggregates", "1.0/stats/aggregates",
monthly_listening_interval_duration="12", monthly_listening_interval_duration="12",
monthly_listening_interval_start_date=f"{current_year}-01", monthly_listening_interval_start_date=f"{current_year}-01",
@@ -125,7 +146,7 @@ class StatsScreen(ModalScreen):
while left <= right: while left <= right:
middle = (left + right) // 2 middle = (left + right) // 2
try: try:
stats = self.app.client.get( stats = app.client.get(
"1.0/stats/aggregates", "1.0/stats/aggregates",
monthly_listening_interval_duration="12", monthly_listening_interval_duration="12",
monthly_listening_interval_start_date=f"{middle}-01", monthly_listening_interval_start_date=f"{middle}-01",
@@ -154,11 +175,12 @@ class StatsScreen(ModalScreen):
def _get_listening_time(self, duration: int, start_date: str) -> int: def _get_listening_time(self, duration: int, start_date: str) -> int:
"""Get listening time in milliseconds for a given period.""" """Get listening time in milliseconds for a given period."""
if not self.app.client: app = self._app()
if not app.client:
return 0 return 0
try: try:
stats = self.app.client.get( stats = app.client.get(
"1.0/stats/aggregates", "1.0/stats/aggregates",
monthly_listening_interval_duration=str(duration), monthly_listening_interval_duration=str(duration),
monthly_listening_interval_start_date=start_date, monthly_listening_interval_start_date=start_date,
@@ -171,15 +193,17 @@ class StatsScreen(ModalScreen):
def _get_finished_books_count(self) -> int: def _get_finished_books_count(self) -> int:
"""Get count of finished books from library.""" """Get count of finished books from library."""
if not self.app.library_client or not self.app.all_items: app = self._app()
if not app.library_client or not app.all_items:
return 0 return 0
return sum( return sum(
1 for item in self.app.all_items if self.app.library_client.is_finished(item) 1 for item in app.all_items if app.library_client.is_finished(item)
) )
def _get_account_info(self) -> dict: def _get_account_info(self) -> dict:
"""Get account information including subscription details.""" """Get account information including subscription details."""
if not self.app.client: app = self._app()
if not app.client:
return {} return {}
account_info = {} account_info = {}
@@ -200,7 +224,7 @@ class StatsScreen(ModalScreen):
for endpoint, response_groups in endpoints: for endpoint, response_groups in endpoints:
try: try:
response = self.app.client.get( response = app.client.get(
endpoint, response_groups=response_groups) endpoint, response_groups=response_groups)
account_info.update(response) account_info.update(response)
except Exception: except Exception:
@@ -209,15 +233,151 @@ class StatsScreen(ModalScreen):
return account_info return account_info
def _get_email(self) -> str: def _get_email(self) -> str:
"""Get email from config file.""" """Get email from auth, config, or API."""
app = self._app()
for getter in (
self._get_email_from_auth,
self._get_email_from_config,
self._get_email_from_auth_file,
self._get_email_from_account_info,
):
email = getter(app)
if email:
return email
auth_data: dict[str, Any] | None = None
if app.auth:
try:
auth_data = getattr(app.auth, "data", None)
except Exception:
auth_data = None
account_info = self._get_account_info() if app.client else None
for candidate in (auth_data, account_info):
email = self._find_email_in_data(candidate)
if email:
return email
return "Unknown"
def _get_email_from_auth(self, app: _AppContext) -> str | None:
"""Extract email from the authenticator if available."""
if not app.auth:
return None
try:
email = self._first_email(
getattr(app.auth, "username", None),
getattr(app.auth, "login", None),
getattr(app.auth, "email", None),
)
if email:
return email
except Exception:
return None
try:
customer_info = getattr(app.auth, "customer_info", None)
if isinstance(customer_info, dict):
email = self._first_email(
customer_info.get("email"),
customer_info.get("email_address"),
customer_info.get("primary_email"),
)
if email:
return email
except Exception:
return None
try:
data = getattr(app.auth, "data", None)
if isinstance(data, dict):
return self._first_email(
data.get("username"),
data.get("email"),
data.get("login"),
data.get("user_email"),
)
except Exception:
return None
def _get_email_from_config(self, app: _AppContext) -> str | None:
"""Extract email from the config file."""
try: try:
if CONFIG_PATH.exists(): if CONFIG_PATH.exists():
with open(CONFIG_PATH, "r", encoding="utf-8") as f: with open(CONFIG_PATH, "r", encoding="utf-8") as f:
config = json.load(f) config = json.load(f)
return config.get("email", "Unknown") return self._first_email(
config.get("email"),
config.get("username"),
config.get("login"),
)
except Exception: except Exception:
pass return None
return "Unknown"
def _get_email_from_auth_file(self, app: _AppContext) -> str | None:
"""Extract email from the auth file."""
try:
if AUTH_PATH.exists():
with open(AUTH_PATH, "r", encoding="utf-8") as f:
auth_file_data = json.load(f)
return self._first_email(
auth_file_data.get("username"),
auth_file_data.get("email"),
auth_file_data.get("login"),
auth_file_data.get("user_email"),
)
except Exception:
return None
def _get_email_from_account_info(self, app: _AppContext) -> str | None:
"""Extract email from the account info API."""
if not app.client:
return None
try:
account_info = self._get_account_info()
if account_info:
email = self._first_email(
account_info.get("email"),
account_info.get("customer_email"),
account_info.get("username"),
)
if email:
return email
customer_info = account_info.get("customer_info", {})
if isinstance(customer_info, dict):
return self._first_email(
customer_info.get("email"),
customer_info.get("email_address"),
customer_info.get("primary_email"),
)
except Exception:
return None
def _first_email(self, *values: str | None) -> str | None:
"""Return the first non-empty, non-Unknown email value."""
for value in values:
if value and value != "Unknown":
return value
return None
def _find_email_in_data(self, data: Any) -> str | None:
"""Search nested data for an email-like value."""
if data is None:
return None
stack: list[Any] = [data]
while stack:
current = stack.pop()
if isinstance(current, dict):
stack.extend(current.values())
elif isinstance(current, list):
stack.extend(current)
elif isinstance(current, str):
if "@" in current:
local, _, domain = current.partition("@")
if local and "." in domain:
return current
return None
def _get_subscription_details(self, account_info: dict) -> dict: def _get_subscription_details(self, account_info: dict) -> dict:
"""Extract subscription details from nested API response.""" """Extract subscription details from nested API response."""
@@ -228,7 +388,7 @@ class StatsScreen(ModalScreen):
["subscription", "subscription_details"], ["subscription", "subscription_details"],
] ]
for path in paths: for path in paths:
data = account_info data: Any = account_info
for key in path: for key in path:
if isinstance(data, dict): if isinstance(data, dict):
data = data.get(key) data = data.get(key)
@@ -240,11 +400,12 @@ class StatsScreen(ModalScreen):
def _get_country(self) -> str: def _get_country(self) -> str:
"""Get country from authenticator locale.""" """Get country from authenticator locale."""
if not self.app.auth: app = self._app()
if not app.auth:
return "Unknown" return "Unknown"
try: try:
locale_obj = getattr(self.app.auth, "locale", None) locale_obj = getattr(app.auth, "locale", None)
if not locale_obj: if not locale_obj:
return "Unknown" return "Unknown"
@@ -264,7 +425,8 @@ class StatsScreen(ModalScreen):
return ListItem(Label(text)) return ListItem(Label(text))
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
if not self.app.client: app = self._app()
if not app.client:
with Container(id="help_container"): with Container(id="help_container"):
yield Static("Statistics", id="help_title") yield Static("Statistics", id="help_title")
yield Static( yield Static(
@@ -299,7 +461,8 @@ class StatsScreen(ModalScreen):
month_time = self._get_listening_time(1, today.strftime("%Y-%m")) month_time = self._get_listening_time(1, today.strftime("%Y-%m"))
year_time = self._get_listening_time(12, today.strftime("%Y-01")) year_time = self._get_listening_time(12, today.strftime("%Y-01"))
finished_count = self._get_finished_books_count() finished_count = self._get_finished_books_count()
total_books = len(self.app.all_items) if self.app.all_items else 0 app = self._app()
total_books = len(app.all_items) if app.all_items else 0
email = self._get_email() email = self._get_email()
country = self._get_country() country = self._get_country()
@@ -344,5 +507,5 @@ class StatsScreen(ModalScreen):
return stats_items return stats_items
def action_dismiss(self) -> None: async def action_dismiss(self, result: Any | None = None) -> None:
self.dismiss() await self.dismiss(result)