mirror of
https://github.com/taoky/shadowmire.git
synced 2025-07-07 16:52:43 +00:00
Add threading for updating packages
This commit is contained in:
parent
889ed86497
commit
ad57590bf0
@ -1,6 +1,8 @@
|
||||
requests==2.32.3
|
||||
tqdm==4.66.4
|
||||
|
||||
# dev
|
||||
black==24.4.2
|
||||
mypy==1.11.0
|
||||
types-requests~=2.32.0
|
||||
types-tqdm~=4.66.0
|
||||
|
@ -13,13 +13,31 @@ import argparse
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
import sqlite3
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import signal
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
|
||||
|
||||
|
||||
class PackageNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ExitProgramException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def exit_handler(signum, frame):
|
||||
raise ExitProgramException
|
||||
|
||||
|
||||
signal.signal(signal.SIGTERM, exit_handler)
|
||||
|
||||
|
||||
class LocalVersionKV:
|
||||
"""
|
||||
A key-value database wrapper over sqlite3.
|
||||
@ -159,10 +177,6 @@ def create_requests_session() -> requests.Session:
|
||||
return s
|
||||
|
||||
|
||||
class PackageNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PyPI:
|
||||
"""
|
||||
Upstream which implements full PyPI APIs
|
||||
@ -313,6 +327,8 @@ class SyncBase:
|
||||
self.packages_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.sync_packages = sync_packages
|
||||
self.remote: Optional[dict[str, int]] = None
|
||||
# Note that it's suggested to use only 3 workers for PyPI.
|
||||
self.workers = 3
|
||||
|
||||
def determine_sync_plan(self, local: dict[str, int]) -> Plan:
|
||||
remote = self.fetch_remote_versions()
|
||||
@ -345,16 +361,40 @@ class SyncBase:
|
||||
to_update = plan.update
|
||||
|
||||
for package_name in to_remove:
|
||||
logger.info("removing %s", package_name)
|
||||
self.do_remove(package_name)
|
||||
|
||||
for idx, package_name in enumerate(to_update):
|
||||
logger.info("updating %s", package_name)
|
||||
self.do_update(package_name)
|
||||
if idx % 1000 == 0:
|
||||
self.local_db.dump_json()
|
||||
with ThreadPoolExecutor(max_workers=self.workers) as executor:
|
||||
futures = {
|
||||
executor.submit(self.do_update, package_name, False): (
|
||||
idx,
|
||||
package_name,
|
||||
)
|
||||
for idx, package_name in enumerate(to_update)
|
||||
}
|
||||
|
||||
def do_remove(self, package_name: str) -> bool:
|
||||
try:
|
||||
for future in tqdm(as_completed(futures), total=len(to_update)):
|
||||
idx, package_name = futures[future]
|
||||
try:
|
||||
serial = future.result()
|
||||
if serial:
|
||||
self.local_db.set(package_name, serial)
|
||||
except Exception as e:
|
||||
if e is ExitProgramException:
|
||||
raise
|
||||
logger.warning(
|
||||
"%s generated an exception", package_name, exc_info=True
|
||||
)
|
||||
if idx % 1000 == 0:
|
||||
self.local_db.dump_json()
|
||||
except ExitProgramException:
|
||||
logger.info("Get ExitProgramException, exiting...")
|
||||
for future in futures:
|
||||
future.cancel()
|
||||
sys.exit(1)
|
||||
|
||||
def do_remove(self, package_name: str) -> None:
|
||||
logger.info("removing %s", package_name)
|
||||
meta_dir = self.simple_dir / package_name
|
||||
index_html = meta_dir / "index.html"
|
||||
try:
|
||||
@ -373,7 +413,7 @@ class SyncBase:
|
||||
self.local_db.remove(package_name)
|
||||
remove_dir_with_files(meta_dir)
|
||||
|
||||
def do_update(self, package_name: str) -> bool:
|
||||
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
def finalize(self) -> None:
|
||||
@ -412,7 +452,8 @@ class SyncPyPI(SyncBase):
|
||||
ret[normalize(key)] = remote_serials[key]
|
||||
return ret
|
||||
|
||||
def do_update(self, package_name: str) -> bool:
|
||||
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
|
||||
logger.info("updating %s", package_name)
|
||||
package_simple_path = self.simple_dir / package_name
|
||||
package_simple_path.mkdir(exist_ok=True)
|
||||
try:
|
||||
@ -420,9 +461,9 @@ class SyncPyPI(SyncBase):
|
||||
logger.debug("%s meta: %s", package_name, meta)
|
||||
except PackageNotFoundError:
|
||||
logger.warning("%s missing from upstream, skip.", package_name)
|
||||
return False
|
||||
return None
|
||||
|
||||
last_serial = meta["last_serial"]
|
||||
last_serial: int = meta["last_serial"]
|
||||
# OK, here we don't bother store raw name
|
||||
# Considering that JSON API even does not give package raw name, why bother we use it?
|
||||
simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name)
|
||||
@ -440,9 +481,10 @@ class SyncPyPI(SyncBase):
|
||||
if self.sync_packages:
|
||||
raise NotImplementedError
|
||||
|
||||
self.local_db.set(package_name, last_serial)
|
||||
if write_db:
|
||||
self.local_db.set(package_name, last_serial)
|
||||
|
||||
return True
|
||||
return last_serial
|
||||
|
||||
|
||||
class SyncPlainHTTP(SyncBase):
|
||||
@ -464,7 +506,8 @@ class SyncPlainHTTP(SyncBase):
|
||||
remote: dict[str, int] = resp.json()
|
||||
return remote
|
||||
|
||||
def do_update(self, package_name: str) -> bool:
|
||||
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
|
||||
logger.info("updating %s", package_name)
|
||||
package_simple_path = self.simple_dir / package_name
|
||||
package_simple_path.mkdir(exist_ok=True)
|
||||
# directly fetch remote files
|
||||
@ -477,7 +520,7 @@ class SyncPlainHTTP(SyncBase):
|
||||
continue
|
||||
else:
|
||||
logger.error("%s does not exist. Stop with this.", file_url)
|
||||
return False
|
||||
return None
|
||||
else:
|
||||
resp.raise_for_status()
|
||||
content = resp.content
|
||||
@ -491,9 +534,10 @@ class SyncPlainHTTP(SyncBase):
|
||||
if not last_serial:
|
||||
logger.warning("cannot get valid package serial from %s", package_name)
|
||||
else:
|
||||
self.local_db.set(package_name, last_serial)
|
||||
if write_db:
|
||||
self.local_db.set(package_name, last_serial)
|
||||
|
||||
return True
|
||||
return last_serial
|
||||
|
||||
|
||||
def get_local_serial(package_simple_path: Path) -> Optional[int]:
|
||||
@ -542,12 +586,10 @@ def main(args: argparse.Namespace) -> None:
|
||||
elif args.command == "verify":
|
||||
sync = SyncPyPI(basedir=basedir, local_db=local_db)
|
||||
local_names = set(local_db.keys())
|
||||
simple_dirs = set(list((basedir / "simple").iterdir()))
|
||||
simple_dirs = set([i.name for i in (basedir / "simple").iterdir()])
|
||||
for package_name in simple_dirs - local_names:
|
||||
logger.info("removing %s", package_name)
|
||||
sync.do_remove(package_name)
|
||||
for package_name in local_names:
|
||||
logger.info("updating %s", package_name)
|
||||
sync.do_update(package_name)
|
||||
sync.finalize()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user