diff --git a/README.md b/README.md index 6e66450..9f5c563 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,10 @@ Obviously, `list_packages_with_serial()`'s alternative is the `local.json`, whic If you just need to fetch all indexes (and then use a cache solution for packages): ```shell -REPO=/path/to/pypi ./shadowmire.py sync +./shadowmire.py --repo /path/to/pypi sync ``` -If `REPO` env is not set, it defaults to current working directory. +If `--repo` argument is not set, it defaults to current working directory. If you need to download all packages, add `--sync-packages`. diff --git a/config.example.toml b/config.example.toml index 00adf0e..412f4b7 100644 --- a/config.example.toml +++ b/config.example.toml @@ -1,6 +1,7 @@ [options] +# repo = "/example/path" sync_packages = true -# shadowmire_upstream = http://example.com/pypi/web/ +# shadowmire_upstream = "http://example.com/pypi/web/" exclude = [ "[a-z]" ] diff --git a/shadowmire.py b/shadowmire.py index 2c9b4f6..3855e9e 100755 --- a/shadowmire.py +++ b/shadowmire.py @@ -900,6 +900,10 @@ def sync_shared_args(func: Callable[..., Any]) -> Callable[..., Any]: def read_config( ctx: click.Context, param: click.Option, filename: Optional[str] ) -> None: + # Set default repo as cwd + ctx.default_map = {} + ctx.default_map["repo"] = "." + if filename is None: return with open(filename, "rb") as f: @@ -908,12 +912,14 @@ def read_config( options = dict(data["options"]) except KeyError: options = {} - ctx.default_map = { - "sync": options, - "verify": options, - "do-update": options, - "do-remove": options, - } + if options.get("repo"): + ctx.default_map["repo"] = options["repo"] + del options["repo"] + + ctx.default_map["sync"] = options + ctx.default_map["verify"] = options + ctx.default_map["do-update"] = options + ctx.default_map["do-remove"] = options @click.group() @@ -924,8 +930,9 @@ def read_config( callback=read_config, expose_value=False, ) +@click.option("--repo", type=click.Path(file_okay=False), help="Repo (basedir) path") @click.pass_context -def cli(ctx: click.Context) -> None: +def cli(ctx: click.Context, repo: str) -> None: log_level = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO logging.basicConfig(level=log_level) ctx.ensure_object(dict) @@ -937,7 +944,7 @@ def cli(ctx: click.Context) -> None: logger.warning("Don't blame me if you were banned!") # Make sure basedir is absolute - basedir = Path(os.environ.get("REPO", ".")).resolve() + basedir = Path(repo).resolve() local_db = LocalVersionKV(basedir / "local.db", basedir / "local.json") ctx.obj["basedir"] = basedir