Commit e975b130 by Sébastien Eustace Committed by GitHub

Fix handling of extra markers (#1555)

parent 409756dc
......@@ -219,8 +219,12 @@ class Solver:
else:
category = dep.category
optional = dep.is_optional() and not dep.is_activated()
intersection = previous["marker"].intersect(previous_dep.marker)
intersection = intersection.intersect(package.marker)
intersection = (
previous["marker"]
.without_extras()
.intersect(previous_dep.marker.without_extras())
)
intersection = intersection.intersect(package.marker.without_extras())
marker = intersection
......
......@@ -204,6 +204,15 @@ class AnyMarker(BaseMarker):
def __repr__(self):
return "<AnyMarker>"
def __hash__(self):
return hash(("<any>", "<any>"))
def __eq__(self, other):
if not isinstance(other, BaseMarker):
return NotImplemented
return isinstance(other, AnyMarker)
class EmptyMarker(BaseMarker):
def intersect(self, other):
......@@ -230,6 +239,15 @@ class EmptyMarker(BaseMarker):
def __repr__(self):
return "<EmptyMarker>"
def __hash__(self):
return hash(("<empty>", "<empty>"))
def __eq__(self, other):
if not isinstance(other, BaseMarker):
return NotImplemented
return isinstance(other, EmptyMarker)
class SingleMarker(BaseMarker):
......@@ -329,7 +347,7 @@ class SingleMarker(BaseMarker):
if self == other:
return self
return MarkerUnion(self, other)
return MarkerUnion.of(self, other)
return other.union(self)
......@@ -344,7 +362,7 @@ class SingleMarker(BaseMarker):
def without_extras(self):
if self.name == "extra":
return EmptyMarker()
return AnyMarker()
return self
......@@ -443,7 +461,7 @@ class MultiMarker(BaseMarker):
def union(self, other):
if isinstance(other, (SingleMarker, MultiMarker)):
return MarkerUnion(self, other)
return MarkerUnion.of(self, other)
return other.union(self)
......@@ -493,17 +511,24 @@ class MultiMarker(BaseMarker):
class MarkerUnion(BaseMarker):
def __init__(self, *markers):
self._markers = []
self._markers = list(markers)
@property
def markers(self):
return self._markers
markers = _flatten_markers(markers, MarkerUnion)
@classmethod
def of(cls, *markers): # type: (tuple) -> MarkerUnion
flattened_markers = _flatten_markers(markers, MarkerUnion)
for marker in markers:
if marker in self._markers:
markers = []
for marker in flattened_markers:
if marker in markers:
continue
if isinstance(marker, SingleMarker) and marker.name == "python_version":
intersected = False
for i, mark in enumerate(self._markers):
for i, mark in enumerate(markers):
if (
not isinstance(mark, SingleMarker)
or isinstance(mark, SingleMarker)
......@@ -516,18 +541,19 @@ class MarkerUnion(BaseMarker):
intersected = True
break
elif intersection == marker.constraint:
self._markers[i] = marker
markers[i] = marker
intersected = True
break
if intersected:
continue
self._markers.append(marker)
markers.append(marker)
@property
def markers(self):
return self._markers
if len(markers) == 1 and markers[0].is_any():
return AnyMarker()
return MarkerUnion(*markers)
def append(self, marker):
if marker in self._markers:
......@@ -557,7 +583,7 @@ class MarkerUnion(BaseMarker):
if not intersection.is_empty():
new_markers.append(intersection)
return MarkerUnion(*new_markers)
return MarkerUnion.of(*new_markers)
def union(self, other):
if other.is_any():
......@@ -568,7 +594,7 @@ class MarkerUnion(BaseMarker):
new_markers = self._markers + [other]
return MarkerUnion(*new_markers)
return MarkerUnion.of(*new_markers)
def validate(self, environment):
for m in self._markers:
......@@ -654,4 +680,4 @@ def _compact_markers(markers):
if len(groups) == 1:
return groups[0]
return MarkerUnion(*groups)
return MarkerUnion.of(*groups)
......@@ -407,8 +407,10 @@ def test_solver_returns_extras_if_requested(solver, repo, package):
package_b = get_package("B", "1.0")
package_c = get_package("C", "1.0")
package_b.extras = {"foo": [get_dependency("C", "^1.0")]}
package_b.add_dependency("C", {"version": "^1.0", "optional": True})
dep = get_dependency("C", "^1.0", optional=True)
dep.marker = parse_marker("extra == 'foo'")
package_b.extras = {"foo": [dep]}
package_b.requires.append(dep)
repo.add_package(package_a)
repo.add_package(package_b)
......@@ -425,6 +427,9 @@ def test_solver_returns_extras_if_requested(solver, repo, package):
],
)
assert ops[-1].package.marker.is_any()
assert ops[0].package.marker.is_any()
def test_solver_returns_prereleases_if_requested(solver, repo, package):
package.add_dependency("A")
......@@ -1148,10 +1153,7 @@ def test_solver_does_not_trigger_new_resolution_on_duplicate_dependencies_if_onl
],
)
assert str(ops[0].package.marker) in [
'extra == "foo" or extra == "bar"',
'extra == "bar" or extra == "foo"',
]
assert str(ops[0].package.marker) == ""
assert str(ops[1].package.marker) == ""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment