Add some type hints for mypy strict mode

This commit is contained in:
taoky 2024-08-06 18:39:23 +08:00
parent 6751bc93cb
commit aa4ae7e477

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
import sys import sys
from typing import Any, Optional from types import FrameType
from typing import IO, Any, Callable, Generator, Optional
import xmlrpc.client import xmlrpc.client
from dataclasses import dataclass from dataclasses import dataclass
import re import re
@ -47,7 +48,7 @@ class ExitProgramException(Exception):
pass pass
def exit_handler(signum, frame): def exit_handler(signum: int, frame: Optional[FrameType]) -> None:
raise ExitProgramException raise ExitProgramException
@ -126,7 +127,9 @@ class LocalVersionKV:
@contextmanager @contextmanager
def overwrite(file_path: Path, mode: str = "w", tmp_suffix: str = ".tmp"): def overwrite(
file_path: Path, mode: str = "w", tmp_suffix: str = ".tmp"
) -> Generator[IO[Any], None, None]:
tmp_path = file_path.parent / (file_path.name + tmp_suffix) tmp_path = file_path.parent / (file_path.name + tmp_suffix)
try: try:
with open(tmp_path, mode) as tmp_file: with open(tmp_path, mode) as tmp_file:
@ -249,14 +252,16 @@ class PyPI:
self.session = create_requests_session() self.session = create_requests_session()
def list_packages_with_serial(self) -> dict[str, int]: def list_packages_with_serial(self) -> dict[str, int]:
logger.info("Calling list_packages_with_serial() RPC, this requires some time...") logger.info(
"Calling list_packages_with_serial() RPC, this requires some time..."
)
return self.xmlrpc_client.list_packages_with_serial() # type: ignore return self.xmlrpc_client.list_packages_with_serial() # type: ignore
def get_package_metadata(self, package_name: str) -> dict: def get_package_metadata(self, package_name: str) -> dict:
req = self.session.get(urljoin(self.host, f"pypi/{package_name}/json")) req = self.session.get(urljoin(self.host, f"pypi/{package_name}/json"))
if req.status_code == 404: if req.status_code == 404:
raise PackageNotFoundError raise PackageNotFoundError
return req.json() return req.json() # type: ignore
@staticmethod @staticmethod
def get_release_files_from_meta(package_meta: dict) -> list[dict]: def get_release_files_from_meta(package_meta: dict) -> list[dict]:
@ -509,7 +514,9 @@ class SyncBase:
if serial: if serial:
self.local_db.set(package_name, serial) self.local_db.set(package_name, serial)
except Exception as e: except Exception as e:
if e is ExitProgramException or e is KeyboardInterrupt: if isinstance(e, ExitProgramException) or isinstance(
e, KeyboardInterrupt
):
raise raise
logger.warning( logger.warning(
"%s generated an exception", package_name, exc_info=True "%s generated an exception", package_name, exc_info=True
@ -810,13 +817,13 @@ def get_local_serial(package_meta_path: Path) -> Optional[int]:
return None return None
try: try:
meta = json.loads(contents) meta = json.loads(contents)
return meta["last_serial"] return meta["last_serial"] # type: ignore
except Exception: except Exception:
logger.warning("cannot parse %s's JSON metadata", package_name, exc_info=True) logger.warning("cannot parse %s's JSON metadata", package_name, exc_info=True)
return None return None
def sync_shared_args(func): def sync_shared_args(func: Callable[..., Any]) -> Callable[..., Any]:
shared_options = [ shared_options = [
click.option( click.option(
"--sync-packages/--no-sync-packages", "--sync-packages/--no-sync-packages",