Add local db

taoky 2024-08-02 15:48:55 +08:00
parent 577d021d70
commit 18a3ebbe03
2 changed files with 135 additions and 35 deletions

.gitignore

@@ -1,6 +1,7 @@
.mypy_cache
simple/
local.json
local.db
plan.json
remote.json
venv/


@@ -13,12 +13,71 @@ import requests
import argparse
import os
from contextlib import contextmanager
import sqlite3
logger = logging.getLogger(__name__)
USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
class LocalVersionKV:
"""
A key-value database wrapper over sqlite3.
The sqlite database could be inconsistent if a downstream mirror downloads it while it is being written,
so an extra "jsonpath" is used to dump the contents to a JSON snapshot when necessary.
"""
def __init__(self, dbpath: Path, jsonpath: Path) -> None:
self.conn = sqlite3.connect(dbpath)
self.jsonpath = jsonpath
cur = self.conn.cursor()
cur.execute(
"CREATE TABLE IF NOT EXISTS local(key TEXT PRIMARY KEY, value INT NOT NULL)"
)
self.conn.commit()
def get(self, key: str) -> Optional[int]:
cur = self.conn.cursor()
res = cur.execute("SELECT key, value FROM local WHERE key = ?", (key,))
row = res.fetchone()
return row[1] if row else None  # row is (key, value); return the value
INSERT_SQL = "INSERT INTO local (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value=excluded.value"
def set(self, key: str, value: int) -> None:
cur = self.conn.cursor()
cur.execute(self.INSERT_SQL, (key, value))
self.conn.commit()
def batch_set(self, d: dict[str, int]) -> None:
cur = self.conn.cursor()
kvs = [(k, v) for k, v in d.items()]
cur.executemany(self.INSERT_SQL, kvs)
self.conn.commit()
def remove(self, key: str) -> None:
cur = self.conn.cursor()
cur.execute("DELETE FROM local WHERE key = ?", (key,))
self.conn.commit()
def nuke(self, commit: bool = True) -> None:
cur = self.conn.cursor()
cur.execute("DELETE FROM local")
if commit:
self.conn.commit()
def dump(self) -> dict[str, int]:
cur = self.conn.cursor()
res = cur.execute("SELECT key, value FROM local")
rows = res.fetchall()
return {row[0]: row[1] for row in rows}
def dump_json(self) -> None:
res = self.dump()
with overwrite(self.jsonpath) as f:
json.dump(res, f)
@contextmanager
def overwrite(file_path: Path, mode: str = "w", tmp_suffix: str = ".tmp"):
tmp_path = file_path.parent / (file_path.name + tmp_suffix)
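The LocalVersionKV class above is the new source of truth for package serials, and dump_json() produces the local.json snapshot that downstream mirrors fetch. A minimal usage sketch, assuming the class exactly as shown (the paths mirror the defaults wired up in main() below; the serial numbers are made up):

from pathlib import Path

local_db = LocalVersionKV(Path("local.db"), Path("local.json"))
local_db.set("requests", 24364442)                           # upsert one package -> serial entry
local_db.batch_set({"flask": 24364443, "django": 24364444})  # upsert many at once
print(local_db.get("requests"))                              # 24364442; None for unknown keys
local_db.remove("flask")
local_db.dump_json()                                         # snapshot the whole table into local.json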
@@ -235,8 +294,9 @@ class Plan:
class SyncBase:
def __init__(self, basedir: Path, sync_packages: bool = False) -> None:
def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None:
self.basedir = basedir
self.local_db = local_db
self.simple_dir = basedir / "simple"
self.packages_dir = basedir / "packages"
# create the dirs, if not exist
@@ -275,9 +335,9 @@ class SyncBase:
to_remove = plan.remove
to_update = plan.update
for package in to_remove:
logger.info("Removing %s", package)
meta_dir = self.simple_dir / package
for package_name in to_remove:
logger.info("Removing %s", package_name)
meta_dir = self.simple_dir / package_name
index_html = meta_dir / "index.html"
try:
with open(index_html) as f:
@@ -292,12 +352,15 @@ class SyncBase:
except FileNotFoundError:
pass
# remove all files inside meta_dir
self.local_db.remove(package_name)
remove_dir_with_files(meta_dir)
for package in to_update:
logger.info("Updating %s", package)
self.do_update((package, self.remote[package]))
for idx, package_name in enumerate(to_update):
logger.info("Updating %s", package_name)
self.do_update(package_name)
if idx % 1000 == 0:
self.local_db.dump_json()
def do_update(self, package: ShadowmirePackageItem) -> bool:
def do_update(self, package_name: str) -> bool:
raise NotImplementedError
def finalize(self) -> None:
@@ -325,10 +388,10 @@ class SyncBase:
class SyncPyPI(SyncBase):
def __init__(self, basedir: Path, sync_packages: bool = False) -> None:
def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None:
self.pypi = PyPI()
self.session = create_requests_session()
super().__init__(basedir, sync_packages)
super().__init__(basedir, local_db, sync_packages)
def fetch_remote_versions(self) -> dict[str, int]:
remote_serials = self.pypi.list_packages_with_serial()
@@ -337,13 +400,9 @@ class SyncPyPI(SyncBase):
ret[normalize(key)] = remote_serials[key]
return ret
def do_update(self, package: ShadowmirePackageItem) -> bool:
package_name = package[0]
# The serial we get from the metadata now might be newer than package_serial...
# package_serial = package[1]
package_simple_dir = self.simple_dir / package_name
package_simple_dir.mkdir(exist_ok=True)
def do_update(self, package_name: str) -> bool:
package_simple_path = self.simple_dir / package_name
package_simple_path.mkdir(exist_ok=True)
try:
meta = self.pypi.get_package_metadata(package_name)
logger.debug("%s meta: %s", package_name, meta)
@@ -351,33 +410,36 @@ class SyncPyPI(SyncBase):
logger.warning("%s missing from upstream, skip.", package_name)
return False
last_serial = meta["last_serial"]
# OK, here we don't bother storing the raw name.
# Considering that the JSON API does not even give the package raw name, why bother using it?
simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name)
simple_json_contents = self.pypi.generate_json_simple_page(meta)
for html_filename in ("index.html", "index.v1_html"):
html_path = package_simple_dir / html_filename
html_path = package_simple_path / html_filename
with overwrite(html_path) as f:
f.write(simple_html_contents)
for json_filename in ("index.v1_json",):
json_path = package_simple_dir / json_filename
json_path = package_simple_path / json_filename
with overwrite(json_path) as f:
f.write(simple_json_contents)
if self.sync_packages:
raise NotImplementedError
self.local_db.set(package_name, last_serial)
return True
class SyncPlainHTTP(SyncBase):
def __init__(
self, upstream: str, basedir: Path, sync_packages: bool = False
self, upstream: str, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False
) -> None:
self.upstream = upstream
self.session = create_requests_session()
super().__init__(basedir, sync_packages)
super().__init__(basedir, local_db, sync_packages)
def fetch_remote_versions(self) -> dict[str, int]:
remote_url = urljoin(self.upstream, "local.json")
@@ -386,10 +448,9 @@ class SyncPlainHTTP(SyncBase):
remote: dict[str, int] = resp.json()
return remote
def do_update(self, package: tuple[str, int]) -> bool:
package_name = package[0]
package_simple_dir = self.simple_dir / package_name
package_simple_dir.mkdir(exist_ok=True)
def do_update(self, package_name: str) -> bool:
package_simple_path = self.simple_dir / package_name
package_simple_path.mkdir(exist_ok=True)
# directly fetch remote files
for filename in ("index.html", "index.v1_html", "index.v1_json"):
file_url = urljoin(self.upstream, f"/simple/{package_name}/{filename}")
@@ -405,32 +466,48 @@ class SyncPlainHTTP(SyncBase):
else:
resp.raise_for_status()
content = resp.content
with open(package_simple_dir / filename, "wb") as f:
with open(package_simple_path / filename, "wb") as f:
f.write(content)
if self.sync_packages:
raise NotImplementedError
last_serial = get_local_serial(package_simple_path)
if not last_serial:
logger.warning("cannot get valid package serial from %s", package_name)
else:
self.local_db.set(package_name, last_serial)
return True
def load_local(basedir: Path) -> dict[str, int]:
def get_local_serial(package_simple_path: Path) -> Optional[int]:
package_name = package_simple_path.name
package_index_path = package_simple_path / "index.html"
try:
with open(basedir / "local.json") as f:
r = json.load(f)
return r
with open(package_index_path) as f:
contents = f.read()
except FileNotFoundError:
return {}
logger.warning("%s does not have index.html, skipping", package_name)
return None
try:
serial_comment = contents.splitlines()[-1].strip()
serial = int(serial_comment.removeprefix("<!--SERIAL ").removesuffix("-->"))
return serial
except Exception:
logger.warning("cannot parse %s index.html", package_name, exc_info=True)
return None
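get_local_serial above relies on the serial comment carried as the last line of the index.html files under simple/. A small illustration of that parsing step, using a made-up serial value:

contents = "<html>\n...\n</html>\n<!--SERIAL 24364442-->"
serial_comment = contents.splitlines()[-1].strip()
serial = int(serial_comment.removeprefix("<!--SERIAL ").removesuffix("-->"))
assert serial == 24364442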
def main(args: argparse.Namespace) -> None:
log_level = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
logging.basicConfig(level=log_level)
basedir = Path(".")
local_db = LocalVersionKV(basedir / "local.db", basedir / "local.json")
if args.command == "sync":
sync = SyncPyPI(basedir=basedir)
local = load_local(basedir)
sync = SyncPyPI(basedir=basedir, local_db=local_db)
local = local_db.dump()
plan = sync.determine_sync_plan(local)
# save plan for debugging
with overwrite(basedir / "plan.json") as f:
@@ -438,7 +515,26 @@ def main(args: argparse.Namespace) -> None:
sync.do_sync_plan(plan)
sync.finalize()
elif args.command == "genlocal":
pass
local = {}
for package_path in (basedir / "simple").iterdir():
package_name = package_path.name
serial = get_local_serial(package_path)
if serial:
local[package_name] = serial
local_db.nuke(commit=False)
local_db.batch_set(local)
local_db.dump_json()
elif args.command == "verify":
sync = SyncPyPI(basedir=basedir, local_db=local_db)
remote = sync.fetch_remote_versions()
for package_name in remote:
local_serial = get_local_serial(basedir / "simple" / package_name)
if local_serial == remote[package_name]:
logger.info("%s serial same as remote", package_name)
continue
logger.info("updating %s", package_name)
sync.do_update(package_name)
sync.finalize()
if __name__ == "__main__":
@ -447,7 +543,10 @@ if __name__ == "__main__":
parser_sync = subparsers.add_parser("sync", help="Sync from upstream")
parser_genlocal = subparsers.add_parser(
"genlocal", help="(Re)generate local.json file"
"genlocal", help="(Re)generate local db and json from simple/"
)
parser_verify = subparsers.add_parser(
"verify", help="Verify existing sync and download missing things"
)
args = parser.parse_args()