diff --git a/shadowmire.py b/shadowmire.py
index 209fa50..c0df66d 100644
--- a/shadowmire.py
+++ b/shadowmire.py
@@ -23,9 +23,11 @@ logger = logging.getLogger(__name__)
 USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
 
 # Note that it's suggested to use only 3 workers for PyPI.
-WORKERS = int(os.environ("SHADOWMIRE_WORKERS", "3"))
+WORKERS = int(os.environ.get("SHADOWMIRE_WORKERS", "3"))
 if WORKERS > 10:
-    logger.warning("You have set a worker value larger than 10, which is forbidden by PyPI maintainers.")
+    logger.warning(
+        "You have set a worker value larger than 10, which is forbidden by PyPI maintainers."
+    )
     logger.warning("Don't blame me if you were banned!")
 
 
@@ -342,11 +344,24 @@ class SyncBase:
         self.simple_dir.mkdir(parents=True, exist_ok=True)
         self.packages_dir.mkdir(parents=True, exist_ok=True)
         self.sync_packages = sync_packages
-        self.remote: Optional[dict[str, int]] = None
+
+    def filter_remote_with_excludes(self, remote: dict[str, int], excludes: list[re.Pattern]) -> dict[str, int]:
+        if not excludes:
+            return remote
+        res = {}
+        for k, v in remote.items():
+            matched = False
+            for exclude in excludes:
+                if exclude.match(k):
+                    matched = True
+                    break
+            if not matched:
+                res[k] = v
+        return res
 
-    def determine_sync_plan(self, local: dict[str, int]) -> Plan:
+    def determine_sync_plan(self, local: dict[str, int], excludes: list[re.Pattern]) -> Plan:
         remote = self.fetch_remote_versions()
-        self.remote = remote
+        remote = self.filter_remote_with_excludes(remote, excludes)
         # store remote to remote.json
         with overwrite(self.basedir / "remote.json") as f:
             json.dump(remote, f)
@@ -356,6 +371,7 @@ class SyncBase:
         remote_keys = set(remote.keys())
         for i in local_keys - remote_keys:
             to_remove.append(i)
+            local_keys.remove(i)
         for i in remote_keys - local_keys:
             to_update.append(i)
         for i in local_keys:
@@ -400,7 +416,6 @@ class SyncBase:
                 sys.exit(1)
 
     def do_sync_plan(self, plan: Plan) -> None:
-        assert self.remote
         to_remove = plan.remove
         to_update = plan.update
 
@@ -580,9 +595,7 @@ class SyncPlainHTTP(SyncBase):
             package_simple_url = urljoin(self.upstream, f"/simple/{package_name}/")
             for href in current_hrefs:
                 url = urljoin(package_simple_url, href)
-                dest = (
-                    package_simple_path / href
-                ).resolve()
+                dest = (package_simple_path / href).resolve()
                 logger.info("downloading file %s -> %s", url, dest)
                 if dest.exists():
                     continue
@@ -620,6 +633,8 @@ def get_local_serial(package_simple_path: Path) -> Optional[int]:
 def main(args: argparse.Namespace) -> None:
     log_level = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
     logging.basicConfig(level=log_level)
+    logger.debug(args)
+
     basedir = Path(".")
     local_db = LocalVersionKV(basedir / "local.db", basedir / "local.json")
 
@@ -628,7 +643,7 @@ def main(args: argparse.Namespace) -> None:
             basedir=basedir, local_db=local_db, sync_packages=args.sync_packages
         )
         local = local_db.dump()
-        plan = sync.determine_sync_plan(local)
+        plan = sync.determine_sync_plan(local, args.excludes)
         # save plan for debugging
         with overwrite(basedir / "plan.json") as f:
             json.dump(plan, f, default=vars)
@@ -649,7 +664,7 @@ def main(args: argparse.Namespace) -> None:
             basedir=basedir, local_db=local_db, sync_packages=args.sync_packages
         )
         local_names = set(local_db.keys())
-        simple_dirs = set([i.name for i in (basedir / "simple").iterdir()])
+        simple_dirs = set([i.name for i in (basedir / "simple").iterdir() if i.is_dir()])
         for package_name in simple_dirs - local_names:
             sync.do_remove(package_name)
         sync.parallel_update(list(local_names))
@@ -657,14 +672,13 @@ def main(args: argparse.Namespace) -> None:
         # clean up unreferenced package files
         ref_set = set()
         for sname in simple_dirs:
-            sd = basedir / sname
-            index_html = sd / "index.html"
-            hrefs = get_existing_hrefs(index_html)
+            sd = basedir / "simple" / sname
+            hrefs = get_existing_hrefs(sd)
             for i in hrefs:
                 ref_set.add(str((sd / i).resolve()))
         for file in (basedir / "packages").glob("*/*/*/*"):
             file = file.resolve()
-            if file not in ref_set:
+            if str(file) not in ref_set:
                 logger.info("removing unreferenced %s", file)
                 file.unlink()
 
@@ -679,6 +693,9 @@ if __name__ == "__main__":
         help="Sync packages instead of just indexes",
         action="store_true",
     )
+    parser_sync.add_argument(
+        "--exclude", help="Remote package names to exclude. Regex.", nargs="*"
+    )
     parser_genlocal = subparsers.add_parser(
         "genlocal", help="(Re)generate local db and json from simple/"
     )
@@ -696,4 +713,6 @@ if __name__ == "__main__":
     if args.command is None:
         parser.print_help()
         sys.exit(1)
+    if args.command == "sync" and args.exclude:
+        args.excludes = [re.compile(i) for i in args.exclude]
     main(args)