Signed-off-by: Shengqi Chen <harry-chen@outlook.com>
Shengqi Chen 2024-10-08 21:48:00 +08:00
parent 1a6b3f6add
commit 69b7d8d993


@@ -2,12 +2,12 @@
import sys
from types import FrameType
from typing import IO, Any, Callable, Generator, Optional
from typing import IO, Any, Callable, Generator, Literal, NoReturn, Optional
import xmlrpc.client
from dataclasses import dataclass
import re
import json
from urllib.parse import urljoin, urlparse, urlunparse
from urllib.parse import urljoin, urlparse, urlunparse, unquote
from pathlib import Path
from html.parser import HTMLParser
import logging
@@ -18,10 +18,12 @@ from os.path import (
) # fast path computation, instead of accessing real files like pathlib
from contextlib import contextmanager
import sqlite3
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
import signal
import tomllib
from copy import deepcopy
import functools
from http.client import HTTPConnection
import requests
import click
@@ -36,6 +38,10 @@ USER_AGENT = "Shadowmire (https://github.com/taoky/shadowmire)"
# Note that it's suggested to use only 3 workers for PyPI.
WORKERS = int(os.environ.get("SHADOWMIRE_WORKERS", "3"))
# Use threads to parallelize local I/O during verification
IOWORKERS = int(os.environ.get("SHADOWMIRE_IOWORKERS", "2"))
# A safety net -- to avoid upstream issues causing too many packages to be removed when determining the sync plan.
MAX_DELETION = int(os.environ.get("SHADOWMIRE_MAX_DELETION", "50000"))
# https://github.com/pypa/bandersnatch/blob/a05af547f8d1958217ef0dc0028890b1839e6116/src/bandersnatch_filter_plugins/prerelease_name.py#L18C1-L23C6
PRERELEASE_PATTERNS = (
@@ -61,6 +67,13 @@ def exit_handler(signum: int, frame: Optional[FrameType]) -> None:
signal.signal(signal.SIGTERM, exit_handler)
def exit_with_futures(futures: dict[Future[Any], Any]) -> NoReturn:
logger.info("Exiting...")
for future in futures:
future.cancel()
sys.exit(1)
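For illustration, a minimal sketch of the pattern this helper supports (some_task and names are hypothetical placeholders; ExitProgramException is defined elsewhere in this file):
with ThreadPoolExecutor(max_workers=IOWORKERS) as executor:
    futures = {executor.submit(some_task, name): name for name in names}
    try:
        for future in as_completed(futures):
            future.result()
    except (ExitProgramException, KeyboardInterrupt):
        # cancel anything not yet started, then sys.exit(1)
        exit_with_futures(futures)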
class LocalVersionKV:
"""
A key-value database wrapper over sqlite3.
@@ -146,6 +159,20 @@ def overwrite(
raise
def fast_readall(file_path: Path) -> bytes:
"""
Saves some extra read(), lseek() and ioctl() calls compared to a buffered open() + read().
"""
fd = os.open(file_path, os.O_RDONLY)
if fd < 0:
raise FileNotFoundError(file_path)
try:
contents = os.read(fd, file_path.stat().st_size)
return contents
finally:
os.close(fd)
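As a usage sketch (the path is hypothetical), this reads the same bytes as pathlib while issuing a single read() sized from stat() instead of a buffered read loop:
meta = Path("json/example-package")  # hypothetical file
data = fast_readall(meta)            # open + stat + one read + close
assert data == meta.read_bytes()     # same contents, fewer syscalls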
def normalize(name: str) -> str:
"""
See https://peps.python.org/pep-0503/#normalized-names
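The body lies outside this hunk; for reference, PEP 503 normalization collapses runs of '-', '_' and '.' into a single '-' and lowercases the result:
import re

def pep503_normalize(name: str) -> str:
    # reference implementation from PEP 503
    return re.sub(r"[-_.]+", "-", name).lower()

assert pep503_normalize("Flask_SQLAlchemy") == "flask-sqlalchemy"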
@@ -166,6 +193,22 @@ def remove_dir_with_files(directory: Path) -> None:
logger.info("Removed dir %s", directory)
def fast_iterdir(
directory: Path | str, filter_type: Literal["dir", "file"]
) -> Generator[os.DirEntry[str], Any, None]:
"""
iterdir() in pathlib discards the file type information that getdents64() already provides,
which is not acceptable when a directory contains millions of entries
and you need to filter them by type.
"""
assert filter_type in ["dir", "file"]
for item in os.scandir(directory):
if filter_type == "dir" and item.is_dir():
yield item
elif filter_type == "file" and item.is_file():
yield item
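Usage sketch: because scandir()'s DirEntry carries the type bits from getdents64(), a filter like the ones in verify below needs no per-entry stat() (the directory layout here is hypothetical):
simple_dirs = {entry.name for entry in fast_iterdir(Path("simple"), "dir")}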
def get_package_urls_from_index_html(html_path: Path) -> list[str]:
"""
Get all <a> href (fragments removed) from given simple/<package>/index.html contents
@@ -185,8 +228,8 @@ def get_package_urls_from_index_html(html_path: Path) -> list[str]:
self.hrefs.append(attr[1])
p = ATagHTMLParser()
with open(html_path) as f:
p.feed(f.read())
contents = fast_readall(html_path).decode()
p.feed(contents)
ret = []
for href in p.hrefs:
@@ -201,8 +244,8 @@ def get_package_urls_from_index_json(json_path: Path) -> list[str]:
"""
Get all urls from given simple/<package>/index.v1_json contents
"""
with open(json_path) as f:
contents_dict = json.load(f)
contents = fast_readall(json_path)
contents_dict = json.loads(contents)
urls = [i["url"] for i in contents_dict["files"]]
return urls
@@ -213,8 +256,8 @@ def get_package_urls_size_from_index_json(json_path: Path) -> list[tuple[str, in
If size is not available, returns size as -1
"""
with open(json_path) as f:
contents_dict = json.load(f)
contents = fast_readall(json_path)
contents_dict = json.loads(contents)
ret = [(i["url"], i.get("size", -1)) for i in contents_dict["files"]]
return ret
@@ -226,15 +269,15 @@ def get_existing_hrefs(package_simple_path: Path) -> Optional[list[str]]:
Priority: index.v1_json -> index.html
"""
if not package_simple_path.exists():
return None
json_file = package_simple_path / "index.v1_json"
html_file = package_simple_path / "index.html"
if json_file.exists():
try:
return get_package_urls_from_index_json(json_file)
if html_file.exists():
return get_package_urls_from_index_html(html_file)
return None
except FileNotFoundError:
try:
return get_package_urls_from_index_html(html_file)
except FileNotFoundError:
return None
class CustomXMLRPCTransport(xmlrpc.client.Transport):
@@ -244,9 +287,20 @@ class CustomXMLRPCTransport(xmlrpc.client.Transport):
user_agent = USER_AGENT
def make_connection(self, host: tuple[str, dict[str, str]] | str) -> HTTPConnection:
conn = super().make_connection(host)
if conn.timeout is None:
# 2 min timeout
conn.timeout = 120
return conn
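Usage sketch: the transport plugs into ServerProxy as usual. The URL is a hypothetical plain-HTTP endpoint, since this Transport subclass builds an HTTPConnection:
import xmlrpc.client

client = xmlrpc.client.ServerProxy(
    "http://pypi.example.internal/pypi",  # hypothetical mirror endpoint
    transport=CustomXMLRPCTransport(),
)
# list_packages_with_serial() now fails after 120 s instead of hanging forever
pkgs = client.list_packages_with_serial()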
def create_requests_session() -> requests.Session:
s = requests.Session()
# hardcode a 1 min timeout for connect & read for now
# https://requests.readthedocs.io/en/latest/user/advanced/#timeouts
# A hack to overwrite the get() method
s.get_orig, s.get = s.get, functools.partial(s.get, timeout=(60, 60)) # type: ignore
retries = Retry(total=3, backoff_factor=0.1)
s.mount("http://", HTTPAdapter(max_retries=retries))
s.mount("https://", HTTPAdapter(max_retries=retries))
@@ -299,11 +353,25 @@ class PyPI:
@staticmethod
def file_url_to_local_url(url: str) -> str:
"""
This function should NOT be used to construct a local Path!
"""
parsed = urlparse(url)
assert parsed.path.startswith("/packages")
prefix = "../.."
return prefix + parsed.path
@staticmethod
def file_url_to_local_path(url: str) -> Path:
"""
Unquotes the URL path and returns a Path
"""
path = urlparse(url).path
path = unquote(path)
assert path.startswith("/packages")
path = path[1:]
return Path("../..") / path
# Func modified from bandersnatch
@classmethod
def generate_html_simple_page(cls, package_meta: dict) -> str:
@@ -459,6 +527,21 @@ class SyncBase:
for i in local_keys - remote_keys:
to_remove.append(i)
local_keys.remove(i)
# There are always some packages that show up in PyPI's list_packages_with_serial() but do not actually exist
# Don't count them when comparing len(to_remove) with MAX_DELETION
if len(to_remove) > MAX_DELETION:
logger.error(
"Too many packages to remove (%d > %d)", len(to_remove), MAX_DELETION
)
logger.info("Some packages that would be removed:")
for p in to_remove[:100]:
logger.info("- %s", p)
for p in to_remove[100:]:
logger.debug("- %s", p)
logger.error(
"Use SHADOWMIRE_MAX_DELETION env to adjust the threshold if you really want to proceed"
)
sys.exit(2)
for i in remote_keys - local_keys:
to_update.append(i)
for i in local_keys:
@@ -480,33 +563,32 @@ class SyncBase:
self,
package_names: list[str],
prerelease_excludes: list[re.Pattern[str]],
json_files: set[str],
packages_pathcache: set[str],
compare_size: bool,
) -> bool:
to_update = []
for package_name in tqdm(package_names, desc="Checking consistency"):
package_jsonmeta_path = self.jsonmeta_dir / package_name
if not package_jsonmeta_path.exists():
def is_consistent(package_name: str) -> bool:
if package_name not in json_files:
# checking the prebuilt set saves a newfstatat() compared to exists()
logger.info("add %s as it does not have a JSON API file", package_name)
to_update.append(package_name)
continue
return False
package_simple_path = self.simple_dir / package_name
html_simple = package_simple_path / "index.html"
htmlv1_simple = package_simple_path / "index.v1_html"
json_simple = package_simple_path / "index.v1_json"
if not (
html_simple.exists() and json_simple.exists() and htmlv1_simple.exists()
):
try:
# always create index.html symlink, if not exists or not a symlink
if not html_simple.is_symlink():
html_simple.unlink(missing_ok=True)
html_simple.symlink_to("index.v1_html")
hrefs_html = get_package_urls_from_index_html(htmlv1_simple)
hrefsize_json = get_package_urls_size_from_index_json(json_simple)
except FileNotFoundError:
logger.info(
"add %s as it does not have index.html, index.v1_html or index.v1_json",
"add %s as it does not have index.v1_html or index.v1_json",
package_name,
)
to_update.append(package_name)
continue
if not html_simple.is_symlink():
html_simple.unlink()
html_simple.symlink_to("index.v1_html")
hrefs_html = get_package_urls_from_index_html(html_simple)
hrefsize_json = get_package_urls_size_from_index_json(json_simple)
return False
if (
hrefs_html is None
or hrefsize_json is None
@@ -514,36 +596,67 @@ class SyncBase:
):
# something unexpected happens...
logger.info("add %s as its indexes are not consistent", package_name)
to_update.append(package_name)
continue
return False
# OK, check if all hrefs have corresponding files
if self.sync_packages:
should_update = False
for href, size in hrefsize_json:
dest = Path(normpath(package_simple_path / href))
if not dest.exists():
relative_path = unquote(href)
dest_pathstr = normpath(package_simple_path / relative_path)
try:
# Fast shortcut to avoid stat() it
if dest_pathstr not in packages_pathcache:
raise FileNotFoundError
if compare_size and size != -1:
dest = Path(dest_pathstr)
# So, only stat() for real when we need to compare sizes,
# a size is known, and the file is present in the pathcache.
dest_stat = dest.stat()
dest_size = dest_stat.st_size
if dest_size != size:
logger.info(
"add %s as its local size %s != %s",
package_name,
dest_size,
size,
)
return False
except FileNotFoundError:
logger.info("add %s as it's missing packages", package_name)
should_update = True
break
if compare_size and size != -1:
dest_size = dest.stat().st_size
if dest_size != size:
logger.info(
"add %s as its local size %s != %s",
package_name,
dest_size,
size,
)
should_update = True
break
if should_update:
to_update.append(package_name)
return False
return True
to_update = []
with ThreadPoolExecutor(max_workers=IOWORKERS) as executor:
futures = {
executor.submit(is_consistent, package_name): package_name
for package_name in package_names
}
try:
for future in tqdm(
as_completed(futures),
total=len(package_names),
desc="Checking consistency",
):
package_name = futures[future]
try:
consistent = future.result()
if not consistent:
to_update.append(package_name)
except Exception:
logger.warning(
"%s generated an exception", package_name, exc_info=True
)
raise
except:
exit_with_futures(futures)
logger.info("%s packages to update in check_and_update()", len(to_update))
return self.parallel_update(to_update, prerelease_excludes)
def parallel_update(
self, package_names: list, prerelease_excludes: list[re.Pattern[str]]
self, package_names: list[str], prerelease_excludes: list[re.Pattern[str]]
) -> bool:
success = True
with ThreadPoolExecutor(max_workers=WORKERS) as executor:
@@ -566,7 +679,7 @@ class SyncBase:
if serial:
self.local_db.set(package_name, serial)
except Exception as e:
if isinstance(e, (ExitProgramException, KeyboardInterrupt)):
if isinstance(e, (KeyboardInterrupt)):
raise
logger.warning(
"%s generated an exception", package_name, exc_info=True
@@ -576,10 +689,7 @@ class SyncBase:
logger.info("dumping local db...")
self.local_db.dump_json()
except (ExitProgramException, KeyboardInterrupt):
logger.info("Get ExitProgramException or KeyboardInterrupt, exiting...")
for future in futures:
future.cancel()
sys.exit(1)
exit_with_futures(futures)
return success
def do_sync_plan(
@@ -662,6 +772,27 @@ class SyncBase:
f.write(" </body>\n</html>")
self.local_db.dump_json()
def skip_this_package(self, i: dict, dest: Path) -> bool:
"""
A helper function for subclasses implementing do_update().
As the existence check is also done via stat(), this adds no extra I/O overhead.
Returns whether this file should be skipped.
"""
try:
dest_size = dest.stat().st_size
i_size = i.get("size", -1)
if i_size == -1:
return True
if dest_size == i_size:
return True
logger.warning(
"file %s exists locally, but size does not match with upstream, so it would still be downloaded.",
dest,
)
return False
except FileNotFoundError:
return False
def download(
session: requests.Session, url: str, dest: Path
@@ -746,7 +877,8 @@ class SyncPyPI(SyncBase):
self.pypi.file_url_to_local_url(i["url"]) for i in release_files
]
should_remove = list(set(existing_hrefs) - set(remote_hrefs))
for p in should_remove:
for href in should_remove:
p = unquote(href)
logger.info("removing file %s (if exists)", p)
package_path = Path(normpath(package_simple_path / p))
package_path.unlink(missing_ok=True)
@@ -754,12 +886,13 @@ class SyncPyPI(SyncBase):
url = i["url"]
dest = Path(
normpath(
package_simple_path / self.pypi.file_url_to_local_url(i["url"])
package_simple_path / self.pypi.file_url_to_local_path(i["url"])
)
)
logger.info("downloading file %s -> %s", url, dest)
if dest.exists():
if self.skip_this_package(i, dest):
continue
dest.parent.mkdir(parents=True, exist_ok=True)
success, _resp = download(self.session, url, dest)
if not success:
@@ -847,16 +980,19 @@ class SyncPlainHTTP(SyncBase):
release_files = PyPI.get_release_files_from_meta(meta)
remote_hrefs = [PyPI.file_url_to_local_url(i["url"]) for i in release_files]
should_remove = list(set(existing_hrefs) - set(remote_hrefs))
for p in should_remove:
for href in should_remove:
p = unquote(href)
logger.info("removing file %s (if exists)", p)
package_path = Path(normpath(package_simple_path / p))
package_path.unlink(missing_ok=True)
package_simple_url = urljoin(self.upstream, f"simple/{package_name}/")
for href in remote_hrefs:
for i in release_files:
href = PyPI.file_url_to_local_url(i["url"])
path = PyPI.file_url_to_local_path(i["url"])
url = urljoin(package_simple_url, href)
dest = Path(normpath(package_simple_path / href))
dest = Path(normpath(package_simple_path / path))
logger.info("downloading file %s -> %s", url, dest)
if dest.exists():
if self.skip_this_package(i, dest):
continue
dest.parent.mkdir(parents=True, exist_ok=True)
success, resp = download(self.session, url, dest)
@@ -878,14 +1014,13 @@ class SyncPlainHTTP(SyncBase):
return last_serial
def get_local_serial(package_meta_path: Path) -> Optional[int]:
def get_local_serial(package_meta_direntry: os.DirEntry[str]) -> Optional[int]:
"""
Accepts an os.DirEntry pointing at json/<package_name>
"""
package_name = package_meta_path.name
package_name = package_meta_direntry.name
try:
with open(package_meta_path) as f:
contents = f.read()
contents = fast_readall(Path(package_meta_direntry.path))
except FileNotFoundError:
logger.warning("%s does not have JSON metadata, skipping", package_name)
return None
@@ -1052,13 +1187,32 @@ def genlocal(ctx: click.Context) -> None:
local = {}
json_dir = basedir / "json"
logger.info("Iterating all items under %s", json_dir)
dir_items = [d for d in json_dir.iterdir() if d.is_file()]
dir_items = [d for d in fast_iterdir(json_dir, "file")]
logger.info("Detected %s packages in %s in total", len(dir_items), json_dir)
for package_metapath in tqdm(dir_items, desc="Reading packages from json/"):
package_name = package_metapath.name
serial = get_local_serial(package_metapath)
if serial:
local[package_name] = serial
with ThreadPoolExecutor(max_workers=IOWORKERS) as executor:
futures = {
executor.submit(get_local_serial, package_metapath): package_metapath
for package_metapath in dir_items
}
try:
for future in tqdm(
as_completed(futures),
total=len(dir_items),
desc="Reading packages from json/",
):
package_name = futures[future].name
try:
serial = future.result()
if serial:
local[package_name] = serial
except Exception as e:
if isinstance(e, (KeyboardInterrupt)):
raise
logger.warning(
"%s generated an exception", package_name, exc_info=True
)
except (ExitProgramException, KeyboardInterrupt):
exit_with_futures(futures)
logger.info(
"%d out of %d packages have valid serial number", len(local), len(dir_items)
)
@@ -1100,8 +1254,8 @@ def verify(
logger.info("====== Step 1. Remove packages NOT in local db ======")
local_names = set(local_db.keys())
simple_dirs = {i.name for i in (basedir / "simple").iterdir() if i.is_dir()}
json_files = {i.name for i in (basedir / "json").iterdir() if i.is_file()}
simple_dirs = {i.name for i in fast_iterdir((basedir / "simple"), "dir")}
json_files = {i.name for i in fast_iterdir((basedir / "json"), "file")}
not_in_local = (simple_dirs | json_files) - local_names
logger.info(
"%d out of %d local packages NOT in local db",
@@ -1133,37 +1287,101 @@ def verify(
# After some removal, local_names is changed.
local_names = set(local_db.keys())
logger.info("====== Step 3. Caching packages/ dirtree in memory for Step 4 & 5.")
packages_pathcache: set[str] = set()
with ThreadPoolExecutor(max_workers=IOWORKERS) as executor:
def packages_iterate(first_dirname: str, position: int) -> list[str]:
with tqdm(
desc=f"Iterating packages/{first_dirname}/*/*/*", position=position
) as pb:
res = []
for d1 in fast_iterdir(basedir / "packages" / first_dirname, "dir"):
for d2 in fast_iterdir(d1.path, "dir"):
for file in fast_iterdir(d2.path, "file"):
pb.update(1)
res.append(file.path)
return res
futures = {
executor.submit(packages_iterate, first_dir.name, idx % IOWORKERS): first_dir.name # type: ignore
for idx, first_dir in enumerate(fast_iterdir((basedir / "packages"), "dir"))
}
try:
for future in as_completed(futures):
sname = futures[future]
try:
for p in future.result():
packages_pathcache.add(p)
except Exception as e:
if isinstance(e, (KeyboardInterrupt)):
raise
logger.warning("%s generated an exception", sname, exc_info=True)
success = False
except (ExitProgramException, KeyboardInterrupt):
exit_with_futures(futures)
logger.info(
"====== Step 3. Make sure all local indexes are valid, and (if --sync-packages) have valid local package files ======"
"====== Step 4. Make sure all local indexes are valid, and (if --sync-packages) have valid local package files ======"
)
success = syncer.check_and_update(
list(local_names), prerelease_excludes, compare_size
list(local_names),
prerelease_excludes,
json_files,
packages_pathcache,
compare_size,
)
syncer.finalize()
logger.info(
"====== Step 4. Remove any unreferenced files in `packages` folder ======"
"====== Step 5. Remove any unreferenced files in `packages` folder ======"
)
ref_set = set()
for sname in tqdm(simple_dirs, desc="Iterating simple/ directory"):
sd = basedir / "simple" / sname
hrefs = get_existing_hrefs(sd)
hrefs = [] if hrefs is None else hrefs
for i in hrefs:
# use normpath, which is much faster than pathlib resolve(), as it does not need to access fs
# we could make sure no symlinks could affect this here
np = normpath(sd / i)
logger.debug("add to ref_set: %s", np)
ref_set.add(np)
for file in tqdm(
(basedir / "packages").glob("*/*/*/*"), desc="Iterating packages/*/*/*/*"
):
# basedir is absolute, so file is also absolute
# just convert to str to match normpath result
logger.debug("find file %s", file)
if str(file) not in ref_set:
logger.info("removing unreferenced file %s", file)
file.unlink()
ref_set: set[str] = set()
with ThreadPoolExecutor(max_workers=IOWORKERS) as executor:
# Part 1: iterate simple/
def iterate_simple(sname: str) -> list[str]:
sd = basedir / "simple" / sname
hrefs = get_existing_hrefs(sd)
hrefs = [] if hrefs is None else hrefs
nps = []
for href in hrefs:
i = unquote(href)
# use normpath, which is much faster than pathlib resolve(), as it does not need to access fs
# we could make sure no symlinks could affect this here
np = normpath(sd / i)
logger.debug("add to ref_set: %s", np)
nps.append(np)
return nps
# MyPy rejects reusing a variable name with a different type, even with --allow-redefinition
# Ignore here to make mypy happy
futures = {
executor.submit(iterate_simple, sname): sname for sname in simple_dirs # type: ignore
}
try:
for future in tqdm(
as_completed(futures),
total=len(simple_dirs),
desc="Iterating simple/ directory",
):
sname = futures[future]
try:
nps = future.result()
for np in nps:
ref_set.add(np)
except Exception as e:
if isinstance(e, (KeyboardInterrupt)):
raise
logger.warning("%s generated an exception", sname, exc_info=True)
success = False
except (ExitProgramException, KeyboardInterrupt):
exit_with_futures(futures)
# Part 2: handling packages
for path in tqdm(packages_pathcache, desc="Iterating path cache"):
if path not in ref_set:
logger.info("removing unreferenced file %s", path)
Path(path).unlink()
logger.info("Verification finished. Success: %s", success)