From ad57590bf0a33b48e1be45ef0584bc14a91211c6 Mon Sep 17 00:00:00 2001 From: taoky Date: Fri, 2 Aug 2024 16:56:23 +0800 Subject: [PATCH] Add threading for updating packages --- requirements.txt | 2 ++ shadowmire.py | 90 +++++++++++++++++++++++++++++++++++------------- 2 files changed, 68 insertions(+), 24 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4a84c9b..e62b54d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ requests==2.32.3 +tqdm==4.66.4 # dev black==24.4.2 mypy==1.11.0 types-requests~=2.32.0 +types-tqdm~=4.66.0 diff --git a/shadowmire.py b/shadowmire.py index 79f3de5..68e23d2 100644 --- a/shadowmire.py +++ b/shadowmire.py @@ -13,13 +13,31 @@ import argparse import os from contextlib import contextmanager import sqlite3 +from concurrent.futures import ThreadPoolExecutor, as_completed +import signal import requests +from tqdm import tqdm logger = logging.getLogger(__name__) USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)" +class PackageNotFoundError(Exception): + pass + + +class ExitProgramException(Exception): + pass + + +def exit_handler(signum, frame): + raise ExitProgramException + + +signal.signal(signal.SIGTERM, exit_handler) + + class LocalVersionKV: """ A key-value database wrapper over sqlite3. @@ -159,10 +177,6 @@ def create_requests_session() -> requests.Session: return s -class PackageNotFoundError(Exception): - pass - - class PyPI: """ Upstream which implements full PyPI APIs @@ -313,6 +327,8 @@ class SyncBase: self.packages_dir.mkdir(parents=True, exist_ok=True) self.sync_packages = sync_packages self.remote: Optional[dict[str, int]] = None + # Note that it's suggested to use only 3 workers for PyPI. + self.workers = 3 def determine_sync_plan(self, local: dict[str, int]) -> Plan: remote = self.fetch_remote_versions() @@ -345,16 +361,40 @@ class SyncBase: to_update = plan.update for package_name in to_remove: - logger.info("removing %s", package_name) self.do_remove(package_name) - for idx, package_name in enumerate(to_update): - logger.info("updating %s", package_name) - self.do_update(package_name) - if idx % 1000 == 0: - self.local_db.dump_json() + with ThreadPoolExecutor(max_workers=self.workers) as executor: + futures = { + executor.submit(self.do_update, package_name, False): ( + idx, + package_name, + ) + for idx, package_name in enumerate(to_update) + } - def do_remove(self, package_name: str) -> bool: + try: + for future in tqdm(as_completed(futures), total=len(to_update)): + idx, package_name = futures[future] + try: + serial = future.result() + if serial: + self.local_db.set(package_name, serial) + except Exception as e: + if e is ExitProgramException: + raise + logger.warning( + "%s generated an exception", package_name, exc_info=True + ) + if idx % 1000 == 0: + self.local_db.dump_json() + except ExitProgramException: + logger.info("Get ExitProgramException, exiting...") + for future in futures: + future.cancel() + sys.exit(1) + + def do_remove(self, package_name: str) -> None: + logger.info("removing %s", package_name) meta_dir = self.simple_dir / package_name index_html = meta_dir / "index.html" try: @@ -373,7 +413,7 @@ class SyncBase: self.local_db.remove(package_name) remove_dir_with_files(meta_dir) - def do_update(self, package_name: str) -> bool: + def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]: raise NotImplementedError def finalize(self) -> None: @@ -412,7 +452,8 @@ class SyncPyPI(SyncBase): ret[normalize(key)] = remote_serials[key] return ret - def do_update(self, package_name: str) -> bool: + def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]: + logger.info("updating %s", package_name) package_simple_path = self.simple_dir / package_name package_simple_path.mkdir(exist_ok=True) try: @@ -420,9 +461,9 @@ class SyncPyPI(SyncBase): logger.debug("%s meta: %s", package_name, meta) except PackageNotFoundError: logger.warning("%s missing from upstream, skip.", package_name) - return False + return None - last_serial = meta["last_serial"] + last_serial: int = meta["last_serial"] # OK, here we don't bother store raw name # Considering that JSON API even does not give package raw name, why bother we use it? simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name) @@ -440,9 +481,10 @@ class SyncPyPI(SyncBase): if self.sync_packages: raise NotImplementedError - self.local_db.set(package_name, last_serial) + if write_db: + self.local_db.set(package_name, last_serial) - return True + return last_serial class SyncPlainHTTP(SyncBase): @@ -464,7 +506,8 @@ class SyncPlainHTTP(SyncBase): remote: dict[str, int] = resp.json() return remote - def do_update(self, package_name: str) -> bool: + def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]: + logger.info("updating %s", package_name) package_simple_path = self.simple_dir / package_name package_simple_path.mkdir(exist_ok=True) # directly fetch remote files @@ -477,7 +520,7 @@ class SyncPlainHTTP(SyncBase): continue else: logger.error("%s does not exist. Stop with this.", file_url) - return False + return None else: resp.raise_for_status() content = resp.content @@ -491,9 +534,10 @@ class SyncPlainHTTP(SyncBase): if not last_serial: logger.warning("cannot get valid package serial from %s", package_name) else: - self.local_db.set(package_name, last_serial) + if write_db: + self.local_db.set(package_name, last_serial) - return True + return last_serial def get_local_serial(package_simple_path: Path) -> Optional[int]: @@ -542,12 +586,10 @@ def main(args: argparse.Namespace) -> None: elif args.command == "verify": sync = SyncPyPI(basedir=basedir, local_db=local_db) local_names = set(local_db.keys()) - simple_dirs = set(list((basedir / "simple").iterdir())) + simple_dirs = set([i.name for i in (basedir / "simple").iterdir()]) for package_name in simple_dirs - local_names: - logger.info("removing %s", package_name) sync.do_remove(package_name) for package_name in local_names: - logger.info("updating %s", package_name) sync.do_update(package_name) sync.finalize()