diff --git a/conda_lock/models/lock_spec.py b/conda_lock/models/lock_spec.py index 6448800b0..e67ca7871 100644 --- a/conda_lock/models/lock_spec.py +++ b/conda_lock/models/lock_spec.py @@ -1,8 +1,12 @@ +from __future__ import annotations + +import copy import hashlib import json import pathlib import typing +from fnmatch import fnmatchcase from typing import Dict, List, Optional, Union from pydantic import BaseModel, Field, validator @@ -24,23 +28,126 @@ class _BaseDependency(StrictModel): def sorted_extras(cls, v: List[str]) -> List[str]: return sorted(v) + def _merge_base(self, other: _BaseDependency) -> _BaseDependency: + if other is None: + return self + if ( + self.name != other.name + or self.manager != other.manager + or self.category != other.category + ): + raise ValueError( + "Cannot merge incompatible dependencies: {self} != {other}" + ) + return _BaseDependency( + name=self.name, + manager=self.manager, + category=self.category, + extras=list(set(self.extras + other.extras)), + ) + class VersionedDependency(_BaseDependency): version: str build: Optional[str] = None conda_channel: Optional[str] = None + @staticmethod + def _merge_matchspec( + matchspec1: Optional[str], matchspec2: Optional[str] + ) -> Optional[str]: + if matchspec1 == matchspec2: + return copy.copy(matchspec1) + if matchspec1 is None or matchspec1 == "": + return copy.copy(matchspec2) + if matchspec2 is None or matchspec2 == "": + return copy.copy(matchspec1) + if fnmatchcase(matchspec1, matchspec2): + return copy.copy(matchspec1) + if fnmatchcase(matchspec2, matchspec1): + return copy.copy(matchspec2) + return f"{matchspec1},{matchspec2}" + + def merge(self, other: Optional[VersionedDependency]) -> VersionedDependency: + if other is None: + return self + + if ( + self.conda_channel is not None + and other.conda_channel is not None + and self.conda_channel != other.conda_channel + ): + raise ValueError( + f"VersionedDependency has two different conda_channels:\n{self}\n{other}" + ) + merged_base = self._merge_base(other) + return VersionedDependency( + name=merged_base.name, + manager=merged_base.manager, + category=merged_base.category, + extras=merged_base.extras, + version=self._merge_matchspec(self.version, other.version), # type: ignore + build=self._merge_matchspec(self.build, other.build), + conda_channel=self.conda_channel or other.conda_channel, + ) + class URLDependency(_BaseDependency): url: str hashes: List[str] + def merge(self, other: Optional[URLDependency]) -> URLDependency: + if other is None: + return self + if self.url != other.url: + raise ValueError(f"URLDependency has two different urls:\n{self}\n{other}") + + if self.hashes != other.hashes: + raise ValueError( + f"URLDependency has two different hashess:\n{self}\n{other}" + ) + merged_base = self._merge_base(other) + + return URLDependency( + name=merged_base.name, + manager=merged_base.manager, + category=merged_base.category, + extras=merged_base.extras, + url=self.url, + hashes=self.hashes, + ) + class VCSDependency(_BaseDependency): source: str vcs: str rev: Optional[str] = None + def merge(self, other: Optional[VCSDependency]) -> VCSDependency: + if other is None: + return self + if self.source != other.source: + raise ValueError( + f"VCSDependency has two different sources:\n{self}\n{other}" + ) + + if self.vcs != other.vcs: + raise ValueError(f"VCSDependency has two different vcss:\n{self}\n{other}") + + if self.rev is not None and other.rev is not None and self.rev != other.rev: + raise ValueError(f"VCSDependency has two different revs:\n{self}\n{other}") + merged_base = self._merge_base(other) + + return VCSDependency( + name=merged_base.name, + manager=merged_base.manager, + category=merged_base.category, + extras=merged_base.extras, + source=self.source, + vcs=self.vcs, + rev=self.rev or other.rev, + ) + Dependency = Union[VersionedDependency, URLDependency, VCSDependency] diff --git a/conda_lock/src_parser/aggregation.py b/conda_lock/src_parser/aggregation.py index d2b5349b3..51a033ddd 100644 --- a/conda_lock/src_parser/aggregation.py +++ b/conda_lock/src_parser/aggregation.py @@ -34,7 +34,11 @@ def aggregate_lock_specs( lock_spec.dependencies.get(platform, []) for lock_spec in lock_specs ): key = (dep.manager, dep.name) - unique_deps[key] = dep + if unique_deps.get(key) is not None and type(unique_deps[key]) != type(dep): + raise ValueError( + f"Unsupported use of different dependency types for same package:\n{dep}\n{unique_deps[key]}" + ) + unique_deps[key] = dep.merge(unique_deps.get(key)) # type: ignore dependencies[platform] = list(unique_deps.values()) diff --git a/tests/test_conda_lock.py b/tests/test_conda_lock.py index afbc5f402..5d49b4c3d 100644 --- a/tests/test_conda_lock.py +++ b/tests/test_conda_lock.py @@ -1622,22 +1622,28 @@ def test_aggregate_lock_specs(): assert actual.content_hash() == expected.content_hash() -def test_aggregate_lock_specs_override_version(): - base_spec = LockSpecification( - dependencies={"linux-64": [_make_spec("package", "=1.0")]}, +def test_aggregate_lock_specs_combine_version(): + first_spec = LockSpecification( + dependencies={"linux-64": [_make_spec("package", ">1.0")]}, channels=[Channel.from_string("conda-forge")], sources=[Path("base.yml")], ) - override_spec = LockSpecification( - dependencies={"linux-64": [_make_spec("package", "=2.0")]}, + second_spec = LockSpecification( + dependencies={"linux-64": [_make_spec("package", "<2.0")]}, + channels=[Channel.from_string("internal"), Channel.from_string("conda-forge")], + sources=[Path("additional.yml")], + ) + + result_spec = LockSpecification( + dependencies={"linux-64": [_make_spec("package", "<2.0,>1.0")]}, channels=[Channel.from_string("internal"), Channel.from_string("conda-forge")], - sources=[Path("override.yml")], + sources=[Path("additional.yml")], ) - agg_spec = aggregate_lock_specs([base_spec, override_spec], platforms=["linux-64"]) + agg_spec = aggregate_lock_specs([first_spec, second_spec], platforms=["linux-64"]) - assert agg_spec.dependencies == override_spec.dependencies + assert agg_spec.dependencies == result_spec.dependencies def test_aggregate_lock_specs_invalid_channels():