diff --git a/shadowmire.py b/shadowmire.py index 27442fe..0f5fa42 100644 --- a/shadowmire.py +++ b/shadowmire.py @@ -9,11 +9,11 @@ from pathlib import Path from html.parser import HTMLParser import logging import html -import requests import argparse import os from contextlib import contextmanager import sqlite3 +import requests logger = logging.getLogger(__name__) @@ -27,6 +27,7 @@ class LocalVersionKV: As it would have consistency issue if it's writing while downstream is downloading the database. An extra "jsonpath" is used, to store kv results when necessary. """ + def __init__(self, dbpath: Path, jsonpath: Path) -> None: self.conn = sqlite3.connect(dbpath) self.jsonpath = jsonpath @@ -54,18 +55,24 @@ class LocalVersionKV: kvs = [(k, v) for k, v in d.items()] cur.executemany(self.INSERT_SQL, kvs) self.conn.commit() - + def remove(self, key: str) -> None: cur = self.conn.cursor() cur.execute("DELETE FROM local WHERE key = ?", (key,)) self.conn.commit() - + def nuke(self, commit: bool = True) -> None: cur = self.conn.cursor() cur.execute("DELETE FROM local") if commit: self.conn.commit() + def keys(self) -> list[str]: + cur = self.conn.cursor() + res = cur.execute("SELECT key FROM local") + rows = res.fetchall() + return [row[0] for row in rows] + def dump(self) -> dict[str, int]: cur = self.conn.cursor() res = cur.execute("SELECT key, value FROM local") @@ -294,7 +301,9 @@ class Plan: class SyncBase: - def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None: + def __init__( + self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False + ) -> None: self.basedir = basedir self.local_db = local_db self.simple_dir = basedir / "simple" @@ -364,7 +373,7 @@ class SyncBase: raise NotImplementedError def finalize(self) -> None: - assert self.remote + local_names = self.local_db.keys() # generate index.html at basedir index_path = self.basedir / "simple" / "index.html" # modified from bandersnatch @@ -378,17 +387,16 @@ class SyncBase: f.write(" \n") # This will either be the simple dir, or if we are using index # directory hashing, a list of subdirs to process. - for pkg in self.remote: + for pkg in local_names: # We're really trusty that this is all encoded in UTF-8. :/ f.write(f' {pkg}
\n') f.write(" \n") - remote_json_path = self.basedir / "remote.json" - local_json_path = self.basedir / "local.json" - remote_json_path.rename(local_json_path) class SyncPyPI(SyncBase): - def __init__(self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False) -> None: + def __init__( + self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False + ) -> None: self.pypi = PyPI() self.session = create_requests_session() super().__init__(basedir, local_db, sync_packages) @@ -410,7 +418,7 @@ class SyncPyPI(SyncBase): logger.warning("%s missing from upstream, skip.", package_name) return False - last_serial = meta['last_serial'] + last_serial = meta["last_serial"] # OK, here we don't bother store raw name # Considering that JSON API even does not give package raw name, why bother we use it? simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name) @@ -427,7 +435,7 @@ class SyncPyPI(SyncBase): if self.sync_packages: raise NotImplementedError - + self.local_db.set(package_name, last_serial) return True @@ -435,7 +443,11 @@ class SyncPyPI(SyncBase): class SyncPlainHTTP(SyncBase): def __init__( - self, upstream: str, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False + self, + upstream: str, + basedir: Path, + local_db: LocalVersionKV, + sync_packages: bool = False, ) -> None: self.upstream = upstream self.session = create_requests_session() @@ -448,7 +460,7 @@ class SyncPlainHTTP(SyncBase): remote: dict[str, int] = resp.json() return remote - def do_update(self, package_name) -> bool: + def do_update(self, package_name: str) -> bool: package_simple_path = self.simple_dir / package_name package_simple_path.mkdir(exist_ok=True) # directly fetch remote files @@ -461,8 +473,7 @@ class SyncPlainHTTP(SyncBase): continue else: logger.error("%s does not exist. Stop with this.", file_url) - # TODO: error handling - break + return False else: resp.raise_for_status() content = resp.content