mirror of https://github.com/taoky/shadowmire.git
synced 2025-07-08 17:32:43 +00:00
Add local db

commit 18a3ebbe03, parent 577d021d70
.gitignore (vendored), 1 line added:

--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 .mypy_cache
 simple/
 local.json
+local.db
 plan.json
 remote.json
 venv/
shadowmire.py, 169 lines changed:

--- a/shadowmire.py
+++ b/shadowmire.py
@@ -13,12 +13,71 @@ import requests
 import argparse
 import os
 from contextlib import contextmanager
+import sqlite3
 
 logger = logging.getLogger(__name__)
 
 USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
 
 
+class LocalVersionKV:
+    """
+    A key-value database wrapper over sqlite3.
+
+    As the sqlite3 database could be inconsistent if it is written to while a downstream is downloading it,
+    an extra "jsonpath" is used to store a JSON dump of the kv pairs when necessary.
+    """
+
+    def __init__(self, dbpath: Path, jsonpath: Path) -> None:
+        self.conn = sqlite3.connect(dbpath)
+        self.jsonpath = jsonpath
+        cur = self.conn.cursor()
+        cur.execute(
+            "CREATE TABLE IF NOT EXISTS local(key TEXT PRIMARY KEY, value INT NOT NULL)"
+        )
+        self.conn.commit()
+
+    def get(self, key: str) -> Optional[int]:
+        cur = self.conn.cursor()
+        res = cur.execute("SELECT key, value FROM local WHERE key = ?", (key,))
+        row = res.fetchone()
+        return row[1] if row else None  # row is (key, value)
+
+    INSERT_SQL = "INSERT INTO local (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value=excluded.value"
+
+    def set(self, key: str, value: int) -> None:
+        cur = self.conn.cursor()
+        cur.execute(self.INSERT_SQL, (key, value))
+        self.conn.commit()
+
+    def batch_set(self, d: dict[str, int]) -> None:
+        cur = self.conn.cursor()
+        kvs = list(d.items())
+        cur.executemany(self.INSERT_SQL, kvs)
+        self.conn.commit()
+
+    def remove(self, key: str) -> None:
+        cur = self.conn.cursor()
+        cur.execute("DELETE FROM local WHERE key = ?", (key,))
+        self.conn.commit()
+
+    def nuke(self, commit: bool = True) -> None:
+        cur = self.conn.cursor()
+        cur.execute("DELETE FROM local")
+        if commit:
+            self.conn.commit()
+
+    def dump(self) -> dict[str, int]:
+        cur = self.conn.cursor()
+        res = cur.execute("SELECT key, value FROM local")
+        rows = res.fetchall()
+        return {row[0]: row[1] for row in rows}
+
+    def dump_json(self) -> None:
+        res = self.dump()
+        with overwrite(self.jsonpath) as f:
+            json.dump(res, f)
+
+
 @contextmanager
 def overwrite(file_path: Path, mode: str = "w", tmp_suffix: str = ".tmp"):
     tmp_path = file_path.parent / (file_path.name + tmp_suffix)
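A minimal sketch of how this class is exercised, assuming the wiring that main() sets up later in this commit; the package names and serials here are made up:

    # hypothetical smoke test for LocalVersionKV
    db = LocalVersionKV(Path("local.db"), Path("local.json"))
    db.set("requests", 12345678)              # upsert via INSERT ... ON CONFLICT
    db.batch_set({"flask": 1, "django": 2})   # many upserts, one commit
    assert db.get("requests") == 12345678
    db.dump_json()                            # snapshot the whole table to local.json

Note that every mutator commits immediately, so a crash between set() calls loses at most the in-flight key.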
@@ -235,8 +294,9 @@ class Plan:
 
 
 class SyncBase:
-    def __init__(self, basedir: Path, sync_packages: bool = False) -> None:
+    def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None:
         self.basedir = basedir
+        self.local_db = local_db
         self.simple_dir = basedir / "simple"
         self.packages_dir = basedir / "packages"
         # create the dirs, if not exist
@@ -275,9 +335,9 @@ class SyncBase:
         to_remove = plan.remove
         to_update = plan.update
 
-        for package in to_remove:
-            logger.info("Removing %s", package)
-            meta_dir = self.simple_dir / package
+        for package_name in to_remove:
+            logger.info("Removing %s", package_name)
+            meta_dir = self.simple_dir / package_name
             index_html = meta_dir / "index.html"
             try:
                 with open(index_html) as f:
@@ -292,12 +352,15 @@ class SyncBase:
             except FileNotFoundError:
                 pass
             # remove all files inside meta_dir
+            self.local_db.remove(package_name)
             remove_dir_with_files(meta_dir)
-        for package in to_update:
-            logger.info("Updating %s", package)
-            self.do_update((package, self.remote[package]))
+        for idx, package_name in enumerate(to_update):
+            logger.info("Updating %s", package_name)
+            self.do_update(package_name)
+            if idx % 1000 == 0:
+                self.local_db.dump_json()
 
-    def do_update(self, package: ShadowmirePackageItem) -> bool:
+    def do_update(self, package_name: str) -> bool:
         raise NotImplementedError
 
     def finalize(self) -> None:
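The dump_json() checkpoint every 1000 packages above only stays safe for downstream consumers because overwrite() (whose body the earlier hunk truncates) replaces the file in one step. A plausible completion of such a write-temp-then-rename helper, as a sketch rather than the commit's actual code:

    @contextmanager
    def overwrite(file_path: Path, mode: str = "w", tmp_suffix: str = ".tmp"):
        # write to a sibling .tmp file, then atomically rename it over the
        # target, so readers never observe a partially written file
        tmp_path = file_path.parent / (file_path.name + tmp_suffix)
        try:
            with open(tmp_path, mode) as f:
                yield f
            os.replace(tmp_path, file_path)  # atomic on POSIX filesystems
        except Exception:
            tmp_path.unlink(missing_ok=True)
            raise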
@@ -325,10 +388,10 @@ class SyncBase:
 
 
 class SyncPyPI(SyncBase):
-    def __init__(self, basedir: Path, sync_packages: bool = False) -> None:
+    def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None:
         self.pypi = PyPI()
         self.session = create_requests_session()
-        super().__init__(basedir, sync_packages)
+        super().__init__(basedir, local_db, sync_packages)
 
     def fetch_remote_versions(self) -> dict[str, int]:
         remote_serials = self.pypi.list_packages_with_serial()
@@ -337,13 +400,9 @@ class SyncPyPI(SyncBase):
             ret[normalize(key)] = remote_serials[key]
         return ret
 
-    def do_update(self, package: ShadowmirePackageItem) -> bool:
-        package_name = package[0]
-        # The serial get from metadata now might be newer than package_serial...
-        # package_serial = package[1]
-
-        package_simple_dir = self.simple_dir / package_name
-        package_simple_dir.mkdir(exist_ok=True)
+    def do_update(self, package_name: str) -> bool:
+        package_simple_path = self.simple_dir / package_name
+        package_simple_path.mkdir(exist_ok=True)
         try:
             meta = self.pypi.get_package_metadata(package_name)
             logger.debug("%s meta: %s", package_name, meta)
@@ -351,33 +410,36 @@ class SyncPyPI(SyncBase):
             logger.warning("%s missing from upstream, skip.", package_name)
             return False
 
+        last_serial = meta["last_serial"]
         # OK, we don't bother storing the raw name here.
         # Considering that the JSON API doesn't even give the raw package name, why bother using it?
         simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name)
         simple_json_contents = self.pypi.generate_json_simple_page(meta)
 
         for html_filename in ("index.html", "index.v1_html"):
-            html_path = package_simple_dir / html_filename
+            html_path = package_simple_path / html_filename
             with overwrite(html_path) as f:
                 f.write(simple_html_contents)
         for json_filename in ("index.v1_json",):
-            json_path = package_simple_dir / json_filename
+            json_path = package_simple_path / json_filename
             with overwrite(json_path) as f:
                 f.write(simple_json_contents)
 
         if self.sync_packages:
             raise NotImplementedError
 
+        self.local_db.set(package_name, last_serial)
+
         return True
 
 
 class SyncPlainHTTP(SyncBase):
     def __init__(
-        self, upstream: str, basedir: Path, sync_packages: bool = False
+        self, upstream: str, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False
     ) -> None:
         self.upstream = upstream
         self.session = create_requests_session()
-        super().__init__(basedir, sync_packages)
+        super().__init__(basedir, local_db, sync_packages)
 
     def fetch_remote_versions(self) -> dict[str, int]:
         remote_url = urljoin(self.upstream, "local.json")
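The last_serial recorded above is PyPI's global changelog serial, which the JSON API returns alongside the package metadata; presumably get_package_metadata() wraps that endpoint. A quick illustration of where the field lives (the package name is arbitrary):

    import requests
    resp = requests.get("https://pypi.org/pypi/requests/json", timeout=10)
    meta = resp.json()
    print(meta["last_serial"])  # increases monotonically with every PyPI event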
@@ -386,10 +448,9 @@ class SyncPlainHTTP(SyncBase):
         remote: dict[str, int] = resp.json()
         return remote
 
-    def do_update(self, package: tuple[str, int]) -> bool:
-        package_name = package[0]
-        package_simple_dir = self.simple_dir / package_name
-        package_simple_dir.mkdir(exist_ok=True)
+    def do_update(self, package_name: str) -> bool:
+        package_simple_path = self.simple_dir / package_name
+        package_simple_path.mkdir(exist_ok=True)
         # directly fetch remote files
         for filename in ("index.html", "index.v1_html", "index.v1_json"):
             file_url = urljoin(self.upstream, f"/simple/{package_name}/{filename}")
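One subtlety in the file_url construction above: because the second argument to urljoin() begins with "/", the join resolves against the host root and discards any path component of upstream, so an upstream is assumed to serve simple/ at its root. For example:

    from urllib.parse import urljoin
    urljoin("https://example.com/mirror/", "/simple/foo/index.html")
    # -> 'https://example.com/simple/foo/index.html' ("/mirror/" is dropped)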
@@ -405,32 +466,48 @@ class SyncPlainHTTP(SyncBase):
             else:
                 resp.raise_for_status()
             content = resp.content
-            with open(package_simple_dir / filename, "wb") as f:
+            with open(package_simple_path / filename, "wb") as f:
                 f.write(content)
 
         if self.sync_packages:
             raise NotImplementedError
 
+        last_serial = get_local_serial(package_simple_path)
+        if not last_serial:
+            logger.warning("cannot get valid package serial from %s", package_name)
+        else:
+            self.local_db.set(package_name, last_serial)
+
         return True
 
 
-def load_local(basedir: Path) -> dict[str, int]:
+def get_local_serial(package_simple_path: Path) -> Optional[int]:
+    package_name = package_simple_path.name
+    package_index_path = package_simple_path / "index.html"
     try:
-        with open(basedir / "local.json") as f:
-            r = json.load(f)
-        return r
+        with open(package_index_path) as f:
+            contents = f.read()
     except FileNotFoundError:
-        return {}
+        logger.warning("%s does not have index.html, skipping", package_name)
+        return None
+    try:
+        serial_comment = contents.splitlines()[-1].strip()
+        serial = int(serial_comment.removeprefix("<!--SERIAL ").removesuffix("-->"))
+        return serial
+    except Exception:
+        logger.warning("cannot parse %s index.html", package_name, exc_info=True)
+        return None
 
 
 def main(args: argparse.Namespace) -> None:
     log_level = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
     logging.basicConfig(level=log_level)
     basedir = Path(".")
+    local_db = LocalVersionKV(basedir / "local.db", basedir / "local.json")
 
     if args.command == "sync":
-        sync = SyncPyPI(basedir=basedir)
-        local = load_local(basedir)
+        sync = SyncPyPI(basedir=basedir, local_db=local_db)
+        local = local_db.dump()
         plan = sync.determine_sync_plan(local)
         # save plan for debugging
         with overwrite(basedir / "plan.json") as f:
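get_local_serial() above relies on the trailer the generated index.html ends with: a <!--SERIAL ...--> comment on its last line, the same convention bandersnatch mirrors use. A minimal illustration of the parse, with a made-up serial:

    contents = "<!DOCTYPE html>\n<html>...</html>\n<!--SERIAL 29373231-->"
    serial_comment = contents.splitlines()[-1].strip()
    serial = int(serial_comment.removeprefix("<!--SERIAL ").removesuffix("-->"))
    assert serial == 29373231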
@@ -438,7 +515,26 @@ def main(args: argparse.Namespace) -> None:
         sync.do_sync_plan(plan)
         sync.finalize()
     elif args.command == "genlocal":
-        pass
+        local: dict[str, int] = {}
+        for package_path in (basedir / "simple").iterdir():
+            package_name = package_path.name
+            serial = get_local_serial(package_path)
+            if serial:
+                local[package_name] = serial
+        local_db.nuke(commit=False)
+        local_db.batch_set(local)
+        local_db.dump_json()
+    elif args.command == "verify":
+        sync = SyncPyPI(basedir=basedir, local_db=local_db)
+        remote = sync.fetch_remote_versions()
+        for package_name in remote:
+            local_serial = get_local_serial(basedir / "simple" / package_name)
+            if local_serial == remote[package_name]:
+                logger.info("%s serial same as remote", package_name)
+                continue
+            logger.info("updating %s", package_name)
+            sync.do_update(package_name)
+        sync.finalize()
 
 
 if __name__ == "__main__":
@ -447,7 +543,10 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
parser_sync = subparsers.add_parser("sync", help="Sync from upstream")
|
parser_sync = subparsers.add_parser("sync", help="Sync from upstream")
|
||||||
parser_genlocal = subparsers.add_parser(
|
parser_genlocal = subparsers.add_parser(
|
||||||
"genlocal", help="(Re)generate local.json file"
|
"genlocal", help="(Re)generate local db and json from simple/"
|
||||||
|
)
|
||||||
|
parser_verify = subparsers.add_parser(
|
||||||
|
"verify", help="Verify existing sync and download missing things"
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
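With the verify subcommand registered above, the three entry points of this commit line up as follows (invocation form is an assumption; adapt to however the script is launched):

    python shadowmire.py sync      # plan against upstream, apply, finalize
    python shadowmire.py genlocal  # rebuild local.db and local.json by scanning simple/
    python shadowmire.py verify    # re-update any package whose local serial mismatches remote

Setting the DEBUG environment variable turns on debug-level logging, per the log_level line in main().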