From 18a3ebbe03e1f16540787d49aaacfd89a4262fe0 Mon Sep 17 00:00:00 2001
From: taoky
Date: Fri, 2 Aug 2024 15:48:55 +0800
Subject: [PATCH] Add local db

---
 .gitignore    |   1 +
 shadowmire.py | 169 +++++++++++++++++++++++++++++++++++++++-----------
 2 files changed, 135 insertions(+), 35 deletions(-)

diff --git a/.gitignore b/.gitignore
index 5237b83..5a6fef1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 .mypy_cache
 simple/
 local.json
+local.db
 plan.json
 remote.json
 venv/
diff --git a/shadowmire.py b/shadowmire.py
index c0c7215..27442fe 100644
--- a/shadowmire.py
+++ b/shadowmire.py
@@ -13,12 +13,71 @@ import requests
 import argparse
 import os
 from contextlib import contextmanager
+import sqlite3
 
 logger = logging.getLogger(__name__)
 
 USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
 
 
+class LocalVersionKV:
+    """
+    A key-value database wrapper over sqlite3.
+
+    As it would have consistency issue if it's writing while downstream is downloading the database.
+    An extra "jsonpath" is used, to store kv results when necessary.
+    """
+    def __init__(self, dbpath: Path, jsonpath: Path) -> None:
+        self.conn = sqlite3.connect(dbpath)
+        self.jsonpath = jsonpath
+        cur = self.conn.cursor()
+        cur.execute(
+            "CREATE TABLE IF NOT EXISTS local(key TEXT PRIMARY KEY, value INT NOT NULL)"
+        )
+        self.conn.commit()
+
+    def get(self, key: str) -> Optional[int]:
+        cur = self.conn.cursor()
+        res = cur.execute("SELECT key, value FROM local WHERE key = ?", (key,))
+        row = res.fetchone()
+        return row[1] if row else None
+
+    INSERT_SQL = "INSERT INTO local (key, value) VALUES (?, ?) 
ON CONFLICT(key) DO UPDATE SET value=excluded.value" + + def set(self, key: str, value: int) -> None: + cur = self.conn.cursor() + cur.execute(self.INSERT_SQL, (key, value)) + self.conn.commit() + + def batch_set(self, d: dict[str, int]) -> None: + cur = self.conn.cursor() + kvs = [(k, v) for k, v in d.items()] + cur.executemany(self.INSERT_SQL, kvs) + self.conn.commit() + + def remove(self, key: str) -> None: + cur = self.conn.cursor() + cur.execute("DELETE FROM local WHERE key = ?", (key,)) + self.conn.commit() + + def nuke(self, commit: bool = True) -> None: + cur = self.conn.cursor() + cur.execute("DELETE FROM local") + if commit: + self.conn.commit() + + def dump(self) -> dict[str, int]: + cur = self.conn.cursor() + res = cur.execute("SELECT key, value FROM local") + rows = res.fetchall() + return {row[0]: row[1] for row in rows} + + def dump_json(self) -> None: + res = self.dump() + with overwrite(self.jsonpath) as f: + json.dump(res, f) + + @contextmanager def overwrite(file_path: Path, mode: str = "w", tmp_suffix: str = ".tmp"): tmp_path = file_path.parent / (file_path.name + tmp_suffix) @@ -235,8 +294,9 @@ class Plan: class SyncBase: - def __init__(self, basedir: Path, sync_packages: bool = False) -> None: + def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None: self.basedir = basedir + self.local_db = local_db self.simple_dir = basedir / "simple" self.packages_dir = basedir / "packages" # create the dirs, if not exist @@ -275,9 +335,9 @@ class SyncBase: to_remove = plan.remove to_update = plan.update - for package in to_remove: - logger.info("Removing %s", package) - meta_dir = self.simple_dir / package + for package_name in to_remove: + logger.info("Removing %s", package_name) + meta_dir = self.simple_dir / package_name index_html = meta_dir / "index.html" try: with open(index_html) as f: @@ -292,12 +352,15 @@ class SyncBase: except FileNotFoundError: pass # remove all files inside meta_dir + 
self.local_db.remove(package_name) remove_dir_with_files(meta_dir) - for package in to_update: - logger.info("Updating %s", package) - self.do_update((package, self.remote[package])) + for idx, package_name in enumerate(to_update): + logger.info("Updating %s", package_name) + self.do_update(package_name) + if idx % 1000 == 0: + self.local_db.dump_json() - def do_update(self, package: ShadowmirePackageItem) -> bool: + def do_update(self, package_name: str) -> bool: raise NotImplementedError def finalize(self) -> None: @@ -325,10 +388,10 @@ class SyncBase: class SyncPyPI(SyncBase): - def __init__(self, basedir: Path, sync_packages: bool = False) -> None: + def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None: self.pypi = PyPI() self.session = create_requests_session() - super().__init__(basedir, sync_packages) + super().__init__(basedir, local_db, sync_packages) def fetch_remote_versions(self) -> dict[str, int]: remote_serials = self.pypi.list_packages_with_serial() @@ -337,13 +400,9 @@ class SyncPyPI(SyncBase): ret[normalize(key)] = remote_serials[key] return ret - def do_update(self, package: ShadowmirePackageItem) -> bool: - package_name = package[0] - # The serial get from metadata now might be newer than package_serial... - # package_serial = package[1] - - package_simple_dir = self.simple_dir / package_name - package_simple_dir.mkdir(exist_ok=True) + def do_update(self, package_name: str) -> bool: + package_simple_path = self.simple_dir / package_name + package_simple_path.mkdir(exist_ok=True) try: meta = self.pypi.get_package_metadata(package_name) logger.debug("%s meta: %s", package_name, meta) @@ -351,33 +410,36 @@ class SyncPyPI(SyncBase): logger.warning("%s missing from upstream, skip.", package_name) return False + last_serial = meta['last_serial'] # OK, here we don't bother store raw name # Considering that JSON API even does not give package raw name, why bother we use it? 
         simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name)
         simple_json_contents = self.pypi.generate_json_simple_page(meta)
         for html_filename in ("index.html", "index.v1_html"):
-            html_path = package_simple_dir / html_filename
+            html_path = package_simple_path / html_filename
             with overwrite(html_path) as f:
                 f.write(simple_html_contents)
         for json_filename in ("index.v1_json",):
-            json_path = package_simple_dir / json_filename
+            json_path = package_simple_path / json_filename
             with overwrite(json_path) as f:
                 f.write(simple_json_contents)
 
         if self.sync_packages:
             raise NotImplementedError
+
+        self.local_db.set(package_name, last_serial)
 
         return True
 
 
 class SyncPlainHTTP(SyncBase):
     def __init__(
-        self, upstream: str, basedir: Path, sync_packages: bool = False
+        self, upstream: str, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False
     ) -> None:
         self.upstream = upstream
         self.session = create_requests_session()
-        super().__init__(basedir, sync_packages)
+        super().__init__(basedir, local_db, sync_packages)
 
     def fetch_remote_versions(self) -> dict[str, int]:
         remote_url = urljoin(self.upstream, "local.json")
@@ -386,10 +448,9 @@ class SyncPlainHTTP(SyncBase):
         remote: dict[str, int] = resp.json()
         return remote
 
-    def do_update(self, package: tuple[str, int]) -> bool:
-        package_name = package[0]
-        package_simple_dir = self.simple_dir / package_name
-        package_simple_dir.mkdir(exist_ok=True)
+    def do_update(self, package_name) -> bool:
+        package_simple_path = self.simple_dir / package_name
+        package_simple_path.mkdir(exist_ok=True)
         # directly fetch remote files
         for filename in ("index.html", "index.v1_html", "index.v1_json"):
             file_url = urljoin(self.upstream, f"/simple/{package_name}/{filename}")
@@ -405,32 +466,48 @@ class SyncPlainHTTP(SyncBase):
         else:
             resp.raise_for_status()
         content = resp.content
-        with open(package_simple_dir / filename, "wb") as f:
+        with open(package_simple_path / filename, "wb") as f:
             f.write(content)
 
         if self.sync_packages:
             raise NotImplementedError
 
+        last_serial = get_local_serial(package_simple_path)
+        if not last_serial:
+            logger.warning("cannot get valid package serial from %s", package_name)
+        else:
+            self.local_db.set(package_name, last_serial)
+
         return True
 
 
-def load_local(basedir: Path) -> dict[str, int]:
+def get_local_serial(package_simple_path: Path) -> Optional[int]:
+    package_name = package_simple_path.name
+    package_index_path = package_simple_path / "index.html"
     try:
-        with open(basedir / "local.json") as f:
-            r = json.load(f)
-            return r
+        with open(package_index_path) as f:
+            contents = f.read()
     except FileNotFoundError:
-        return {}
+        logger.warning("%s does not have index.html, skipping", package_name)
+        return None
+    try:
+        serial_comment = contents.splitlines()[-1].strip()
+        serial = int(serial_comment.removeprefix("<!--SERIAL ").removesuffix("-->"))
+        return serial
+    except Exception:
+        logger.warning("cannot parse %s index.html", package_name, exc_info=True)
+        return None
 
 
 def main(args: argparse.Namespace) -> None:
     log_level = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
     logging.basicConfig(level=log_level)
     basedir = Path(".")
+    local_db = LocalVersionKV(basedir / "local.db", basedir / "local.json")
 
     if args.command == "sync":
-        sync = SyncPyPI(basedir=basedir)
-        local = load_local(basedir)
+        sync = SyncPyPI(basedir=basedir, local_db=local_db)
+        local = local_db.dump()
         plan = sync.determine_sync_plan(local)
         # save plan for debugging
         with overwrite(basedir / "plan.json") as f:
@@ -438,7 +515,26 @@ def main(args: argparse.Namespace) -> None:
         sync.do_sync_plan(plan)
         sync.finalize()
     elif args.command == "genlocal":
-        pass
+        local = {}
+        for package_path in (basedir / "simple").iterdir():
+            package_name = package_path.name
+            serial = get_local_serial(package_path)
+            if serial:
+                local[package_name] = serial
+        local_db.nuke(commit=False)
+        local_db.batch_set(local)
+        local_db.dump_json()
+    elif args.command == "verify":
+        sync = SyncPyPI(basedir=basedir, local_db=local_db)
+        remote = 
sync.fetch_remote_versions() + for package_name in remote: + local_serial = get_local_serial(basedir / "simple" / package_name) + if local_serial == remote[package_name]: + logger.info("%s serial same as remote", package_name) + continue + logger.info("updating %s", package_name) + sync.do_update(package_name) + sync.finalize() if __name__ == "__main__": @@ -447,7 +543,10 @@ if __name__ == "__main__": parser_sync = subparsers.add_parser("sync", help="Sync from upstream") parser_genlocal = subparsers.add_parser( - "genlocal", help="(Re)generate local.json file" + "genlocal", help="(Re)generate local db and json from simple/" + ) + parser_verify = subparsers.add_parser( + "verify", help="Verify existing sync and download missing things" ) args = parser.parse_args()