Add threading for updating packages

This commit is contained in:
taoky 2024-08-02 16:56:23 +08:00
parent 889ed86497
commit ad57590bf0
2 changed files with 68 additions and 24 deletions

View File

@ -1,6 +1,8 @@
requests==2.32.3
tqdm==4.66.4
# dev
black==24.4.2
mypy==1.11.0
types-requests~=2.32.0
types-tqdm~=4.66.0

View File

@ -13,13 +13,31 @@ import argparse
import os
from contextlib import contextmanager
import sqlite3
from concurrent.futures import ThreadPoolExecutor, as_completed
import signal
import requests
from tqdm import tqdm
logger = logging.getLogger(__name__)
USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
class PackageNotFoundError(Exception):
pass
class ExitProgramException(Exception):
pass
def exit_handler(signum, frame):
raise ExitProgramException
signal.signal(signal.SIGTERM, exit_handler)
class LocalVersionKV:
"""
A key-value database wrapper over sqlite3.
@ -159,10 +177,6 @@ def create_requests_session() -> requests.Session:
return s
class PackageNotFoundError(Exception):
pass
class PyPI:
"""
Upstream which implements full PyPI APIs
@ -313,6 +327,8 @@ class SyncBase:
self.packages_dir.mkdir(parents=True, exist_ok=True)
self.sync_packages = sync_packages
self.remote: Optional[dict[str, int]] = None
# Note that it's suggested to use only 3 workers for PyPI.
self.workers = 3
def determine_sync_plan(self, local: dict[str, int]) -> Plan:
remote = self.fetch_remote_versions()
@ -345,16 +361,40 @@ class SyncBase:
to_update = plan.update
for package_name in to_remove:
logger.info("removing %s", package_name)
self.do_remove(package_name)
for idx, package_name in enumerate(to_update):
logger.info("updating %s", package_name)
self.do_update(package_name)
if idx % 1000 == 0:
self.local_db.dump_json()
with ThreadPoolExecutor(max_workers=self.workers) as executor:
futures = {
executor.submit(self.do_update, package_name, False): (
idx,
package_name,
)
for idx, package_name in enumerate(to_update)
}
def do_remove(self, package_name: str) -> bool:
try:
for future in tqdm(as_completed(futures), total=len(to_update)):
idx, package_name = futures[future]
try:
serial = future.result()
if serial:
self.local_db.set(package_name, serial)
except Exception as e:
if e is ExitProgramException:
raise
logger.warning(
"%s generated an exception", package_name, exc_info=True
)
if idx % 1000 == 0:
self.local_db.dump_json()
except ExitProgramException:
logger.info("Get ExitProgramException, exiting...")
for future in futures:
future.cancel()
sys.exit(1)
def do_remove(self, package_name: str) -> None:
logger.info("removing %s", package_name)
meta_dir = self.simple_dir / package_name
index_html = meta_dir / "index.html"
try:
@ -373,7 +413,7 @@ class SyncBase:
self.local_db.remove(package_name)
remove_dir_with_files(meta_dir)
def do_update(self, package_name: str) -> bool:
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
raise NotImplementedError
def finalize(self) -> None:
@ -412,7 +452,8 @@ class SyncPyPI(SyncBase):
ret[normalize(key)] = remote_serials[key]
return ret
def do_update(self, package_name: str) -> bool:
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
logger.info("updating %s", package_name)
package_simple_path = self.simple_dir / package_name
package_simple_path.mkdir(exist_ok=True)
try:
@ -420,9 +461,9 @@ class SyncPyPI(SyncBase):
logger.debug("%s meta: %s", package_name, meta)
except PackageNotFoundError:
logger.warning("%s missing from upstream, skip.", package_name)
return False
return None
last_serial = meta["last_serial"]
last_serial: int = meta["last_serial"]
# OK, here we don't bother store raw name
# Considering that JSON API even does not give package raw name, why bother we use it?
simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name)
@ -440,9 +481,10 @@ class SyncPyPI(SyncBase):
if self.sync_packages:
raise NotImplementedError
self.local_db.set(package_name, last_serial)
if write_db:
self.local_db.set(package_name, last_serial)
return True
return last_serial
class SyncPlainHTTP(SyncBase):
@ -464,7 +506,8 @@ class SyncPlainHTTP(SyncBase):
remote: dict[str, int] = resp.json()
return remote
def do_update(self, package_name: str) -> bool:
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
logger.info("updating %s", package_name)
package_simple_path = self.simple_dir / package_name
package_simple_path.mkdir(exist_ok=True)
# directly fetch remote files
@ -477,7 +520,7 @@ class SyncPlainHTTP(SyncBase):
continue
else:
logger.error("%s does not exist. Stop with this.", file_url)
return False
return None
else:
resp.raise_for_status()
content = resp.content
@ -491,9 +534,10 @@ class SyncPlainHTTP(SyncBase):
if not last_serial:
logger.warning("cannot get valid package serial from %s", package_name)
else:
self.local_db.set(package_name, last_serial)
if write_db:
self.local_db.set(package_name, last_serial)
return True
return last_serial
def get_local_serial(package_simple_path: Path) -> Optional[int]:
@ -542,12 +586,10 @@ def main(args: argparse.Namespace) -> None:
elif args.command == "verify":
sync = SyncPyPI(basedir=basedir, local_db=local_db)
local_names = set(local_db.keys())
simple_dirs = set(list((basedir / "simple").iterdir()))
simple_dirs = set([i.name for i in (basedir / "simple").iterdir()])
for package_name in simple_dirs - local_names:
logger.info("removing %s", package_name)
sync.do_remove(package_name)
for package_name in local_names:
logger.info("updating %s", package_name)
sync.do_update(package_name)
sync.finalize()