Commit e975b130 by Sébastien Eustace Committed by GitHub

Fix handling of extra markers (#1555)

parent 409756dc
...@@ -219,8 +219,12 @@ class Solver: ...@@ -219,8 +219,12 @@ class Solver:
else: else:
category = dep.category category = dep.category
optional = dep.is_optional() and not dep.is_activated() optional = dep.is_optional() and not dep.is_activated()
intersection = previous["marker"].intersect(previous_dep.marker) intersection = (
intersection = intersection.intersect(package.marker) previous["marker"]
.without_extras()
.intersect(previous_dep.marker.without_extras())
)
intersection = intersection.intersect(package.marker.without_extras())
marker = intersection marker = intersection
......
...@@ -204,6 +204,15 @@ class AnyMarker(BaseMarker): ...@@ -204,6 +204,15 @@ class AnyMarker(BaseMarker):
def __repr__(self): def __repr__(self):
return "<AnyMarker>" return "<AnyMarker>"
def __hash__(self):
return hash(("<any>", "<any>"))
def __eq__(self, other):
if not isinstance(other, BaseMarker):
return NotImplemented
return isinstance(other, AnyMarker)
class EmptyMarker(BaseMarker): class EmptyMarker(BaseMarker):
def intersect(self, other): def intersect(self, other):
...@@ -230,6 +239,15 @@ class EmptyMarker(BaseMarker): ...@@ -230,6 +239,15 @@ class EmptyMarker(BaseMarker):
def __repr__(self): def __repr__(self):
return "<EmptyMarker>" return "<EmptyMarker>"
def __hash__(self):
return hash(("<empty>", "<empty>"))
def __eq__(self, other):
if not isinstance(other, BaseMarker):
return NotImplemented
return isinstance(other, EmptyMarker)
class SingleMarker(BaseMarker): class SingleMarker(BaseMarker):
...@@ -329,7 +347,7 @@ class SingleMarker(BaseMarker): ...@@ -329,7 +347,7 @@ class SingleMarker(BaseMarker):
if self == other: if self == other:
return self return self
return MarkerUnion(self, other) return MarkerUnion.of(self, other)
return other.union(self) return other.union(self)
...@@ -344,7 +362,7 @@ class SingleMarker(BaseMarker): ...@@ -344,7 +362,7 @@ class SingleMarker(BaseMarker):
def without_extras(self): def without_extras(self):
if self.name == "extra": if self.name == "extra":
return EmptyMarker() return AnyMarker()
return self return self
...@@ -443,7 +461,7 @@ class MultiMarker(BaseMarker): ...@@ -443,7 +461,7 @@ class MultiMarker(BaseMarker):
def union(self, other): def union(self, other):
if isinstance(other, (SingleMarker, MultiMarker)): if isinstance(other, (SingleMarker, MultiMarker)):
return MarkerUnion(self, other) return MarkerUnion.of(self, other)
return other.union(self) return other.union(self)
...@@ -493,17 +511,24 @@ class MultiMarker(BaseMarker): ...@@ -493,17 +511,24 @@ class MultiMarker(BaseMarker):
class MarkerUnion(BaseMarker): class MarkerUnion(BaseMarker):
def __init__(self, *markers): def __init__(self, *markers):
self._markers = [] self._markers = list(markers)
markers = _flatten_markers(markers, MarkerUnion) @property
def markers(self):
return self._markers
for marker in markers: @classmethod
if marker in self._markers: def of(cls, *markers): # type: (tuple) -> MarkerUnion
flattened_markers = _flatten_markers(markers, MarkerUnion)
markers = []
for marker in flattened_markers:
if marker in markers:
continue continue
if isinstance(marker, SingleMarker) and marker.name == "python_version": if isinstance(marker, SingleMarker) and marker.name == "python_version":
intersected = False intersected = False
for i, mark in enumerate(self._markers): for i, mark in enumerate(markers):
if ( if (
not isinstance(mark, SingleMarker) not isinstance(mark, SingleMarker)
or isinstance(mark, SingleMarker) or isinstance(mark, SingleMarker)
...@@ -516,18 +541,19 @@ class MarkerUnion(BaseMarker): ...@@ -516,18 +541,19 @@ class MarkerUnion(BaseMarker):
intersected = True intersected = True
break break
elif intersection == marker.constraint: elif intersection == marker.constraint:
self._markers[i] = marker markers[i] = marker
intersected = True intersected = True
break break
if intersected: if intersected:
continue continue
self._markers.append(marker) markers.append(marker)
@property if len(markers) == 1 and markers[0].is_any():
def markers(self): return AnyMarker()
return self._markers
return MarkerUnion(*markers)
def append(self, marker): def append(self, marker):
if marker in self._markers: if marker in self._markers:
...@@ -557,7 +583,7 @@ class MarkerUnion(BaseMarker): ...@@ -557,7 +583,7 @@ class MarkerUnion(BaseMarker):
if not intersection.is_empty(): if not intersection.is_empty():
new_markers.append(intersection) new_markers.append(intersection)
return MarkerUnion(*new_markers) return MarkerUnion.of(*new_markers)
def union(self, other): def union(self, other):
if other.is_any(): if other.is_any():
...@@ -568,7 +594,7 @@ class MarkerUnion(BaseMarker): ...@@ -568,7 +594,7 @@ class MarkerUnion(BaseMarker):
new_markers = self._markers + [other] new_markers = self._markers + [other]
return MarkerUnion(*new_markers) return MarkerUnion.of(*new_markers)
def validate(self, environment): def validate(self, environment):
for m in self._markers: for m in self._markers:
...@@ -654,4 +680,4 @@ def _compact_markers(markers): ...@@ -654,4 +680,4 @@ def _compact_markers(markers):
if len(groups) == 1: if len(groups) == 1:
return groups[0] return groups[0]
return MarkerUnion(*groups) return MarkerUnion.of(*groups)
...@@ -407,8 +407,10 @@ def test_solver_returns_extras_if_requested(solver, repo, package): ...@@ -407,8 +407,10 @@ def test_solver_returns_extras_if_requested(solver, repo, package):
package_b = get_package("B", "1.0") package_b = get_package("B", "1.0")
package_c = get_package("C", "1.0") package_c = get_package("C", "1.0")
package_b.extras = {"foo": [get_dependency("C", "^1.0")]} dep = get_dependency("C", "^1.0", optional=True)
package_b.add_dependency("C", {"version": "^1.0", "optional": True}) dep.marker = parse_marker("extra == 'foo'")
package_b.extras = {"foo": [dep]}
package_b.requires.append(dep)
repo.add_package(package_a) repo.add_package(package_a)
repo.add_package(package_b) repo.add_package(package_b)
...@@ -425,6 +427,9 @@ def test_solver_returns_extras_if_requested(solver, repo, package): ...@@ -425,6 +427,9 @@ def test_solver_returns_extras_if_requested(solver, repo, package):
], ],
) )
assert ops[-1].package.marker.is_any()
assert ops[0].package.marker.is_any()
def test_solver_returns_prereleases_if_requested(solver, repo, package): def test_solver_returns_prereleases_if_requested(solver, repo, package):
package.add_dependency("A") package.add_dependency("A")
...@@ -1148,10 +1153,7 @@ def test_solver_does_not_trigger_new_resolution_on_duplicate_dependencies_if_onl ...@@ -1148,10 +1153,7 @@ def test_solver_does_not_trigger_new_resolution_on_duplicate_dependencies_if_onl
], ],
) )
assert str(ops[0].package.marker) in [ assert str(ops[0].package.marker) == ""
'extra == "foo" or extra == "bar"',
'extra == "bar" or extra == "foo"',
]
assert str(ops[1].package.marker) == "" assert str(ops[1].package.marker) == ""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment