mirror of
https://github.com/taoky/shadowmire.git
synced 2025-07-08 09:12:43 +00:00
Add threading for updating packages
This commit is contained in:
parent
889ed86497
commit
ad57590bf0
@ -1,6 +1,8 @@
|
|||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
|
tqdm==4.66.4
|
||||||
|
|
||||||
# dev
|
# dev
|
||||||
black==24.4.2
|
black==24.4.2
|
||||||
mypy==1.11.0
|
mypy==1.11.0
|
||||||
types-requests~=2.32.0
|
types-requests~=2.32.0
|
||||||
|
types-tqdm~=4.66.0
|
||||||
|
@ -13,13 +13,31 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
import signal
|
||||||
import requests
|
import requests
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
|
USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
|
||||||
|
|
||||||
|
|
||||||
|
class PackageNotFoundError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ExitProgramException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def exit_handler(signum, frame):
|
||||||
|
raise ExitProgramException
|
||||||
|
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, exit_handler)
|
||||||
|
|
||||||
|
|
||||||
class LocalVersionKV:
|
class LocalVersionKV:
|
||||||
"""
|
"""
|
||||||
A key-value database wrapper over sqlite3.
|
A key-value database wrapper over sqlite3.
|
||||||
@ -159,10 +177,6 @@ def create_requests_session() -> requests.Session:
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
class PackageNotFoundError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class PyPI:
|
class PyPI:
|
||||||
"""
|
"""
|
||||||
Upstream which implements full PyPI APIs
|
Upstream which implements full PyPI APIs
|
||||||
@ -313,6 +327,8 @@ class SyncBase:
|
|||||||
self.packages_dir.mkdir(parents=True, exist_ok=True)
|
self.packages_dir.mkdir(parents=True, exist_ok=True)
|
||||||
self.sync_packages = sync_packages
|
self.sync_packages = sync_packages
|
||||||
self.remote: Optional[dict[str, int]] = None
|
self.remote: Optional[dict[str, int]] = None
|
||||||
|
# Note that it's suggested to use only 3 workers for PyPI.
|
||||||
|
self.workers = 3
|
||||||
|
|
||||||
def determine_sync_plan(self, local: dict[str, int]) -> Plan:
|
def determine_sync_plan(self, local: dict[str, int]) -> Plan:
|
||||||
remote = self.fetch_remote_versions()
|
remote = self.fetch_remote_versions()
|
||||||
@ -345,16 +361,40 @@ class SyncBase:
|
|||||||
to_update = plan.update
|
to_update = plan.update
|
||||||
|
|
||||||
for package_name in to_remove:
|
for package_name in to_remove:
|
||||||
logger.info("removing %s", package_name)
|
|
||||||
self.do_remove(package_name)
|
self.do_remove(package_name)
|
||||||
|
|
||||||
for idx, package_name in enumerate(to_update):
|
with ThreadPoolExecutor(max_workers=self.workers) as executor:
|
||||||
logger.info("updating %s", package_name)
|
futures = {
|
||||||
self.do_update(package_name)
|
executor.submit(self.do_update, package_name, False): (
|
||||||
if idx % 1000 == 0:
|
idx,
|
||||||
self.local_db.dump_json()
|
package_name,
|
||||||
|
)
|
||||||
|
for idx, package_name in enumerate(to_update)
|
||||||
|
}
|
||||||
|
|
||||||
def do_remove(self, package_name: str) -> bool:
|
try:
    # Consume results as workers finish; tqdm reports overall progress.
    for future in tqdm(as_completed(futures), total=len(to_update)):
        idx, package_name = futures[future]
        try:
            # Re-raises whatever exception the worker raised, if any.
            serial = future.result()
            if serial:
                # Workers are submitted with write_db=False, so the new
                # serial is recorded here by the consumer of the results.
                self.local_db.set(package_name, serial)
        except ExitProgramException:
            # Shutdown request (raised by the SIGTERM handler): it must reach
            # the outer handler instead of being logged as a package failure.
            # The previous `e is ExitProgramException` check compared an
            # exception instance with the class and therefore never matched.
            raise
        except Exception:
            # A single failing package must not abort the whole run.
            logger.warning(
                "%s generated an exception", package_name, exc_info=True
            )
        # idx is the submission index of the finished task, not a count of
        # completed tasks; the periodic dump fires when such a task completes.
        if idx % 1000 == 0:
            self.local_db.dump_json()
except ExitProgramException:
    logger.info("Get ExitProgramException, exiting...")
    # cancel() only prevents futures that have not started yet from running;
    # tasks already executing are not interrupted by it.
    for future in futures:
        future.cancel()
    sys.exit(1)
|
||||||
|
|
||||||
|
def do_remove(self, package_name: str) -> None:
|
||||||
|
logger.info("removing %s", package_name)
|
||||||
meta_dir = self.simple_dir / package_name
|
meta_dir = self.simple_dir / package_name
|
||||||
index_html = meta_dir / "index.html"
|
index_html = meta_dir / "index.html"
|
||||||
try:
|
try:
|
||||||
@ -373,7 +413,7 @@ class SyncBase:
|
|||||||
self.local_db.remove(package_name)
|
self.local_db.remove(package_name)
|
||||||
remove_dir_with_files(meta_dir)
|
remove_dir_with_files(meta_dir)
|
||||||
|
|
||||||
def do_update(self, package_name: str) -> bool:
|
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def finalize(self) -> None:
|
def finalize(self) -> None:
|
||||||
@ -412,7 +452,8 @@ class SyncPyPI(SyncBase):
|
|||||||
ret[normalize(key)] = remote_serials[key]
|
ret[normalize(key)] = remote_serials[key]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def do_update(self, package_name: str) -> bool:
|
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
|
||||||
|
logger.info("updating %s", package_name)
|
||||||
package_simple_path = self.simple_dir / package_name
|
package_simple_path = self.simple_dir / package_name
|
||||||
package_simple_path.mkdir(exist_ok=True)
|
package_simple_path.mkdir(exist_ok=True)
|
||||||
try:
|
try:
|
||||||
@ -420,9 +461,9 @@ class SyncPyPI(SyncBase):
|
|||||||
logger.debug("%s meta: %s", package_name, meta)
|
logger.debug("%s meta: %s", package_name, meta)
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
logger.warning("%s missing from upstream, skip.", package_name)
|
logger.warning("%s missing from upstream, skip.", package_name)
|
||||||
return False
|
return None
|
||||||
|
|
||||||
last_serial = meta["last_serial"]
|
last_serial: int = meta["last_serial"]
|
||||||
# OK, here we don't bother store raw name
|
# OK, here we don't bother store raw name
|
||||||
# Considering that JSON API even does not give package raw name, why bother we use it?
|
# Considering that JSON API even does not give package raw name, why bother we use it?
|
||||||
simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name)
|
simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name)
|
||||||
@ -440,9 +481,10 @@ class SyncPyPI(SyncBase):
|
|||||||
if self.sync_packages:
|
if self.sync_packages:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
self.local_db.set(package_name, last_serial)
|
if write_db:
|
||||||
|
self.local_db.set(package_name, last_serial)
|
||||||
|
|
||||||
return True
|
return last_serial
|
||||||
|
|
||||||
|
|
||||||
class SyncPlainHTTP(SyncBase):
|
class SyncPlainHTTP(SyncBase):
|
||||||
@ -464,7 +506,8 @@ class SyncPlainHTTP(SyncBase):
|
|||||||
remote: dict[str, int] = resp.json()
|
remote: dict[str, int] = resp.json()
|
||||||
return remote
|
return remote
|
||||||
|
|
||||||
def do_update(self, package_name: str) -> bool:
|
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
|
||||||
|
logger.info("updating %s", package_name)
|
||||||
package_simple_path = self.simple_dir / package_name
|
package_simple_path = self.simple_dir / package_name
|
||||||
package_simple_path.mkdir(exist_ok=True)
|
package_simple_path.mkdir(exist_ok=True)
|
||||||
# directly fetch remote files
|
# directly fetch remote files
|
||||||
@ -477,7 +520,7 @@ class SyncPlainHTTP(SyncBase):
|
|||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
logger.error("%s does not exist. Stop with this.", file_url)
|
logger.error("%s does not exist. Stop with this.", file_url)
|
||||||
return False
|
return None
|
||||||
else:
|
else:
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
content = resp.content
|
content = resp.content
|
||||||
@ -491,9 +534,10 @@ class SyncPlainHTTP(SyncBase):
|
|||||||
if not last_serial:
|
if not last_serial:
|
||||||
logger.warning("cannot get valid package serial from %s", package_name)
|
logger.warning("cannot get valid package serial from %s", package_name)
|
||||||
else:
|
else:
|
||||||
self.local_db.set(package_name, last_serial)
|
if write_db:
|
||||||
|
self.local_db.set(package_name, last_serial)
|
||||||
|
|
||||||
return True
|
return last_serial
|
||||||
|
|
||||||
|
|
||||||
def get_local_serial(package_simple_path: Path) -> Optional[int]:
|
def get_local_serial(package_simple_path: Path) -> Optional[int]:
|
||||||
@ -542,12 +586,10 @@ def main(args: argparse.Namespace) -> None:
|
|||||||
elif args.command == "verify":
|
elif args.command == "verify":
|
||||||
sync = SyncPyPI(basedir=basedir, local_db=local_db)
|
sync = SyncPyPI(basedir=basedir, local_db=local_db)
|
||||||
local_names = set(local_db.keys())
|
local_names = set(local_db.keys())
|
||||||
simple_dirs = set(list((basedir / "simple").iterdir()))
|
simple_dirs = set([i.name for i in (basedir / "simple").iterdir()])
|
||||||
for package_name in simple_dirs - local_names:
|
for package_name in simple_dirs - local_names:
|
||||||
logger.info("removing %s", package_name)
|
|
||||||
sync.do_remove(package_name)
|
sync.do_remove(package_name)
|
||||||
for package_name in local_names:
|
for package_name in local_names:
|
||||||
logger.info("updating %s", package_name)
|
|
||||||
sync.do_update(package_name)
|
sync.do_update(package_name)
|
||||||
sync.finalize()
|
sync.finalize()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user