Skip to content

Commit

Permalink
BUG fixed issue with cache in mamba
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed Jun 5, 2024
1 parent 8d6a97e commit d7b061b
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 23 deletions.
74 changes: 52 additions & 22 deletions conda_forge_feedstock_check_solvable/mamba_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
get_run_exports,
print_debug,
print_warning,
suppress_output,
)

pkgs_dirs = context.pkgs_dirs
Expand All @@ -41,6 +42,42 @@
api.Context().channel_priority = api.ChannelPriority.kStrict


def _get_pool(channels, platform, constraints):
with suppress_output():
pool = api.Pool()

repos = []
load_channels(
pool,
channels,
repos,
platform=platform,
has_priority=True,
)
for repo in repos:
# need set_installed for add_pin, not sure why
repo.set_installed()

return pool


def _get_solver(channels, platform, constraints):
pool = _get_pool(channels, platform, constraints)

solver_options = [(api.SOLVER_FLAG_ALLOW_DOWNGRADE, 1)]
solver = api.Solver(pool, solver_options)

for constraint in constraints:
solver.add_pin(constraint)

return solver, pool


@lru_cache(maxsize=128)
def _get_solver_cached(channels, platform, constraints):
return _get_solver(channels, platform, constraints)


class MambaSolver:
"""Run the mamba solver.
Expand All @@ -57,22 +94,10 @@ class MambaSolver:
>>> solver.solve(["xtensor 0.18"])
"""

def __init__(self, channels, platform):
def __init__(self, channels, platform, _use_cache=False):
self.channels = channels
self.platform = platform
self.pool = api.Pool()

self.repos = []
self.index = load_channels(
self.pool,
self.channels,
self.repos,
platform=platform,
has_priority=True,
)
for repo in self.repos:
# need set_installed for add_pin, not sure why
repo.set_installed()
self._use_cache = _use_cache

def solve(
self,
Expand Down Expand Up @@ -121,19 +146,23 @@ def solve(
ignore_run_exports_from = ignore_run_exports_from or []
ignore_run_exports = ignore_run_exports or []

solver_options = [(api.SOLVER_FLAG_ALLOW_DOWNGRADE, 1)]
solver = api.Solver(self.pool, solver_options)

_specs = [convert_spec_to_conda_build(s) for s in specs]
_constraints = [convert_spec_to_conda_build(s) for s in constraints or []]

if self._use_cache:
solver, pool = _get_solver_cached(
self.channels, self.platform, tuple(_constraints)
)
else:
solver, pool = _get_solver(
self.channels, self.platform, tuple(_constraints)
)

print_debug(
"MAMBA running solver for specs \n\n%s\nconstraints: %s\n",
pprint.pformat(_specs),
pprint.pformat(_constraints),
)
for constraint in _constraints:
solver.add_pin(constraint)

solver.add_jobs(_specs, api.SOLVER_INSTALL)
success = solver.solve()
Expand All @@ -143,18 +172,20 @@ def solve(
print_warning(
"MAMBA failed to solve specs \n\n%s\n\nwith "
"constraints \n\n%s\n\nfor channels "
"\n\n%s\n\non platform "
"\n\n%s\n\nThe reported errors are:\n\n%s\n",
textwrap.indent(pprint.pformat(_specs), " "),
textwrap.indent(pprint.pformat(_constraints), " "),
textwrap.indent(pprint.pformat(self.channels), " "),
textwrap.indent(pprint.pformat(self.platform), " "),
textwrap.indent(solver.explain_problems(), " "),
)
err = solver.explain_problems()
solution = None
run_exports = copy.deepcopy(DEFAULT_RUN_EXPORTS)
else:
t = api.Transaction(
self.pool,
pool,
solver,
PACKAGE_CACHE,
)
Expand Down Expand Up @@ -215,6 +246,5 @@ def _get_run_exports(
return run_exports


@lru_cache(maxsize=128)
def mamba_solver_factory(channels, platform):
return MambaSolver(list(channels), platform)
return MambaSolver(tuple(channels), platform, _use_cache=True)
2 changes: 2 additions & 0 deletions conda_forge_feedstock_check_solvable/rattler_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,12 @@ def solve(
print_warning(
"MAMBA failed to solve specs \n\n%s\n\nwith "
"constraints \n\n%s\n\nfor channels "
"\n\n%s\n\non platform "
"\n\n%s\n\nThe reported errors are:\n\n%s\n",
textwrap.indent(pprint.pformat(specs), " "),
textwrap.indent(pprint.pformat(constraints), " "),
textwrap.indent(pprint.pformat(self.channels), " "),
textwrap.indent(pprint.pformat(self.platform_arch), " "),
textwrap.indent(err, " "),
)
success = False
Expand Down
2 changes: 1 addition & 1 deletion conda_forge_feedstock_check_solvable/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def override_env_var(name, value):

@contextlib.contextmanager
def suppress_output():
if "CONDA_FORGE_FEEDSTOCK_CHECK_SOLVABLE_DEBUG" in os.environ:
if "CONDA_FORGE_FEEDSTOCK_CHECK_SOLVABLE_DEBUG" in os.environ or VERBOSITY > 2:
suppress = False
else:
suppress = True
Expand Down
153 changes: 153 additions & 0 deletions tests/test_solvers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import inspect
import pprint

import pytest
from flaky import flaky

from conda_forge_feedstock_check_solvable.mamba_solver import (
MambaSolver,
_get_solver_cached,
mamba_solver_factory,
)
from conda_forge_feedstock_check_solvable.rattler_solver import (
RattlerSolver,
rattler_solver_factory,
)
from conda_forge_feedstock_check_solvable.utils import apply_pins, suppress_output
from conda_forge_feedstock_check_solvable.virtual_packages import (
virtual_package_repodata,
Expand Down Expand Up @@ -215,3 +226,145 @@ def test_solvers_hang(solver_factory):
],
)
assert res[0]


@pytest.mark.parametrize("mamba_factory", [MambaSolver, mamba_solver_factory])
@pytest.mark.parametrize("rattler_factory", [RattlerSolver, rattler_solver_factory])
def test_solvers_compare_output(mamba_factory, rattler_factory):
specs_linux = (
"libutf8proc >=2.8.0,<3.0a0",
"orc >=2.0.1,<2.0.2.0a0",
"glog >=0.7.0,<0.8.0a0",
"libabseil * cxx17*",
"libgcc-ng >=12",
"libbrotlidec >=1.1.0,<1.2.0a0",
"bzip2 >=1.0.8,<2.0a0",
"libbrotlienc >=1.1.0,<1.2.0a0",
"libgoogle-cloud-storage >=2.24.0,<2.25.0a0",
"libstdcxx-ng >=12",
"re2",
"gflags >=2.2.2,<2.3.0a0",
"libabseil >=20240116.2,<20240117.0a0",
"libre2-11 >=2023.9.1,<2024.0a0",
"libgoogle-cloud >=2.24.0,<2.25.0a0",
"lz4-c >=1.9.3,<1.10.0a0",
"libbrotlicommon >=1.1.0,<1.2.0a0",
"aws-sdk-cpp >=1.11.329,<1.11.330.0a0",
"snappy >=1.2.0,<1.3.0a0",
"zstd >=1.5.6,<1.6.0a0",
"aws-crt-cpp >=0.26.9,<0.26.10.0a0",
"libzlib >=1.2.13,<2.0a0",
)
constraints_linux = ("apache-arrow-proc * cpu", "arrow-cpp <0.0a0")

specs_linux_again = (
"glog >=0.7.0,<0.8.0a0",
"bzip2 >=1.0.8,<2.0a0",
"lz4-c >=1.9.3,<1.10.0a0",
"libbrotlidec >=1.1.0,<1.2.0a0",
"zstd >=1.5.6,<1.6.0a0",
"gflags >=2.2.2,<2.3.0a0",
"libzlib >=1.2.13,<2.0a0",
"libbrotlienc >=1.1.0,<1.2.0a0",
"re2",
"aws-sdk-cpp >=1.11.329,<1.11.330.0a0",
"libgoogle-cloud-storage >=2.24.0,<2.25.0a0",
"libgoogle-cloud >=2.24.0,<2.25.0a0",
"libstdcxx-ng >=12",
"libutf8proc >=2.8.0,<3.0a0",
"libabseil * cxx17*",
"snappy >=1.2.0,<1.3.0a0",
"__glibc >=2.17,<3.0.a0",
"orc >=2.0.1,<2.0.2.0a0",
"libgcc-ng >=12",
"libabseil >=20240116.2,<20240117.0a0",
"libbrotlicommon >=1.1.0,<1.2.0a0",
"libre2-11 >=2023.9.1,<2024.0a0",
"aws-crt-cpp >=0.26.9,<0.26.10.0a0",
)
constraints_linux_again = ("arrow-cpp <0.0a0", "apache-arrow-proc * cuda")

specs_win = (
"re2",
"libabseil * cxx17*",
"vc >=14.2,<15",
"libbrotlidec >=1.1.0,<1.2.0a0",
"lz4-c >=1.9.3,<1.10.0a0",
"aws-sdk-cpp >=1.11.329,<1.11.330.0a0",
"libbrotlicommon >=1.1.0,<1.2.0a0",
"snappy >=1.2.0,<1.3.0a0",
"ucrt >=10.0.20348.0",
"orc >=2.0.1,<2.0.2.0a0",
"zstd >=1.5.6,<1.6.0a0",
"libcrc32c >=1.1.2,<1.2.0a0",
"libre2-11 >=2023.9.1,<2024.0a0",
"libbrotlienc >=1.1.0,<1.2.0a0",
"libcurl >=8.8.0,<9.0a0",
"libabseil >=20240116.2,<20240117.0a0",
"bzip2 >=1.0.8,<2.0a0",
"libgoogle-cloud >=2.24.0,<2.25.0a0",
"vc14_runtime >=14.29.30139",
"libzlib >=1.2.13,<2.0a0",
"libgoogle-cloud-storage >=2.24.0,<2.25.0a0",
"libutf8proc >=2.8.0,<3.0a0",
"aws-crt-cpp >=0.26.9,<0.26.10.0a0",
)
constraints_win = ("arrow-cpp <0.0a0", "apache-arrow-proc * cuda")

channels = (virtual_package_repodata(), "conda-forge", "msys2")

platform = "linux-64"
mamba_solver = mamba_factory(channels, platform)
rattler_solver = rattler_factory(channels, platform)
mamba_solvable, mamba_err, mamba_solution = mamba_solver.solve(
specs_linux, constraints=constraints_linux
)
rattler_solvable, rattler_err, rattler_solution = rattler_solver.solve(
specs_linux, constraints=constraints_linux
)
assert set(mamba_solution or []) == set(rattler_solution or [])
assert mamba_solvable == rattler_solvable

platform = "linux-64"
mamba_solver = mamba_factory(channels, platform)
rattler_solver = rattler_factory(channels, platform)
mamba_solvable, mamba_err, mamba_solution = mamba_solver.solve(
specs_linux_again, constraints=constraints_linux_again
)
rattler_solvable, rattler_err, rattler_solution = rattler_solver.solve(
specs_linux_again, constraints=constraints_linux_again
)
assert set(mamba_solution or []) == set(rattler_solution or [])
assert mamba_solvable == rattler_solvable

platform = "linux-64"
mamba_solver = mamba_factory(channels, platform)
rattler_solver = rattler_factory(channels, platform)
mamba_solvable, mamba_err, mamba_solution = mamba_solver.solve(
specs_linux, constraints=constraints_linux
)
rattler_solvable, rattler_err, rattler_solution = rattler_solver.solve(
specs_linux, constraints=constraints_linux
)
assert set(mamba_solution or []) == set(rattler_solution or [])
assert mamba_solvable == rattler_solvable

platform = "win-64"
mamba_solver = mamba_factory(channels, platform)
rattler_solver = rattler_factory(channels, platform)
mamba_solvable, mamba_err, mamba_solution = mamba_solver.solve(
specs_win, constraints=constraints_win
)
rattler_solvable, rattler_err, rattler_solution = rattler_solver.solve(
specs_win, constraints=constraints_win
)
assert set(mamba_solution or []) == set(rattler_solution or [])
assert mamba_solvable == rattler_solvable

if inspect.isfunction(mamba_factory):
assert (
_get_solver_cached.cache_info().misses == 3
), _get_solver_cached.cache_info()

if hasattr(rattler_factory, "cache_info"):
assert rattler_factory.cache_info().misses == 2, rattler_factory.cache_info()

0 comments on commit d7b061b

Please sign in to comment.