Mirror of https://github.com/taoky/shadowmire.git, synced 2025-07-08 17:32:43 +00:00
Add exclude func (for testing), and fix unreferenced files cleanup
parent 103252ad14
commit af49fb183f
@@ -23,9 +23,11 @@ logger = logging.getLogger(__name__)
 USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
 
 # Note that it's suggested to use only 3 workers for PyPI.
-WORKERS = int(os.environ("SHADOWMIRE_WORKERS", "3"))
+WORKERS = int(os.environ.get("SHADOWMIRE_WORKERS", "3"))
 if WORKERS > 10:
-    logger.warning("You have set a worker value larger than 10, which is forbidden by PyPI maintainers.")
+    logger.warning(
+        "You have set a worker value larger than 10, which is forbidden by PyPI maintainers."
+    )
     logger.warning("Don't blame me if you were banned!")
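The substantive fix in this hunk: os.environ is a mapping, not a callable, so calling os.environ("SHADOWMIRE_WORKERS", "3") raises TypeError at import time, whereas os.environ.get is the lookup with a default. A minimal sketch of the corrected pattern (everything but the variable name is generic):

    import os

    # Returns the value if SHADOWMIRE_WORKERS is set, else "3".
    # int() still raises ValueError if the value is not an integer.
    workers = int(os.environ.get("SHADOWMIRE_WORKERS", "3"))
    print(workers)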
@@ -342,11 +344,24 @@ class SyncBase:
         self.simple_dir.mkdir(parents=True, exist_ok=True)
         self.packages_dir.mkdir(parents=True, exist_ok=True)
         self.sync_packages = sync_packages
-        self.remote: Optional[dict[str, int]] = None
 
-    def determine_sync_plan(self, local: dict[str, int]) -> Plan:
+    def filter_remote_with_excludes(self, remote: dict[str, int], excludes: list[re.Pattern]) -> dict[str, int]:
+        if not excludes:
+            return remote
+        res = {}
+        for k, v in remote.items():
+            matched = False
+            for exclude in excludes:
+                if exclude.match(k):
+                    matched = True
+                    break
+            if not matched:
+                res[k] = v
+        return res
+
+    def determine_sync_plan(self, local: dict[str, int], excludes: list[re.Pattern]) -> Plan:
         remote = self.fetch_remote_versions()
-        self.remote = remote
+        remote = self.filter_remote_with_excludes(remote, excludes)
         # store remote to remote.json
         with overwrite(self.basedir / "remote.json") as f:
             json.dump(remote, f)
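Worth noting about the new filter: re.Pattern.match anchors at the start of the package name but not at the end, so a bare pattern behaves as a prefix match unless it ends with $. A small, self-contained illustration (package names and serials here are made up):

    import re

    excludes = [re.compile("test")]
    remote = {"test-pkg": 1, "pytest": 2, "numpy": 3}

    # match() anchors at position 0: "test-pkg" is excluded,
    # while "pytest" survives because "test" is not at its start.
    kept = {
        name: serial
        for name, serial in remote.items()
        if not any(p.match(name) for p in excludes)
    }
    assert kept == {"pytest": 2, "numpy": 3}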
@@ -356,6 +371,7 @@ class SyncBase:
         remote_keys = set(remote.keys())
         for i in local_keys - remote_keys:
             to_remove.append(i)
+            local_keys.remove(i)
         for i in remote_keys - local_keys:
             to_update.append(i)
         for i in local_keys:
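The added local_keys.remove(i) keeps packages queued for removal out of the final "for i in local_keys:" pass. Mutating the set mid-loop is safe here because iteration runs over local_keys - remote_keys, which is a new set object. A toy walkthrough (names invented):

    local_keys = {"a", "b", "c"}
    remote_keys = {"b", "c"}
    to_remove = []

    # local_keys - remote_keys builds a NEW set, so removing from
    # local_keys inside the loop does not disturb the iteration.
    for i in local_keys - remote_keys:
        to_remove.append(i)
        local_keys.remove(i)

    # The later per-package comparison now sees only packages
    # that still exist upstream.
    assert local_keys == {"b", "c"}
    assert to_remove == ["a"]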
@@ -400,7 +416,6 @@ class SyncBase:
             sys.exit(1)
 
     def do_sync_plan(self, plan: Plan) -> None:
-        assert self.remote
         to_remove = plan.remove
         to_update = plan.update
 
@@ -580,9 +595,7 @@ class SyncPlainHTTP(SyncBase):
         package_simple_url = urljoin(self.upstream, f"/simple/{package_name}/")
         for href in current_hrefs:
             url = urljoin(package_simple_url, href)
-            dest = (
-                package_simple_path / href
-            ).resolve()
+            dest = (package_simple_path / href).resolve()
             logger.info("downloading file %s -> %s", url, dest)
             if dest.exists():
                 continue
@@ -620,6 +633,8 @@ def get_local_serial(package_simple_path: Path) -> Optional[int]:
 
 def main(args: argparse.Namespace) -> None:
+    log_level = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
+    logging.basicConfig(level=log_level)
     logger.debug(args)
 
     basedir = Path(".")
     local_db = LocalVersionKV(basedir / "local.db", basedir / "local.json")
@@ -628,7 +643,7 @@ def main(args: argparse.Namespace) -> None:
         basedir=basedir, local_db=local_db, sync_packages=args.sync_packages
     )
     local = local_db.dump()
-    plan = sync.determine_sync_plan(local)
+    plan = sync.determine_sync_plan(local, args.excludes)
     # save plan for debugging
     with overwrite(basedir / "plan.json") as f:
         json.dump(plan, f, default=vars)
@@ -649,7 +664,7 @@ def main(args: argparse.Namespace) -> None:
         basedir=basedir, local_db=local_db, sync_packages=args.sync_packages
     )
     local_names = set(local_db.keys())
-    simple_dirs = set([i.name for i in (basedir / "simple").iterdir()])
+    simple_dirs = set([i.name for i in (basedir / "simple").iterdir() if i.is_dir()])
     for package_name in simple_dirs - local_names:
         sync.do_remove(package_name)
     sync.parallel_update(list(local_names))
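The is_dir() guard matters because anything else sitting in simple/ (a stray file, for instance) would otherwise be picked up as a package name and handed to do_remove. A runnable illustration of the difference, using a throwaway directory:

    import tempfile
    from pathlib import Path

    base = Path(tempfile.mkdtemp())
    (base / "numpy").mkdir()       # a real package directory
    (base / "stray.html").touch()  # not a package

    # Without the guard, the stray file looks like a package name.
    assert {p.name for p in base.iterdir()} == {"numpy", "stray.html"}
    # With it, only directories are considered.
    assert {p.name for p in base.iterdir() if p.is_dir()} == {"numpy"}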
@@ -657,14 +672,13 @@ def main(args: argparse.Namespace) -> None:
     # clean up unreferenced package files
     ref_set = set()
     for sname in simple_dirs:
-        sd = basedir / sname
-        index_html = sd / "index.html"
-        hrefs = get_existing_hrefs(index_html)
+        sd = basedir / "simple" / sname
+        hrefs = get_existing_hrefs(sd)
         for i in hrefs:
             ref_set.add(str((sd / i).resolve()))
     for file in (basedir / "packages").glob("*/*/*/*"):
         file = file.resolve()
-        if file not in ref_set:
+        if str(file) not in ref_set:
             logger.info("removing unreferenced %s", file)
             file.unlink()
 
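Two bugs are fixed at once here: the reference scan pointed at basedir / sname instead of basedir / "simple" / sname, so hrefs were collected from the wrong location, and the membership test compared a Path against a set of strings, which can never match, so every package file looked unreferenced. The second pitfall in isolation (paths invented; nothing touches the filesystem):

    from pathlib import Path

    ref_set = {str(Path("packages/a/b/c/pkg-1.0.tar.gz").resolve())}
    file = Path("packages/a/b/c/pkg-1.0.tar.gz").resolve()

    # A Path never compares equal to a str, so the old test misfires:
    assert file not in ref_set   # looks "unreferenced"!
    assert str(file) in ref_set  # the fixed comparison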
@@ -679,6 +693,9 @@ if __name__ == "__main__":
         help="Sync packages instead of just indexes",
         action="store_true",
     )
+    parser_sync.add_argument(
+        "--exclude", help="Remote package names to exclude. Regex.", nargs="*"
+    )
     parser_genlocal = subparsers.add_parser(
         "genlocal", help="(Re)generate local db and json from simple/"
     )
@@ -696,4 +713,6 @@ if __name__ == "__main__":
     if args.command is None:
         parser.print_help()
         sys.exit(1)
    if args.command == "sync" and args.exclude:
        args.excludes = [re.compile(i) for i in args.exclude]
    main(args)
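One caveat visible in this hunk: args.excludes is only assigned when --exclude is actually passed, so the unconditional args.excludes access in main relies on that branch having run. A defensive variant of the same wiring (a sketch, not the project's code; the test invocation is invented):

    import argparse
    import re

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command")
    parser_sync = subparsers.add_parser("sync")
    parser_sync.add_argument(
        "--exclude", help="Remote package names to exclude. Regex.", nargs="*"
    )

    args = parser.parse_args(["sync", "--exclude", "^test-", ".*-nightly$"])
    # Always set excludes, defaulting to an empty list when absent.
    args.excludes = [re.compile(i) for i in (args.exclude or [])]
    assert [p.pattern for p in args.excludes] == ["^test-", ".*-nightly$"]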