diff --git a/conda_forge_feedstock_check_solvable/mamba_solver.py b/conda_forge_feedstock_check_solvable/mamba_solver.py index af4b139..9130e11 100644 --- a/conda_forge_feedstock_check_solvable/mamba_solver.py +++ b/conda_forge_feedstock_check_solvable/mamba_solver.py @@ -28,6 +28,7 @@ get_run_exports, print_debug, print_warning, + suppress_output, ) pkgs_dirs = context.pkgs_dirs @@ -41,6 +42,42 @@ api.Context().channel_priority = api.ChannelPriority.kStrict +def _get_pool(channels, platform, constraints): + with suppress_output(): + pool = api.Pool() + + repos = [] + load_channels( + pool, + channels, + repos, + platform=platform, + has_priority=True, + ) + for repo in repos: + # need set_installed for add_pin, not sure why + repo.set_installed() + + return pool + + +def _get_solver(channels, platform, constraints): + pool = _get_pool(channels, platform, constraints) + + solver_options = [(api.SOLVER_FLAG_ALLOW_DOWNGRADE, 1)] + solver = api.Solver(pool, solver_options) + + for constraint in constraints: + solver.add_pin(constraint) + + return solver, pool + + +@lru_cache(maxsize=128) +def _get_solver_cached(channels, platform, constraints): + return _get_solver(channels, platform, constraints) + + class MambaSolver: """Run the mamba solver. @@ -57,22 +94,10 @@ class MambaSolver: >>> solver.solve(["xtensor 0.18"]) """ - def __init__(self, channels, platform): + def __init__(self, channels, platform, _use_cache=False): self.channels = channels self.platform = platform - self.pool = api.Pool() - - self.repos = [] - self.index = load_channels( - self.pool, - self.channels, - self.repos, - platform=platform, - has_priority=True, - ) - for repo in self.repos: - # need set_installed for add_pin, not sure why - repo.set_installed() + self._use_cache = _use_cache def solve( self, @@ -121,19 +146,23 @@ def solve( ignore_run_exports_from = ignore_run_exports_from or [] ignore_run_exports = ignore_run_exports or [] - solver_options = [(api.SOLVER_FLAG_ALLOW_DOWNGRADE, 1)] - solver = api.Solver(self.pool, solver_options) - _specs = [convert_spec_to_conda_build(s) for s in specs] _constraints = [convert_spec_to_conda_build(s) for s in constraints or []] + if self._use_cache: + solver, pool = _get_solver_cached( + self.channels, self.platform, tuple(_constraints) + ) + else: + solver, pool = _get_solver( + self.channels, self.platform, tuple(_constraints) + ) + print_debug( "MAMBA running solver for specs \n\n%s\nconstraints: %s\n", pprint.pformat(_specs), pprint.pformat(_constraints), ) - for constraint in _constraints: - solver.add_pin(constraint) solver.add_jobs(_specs, api.SOLVER_INSTALL) success = solver.solve() @@ -143,10 +172,12 @@ def solve( print_warning( "MAMBA failed to solve specs \n\n%s\n\nwith " "constraints \n\n%s\n\nfor channels " + "\n\n%s\n\non platform " "\n\n%s\n\nThe reported errors are:\n\n%s\n", textwrap.indent(pprint.pformat(_specs), " "), textwrap.indent(pprint.pformat(_constraints), " "), textwrap.indent(pprint.pformat(self.channels), " "), + textwrap.indent(pprint.pformat(self.platform), " "), textwrap.indent(solver.explain_problems(), " "), ) err = solver.explain_problems() @@ -154,7 +185,7 @@ def solve( run_exports = copy.deepcopy(DEFAULT_RUN_EXPORTS) else: t = api.Transaction( - self.pool, + pool, solver, PACKAGE_CACHE, ) @@ -215,6 +246,5 @@ def _get_run_exports( return run_exports -@lru_cache(maxsize=128) def mamba_solver_factory(channels, platform): - return MambaSolver(list(channels), platform) + return MambaSolver(tuple(channels), platform, _use_cache=True) diff --git a/conda_forge_feedstock_check_solvable/rattler_solver.py b/conda_forge_feedstock_check_solvable/rattler_solver.py index 8d97cd7..aaa45d2 100644 --- a/conda_forge_feedstock_check_solvable/rattler_solver.py +++ b/conda_forge_feedstock_check_solvable/rattler_solver.py @@ -133,10 +133,12 @@ def solve( print_warning( "MAMBA failed to solve specs \n\n%s\n\nwith " "constraints \n\n%s\n\nfor channels " + "\n\n%s\n\non platform " "\n\n%s\n\nThe reported errors are:\n\n%s\n", textwrap.indent(pprint.pformat(specs), " "), textwrap.indent(pprint.pformat(constraints), " "), textwrap.indent(pprint.pformat(self.channels), " "), + textwrap.indent(pprint.pformat(self.platform_arch), " "), textwrap.indent(err, " "), ) success = False diff --git a/conda_forge_feedstock_check_solvable/utils.py b/conda_forge_feedstock_check_solvable/utils.py index f07a75c..f1803a0 100644 --- a/conda_forge_feedstock_check_solvable/utils.py +++ b/conda_forge_feedstock_check_solvable/utils.py @@ -149,7 +149,7 @@ def override_env_var(name, value): @contextlib.contextmanager def suppress_output(): - if "CONDA_FORGE_FEEDSTOCK_CHECK_SOLVABLE_DEBUG" in os.environ: + if "CONDA_FORGE_FEEDSTOCK_CHECK_SOLVABLE_DEBUG" in os.environ or VERBOSITY > 2: suppress = False else: suppress = True diff --git a/tests/test_solvers.py b/tests/test_solvers.py index 5609e77..a1b1962 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -1,7 +1,18 @@ +import inspect import pprint +import pytest from flaky import flaky +from conda_forge_feedstock_check_solvable.mamba_solver import ( + MambaSolver, + _get_solver_cached, + mamba_solver_factory, +) +from conda_forge_feedstock_check_solvable.rattler_solver import ( + RattlerSolver, + rattler_solver_factory, +) from conda_forge_feedstock_check_solvable.utils import apply_pins, suppress_output from conda_forge_feedstock_check_solvable.virtual_packages import ( virtual_package_repodata, @@ -215,3 +226,145 @@ def test_solvers_hang(solver_factory): ], ) assert res[0] + + +@pytest.mark.parametrize("mamba_factory", [MambaSolver, mamba_solver_factory]) +@pytest.mark.parametrize("rattler_factory", [RattlerSolver, rattler_solver_factory]) +def test_solvers_compare_output(mamba_factory, rattler_factory): + specs_linux = ( + "libutf8proc >=2.8.0,<3.0a0", + "orc >=2.0.1,<2.0.2.0a0", + "glog >=0.7.0,<0.8.0a0", + "libabseil * cxx17*", + "libgcc-ng >=12", + "libbrotlidec >=1.1.0,<1.2.0a0", + "bzip2 >=1.0.8,<2.0a0", + "libbrotlienc >=1.1.0,<1.2.0a0", + "libgoogle-cloud-storage >=2.24.0,<2.25.0a0", + "libstdcxx-ng >=12", + "re2", + "gflags >=2.2.2,<2.3.0a0", + "libabseil >=20240116.2,<20240117.0a0", + "libre2-11 >=2023.9.1,<2024.0a0", + "libgoogle-cloud >=2.24.0,<2.25.0a0", + "lz4-c >=1.9.3,<1.10.0a0", + "libbrotlicommon >=1.1.0,<1.2.0a0", + "aws-sdk-cpp >=1.11.329,<1.11.330.0a0", + "snappy >=1.2.0,<1.3.0a0", + "zstd >=1.5.6,<1.6.0a0", + "aws-crt-cpp >=0.26.9,<0.26.10.0a0", + "libzlib >=1.2.13,<2.0a0", + ) + constraints_linux = ("apache-arrow-proc * cpu", "arrow-cpp <0.0a0") + + specs_linux_again = ( + "glog >=0.7.0,<0.8.0a0", + "bzip2 >=1.0.8,<2.0a0", + "lz4-c >=1.9.3,<1.10.0a0", + "libbrotlidec >=1.1.0,<1.2.0a0", + "zstd >=1.5.6,<1.6.0a0", + "gflags >=2.2.2,<2.3.0a0", + "libzlib >=1.2.13,<2.0a0", + "libbrotlienc >=1.1.0,<1.2.0a0", + "re2", + "aws-sdk-cpp >=1.11.329,<1.11.330.0a0", + "libgoogle-cloud-storage >=2.24.0,<2.25.0a0", + "libgoogle-cloud >=2.24.0,<2.25.0a0", + "libstdcxx-ng >=12", + "libutf8proc >=2.8.0,<3.0a0", + "libabseil * cxx17*", + "snappy >=1.2.0,<1.3.0a0", + "__glibc >=2.17,<3.0.a0", + "orc >=2.0.1,<2.0.2.0a0", + "libgcc-ng >=12", + "libabseil >=20240116.2,<20240117.0a0", + "libbrotlicommon >=1.1.0,<1.2.0a0", + "libre2-11 >=2023.9.1,<2024.0a0", + "aws-crt-cpp >=0.26.9,<0.26.10.0a0", + ) + constraints_linux_again = ("arrow-cpp <0.0a0", "apache-arrow-proc * cuda") + + specs_win = ( + "re2", + "libabseil * cxx17*", + "vc >=14.2,<15", + "libbrotlidec >=1.1.0,<1.2.0a0", + "lz4-c >=1.9.3,<1.10.0a0", + "aws-sdk-cpp >=1.11.329,<1.11.330.0a0", + "libbrotlicommon >=1.1.0,<1.2.0a0", + "snappy >=1.2.0,<1.3.0a0", + "ucrt >=10.0.20348.0", + "orc >=2.0.1,<2.0.2.0a0", + "zstd >=1.5.6,<1.6.0a0", + "libcrc32c >=1.1.2,<1.2.0a0", + "libre2-11 >=2023.9.1,<2024.0a0", + "libbrotlienc >=1.1.0,<1.2.0a0", + "libcurl >=8.8.0,<9.0a0", + "libabseil >=20240116.2,<20240117.0a0", + "bzip2 >=1.0.8,<2.0a0", + "libgoogle-cloud >=2.24.0,<2.25.0a0", + "vc14_runtime >=14.29.30139", + "libzlib >=1.2.13,<2.0a0", + "libgoogle-cloud-storage >=2.24.0,<2.25.0a0", + "libutf8proc >=2.8.0,<3.0a0", + "aws-crt-cpp >=0.26.9,<0.26.10.0a0", + ) + constraints_win = ("arrow-cpp <0.0a0", "apache-arrow-proc * cuda") + + channels = (virtual_package_repodata(), "conda-forge", "msys2") + + platform = "linux-64" + mamba_solver = mamba_factory(channels, platform) + rattler_solver = rattler_factory(channels, platform) + mamba_solvable, mamba_err, mamba_solution = mamba_solver.solve( + specs_linux, constraints=constraints_linux + ) + rattler_solvable, rattler_err, rattler_solution = rattler_solver.solve( + specs_linux, constraints=constraints_linux + ) + assert set(mamba_solution or []) == set(rattler_solution or []) + assert mamba_solvable == rattler_solvable + + platform = "linux-64" + mamba_solver = mamba_factory(channels, platform) + rattler_solver = rattler_factory(channels, platform) + mamba_solvable, mamba_err, mamba_solution = mamba_solver.solve( + specs_linux_again, constraints=constraints_linux_again + ) + rattler_solvable, rattler_err, rattler_solution = rattler_solver.solve( + specs_linux_again, constraints=constraints_linux_again + ) + assert set(mamba_solution or []) == set(rattler_solution or []) + assert mamba_solvable == rattler_solvable + + platform = "linux-64" + mamba_solver = mamba_factory(channels, platform) + rattler_solver = rattler_factory(channels, platform) + mamba_solvable, mamba_err, mamba_solution = mamba_solver.solve( + specs_linux, constraints=constraints_linux + ) + rattler_solvable, rattler_err, rattler_solution = rattler_solver.solve( + specs_linux, constraints=constraints_linux + ) + assert set(mamba_solution or []) == set(rattler_solution or []) + assert mamba_solvable == rattler_solvable + + platform = "win-64" + mamba_solver = mamba_factory(channels, platform) + rattler_solver = rattler_factory(channels, platform) + mamba_solvable, mamba_err, mamba_solution = mamba_solver.solve( + specs_win, constraints=constraints_win + ) + rattler_solvable, rattler_err, rattler_solution = rattler_solver.solve( + specs_win, constraints=constraints_win + ) + assert set(mamba_solution or []) == set(rattler_solution or []) + assert mamba_solvable == rattler_solvable + + if inspect.isfunction(mamba_factory): + assert ( + _get_solver_cached.cache_info().misses == 3 + ), _get_solver_cached.cache_info() + + if hasattr(rattler_factory, "cache_info"): + assert rattler_factory.cache_info().misses == 2, rattler_factory.cache_info()