Add local db

taoky 2024-08-02 15:48:55 +08:00
parent 577d021d70
commit 18a3ebbe03
2 changed files with 135 additions and 35 deletions

.gitignore

@@ -1,6 +1,7 @@
 .mypy_cache
 simple/
 local.json
+local.db
 plan.json
 remote.json
 venv/

shadowmire.py

@@ -13,12 +13,71 @@ import requests
 import argparse
 import os
 from contextlib import contextmanager
+import sqlite3

 logger = logging.getLogger(__name__)

 USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
+
+
+class LocalVersionKV:
+    """
+    A key-value database wrapper over sqlite3.
+
+    The database file could be inconsistent if a downstream mirror downloads
+    it while we are writing to it, so an extra "jsonpath" is used to dump the
+    key-value pairs to a plain JSON file when necessary.
+    """
+
+    def __init__(self, dbpath: Path, jsonpath: Path) -> None:
+        self.conn = sqlite3.connect(dbpath)
+        self.jsonpath = jsonpath
+        cur = self.conn.cursor()
+        cur.execute(
+            "CREATE TABLE IF NOT EXISTS local(key TEXT PRIMARY KEY, value INT NOT NULL)"
+        )
+        self.conn.commit()
+
+    def get(self, key: str) -> Optional[int]:
+        cur = self.conn.cursor()
+        res = cur.execute("SELECT key, value FROM local WHERE key = ?", (key,))
+        row = res.fetchone()
+        # row is (key, value); return the value column, not the key
+        return row[1] if row else None
+
+    INSERT_SQL = "INSERT INTO local (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value=excluded.value"
+
+    def set(self, key: str, value: int) -> None:
+        cur = self.conn.cursor()
+        cur.execute(self.INSERT_SQL, (key, value))
+        self.conn.commit()
+
+    def batch_set(self, d: dict[str, int]) -> None:
+        cur = self.conn.cursor()
+        kvs = [(k, v) for k, v in d.items()]
+        cur.executemany(self.INSERT_SQL, kvs)
+        self.conn.commit()
+
+    def remove(self, key: str) -> None:
+        cur = self.conn.cursor()
+        cur.execute("DELETE FROM local WHERE key = ?", (key,))
+        self.conn.commit()
+
+    def nuke(self, commit: bool = True) -> None:
+        cur = self.conn.cursor()
+        cur.execute("DELETE FROM local")
+        if commit:
+            self.conn.commit()
+
+    def dump(self) -> dict[str, int]:
+        cur = self.conn.cursor()
+        res = cur.execute("SELECT key, value FROM local")
+        rows = res.fetchall()
+        return {row[0]: row[1] for row in rows}
+
+    def dump_json(self) -> None:
+        res = self.dump()
+        with overwrite(self.jsonpath) as f:
+            json.dump(res, f)

 @contextmanager
 def overwrite(file_path: Path, mode: str = "w", tmp_suffix: str = ".tmp"):
     tmp_path = file_path.parent / (file_path.name + tmp_suffix)
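
A minimal sketch of how the new LocalVersionKV is driven (the walkthrough, package names, and serial values are illustrative; the paths match what main() passes further down):

    db = LocalVersionKV(Path("local.db"), Path("local.json"))
    db.set("requests", 19151817)                 # upsert a single serial
    db.batch_set({"flask": 19150001, "numpy": 19150002})
    assert db.get("flask") == 19150001           # None for unknown keys
    db.dump_json()                               # snapshot the table to local.json

Thanks to the ON CONFLICT upsert in INSERT_SQL, set() and batch_set() overwrite existing serials instead of failing on duplicate keys.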
@@ -235,8 +294,9 @@ class Plan:
 class SyncBase:
-    def __init__(self, basedir: Path, sync_packages: bool = False) -> None:
+    def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None:
         self.basedir = basedir
+        self.local_db = local_db
         self.simple_dir = basedir / "simple"
         self.packages_dir = basedir / "packages"
         # create the dirs, if not exist
@@ -275,9 +335,9 @@ class SyncBase:
         to_remove = plan.remove
         to_update = plan.update
-        for package in to_remove:
-            logger.info("Removing %s", package)
-            meta_dir = self.simple_dir / package
+        for package_name in to_remove:
+            logger.info("Removing %s", package_name)
+            meta_dir = self.simple_dir / package_name
             index_html = meta_dir / "index.html"
             try:
                 with open(index_html) as f:
@@ -292,12 +352,15 @@ class SyncBase:
             except FileNotFoundError:
                 pass
             # remove all files inside meta_dir
+            self.local_db.remove(package_name)
             remove_dir_with_files(meta_dir)
-        for package in to_update:
-            logger.info("Updating %s", package)
-            self.do_update((package, self.remote[package]))
+        for idx, package_name in enumerate(to_update):
+            logger.info("Updating %s", package_name)
+            self.do_update(package_name)
+            if idx % 1000 == 0:
+                self.local_db.dump_json()

-    def do_update(self, package: ShadowmirePackageItem) -> bool:
+    def do_update(self, package_name: str) -> bool:
         raise NotImplementedError

     def finalize(self) -> None:
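
Note the checkpointing in the update loop: every thousandth package triggers local_db.dump_json(), presumably so that a long sync keeps local.json reasonably fresh for downstream mirrors and an interrupted run still leaves a usable snapshot behind.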
@@ -325,10 +388,10 @@ class SyncBase:

 class SyncPyPI(SyncBase):
-    def __init__(self, basedir: Path, sync_packages: bool = False) -> None:
+    def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None:
         self.pypi = PyPI()
         self.session = create_requests_session()
-        super().__init__(basedir, sync_packages)
+        super().__init__(basedir, local_db, sync_packages)

     def fetch_remote_versions(self) -> dict[str, int]:
         remote_serials = self.pypi.list_packages_with_serial()
@@ -337,13 +400,9 @@ class SyncPyPI(SyncBase):
             ret[normalize(key)] = remote_serials[key]
         return ret

-    def do_update(self, package: ShadowmirePackageItem) -> bool:
-        package_name = package[0]
-        # The serial get from metadata now might be newer than package_serial...
-        # package_serial = package[1]
-        package_simple_dir = self.simple_dir / package_name
-        package_simple_dir.mkdir(exist_ok=True)
+    def do_update(self, package_name: str) -> bool:
+        package_simple_path = self.simple_dir / package_name
+        package_simple_path.mkdir(exist_ok=True)
         try:
             meta = self.pypi.get_package_metadata(package_name)
             logger.debug("%s meta: %s", package_name, meta)
@@ -351,33 +410,36 @@ class SyncPyPI(SyncBase):
             logger.warning("%s missing from upstream, skip.", package_name)
             return False

+        last_serial = meta["last_serial"]
         # OK, here we don't bother to store the raw name
         # Considering that the JSON API doesn't even give the raw package name, why bother using it?
         simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name)
         simple_json_contents = self.pypi.generate_json_simple_page(meta)

         for html_filename in ("index.html", "index.v1_html"):
-            html_path = package_simple_dir / html_filename
+            html_path = package_simple_path / html_filename
             with overwrite(html_path) as f:
                 f.write(simple_html_contents)
         for json_filename in ("index.v1_json",):
-            json_path = package_simple_dir / json_filename
+            json_path = package_simple_path / json_filename
             with overwrite(json_path) as f:
                 f.write(simple_json_contents)

         if self.sync_packages:
             raise NotImplementedError

+        self.local_db.set(package_name, last_serial)
         return True


 class SyncPlainHTTP(SyncBase):
     def __init__(
-        self, upstream: str, basedir: Path, sync_packages: bool = False
+        self, upstream: str, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False
     ) -> None:
         self.upstream = upstream
         self.session = create_requests_session()
-        super().__init__(basedir, sync_packages)
+        super().__init__(basedir, local_db, sync_packages)

     def fetch_remote_versions(self) -> dict[str, int]:
         remote_url = urljoin(self.upstream, "local.json")
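
The deleted comment explains the signature change above: the serial taken from the package list snapshot may already be stale by the time the metadata arrives, so do_update() now takes only the name and records meta["last_serial"] instead. For reference, the top level of a PyPI JSON API response looks roughly like this (abridged sketch, values illustrative):

    meta = {
        "info": {"name": "requests", "version": "2.32.3"},
        "last_serial": 19151817,  # what self.local_db.set() stores
        "urls": [],
    }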
@@ -386,10 +448,9 @@ class SyncPlainHTTP(SyncBase):
         remote: dict[str, int] = resp.json()
         return remote

-    def do_update(self, package: tuple[str, int]) -> bool:
-        package_name = package[0]
-        package_simple_dir = self.simple_dir / package_name
-        package_simple_dir.mkdir(exist_ok=True)
+    def do_update(self, package_name: str) -> bool:
+        package_simple_path = self.simple_dir / package_name
+        package_simple_path.mkdir(exist_ok=True)
         # directly fetch remote files
         for filename in ("index.html", "index.v1_html", "index.v1_json"):
             file_url = urljoin(self.upstream, f"/simple/{package_name}/{filename}")
@@ -405,32 +466,48 @@ class SyncPlainHTTP(SyncBase):
             else:
                 resp.raise_for_status()
             content = resp.content
-            with open(package_simple_dir / filename, "wb") as f:
+            with open(package_simple_path / filename, "wb") as f:
                 f.write(content)

         if self.sync_packages:
             raise NotImplementedError

+        last_serial = get_local_serial(package_simple_path)
+        if not last_serial:
+            logger.warning("cannot get valid package serial from %s", package_name)
+        else:
+            self.local_db.set(package_name, last_serial)
         return True


-def load_local(basedir: Path) -> dict[str, int]:
+def get_local_serial(package_simple_path: Path) -> Optional[int]:
+    package_name = package_simple_path.name
+    package_index_path = package_simple_path / "index.html"
     try:
-        with open(basedir / "local.json") as f:
-            r = json.load(f)
-            return r
+        with open(package_index_path) as f:
+            contents = f.read()
     except FileNotFoundError:
-        return {}
+        logger.warning("%s does not have index.html, skipping", package_name)
+        return None
+    try:
+        serial_comment = contents.splitlines()[-1].strip()
+        serial = int(serial_comment.removeprefix("<!--SERIAL ").removesuffix("-->"))
+        return serial
+    except Exception:
+        logger.warning("cannot parse %s index.html", package_name, exc_info=True)
+        return None


 def main(args: argparse.Namespace) -> None:
     log_level = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
     logging.basicConfig(level=log_level)
     basedir = Path(".")
+    local_db = LocalVersionKV(basedir / "local.db", basedir / "local.json")

     if args.command == "sync":
-        sync = SyncPyPI(basedir=basedir)
-        local = load_local(basedir)
+        sync = SyncPyPI(basedir=basedir, local_db=local_db)
+        local = local_db.dump()
         plan = sync.determine_sync_plan(local)
         # save plan for debugging
         with overwrite(basedir / "plan.json") as f:
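
get_local_serial() relies on the convention that a generated index.html ends with the package serial embedded as an HTML comment on its final line, i.e. something like (serial value illustrative):

    <!--SERIAL 19151817-->

If that trailer is missing or malformed, the int() conversion raises, a warning is logged, and the function falls back to None.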
@@ -438,7 +515,26 @@ def main(args: argparse.Namespace) -> None:
         sync.do_sync_plan(plan)
         sync.finalize()
     elif args.command == "genlocal":
-        pass
+        local = {}
+        for package_path in (basedir / "simple").iterdir():
+            package_name = package_path.name
+            serial = get_local_serial(package_path)
+            if serial:
+                local[package_name] = serial
+        local_db.nuke(commit=False)
+        local_db.batch_set(local)
+        local_db.dump_json()
+    elif args.command == "verify":
+        sync = SyncPyPI(basedir=basedir, local_db=local_db)
+        remote = sync.fetch_remote_versions()
+        for package_name in remote:
+            local_serial = get_local_serial(basedir / "simple" / package_name)
+            if local_serial == remote[package_name]:
+                logger.info("%s serial same as remote", package_name)
+                continue
+            logger.info("updating %s", package_name)
+            sync.do_update(package_name)
+        sync.finalize()


 if __name__ == "__main__":
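
With this commit, main() dispatches three subcommands. A hypothetical session, assuming the script is invoked as shadowmire.py from the mirror's base directory:

    python shadowmire.py sync      # plan against local.db, then apply updates and removals
    python shadowmire.py genlocal  # rebuild local.db and local.json by scanning simple/
    python shadowmire.py verify    # recheck each remote package's serial and update stale ones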
@@ -447,7 +543,10 @@ if __name__ == "__main__":
     parser_sync = subparsers.add_parser("sync", help="Sync from upstream")
     parser_genlocal = subparsers.add_parser(
-        "genlocal", help="(Re)generate local.json file"
+        "genlocal", help="(Re)generate local db and json from simple/"
+    )
+    parser_verify = subparsers.add_parser(
+        "verify", help="Verify existing sync and download missing things"
     )
     args = parser.parse_args()