Add prerelease exclusion

This commit is contained in:
taoky 2024-08-04 16:20:08 +08:00
parent c695419700
commit 5dc892c0b3

View File

@@ -42,13 +42,6 @@ PRERELEASE_PATTERNS = (
) )
def is_version_prerelease(version: str) -> bool:
for p in PRERELEASE_PATTERNS:
if p.match(version):
return True
return False
class PackageNotFoundError(Exception): class PackageNotFoundError(Exception):
pass pass
@@ -353,6 +346,15 @@ class Plan:
update: list[str] update: list[str]
def match_patterns(
s: str, ps: list[re.Pattern[str]] | tuple[re.Pattern[str], ...]
) -> bool:
for p in ps:
if p.match(s):
return True
return False
class SyncBase: class SyncBase:
def __init__( def __init__(
self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False self, basedir: Path, local_db: LocalVersionKV, sync_packages: bool = False
@@ -367,23 +369,19 @@ class SyncBase:
self.sync_packages = sync_packages self.sync_packages = sync_packages
def filter_remote_with_excludes( def filter_remote_with_excludes(
self, remote: dict[str, int], excludes: list[re.Pattern] self, remote: dict[str, int], excludes: list[re.Pattern[str]]
) -> dict[str, int]: ) -> dict[str, int]:
if not excludes: if not excludes:
return remote return remote
res = {} res = {}
for k, v in remote.items(): for k, v in remote.items():
matched = False matched = match_patterns(k, excludes)
for exclude in excludes:
if exclude.match(k):
matched = True
break
if not matched: if not matched:
res[k] = v res[k] = v
return res return res
def determine_sync_plan( def determine_sync_plan(
self, local: dict[str, int], excludes: list[re.Pattern] self, local: dict[str, int], excludes: list[re.Pattern[str]]
) -> Plan: ) -> Plan:
remote = self.fetch_remote_versions() remote = self.fetch_remote_versions()
remote = self.filter_remote_with_excludes(remote, excludes) remote = self.filter_remote_with_excludes(remote, excludes)
@@ -410,10 +408,14 @@ class SyncBase:
def fetch_remote_versions(self) -> dict[str, int]: def fetch_remote_versions(self) -> dict[str, int]:
raise NotImplementedError raise NotImplementedError
def parallel_update(self, package_names: list) -> None: def parallel_update(
self, package_names: list, prerelease_excludes: list[re.Pattern[str]]
) -> None:
with ThreadPoolExecutor(max_workers=WORKERS) as executor: with ThreadPoolExecutor(max_workers=WORKERS) as executor:
futures = { futures = {
executor.submit(self.do_update, package_name, False): ( executor.submit(
self.do_update, package_name, prerelease_excludes, False
): (
idx, idx,
package_name, package_name,
) )
@@ -440,14 +442,14 @@ class SyncBase:
future.cancel() future.cancel()
sys.exit(1) sys.exit(1)
def do_sync_plan(self, plan: Plan) -> None: def do_sync_plan(self, plan: Plan, prerelease_excludes: list[re.Pattern[str]]) -> None:
to_remove = plan.remove to_remove = plan.remove
to_update = plan.update to_update = plan.update
for package_name in to_remove: for package_name in to_remove:
self.do_remove(package_name) self.do_remove(package_name)
self.parallel_update(to_update) self.parallel_update(to_update, prerelease_excludes)
def do_remove(self, package_name: str) -> None: def do_remove(self, package_name: str) -> None:
logger.info("removing %s", package_name) logger.info("removing %s", package_name)
@@ -469,7 +471,12 @@ class SyncBase:
self.local_db.remove(package_name) self.local_db.remove(package_name)
remove_dir_with_files(meta_dir) remove_dir_with_files(meta_dir)
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]: def do_update(
self,
package_name: str,
prerelease_excludes: list[re.Pattern[str]],
write_db: bool = True,
) -> Optional[int]:
raise NotImplementedError raise NotImplementedError
def finalize(self) -> None: def finalize(self) -> None:
@@ -519,7 +526,12 @@ class SyncPyPI(SyncBase):
ret[normalize(key)] = remote_serials[key] ret[normalize(key)] = remote_serials[key]
return ret return ret
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]: def do_update(
self,
package_name: str,
prerelease_excludes: list[re.Pattern[str]],
write_db: bool = True,
) -> Optional[int]:
logger.info("updating %s", package_name) logger.info("updating %s", package_name)
package_simple_path = self.simple_dir / package_name package_simple_path = self.simple_dir / package_name
package_simple_path.mkdir(exist_ok=True) package_simple_path.mkdir(exist_ok=True)
@@ -530,6 +542,12 @@ class SyncPyPI(SyncBase):
logger.warning("%s missing from upstream, skip.", package_name) logger.warning("%s missing from upstream, skip.", package_name)
return None return None
# filter prerelease, if necessary
if prerelease_excludes:
for release in list(meta["releases"].keys()):
if match_patterns(release, PRERELEASE_PATTERNS):
del meta["releases"][release]
if self.sync_packages: if self.sync_packages:
# sync packages first, then sync index # sync packages first, then sync index
existing_hrefs = get_existing_hrefs(package_simple_path) existing_hrefs = get_existing_hrefs(package_simple_path)
@@ -592,7 +610,16 @@ class SyncPlainHTTP(SyncBase):
remote: dict[str, int] = resp.json() remote: dict[str, int] = resp.json()
return remote return remote
def do_update(self, package_name: str, write_db: bool = True) -> Optional[int]: def do_update(
self,
package_name: str,
prerelease_excludes: list[re.Pattern[str]],
write_db: bool = True,
) -> Optional[int]:
if prerelease_excludes:
logger.warning(
"prerelease_excludes is currently ignored in SyncPlainHTTP mode."
)
logger.info("updating %s", package_name) logger.info("updating %s", package_name)
package_simple_path = self.simple_dir / package_name package_simple_path = self.simple_dir / package_name
package_simple_path.mkdir(exist_ok=True) package_simple_path.mkdir(exist_ok=True)
@@ -671,6 +698,11 @@ def sync_shared_args(func):
click.option( click.option(
"--exclude", multiple=True, help="Remote package names to exclude. Regex." "--exclude", multiple=True, help="Remote package names to exclude. Regex."
), ),
click.option(
"--prerelease-exclude",
multiple=True,
help="Package names that shall exclude prerelease. Regex.",
),
] ]
for option in shared_options[::-1]: for option in shared_options[::-1]:
func = option(func) func = option(func)
@@ -691,7 +723,7 @@ def cli(ctx: click.Context) -> None:
ctx.obj["local_db"] = local_db ctx.obj["local_db"] = local_db
def exclude_to_excludes(exclude: tuple[str]) -> list[re.Pattern]: def exclude_to_excludes(exclude: tuple[str]) -> list[re.Pattern[str]]:
return [re.compile(i) for i in exclude] return [re.compile(i) for i in exclude]
@@ -724,17 +756,19 @@ def sync(
sync_packages: bool, sync_packages: bool,
shadowmire_upstream: Optional[str], shadowmire_upstream: Optional[str],
exclude: tuple[str], exclude: tuple[str],
prerelease_exclude: tuple[str],
) -> None: ) -> None:
basedir = ctx.obj["basedir"] basedir = ctx.obj["basedir"]
local_db = ctx.obj["local_db"] local_db = ctx.obj["local_db"]
excludes = exclude_to_excludes(exclude) excludes = exclude_to_excludes(exclude)
prerelease_excludes = exclude_to_excludes(prerelease_exclude)
syncer = get_syncer(basedir, local_db, sync_packages, shadowmire_upstream) syncer = get_syncer(basedir, local_db, sync_packages, shadowmire_upstream)
local = local_db.dump() local = local_db.dump()
plan = syncer.determine_sync_plan(local, excludes) plan = syncer.determine_sync_plan(local, excludes)
# save plan for debugging # save plan for debugging
with overwrite(basedir / "plan.json") as f: with overwrite(basedir / "plan.json") as f:
json.dump(plan, f, default=vars) json.dump(plan, f, default=vars)
syncer.do_sync_plan(plan) syncer.do_sync_plan(plan, prerelease_excludes)
syncer.finalize() syncer.finalize()
@@ -762,16 +796,18 @@ def verify(
sync_packages: bool, sync_packages: bool,
shadowmire_upstream: Optional[str], shadowmire_upstream: Optional[str],
exclude: tuple[str], exclude: tuple[str],
prerelease_exclude: tuple[str],
) -> None: ) -> None:
basedir = ctx.obj["basedir"] basedir = ctx.obj["basedir"]
local_db = ctx.obj["local_db"] local_db = ctx.obj["local_db"]
excludes = exclude_to_excludes(exclude) excludes = exclude_to_excludes(exclude)
prerelease_excludes = exclude_to_excludes(prerelease_exclude)
syncer = get_syncer(basedir, local_db, sync_packages, shadowmire_upstream) syncer = get_syncer(basedir, local_db, sync_packages, shadowmire_upstream)
local_names = set(local_db.keys()) local_names = set(local_db.keys())
simple_dirs = set([i.name for i in (basedir / "simple").iterdir() if i.is_dir()]) simple_dirs = set([i.name for i in (basedir / "simple").iterdir() if i.is_dir()])
for package_name in simple_dirs - local_names: for package_name in simple_dirs - local_names:
syncer.do_remove(package_name) syncer.do_remove(package_name)
syncer.parallel_update(list(local_names)) syncer.parallel_update(list(local_names), prerelease_excludes)
syncer.finalize() syncer.finalize()
# clean up unreferenced package files # clean up unreferenced package files
ref_set = set() ref_set = set()