diff --git a/shadowmire.py b/shadowmire.py
index 1b9c520..520ad66 100755
--- a/shadowmire.py
+++ b/shadowmire.py
@@ -42,13 +42,6 @@ PRERELEASE_PATTERNS = (
 )
 
 
-def is_version_prerelease(version: str) -> bool:
-    for p in PRERELEASE_PATTERNS:
-        if p.match(version):
-            return True
-    return False
-
-
 class PackageNotFoundError(Exception):
     pass
 
@@ -353,6 +346,15 @@ class Plan:
     update: list[str]
 
 
+def match_patterns(
+    s: str, ps: list[re.Pattern[str]] | tuple[re.Pattern[str], ...]
+) -> bool:
+    for p in ps:
+        if p.match(s):
+            return True
+    return False
+
+
 class SyncBase:
     def __init__(
         self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False
@@ -367,23 +369,19 @@ class SyncBase:
         self.sync_packages = sync_packages
 
     def filter_remote_with_excludes(
-        self, remote: dict[str, int], excludes: list[re.Pattern]
+        self, remote: dict[str, int], excludes: list[re.Pattern[str]]
     ) -> dict[str, int]:
         if not excludes:
             return remote
         res = {}
         for k, v in remote.items():
-            matched = False
-            for exclude in excludes:
-                if exclude.match(k):
-                    matched = True
-                    break
+            matched = match_patterns(k, excludes)
             if not matched:
                 res[k] = v
         return res
 
     def determine_sync_plan(
-        self, local: dict[str, int], excludes: list[re.Pattern]
+        self, local: dict[str, int], excludes: list[re.Pattern[str]]
     ) -> Plan:
         remote = self.fetch_remote_versions()
         remote = self.filter_remote_with_excludes(remote, excludes)
@@ -410,10 +408,14 @@ class SyncBase:
     def fetch_remote_versions(self) -> dict[str, int]:
         raise NotImplementedError
 
-    def parallel_update(self, package_names: list) -> None:
+    def parallel_update(
+        self, package_names: list, prerelease_excludes: list[re.Pattern[str]]
+    ) -> None:
         with ThreadPoolExecutor(max_workers=WORKERS) as executor:
             futures = {
-                executor.submit(self.do_update, package_name, False): (
+                executor.submit(
+                    self.do_update, package_name, prerelease_excludes, False
+                ): (
                     idx,
                     package_name,
                 )
@@ -440,14 +442,14 @@ class SyncBase:
                     future.cancel()
                 sys.exit(1)
 
-    def do_sync_plan(self, plan: Plan) -> None:
+    def do_sync_plan(self, plan: Plan, prerelease_excludes: list[re.Pattern[str]]) -> None:
         to_remove = plan.remove
         to_update = plan.update
 
         for package_name in to_remove:
             self.do_remove(package_name)
 
-        self.parallel_update(to_update)
+        self.parallel_update(to_update, prerelease_excludes)
 
     def do_remove(self, package_name: str) -> None:
         logger.info("removing %s", package_name)
@@ -469,7 +471,12 @@ class SyncBase:
             self.local_db.remove(package_name)
         remove_dir_with_files(meta_dir)
 
-    def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
+    def do_update(
+        self,
+        package_name: str,
+        prerelease_excludes: list[re.Pattern[str]],
+        write_db: bool = True,
+    ) -> Optional[int]:
         raise NotImplementedError
 
     def finalize(self) -> None:
@@ -519,7 +526,12 @@ class SyncPyPI(SyncBase):
             ret[normalize(key)] = remote_serials[key]
         return ret
 
-    def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
+    def do_update(
+        self,
+        package_name: str,
+        prerelease_excludes: list[re.Pattern[str]],
+        write_db: bool = True,
+    ) -> Optional[int]:
         logger.info("updating %s", package_name)
         package_simple_path = self.simple_dir / package_name
         package_simple_path.mkdir(exist_ok=True)
@@ -530,6 +542,12 @@ class SyncPyPI(SyncBase):
            logger.warning("%s missing from upstream, skip.", package_name)
            return None
 
+        # filter prereleases for packages matching any --prerelease-exclude pattern
+        if prerelease_excludes and match_patterns(package_name, prerelease_excludes):
+            for release in list(meta["releases"].keys()):
+                if match_patterns(release, PRERELEASE_PATTERNS):
+                    del meta["releases"][release]
+
         if self.sync_packages:
             # sync packages first, then sync index
             existing_hrefs = get_existing_hrefs(package_simple_path)
@@ -592,7 +610,16 @@ class SyncPlainHTTP(SyncBase):
         remote: dict[str, int] = resp.json()
         return remote
 
-    def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]:
+    def do_update(
+        self,
+        package_name: str,
+        prerelease_excludes: list[re.Pattern[str]],
+        write_db: bool = True,
+    ) -> Optional[int]:
+        if prerelease_excludes:
+            logger.warning(
+                "prerelease_excludes is currently ignored in SyncPlainHTTP mode."
+            )
         logger.info("updating %s", package_name)
         package_simple_path = self.simple_dir / package_name
         package_simple_path.mkdir(exist_ok=True)
@@ -671,6 +698,11 @@ def sync_shared_args(func):
         click.option(
             "--exclude", multiple=True, help="Remote package names to exclude. Regex."
         ),
+        click.option(
+            "--prerelease-exclude",
+            multiple=True,
+            help="Package names whose prereleases shall be excluded. Regex.",
+        ),
     ]
     for option in shared_options[::-1]:
         func = option(func)
@@ -691,7 +723,7 @@ def cli(ctx: click.Context) -> None:
     ctx.obj["local_db"] = local_db
 
 
-def exclude_to_excludes(exclude: tuple[str]) -> list[re.Pattern]:
+def exclude_to_excludes(exclude: tuple[str]) -> list[re.Pattern[str]]:
     return [re.compile(i) for i in exclude]
 
 
@@ -724,17 +756,19 @@ def sync(
     sync_packages: bool,
     shadowmire_upstream: Optional[str],
     exclude: tuple[str],
+    prerelease_exclude: tuple[str],
 ) -> None:
     basedir = ctx.obj["basedir"]
     local_db = ctx.obj["local_db"]
     excludes = exclude_to_excludes(exclude)
+    prerelease_excludes = exclude_to_excludes(prerelease_exclude)
     syncer = get_syncer(basedir, local_db, sync_packages, shadowmire_upstream)
     local = local_db.dump()
     plan = syncer.determine_sync_plan(local, excludes)
     # save plan for debugging
     with overwrite(basedir / "plan.json") as f:
         json.dump(plan, f, default=vars)
-    syncer.do_sync_plan(plan)
+    syncer.do_sync_plan(plan, prerelease_excludes)
     syncer.finalize()
 
 
@@ -762,16 +796,18 @@ def verify(
     sync_packages: bool,
     shadowmire_upstream: Optional[str],
     exclude: tuple[str],
+    prerelease_exclude: tuple[str],
 ) -> None:
     basedir = ctx.obj["basedir"]
     local_db = ctx.obj["local_db"]
     excludes = exclude_to_excludes(exclude)
+    prerelease_excludes = exclude_to_excludes(prerelease_exclude)
    syncer = get_syncer(basedir, local_db, sync_packages, shadowmire_upstream)
     local_names = set(local_db.keys())
     simple_dirs = set([i.name for i in (basedir / "simple").iterdir() if i.is_dir()])
     for package_name in simple_dirs - local_names:
         syncer.do_remove(package_name)
-    syncer.parallel_update(list(local_names))
+    syncer.parallel_update(list(local_names), prerelease_excludes)
     syncer.finalize()
     # clean up unreferenced package files
     ref_set = set()