diff --git a/shadowmire.py b/shadowmire.py
index 68e23d2..1952f1c 100644
--- a/shadowmire.py
+++ b/shadowmire.py
@@ -215,9 +215,8 @@ class PyPI:
         return prefix + parsed.path
 
     # Func modified from bandersnatch
-    def generate_html_simple_page(
-        self, package_meta: dict, package_rawname: str
-    ) -> str:
+    def generate_html_simple_page(self, package_meta: dict) -> str:
+        package_rawname = package_meta["info"]["name"]
         simple_page_content = (
             "<!DOCTYPE html>\n"
             "<html>\n"
@@ -355,6 +354,36 @@ class SyncBase:
     def fetch_remote_versions(self) -> dict[str, int]:
         raise NotImplementedError
 
+    def parallel_update(self, package_names: list[str]) -> None:
+        with ThreadPoolExecutor(max_workers=self.workers) as executor:
+            futures = {
+                executor.submit(self.do_update, package_name, False): (
+                    idx,
+                    package_name,
+                )
+                for idx, package_name in enumerate(package_names)
+            }
+            try:
+                for future in tqdm(as_completed(futures), total=len(package_names)):
+                    idx, package_name = futures[future]
+                    try:
+                        serial = future.result()
+                        if serial:
+                            self.local_db.set(package_name, serial)
+                    except Exception as e:
+                        if isinstance(e, (ExitProgramException, KeyboardInterrupt)):
+                            raise
+                        logger.warning(
+                            "%s generated an exception", package_name, exc_info=True
+                        )
+                    if idx % 1000 == 0:
+                        self.local_db.dump_json()
+            except (ExitProgramException, KeyboardInterrupt):
+                logger.info("Got ExitProgramException or KeyboardInterrupt, exiting...")
+                for future in futures:
+                    future.cancel()
+                sys.exit(1)
+
     def do_sync_plan(self, plan: Plan) -> None:
         assert self.remote
         to_remove = plan.remove
@@ -363,35 +392,7 @@ class SyncBase:
         for package_name in to_remove:
             self.do_remove(package_name)
 
-        with ThreadPoolExecutor(max_workers=self.workers) as executor:
-            futures = {
-                executor.submit(self.do_update, package_name, False): (
-                    idx,
-                    package_name,
-                )
-                for idx, package_name in enumerate(to_update)
-            }
-
-            try:
-                for future in tqdm(as_completed(futures), total=len(to_update)):
-                    idx, package_name = futures[future]
-                    try:
-                        serial = future.result()
-                        if serial:
-                            self.local_db.set(package_name, serial)
-                    except Exception as e:
-                        if e is ExitProgramException:
-                            raise
-                        logger.warning(
-                            "%s generated an exception", package_name, exc_info=True
-                        )
-                    if idx % 1000 == 0:
-                        self.local_db.dump_json()
-            except ExitProgramException:
-                logger.info("Get ExitProgramException, exiting...")
-                for future in futures:
-                    future.cancel()
-                sys.exit(1)
+        self.parallel_update(to_update)
 
     def do_remove(self, package_name: str) -> None:
         logger.info("removing %s", package_name)
@@ -462,11 +463,40 @@ class SyncPyPI(SyncBase):
         except PackageNotFoundError:
             logger.warning("%s missing from upstream, skip.", package_name)
             return None
+
+        if self.sync_packages:
+            # sync packages first, then sync index
+            existing_hrefs = []
+            try:
+                with open(package_simple_path / "index.html") as f:
+                    contents = f.read()
+                existing_hrefs = get_packages_from_index_html(contents)
+            except FileNotFoundError:
+                pass
+            release_files = self.pypi.get_release_files_from_meta(meta)
+            # remove packages that no longer exist remotely
+            remote_hrefs = [self.pypi._file_url_to_local_url(i["url"]) for i in release_files]
+            should_remove = list(set(existing_hrefs) - set(remote_hrefs))
+            for p in should_remove:
+                logger.info("removing file %s (if exists)", p)
+                package_path = (package_simple_path / p).resolve()
+                package_path.unlink(missing_ok=True)
+            for i in release_files:
+                url = i["url"]
+                dest = (package_simple_path / self.pypi._file_url_to_local_url(url)).resolve()
+                logger.info("downloading file %s -> %s", url, dest)
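+                # skip files that already exist locally (existence only; no
+                # size or hash verification is attempted here)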
+                if dest.exists():
+                    continue
+                dest.parent.mkdir(parents=True, exist_ok=True)
+                resp = self.session.get(url)
+                if resp.status_code >= 400:
+                    logger.warning("download %s failed, skipping this package", url)
+                    return None
+                with overwrite(dest, "wb") as f:
+                    f.write(resp.content)
 
         last_serial: int = meta["last_serial"]
-        # OK, here we don't bother store raw name
-        # Considering that JSON API even does not give package raw name, why bother we use it?
-        simple_html_contents = self.pypi.generate_html_simple_page(meta, package_name)
+        simple_html_contents = self.pypi.generate_html_simple_page(meta)
         simple_json_contents = self.pypi.generate_json_simple_page(meta)
 
         for html_filename in ("index.html", "index.v1_html"):
@@ -478,9 +508,6 @@ class SyncPyPI(SyncBase):
         with overwrite(json_path) as f:
             f.write(simple_json_contents)
 
-        if self.sync_packages:
-            raise NotImplementedError
-
         if write_db:
             self.local_db.set(package_name, last_serial)
 
@@ -565,7 +592,7 @@ def main(args: argparse.Namespace) -> None:
     local_db = LocalVersionKV(basedir / "local.db", basedir / "local.json")
 
     if args.command == "sync":
-        sync = SyncPyPI(basedir=basedir, local_db=local_db)
+        sync = SyncPyPI(basedir=basedir, local_db=local_db, sync_packages=args.sync_packages)
         local = local_db.dump()
         plan = sync.determine_sync_plan(local)
         # save plan for debugging
@@ -589,8 +616,7 @@ def main(args: argparse.Namespace) -> None:
         simple_dirs = set([i.name for i in (basedir / "simple").iterdir()])
         for package_name in simple_dirs - local_names:
             sync.do_remove(package_name)
-        for package_name in local_names:
-            sync.do_update(package_name)
+        sync.parallel_update(list(local_names))
         sync.finalize()
 
 
@@ -599,6 +625,7 @@ if __name__ == "__main__":
     subparsers = parser.add_subparsers(dest="command")
 
     parser_sync = subparsers.add_parser("sync", help="Sync from upstream")
+    parser_sync.add_argument("--sync-packages", help="Sync packages instead of just indexes", action="store_true")
    parser_genlocal = subparsers.add_parser(
        "genlocal", help="(Re)generate local db and json from simple/"
    )
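
# Usage sketch (an assumption, not part of the patch): with the new flag wired
# through main(), a package-mirroring run would presumably be started as
#   python3 shadowmire.py sync --sync-packages
# while a plain `sync` keeps the previous index-only behavior.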