Mirror of https://github.com/taoky/shadowmire.git, synced 2025-07-08 09:12:43 +00:00

Add local db

commit 18a3ebbe03, parent 577d021d70
.gitignore (vendored): 1 line changed
@@ -1,6 +1,7 @@
 .mypy_cache
 simple/
 local.json
+local.db
 plan.json
 remote.json
 venv/
shadowmire.py: 169 lines changed
@@ -13,12 +13,71 @@ import requests
 import argparse
 import os
 from contextlib import contextmanager
+import sqlite3

 logger = logging.getLogger(__name__)

 USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"


+class LocalVersionKV:
+    """
+    A key-value database wrapper over sqlite3.
+
+    Writing the database while a downstream is downloading it would cause consistency
+    issues, so an extra "jsonpath" is used to store KV results as JSON when necessary.
+    """
+    def __init__(self, dbpath: Path, jsonpath: Path) -> None:
+        self.conn = sqlite3.connect(dbpath)
+        self.jsonpath = jsonpath
+        cur = self.conn.cursor()
+        cur.execute(
+            "CREATE TABLE IF NOT EXISTS local(key TEXT PRIMARY KEY, value INT NOT NULL)"
+        )
+        self.conn.commit()
+
+    def get(self, key: str) -> Optional[int]:
+        cur = self.conn.cursor()
+        res = cur.execute("SELECT key, value FROM local WHERE key = ?", (key,))
+        row = res.fetchone()
+        return row[1] if row else None
+
+    INSERT_SQL = "INSERT INTO local (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value=excluded.value"
+
+    def set(self, key: str, value: int) -> None:
+        cur = self.conn.cursor()
+        cur.execute(self.INSERT_SQL, (key, value))
+        self.conn.commit()
+
+    def batch_set(self, d: dict[str, int]) -> None:
+        cur = self.conn.cursor()
+        kvs = [(k, v) for k, v in d.items()]
+        cur.executemany(self.INSERT_SQL, kvs)
+        self.conn.commit()
+
+    def remove(self, key: str) -> None:
+        cur = self.conn.cursor()
+        cur.execute("DELETE FROM local WHERE key = ?", (key,))
+        self.conn.commit()
+
+    def nuke(self, commit: bool = True) -> None:
+        cur = self.conn.cursor()
+        cur.execute("DELETE FROM local")
+        if commit:
+            self.conn.commit()
+
+    def dump(self) -> dict[str, int]:
+        cur = self.conn.cursor()
+        res = cur.execute("SELECT key, value FROM local")
+        rows = res.fetchall()
+        return {row[0]: row[1] for row in rows}
+
+    def dump_json(self) -> None:
+        res = self.dump()
+        with overwrite(self.jsonpath) as f:
+            json.dump(res, f)
+
+
 @contextmanager
 def overwrite(file_path: Path, mode: str = "w", tmp_suffix: str = ".tmp"):
     tmp_path = file_path.parent / (file_path.name + tmp_suffix)
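LocalVersionKV is the core of this commit: per-package serials move from a flat local.json into sqlite, with the JSON file kept as a dump for downstream consumers. A minimal usage sketch of the class above; the package names, serial values, and paths are made up for illustration:

    from pathlib import Path

    db = LocalVersionKV(Path("local.db"), Path("local.json"))
    db.set("requests", 24000000)                    # plain insert
    db.set("requests", 24000001)                    # ON CONFLICT upsert overwrites the old value
    db.batch_set({"flask": 23999999, "django": 23999998})
    assert db.get("requests") == 24000001
    db.remove("flask")
    print(db.dump())                                # {"requests": 24000001, "django": 23999998}
    db.dump_json()                                  # rewrites local.json atomically via overwrite()

Note the get() fix folded in above: the query selects (key, value), so the value column is row[1], not row[0].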
@@ -235,8 +294,9 @@ class Plan:


 class SyncBase:
-    def __init__(self, basedir: Path, sync_packages: bool = False) -> None:
+    def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None:
         self.basedir = basedir
+        self.local_db = local_db
         self.simple_dir = basedir / "simple"
         self.packages_dir = basedir / "packages"
         # create the dirs, if not exist
@@ -275,9 +335,9 @@ class SyncBase:
         to_remove = plan.remove
         to_update = plan.update

-        for package in to_remove:
-            logger.info("Removing %s", package)
-            meta_dir = self.simple_dir / package
+        for package_name in to_remove:
+            logger.info("Removing %s", package_name)
+            meta_dir = self.simple_dir / package_name
             index_html = meta_dir / "index.html"
             try:
                 with open(index_html) as f:
@@ -292,12 +352,15 @@ class SyncBase:
             except FileNotFoundError:
                 pass
             # remove all files inside meta_dir
+            self.local_db.remove(package_name)
             remove_dir_with_files(meta_dir)
-        for package in to_update:
-            logger.info("Updating %s", package)
-            self.do_update((package, self.remote[package]))
+        for idx, package_name in enumerate(to_update):
+            logger.info("Updating %s", package_name)
+            self.do_update(package_name)
+            if idx % 1000 == 0:
+                self.local_db.dump_json()

-    def do_update(self, package: ShadowmirePackageItem) -> bool:
+    def do_update(self, package_name: str) -> bool:
         raise NotImplementedError

     def finalize(self) -> None:
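The update loop now checkpoints: every 1000th package, the database is re-dumped to local.json, so an interrupted sync still leaves a mostly fresh file for downstreams. The same pattern in isolation; work and flush are hypothetical stand-ins for do_update() and local_db.dump_json():

    from typing import Callable, Iterable

    def run_with_checkpoints(
        items: Iterable[str],
        work: Callable[[str], None],
        flush: Callable[[], None],
        every: int = 1000,
    ) -> None:
        for idx, item in enumerate(items):
            work(item)
            if idx % every == 0:  # also fires at idx == 0, same as the diff above
                flush()

    run_with_checkpoints(["a", "b", "c"], print, lambda: print("checkpoint"), every=2)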
@@ -325,10 +388,10 @@ class SyncBase:


 class SyncPyPI(SyncBase):
-    def __init__(self, basedir: Path, sync_packages: bool = False) -> None:
+    def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None:
         self.pypi = PyPI()
         self.session = create_requests_session()
-        super().__init__(basedir, sync_packages)
+        super().__init__(basedir, local_db, sync_packages)

     def fetch_remote_versions(self) -> dict[str, int]:
         remote_serials = self.pypi.list_packages_with_serial()
@@ -337,13 +400,9 @@ class SyncPyPI(SyncBase):
             ret[normalize(key)] = remote_serials[key]
         return ret

-    def do_update(self, package: ShadowmirePackageItem) -> bool:
-        package_name = package[0]
-        # The serial we get from metadata now might be newer than package_serial...
-        # package_serial = package[1]
-
-        package_simple_dir = self.simple_dir / package_name
-        package_simple_dir.mkdir(exist_ok=True)
+    def do_update(self, package_name: str) -> bool:
+        package_simple_path = self.simple_dir / package_name
+        package_simple_path.mkdir(exist_ok=True)
         try:
             meta = self.pypi.get_package_metadata(package_name)
             logger.debug("%s meta: %s", package_name, meta)
@@ -351,33 +410,36 @@ class SyncPyPI(SyncBase):
             logger.warning("%s missing from upstream, skip.", package_name)
             return False

+        last_serial = meta["last_serial"]
         # OK, here we don't bother to store the raw name.
         # Considering that the JSON API does not even give the package raw name, why bother using it?
         simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name)
         simple_json_contents = self.pypi.generate_json_simple_page(meta)

         for html_filename in ("index.html", "index.v1_html"):
-            html_path = package_simple_dir / html_filename
+            html_path = package_simple_path / html_filename
             with overwrite(html_path) as f:
                 f.write(simple_html_contents)
         for json_filename in ("index.v1_json",):
-            json_path = package_simple_dir / json_filename
+            json_path = package_simple_path / json_filename
             with overwrite(json_path) as f:
                 f.write(simple_json_contents)

         if self.sync_packages:
             raise NotImplementedError

+        self.local_db.set(package_name, last_serial)
+
         return True


 class SyncPlainHTTP(SyncBase):
     def __init__(
-        self, upstream: str, basedir: Path, sync_packages: bool = False
+        self, upstream: str, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False
     ) -> None:
         self.upstream = upstream
         self.session = create_requests_session()
-        super().__init__(basedir, sync_packages)
+        super().__init__(basedir, local_db, sync_packages)

     def fetch_remote_versions(self) -> dict[str, int]:
         remote_url = urljoin(self.upstream, "local.json")
@@ -386,10 +448,9 @@ class SyncPlainHTTP(SyncBase):
         remote: dict[str, int] = resp.json()
         return remote

-    def do_update(self, package: tuple[str, int]) -> bool:
-        package_name = package[0]
-        package_simple_dir = self.simple_dir / package_name
-        package_simple_dir.mkdir(exist_ok=True)
+    def do_update(self, package_name: str) -> bool:
+        package_simple_path = self.simple_dir / package_name
+        package_simple_path.mkdir(exist_ok=True)
         # directly fetch remote files
         for filename in ("index.html", "index.v1_html", "index.v1_json"):
             file_url = urljoin(self.upstream, f"/simple/{package_name}/{filename}")
@@ -405,32 +466,48 @@ class SyncPlainHTTP(SyncBase):
             else:
                 resp.raise_for_status()
             content = resp.content
-            with open(package_simple_dir / filename, "wb") as f:
+            with open(package_simple_path / filename, "wb") as f:
                 f.write(content)

         if self.sync_packages:
             raise NotImplementedError

+        last_serial = get_local_serial(package_simple_path)
+        if not last_serial:
+            logger.warning("cannot get valid package serial from %s", package_name)
+        else:
+            self.local_db.set(package_name, last_serial)
+
         return True


-def load_local(basedir: Path) -> dict[str, int]:
+def get_local_serial(package_simple_path: Path) -> Optional[int]:
+    package_name = package_simple_path.name
+    package_index_path = package_simple_path / "index.html"
     try:
-        with open(basedir / "local.json") as f:
-            r = json.load(f)
-            return r
+        with open(package_index_path) as f:
+            contents = f.read()
     except FileNotFoundError:
-        return {}
+        logger.warning("%s does not have index.html, skipping", package_name)
+        return None
+    try:
+        serial_comment = contents.splitlines()[-1].strip()
+        serial = int(serial_comment.removeprefix("<!--SERIAL ").removesuffix("-->"))
+        return serial
+    except Exception:
+        logger.warning("cannot parse %s index.html", package_name, exc_info=True)
+        return None


 def main(args: argparse.Namespace) -> None:
     log_level = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
     logging.basicConfig(level=log_level)
     basedir = Path(".")
+    local_db = LocalVersionKV(basedir / "local.db", basedir / "local.json")

     if args.command == "sync":
-        sync = SyncPyPI(basedir=basedir)
-        local = load_local(basedir)
+        sync = SyncPyPI(basedir=basedir, local_db=local_db)
+        local = local_db.dump()
         plan = sync.determine_sync_plan(local)
         # save plan for debugging
         with overwrite(basedir / "plan.json") as f:
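get_local_serial() recovers a package's serial from the <!--SERIAL n--> trailer on the last line of its index.html, which is what lets genlocal and verify below work without extra network calls. A worked example of the same parsing, with a made-up serial value:

    contents = "<html>...</html>\n<!--SERIAL 24356270-->"
    serial_comment = contents.splitlines()[-1].strip()
    serial = int(serial_comment.removeprefix("<!--SERIAL ").removesuffix("-->"))
    assert serial == 24356270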
@@ -438,7 +515,26 @@ def main(args: argparse.Namespace) -> None:
         sync.do_sync_plan(plan)
         sync.finalize()
     elif args.command == "genlocal":
-        pass
+        local = {}
+        for package_path in (basedir / "simple").iterdir():
+            package_name = package_path.name
+            serial = get_local_serial(package_path)
+            if serial:
+                local[package_name] = serial
+        local_db.nuke(commit=False)
+        local_db.batch_set(local)
+        local_db.dump_json()
+    elif args.command == "verify":
+        sync = SyncPyPI(basedir=basedir, local_db=local_db)
+        remote = sync.fetch_remote_versions()
+        for package_name in remote:
+            local_serial = get_local_serial(basedir / "simple" / package_name)
+            if local_serial == remote[package_name]:
+                logger.info("%s serial same as remote", package_name)
+                continue
+            logger.info("updating %s", package_name)
+            sync.do_update(package_name)
+        sync.finalize()


 if __name__ == "__main__":
@@ -447,7 +543,10 @@ if __name__ == "__main__":

     parser_sync = subparsers.add_parser("sync", help="Sync from upstream")
     parser_genlocal = subparsers.add_parser(
-        "genlocal", help="(Re)generate local.json file"
+        "genlocal", help="(Re)generate local db and json from simple/"
     )
+    parser_verify = subparsers.add_parser(
+        "verify", help="Verify existing sync and download missing things"
+    )

     args = parser.parse_args()
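Taken together, the subcommands suggest this workflow for the new revision (command names come from the parser above; invoking the script directly is an assumption):

    python shadowmire.py genlocal   # rebuild local.db and local.json from simple/ on disk
    python shadowmire.py sync       # plan against upstream serials, then apply updates and removals
    python shadowmire.py verify     # recheck every package serial and re-sync mismatches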