Signed-off-by: Shengqi Chen <harry-chen@outlook.com>
Shengqi Chen 2024-08-08 16:46:25 +08:00
parent b94fc106ba
commit 957d9f0efd
No known key found for this signature in database
GPG Key ID: 6EE389C0F18AF774
2 changed files with 116 additions and 53 deletions

README.md

@@ -1,7 +1,8 @@
# tunasync-scripts # tunasync-scripts
Custom scripts for mirror jobs Custom scripts for mirror jobs
# LICENCE ## LICENCE
``` ```
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
@@ -17,3 +18,7 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
``` ```
## Note
* `shadowmire.py` is taken as-is from [taoky/shadowmire](https://github.com/taoky/shadowmire). It is licensed under the Apache License.

shadowmire.py

@@ -1,7 +1,5 @@
# Copied from https://github.com/taoky/shadowmire
# NOTE: THIS FILE IS NOT LICENSED WITH THIS REPO.
#!/usr/bin/env python #!/usr/bin/env python
import sys import sys
from types import FrameType from types import FrameType
from typing import IO, Any, Callable, Generator, Optional from typing import IO, Any, Callable, Generator, Optional
@@ -24,12 +22,15 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import signal import signal
import tomllib import tomllib
from copy import deepcopy from copy import deepcopy
import requests import requests
import click import click
from tqdm import tqdm from tqdm import tqdm
from requests.adapters import HTTPAdapter, Retry from requests.adapters import HTTPAdapter, Retry
logger = logging.getLogger(__name__) LOG_FORMAT = "%(asctime)s %(levelname)s: %(message)s (%(filename)s:%(lineno)d)"
logging.basicConfig(format=LOG_FORMAT)
logger = logging.getLogger("shadowmire")
USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)" USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
@@ -93,7 +94,7 @@ class LocalVersionKV:
def batch_set(self, d: dict[str, int]) -> None: def batch_set(self, d: dict[str, int]) -> None:
cur = self.conn.cursor() cur = self.conn.cursor()
kvs = [(k, v) for k, v in d.items()] kvs = list(d.items())
cur.executemany(self.INSERT_SQL, kvs) cur.executemany(self.INSERT_SQL, kvs)
self.conn.commit() self.conn.commit()
@@ -269,11 +270,19 @@ class PyPI:
) )
self.session = create_requests_session() self.session = create_requests_session()
def list_packages_with_serial(self) -> dict[str, int]: def list_packages_with_serial(self, do_normalize: bool = True) -> dict[str, int]:
logger.info( logger.info(
"Calling list_packages_with_serial() RPC, this requires some time..." "Calling list_packages_with_serial() RPC, this requires some time..."
) )
return self.xmlrpc_client.list_packages_with_serial() # type: ignore ret: dict[str, int] = self.xmlrpc_client.list_packages_with_serial() # type: ignore
if do_normalize:
for key in list(ret.keys()):
normalized_key = normalize(key)
if normalized_key == key:
continue
ret[normalized_key] = ret[key]
del ret[key]
return ret
def get_package_metadata(self, package_name: str) -> dict: def get_package_metadata(self, package_name: str) -> dict:
req = self.session.get(urljoin(self.host, f"pypi/{package_name}/json")) req = self.session.get(urljoin(self.host, f"pypi/{package_name}/json"))
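The `do_normalize` branch added above rewrites the keys of the serial dict in place so that later lookups use normalized project names. The `normalize()` helper is not part of this hunk; it presumably implements PEP 503 name normalization, as in this minimal sketch of the effect:

```python
import re

def normalize(name: str) -> str:
    # PEP 503 project-name normalization (assumed to match shadowmire's own
    # normalize() helper): collapse runs of "-", "_" and "." into a single
    # "-" and lowercase the result.
    return re.sub(r"[-_.]+", "-", name).lower()

serials = {"Flask_SQLAlchemy": 1000, "requests": 2000}
normalized = {normalize(k): v for k, v in serials.items()}
print(normalized)  # {'flask-sqlalchemy': 1000, 'requests': 2000}
```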
@@ -473,7 +482,7 @@ class SyncBase:
package_names: list[str], package_names: list[str],
prerelease_excludes: list[re.Pattern[str]], prerelease_excludes: list[re.Pattern[str]],
compare_size: bool, compare_size: bool,
) -> None: ) -> bool:
to_update = [] to_update = []
for package_name in tqdm(package_names, desc="Checking consistency"): for package_name in tqdm(package_names, desc="Checking consistency"):
package_jsonmeta_path = self.jsonmeta_dir / package_name package_jsonmeta_path = self.jsonmeta_dir / package_name
@@ -526,11 +535,12 @@ class SyncBase:
if should_update: if should_update:
to_update.append(package_name) to_update.append(package_name)
logger.info("%s packages to update in check_and_update()", len(to_update)) logger.info("%s packages to update in check_and_update()", len(to_update))
self.parallel_update(to_update, prerelease_excludes) return self.parallel_update(to_update, prerelease_excludes)
def parallel_update( def parallel_update(
self, package_names: list, prerelease_excludes: list[re.Pattern[str]] self, package_names: list, prerelease_excludes: list[re.Pattern[str]]
) -> None: ) -> bool:
success = True
with ThreadPoolExecutor(max_workers=WORKERS) as executor: with ThreadPoolExecutor(max_workers=WORKERS) as executor:
futures = { futures = {
executor.submit( executor.submit(
@@ -551,13 +561,12 @@ class SyncBase:
if serial: if serial:
self.local_db.set(package_name, serial) self.local_db.set(package_name, serial)
except Exception as e: except Exception as e:
if isinstance(e, ExitProgramException) or isinstance( if isinstance(e, (ExitProgramException, KeyboardInterrupt)):
e, KeyboardInterrupt
):
raise raise
logger.warning( logger.warning(
"%s generated an exception", package_name, exc_info=True "%s generated an exception", package_name, exc_info=True
) )
success = False
if idx % 100 == 0: if idx % 100 == 0:
logger.info("dumping local db...") logger.info("dumping local db...")
self.local_db.dump_json() self.local_db.dump_json()
@@ -566,17 +575,18 @@ class SyncBase:
for future in futures: for future in futures:
future.cancel() future.cancel()
sys.exit(1) sys.exit(1)
return success
def do_sync_plan( def do_sync_plan(
self, plan: Plan, prerelease_excludes: list[re.Pattern[str]] self, plan: Plan, prerelease_excludes: list[re.Pattern[str]]
) -> None: ) -> bool:
to_remove = plan.remove to_remove = plan.remove
to_update = plan.update to_update = plan.update
for package_name in to_remove: for package_name in to_remove:
self.do_remove(package_name) self.do_remove(package_name)
self.parallel_update(to_update, prerelease_excludes) return self.parallel_update(to_update, prerelease_excludes)
def do_remove( def do_remove(
self, package_name: str, use_db: bool = True, remove_packages: bool = True self, package_name: str, use_db: bool = True, remove_packages: bool = True
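With these changes `parallel_update()` reports whether every package succeeded, and `check_and_update()` / `do_sync_plan()` pass that result through to their callers instead of returning nothing. The aggregation pattern, reduced to a self-contained sketch with hypothetical tasks:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_all(tasks) -> bool:
    # Remember that at least one task failed, but keep going so a single bad
    # package does not abort the whole run.
    success = True
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = {executor.submit(task): task for task in tasks}
        for future in as_completed(futures):
            try:
                future.result()
            except KeyboardInterrupt:
                raise
            except Exception:
                success = False
    return success

print(run_all([lambda: 1, lambda: 1 / 0]))  # False: one task raised
```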
@@ -585,7 +595,7 @@ class SyncBase:
package_simple_dir = self.simple_dir / package_name package_simple_dir = self.simple_dir / package_name
if metajson_path.exists() or package_simple_dir.exists(): if metajson_path.exists() or package_simple_dir.exists():
# To make this less noisy... # To make this less noisy...
logger.info("removing %s", package_name) logger.info("Removing package %s", package_name)
packages_to_remove = get_existing_hrefs(package_simple_dir) packages_to_remove = get_existing_hrefs(package_simple_dir)
if remove_packages and packages_to_remove: if remove_packages and packages_to_remove:
paths_to_remove = [package_simple_dir / p for p in packages_to_remove] paths_to_remove = [package_simple_dir / p for p in packages_to_remove]
@@ -680,10 +690,7 @@ class SyncPyPI(SyncBase):
super().__init__(basedir, local_db, sync_packages) super().__init__(basedir, local_db, sync_packages)
def fetch_remote_versions(self) -> dict[str, int]: def fetch_remote_versions(self) -> dict[str, int]:
remote_serials = self.pypi.list_packages_with_serial() ret = self.pypi.list_packages_with_serial()
ret = {}
for key in remote_serials:
ret[normalize(key)] = remote_serials[key]
logger.info("Remote has %s packages", len(ret)) logger.info("Remote has %s packages", len(ret))
with overwrite(self.basedir / "remote.json") as f: with overwrite(self.basedir / "remote.json") as f:
json.dump(ret, f) json.dump(ret, f)
@@ -744,7 +751,7 @@ class SyncPyPI(SyncBase):
if dest.exists(): if dest.exists():
continue continue
dest.parent.mkdir(parents=True, exist_ok=True) dest.parent.mkdir(parents=True, exist_ok=True)
success, resp = download(self.session, url, dest) success, _resp = download(self.session, url, dest)
if not success: if not success:
logger.warning("skipping %s as it fails downloading", package_name) logger.warning("skipping %s as it fails downloading", package_name)
return None return None
@@ -770,16 +777,26 @@ class SyncPlainHTTP(SyncBase):
basedir: Path, basedir: Path,
local_db: LocalVersionKV, local_db: LocalVersionKV,
sync_packages: bool = False, sync_packages: bool = False,
use_pypi_index: bool = False,
) -> None: ) -> None:
self.upstream = upstream self.upstream = upstream
self.session = create_requests_session() self.session = create_requests_session()
self.pypi: Optional[PyPI]
if use_pypi_index:
self.pypi = PyPI()
else:
self.pypi = None
super().__init__(basedir, local_db, sync_packages) super().__init__(basedir, local_db, sync_packages)
def fetch_remote_versions(self) -> dict[str, int]: def fetch_remote_versions(self) -> dict[str, int]:
remote_url = urljoin(self.upstream, "local.json") remote: dict[str, int]
resp = self.session.get(remote_url) if not self.pypi:
resp.raise_for_status() remote_url = urljoin(self.upstream, "local.json")
remote: dict[str, int] = resp.json() resp = self.session.get(remote_url)
resp.raise_for_status()
remote = resp.json()
else:
remote = self.pypi.list_packages_with_serial()
logger.info("Remote has %s packages", len(remote)) logger.info("Remote has %s packages", len(remote))
with overwrite(self.basedir / "remote.json") as f: with overwrite(self.basedir / "remote.json") as f:
json.dump(remote, f) json.dump(remote, f)
@@ -863,7 +880,7 @@ def get_local_serial(package_meta_path: Path) -> Optional[int]:
with open(package_meta_path) as f: with open(package_meta_path) as f:
contents = f.read() contents = f.read()
except FileNotFoundError: except FileNotFoundError:
logger.warning("%s does not have index.html, skipping", package_name) logger.warning("%s does not have JSON metadata, skipping", package_name)
return None return None
try: try:
meta = json.loads(contents) meta = json.loads(contents)
@@ -886,13 +903,18 @@ def sync_shared_args(func: Callable[..., Any]) -> Callable[..., Any]:
type=str, type=str,
help="Use another upstream using shadowmire instead of PyPI", help="Use another upstream using shadowmire instead of PyPI",
), ),
click.option(
"--use-pypi-index/--no-use-pypi-index",
default=False,
help="Always use PyPI index metadata (via XMLRPC). It's no-op without --shadowmire-upstream. Some packages might not be downloaded successfully.",
),
click.option( click.option(
"--exclude", multiple=True, help="Remote package names to exclude. Regex." "--exclude", multiple=True, help="Remote package names to exclude. Regex."
), ),
click.option( click.option(
"--prerelease-exclude", "--prerelease-exclude",
multiple=True, multiple=True,
help="Package names that shall exclude prerelease. Regex.", help="Package names of which prereleases will be excluded. Regex.",
), ),
] ]
for option in shared_options[::-1]: for option in shared_options[::-1]:
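For readers unfamiliar with this click idiom: the shared option decorators are applied in reverse so that, after stacking, they behave as if written one per line above the command in declaration order. A self-contained sketch with hypothetical options:

```python
import click

shared_options = [
    click.option("--exclude", multiple=True),
    click.option("--use-pypi-index/--no-use-pypi-index", default=False),
]

def with_shared_options(func):
    # Apply in reverse to mimic stacked decorators written top to bottom.
    for option in shared_options[::-1]:
        func = option(func)
    return func

@click.command()
@with_shared_options
def demo(exclude: tuple[str, ...], use_pypi_index: bool) -> None:
    click.echo(f"exclude={exclude} use_pypi_index={use_pypi_index}")
```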
@@ -903,6 +925,10 @@ def sync_shared_args(func: Callable[..., Any]) -> Callable[..., Any]:
def read_config( def read_config(
ctx: click.Context, param: click.Option, filename: Optional[str] ctx: click.Context, param: click.Option, filename: Optional[str]
) -> None: ) -> None:
# Set default repo as cwd
ctx.default_map = {}
ctx.default_map["repo"] = "."
if filename is None: if filename is None:
return return
with open(filename, "rb") as f: with open(filename, "rb") as f:
@@ -911,12 +937,16 @@ def read_config(
options = dict(data["options"]) options = dict(data["options"])
except KeyError: except KeyError:
options = {} options = {}
ctx.default_map = { if options.get("repo"):
"sync": options, ctx.default_map["repo"] = options["repo"]
"verify": options, del options["repo"]
"do-update": options,
"do-remove": options, logger.info("Read options from %s: %s", filename, options)
}
ctx.default_map["sync"] = options
ctx.default_map["verify"] = options
ctx.default_map["do-update"] = options
ctx.default_map["do-remove"] = options
@click.group() @click.group()
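The reworked `read_config()` now seeds `repo` at the top level of `ctx.default_map` (so the group-level `--repo` option defaults to the config value, falling back to the current directory) and copies the remaining `[options]` entries to each subcommand. A runnable illustration with example values, not taken from shadowmire itself:

```python
# Parsed [options] table from a hypothetical config file.
options = {"repo": "/data/pypi", "exclude": ["^package-to-skip$"]}

default_map: dict = {"repo": "."}          # repo defaults to the current directory
if options.get("repo"):
    default_map["repo"] = options.pop("repo")
for subcommand in ("sync", "verify", "do-update", "do-remove"):
    default_map[subcommand] = options      # remaining options become per-subcommand defaults

print(default_map["repo"])   # /data/pypi
print(default_map["sync"])   # {'exclude': ['^package-to-skip$']}
```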
@@ -927,8 +957,9 @@ def read_config(
callback=read_config, callback=read_config,
expose_value=False, expose_value=False,
) )
@click.option("--repo", type=click.Path(file_okay=False), help="Repo (basedir) path")
@click.pass_context @click.pass_context
def cli(ctx: click.Context) -> None: def cli(ctx: click.Context, repo: str) -> None:
log_level = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO log_level = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
logging.basicConfig(level=log_level) logging.basicConfig(level=log_level)
ctx.ensure_object(dict) ctx.ensure_object(dict)
@@ -940,7 +971,7 @@ def cli(ctx: click.Context) -> None:
logger.warning("Don't blame me if you were banned!") logger.warning("Don't blame me if you were banned!")
# Make sure basedir is absolute # Make sure basedir is absolute
basedir = Path(os.environ.get("REPO", ".")).resolve() basedir = Path(repo).resolve()
local_db = LocalVersionKV(basedir / "local.db", basedir / "local.json") local_db = LocalVersionKV(basedir / "local.db", basedir / "local.json")
ctx.obj["basedir"] = basedir ctx.obj["basedir"] = basedir
@@ -956,6 +987,7 @@ def get_syncer(
local_db: LocalVersionKV, local_db: LocalVersionKV,
sync_packages: bool, sync_packages: bool,
shadowmire_upstream: Optional[str], shadowmire_upstream: Optional[str],
use_pypi_index: bool,
) -> SyncBase: ) -> SyncBase:
syncer: SyncBase syncer: SyncBase
if shadowmire_upstream: if shadowmire_upstream:
@@ -964,6 +996,7 @@ def get_syncer(
basedir=basedir, basedir=basedir,
local_db=local_db, local_db=local_db,
sync_packages=sync_packages, sync_packages=sync_packages,
use_pypi_index=use_pypi_index,
) )
else: else:
syncer = SyncPyPI( syncer = SyncPyPI(
@@ -981,20 +1014,28 @@ def sync(
shadowmire_upstream: Optional[str], shadowmire_upstream: Optional[str],
exclude: tuple[str], exclude: tuple[str],
prerelease_exclude: tuple[str], prerelease_exclude: tuple[str],
use_pypi_index: bool,
) -> None: ) -> None:
basedir: Path = ctx.obj["basedir"] basedir: Path = ctx.obj["basedir"]
local_db: LocalVersionKV = ctx.obj["local_db"] local_db: LocalVersionKV = ctx.obj["local_db"]
excludes = exclude_to_excludes(exclude) excludes = exclude_to_excludes(exclude)
prerelease_excludes = exclude_to_excludes(prerelease_exclude) prerelease_excludes = exclude_to_excludes(prerelease_exclude)
syncer = get_syncer(basedir, local_db, sync_packages, shadowmire_upstream) syncer = get_syncer(
basedir, local_db, sync_packages, shadowmire_upstream, use_pypi_index
)
local = local_db.dump(skip_invalid=False) local = local_db.dump(skip_invalid=False)
plan = syncer.determine_sync_plan(local, excludes) plan = syncer.determine_sync_plan(local, excludes)
# save plan for debugging # save plan for debugging
with overwrite(basedir / "plan.json") as f: with overwrite(basedir / "plan.json") as f:
json.dump(plan, f, default=vars, indent=2) json.dump(plan, f, default=vars, indent=2)
syncer.do_sync_plan(plan, prerelease_excludes) success = syncer.do_sync_plan(plan, prerelease_excludes)
syncer.finalize() syncer.finalize()
logger.info("Synchronization finished. Success: %s", success)
if not success:
sys.exit(1)
@cli.command(help="(Re)generate local db and json from json/") @cli.command(help="(Re)generate local db and json from json/")
@click.pass_context @click.pass_context
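Since `sync` (and `verify`, further down) now exits non-zero when any package failed, a wrapper such as a tunasync job script can detect partial failures from the exit status. A hedged sketch of such a wrapper, with a hypothetical repo path and options:

```python
import subprocess
import sys

# Hypothetical invocation: the repo path is passed via the new group-level
# --repo option rather than the old REPO environment variable.
result = subprocess.run(
    [sys.executable, "shadowmire.py", "--repo", "/data/pypi", "sync", "--sync-packages"],
    check=False,
)
if result.returncode != 0:
    print("sync finished with failed packages; see the log", file=sys.stderr)
    sys.exit(result.returncode)
```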
@@ -1011,6 +1052,7 @@ def genlocal(ctx: click.Context) -> None:
serial = get_local_serial(package_metapath) serial = get_local_serial(package_metapath)
if serial: if serial:
local[package_name] = serial local[package_name] = serial
logger.info("%d out of {} packages have valid serial number", len(local), len(dir_items))
local_db.nuke(commit=False) local_db.nuke(commit=False)
local_db.batch_set(local) local_db.batch_set(local)
local_db.dump_json() local_db.dump_json()
@@ -1037,21 +1079,24 @@ def verify(
prerelease_exclude: tuple[str], prerelease_exclude: tuple[str],
remove_not_in_local: bool, remove_not_in_local: bool,
compare_size: bool, compare_size: bool,
use_pypi_index: bool,
) -> None: ) -> None:
basedir: Path = ctx.obj["basedir"] basedir: Path = ctx.obj["basedir"]
local_db: LocalVersionKV = ctx.obj["local_db"] local_db: LocalVersionKV = ctx.obj["local_db"]
excludes = exclude_to_excludes(exclude) excludes = exclude_to_excludes(exclude)
prerelease_excludes = exclude_to_excludes(prerelease_exclude) prerelease_excludes = exclude_to_excludes(prerelease_exclude)
syncer = get_syncer(basedir, local_db, sync_packages, shadowmire_upstream) syncer = get_syncer(
basedir, local_db, sync_packages, shadowmire_upstream, use_pypi_index
)
logger.info("remove packages NOT in local db") logger.info("====== Step 1. Remove packages NOT in local db ======")
local_names = set(local_db.keys()) local_names = set(local_db.keys())
simple_dirs = set([i.name for i in (basedir / "simple").iterdir() if i.is_dir()]) simple_dirs = {i.name for i in (basedir / "simple").iterdir() if i.is_dir()}
json_files = set([i.name for i in (basedir / "json").iterdir() if i.is_file()]) json_files = {i.name for i in (basedir / "json").iterdir() if i.is_file()}
not_in_local = (simple_dirs | json_files) - local_names not_in_local = (simple_dirs | json_files) - local_names
logger.info("%s packages NOT in local db", len(not_in_local)) logger.info("%d out of %d local packages NOT in local db", len(not_in_local), len(local_names))
for package_name in not_in_local: for package_name in not_in_local:
logger.debug("package %s not in local db", package_name) logger.info("package %s not in local db", package_name)
if remove_not_in_local: if remove_not_in_local:
# Old bandersnatch would download packages without normalization, # Old bandersnatch would download packages without normalization,
# in which case one package file could have multiple "packages" # in which case one package file could have multiple "packages"
@@ -1060,28 +1105,30 @@ def verify(
# In step 4 unreferenced files would be removed, anyway. # In step 4 unreferenced files would be removed, anyway.
syncer.do_remove(package_name, remove_packages=False) syncer.do_remove(package_name, remove_packages=False)
logger.info("remove packages NOT in remote") logger.info("====== Step 2. Remove packages NOT in remote index ======")
local = local_db.dump(skip_invalid=False) local = local_db.dump(skip_invalid=False)
plan = syncer.determine_sync_plan(local, excludes) plan = syncer.determine_sync_plan(local, excludes)
logger.info( logger.info(
"%s packages NOT in remote -- this might contain packages that also do not exist locally", "%s packages NOT in remote index -- this might contain packages that also do not exist locally",
len(plan.remove), len(plan.remove),
) )
for package_name in plan.remove: for package_name in plan.remove:
# We only take the plan.remove part here # We only take the plan.remove part here
logger.debug("package %s not in remote index", package_name) logger.info("package %s not in remote index", package_name)
syncer.do_remove(package_name, remove_packages=False) syncer.do_remove(package_name, remove_packages=False)
# After some removal, local_names is changed. # After some removal, local_names is changed.
local_names = set(local_db.keys()) local_names = set(local_db.keys())
logger.info( logger.info(
"make sure all local indexes are valid, and (if --sync-packages) have valid local package files" "====== Step 3. Make sure all local indexes are valid, and (if --sync-packages) have valid local package files ======"
)
success = syncer.check_and_update(
list(local_names), prerelease_excludes, compare_size
) )
syncer.check_and_update(list(local_names), prerelease_excludes, compare_size)
syncer.finalize() syncer.finalize()
logger.info("delete unreferenced files in `packages` folder") logger.info("====== Step 4. Remove any unreferenced files in `packages` folder ======")
ref_set = set() ref_set = set()
for sname in tqdm(simple_dirs, desc="Iterating simple/ directory"): for sname in tqdm(simple_dirs, desc="Iterating simple/ directory"):
sd = basedir / "simple" / sname sd = basedir / "simple" / sname
@@ -1100,8 +1147,13 @@ def verify(
# just convert to str to match normpath result # just convert to str to match normpath result
logger.debug("find file %s", file) logger.debug("find file %s", file)
if str(file) not in ref_set: if str(file) not in ref_set:
logger.info("removing unreferenced %s", file) logger.info("removing unreferenced file %s", file)
file.unlink() file.unlink()
logger.info("Verification finished. Success: %s", success)
if not success:
sys.exit(1)
@cli.command(help="Manual update given package for debugging purpose") @cli.command(help="Manual update given package for debugging purpose")
@@ -1114,6 +1166,7 @@ def do_update(
shadowmire_upstream: Optional[str], shadowmire_upstream: Optional[str],
exclude: tuple[str], exclude: tuple[str],
prerelease_exclude: tuple[str], prerelease_exclude: tuple[str],
use_pypi_index: bool,
package_name: str, package_name: str,
) -> None: ) -> None:
basedir: Path = ctx.obj["basedir"] basedir: Path = ctx.obj["basedir"]
@@ -1122,7 +1175,9 @@ def do_update(
if excludes: if excludes:
logger.warning("--exclude is ignored in do_update()") logger.warning("--exclude is ignored in do_update()")
prerelease_excludes = exclude_to_excludes(prerelease_exclude) prerelease_excludes = exclude_to_excludes(prerelease_exclude)
syncer = get_syncer(basedir, local_db, sync_packages, shadowmire_upstream) syncer = get_syncer(
basedir, local_db, sync_packages, shadowmire_upstream, use_pypi_index
)
syncer.do_update(package_name, prerelease_excludes) syncer.do_update(package_name, prerelease_excludes)
@@ -1136,13 +1191,16 @@ def do_remove(
shadowmire_upstream: Optional[str], shadowmire_upstream: Optional[str],
exclude: tuple[str], exclude: tuple[str],
prerelease_exclude: tuple[str], prerelease_exclude: tuple[str],
use_pypi_index: bool,
package_name: str, package_name: str,
) -> None: ) -> None:
basedir = ctx.obj["basedir"] basedir = ctx.obj["basedir"]
local_db = ctx.obj["local_db"] local_db = ctx.obj["local_db"]
if exclude or prerelease_exclude: if exclude or prerelease_exclude:
logger.warning("exclusion rules are ignored in do_remove()") logger.warning("exclusion rules are ignored in do_remove()")
syncer = get_syncer(basedir, local_db, sync_packages, shadowmire_upstream) syncer = get_syncer(
basedir, local_db, sync_packages, shadowmire_upstream, use_pypi_index
)
syncer.do_remove(package_name) syncer.do_remove(package_name)