Skip to content

Commit

Permalink
fix: cli arg parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
dPys committed Nov 6, 2024
1 parent dbf42c1 commit 8131f4f
Showing 1 changed file with 53 additions and 59 deletions.
112 changes: 53 additions & 59 deletions nxbench/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,21 @@ def validate_executable(path: str | Path) -> Path:


def safe_run(
cmd: Sequence[str | Path], check: bool = True, **kwargs
cmd: Sequence[str | Path],
check: bool = True,
capture_output: bool = False,
**kwargs,
) -> subprocess.CompletedProcess:
"""Safely run a subprocess command.
This function implements several security measures:
- Validates executable path exists and is executable
- Ensures all arguments are strings or Path objects
- Never uses shell=True
- Converts all arguments to strings
- Runs in a subprocess with captured output
While this function aims to be secure, callers must ensure input validation
at their level for any user-provided values.
"""Safely run a subprocess command with optional output capture.
Parameters
----------
cmd : sequence of str or Path
Command and arguments to run. First item must be path to executable.
check : bool, default=True
Whether to check return code
capture_output : bool, default=False
Whether to capture stdout and stderr
**kwargs : dict
Additional arguments to subprocess.run
Expand All @@ -82,12 +77,6 @@ def safe_run(
If command contains invalid argument types
subprocess.SubprocessError
If command fails and check=True
Notes
-----
This function is designed for running trusted executables with validated
arguments. It should not be used directly with unvalidated user input.
# noqa: S603
"""
if not cmd:
raise ValueError("Empty command")
Expand All @@ -101,7 +90,12 @@ def safe_run(
safe_cmd.append(str(arg))

return subprocess.run( # noqa: S603
safe_cmd, capture_output=True, text=True, shell=False, check=check, **kwargs
safe_cmd,
capture_output=capture_output,
text=True,
shell=False,
check=check,
**kwargs,
)


Expand Down Expand Up @@ -139,12 +133,50 @@ def get_git_hash() -> str:
return "unknown"

try:
proc = safe_run([git_path, "rev-parse", "HEAD"])
proc = safe_run([git_path, "rev-parse", "HEAD"], capture_output=True)
return proc.stdout.strip()
except (subprocess.SubprocessError, ValueError):
return "unknown"


def run_asv_command(
args: Sequence[str], check: bool = True
) -> subprocess.CompletedProcess:
"""Run ASV command with security checks.
Parameters
----------
args : sequence of str
Command arguments
check : bool, default=True
Whether to check return code
Returns
-------
subprocess.CompletedProcess
Completed process info
Raises
------
click.ClickException
If command fails
"""
asv_path = get_asv_executable()
if asv_path is None:
raise click.ClickException("ASV executable not found")

safe_args = []
for arg in args:
if not isinstance(arg, str):
raise click.ClickException(f"Invalid argument type: {type(arg)}")
safe_args.append(arg)

try:
return safe_run([asv_path, *safe_args], check=check)
except (subprocess.SubprocessError, ValueError) as e:
raise click.ClickException(str(e))


@click.group()
@click.option("-v", "--verbose", count=True, help="Increase verbosity.")
@click.option(
Expand Down Expand Up @@ -241,44 +273,6 @@ def benchmark(ctx):
"""Benchmark management commands."""


def run_asv_command(
args: Sequence[str], check: bool = True
) -> subprocess.CompletedProcess:
"""Run ASV command with security checks.
Parameters
----------
args : sequence of str
Command arguments
check : bool, default=True
Whether to check return code
Returns
-------
subprocess.CompletedProcess
Completed process info
Raises
------
click.ClickException
If command fails
"""
asv_path = get_asv_executable()
if asv_path is None:
raise click.ClickException("ASV executable not found")

safe_args = []
for arg in args:
if not isinstance(arg, str):
raise click.ClickException(f"Invalid argument type: {type(arg)}")
safe_args.append(arg)

try:
return safe_run([asv_path, *safe_args], check=check)
except (subprocess.SubprocessError, ValueError) as e:
raise click.ClickException(str(e))


@benchmark.command(name="run")
@click.option(
"--backend",
Expand All @@ -301,7 +295,7 @@ def run_benchmark(ctx, backend: tuple[str], collection: str):
logger.exception("Failed to get git hash")
raise click.ClickException("Could not determine git commit hash")

cmd_args = ["asv", "run", "--quick", f"--set-commit-hash={git_hash}"]
cmd_args = ["run", "--quick", f"--set-commit-hash={git_hash}"]

if package_config.verbosity_level >= 1:
cmd_args.append("--verbose")
Expand Down

0 comments on commit 8131f4f

Please sign in to comment.