Commit 21228d11 by Randy Döring Committed by Bjorn Neergaard

refactor(solver/provider/version_solver): move get_locked() from VersionSolver to Provider

parent cf213245
......@@ -22,6 +22,7 @@ if TYPE_CHECKING:
from collections.abc import Iterable
from cleo.io.io import IO
from packaging.utils import NormalizedName
from poetry.core.packages.project_package import ProjectPackage
from poetry.config.config import Config
......@@ -59,7 +60,7 @@ class Installer:
self._execute_operations = True
self._lock = False
self._whitelist: list[str] = []
self._whitelist: list[NormalizedName] = []
self._extras: list[str] = []
......
......@@ -9,16 +9,10 @@ if TYPE_CHECKING:
from poetry.core.packages.project_package import ProjectPackage
from poetry.mixology.result import SolverResult
from poetry.packages import DependencyPackage
from poetry.puzzle.provider import Provider
def resolve_version(
root: ProjectPackage,
provider: Provider,
locked: dict[str, list[DependencyPackage]] | None = None,
use_latest: list[str] | None = None,
) -> SolverResult:
solver = VersionSolver(root, provider, locked=locked, use_latest=use_latest)
def resolve_version(root: ProjectPackage, provider: Provider) -> SolverResult:
solver = VersionSolver(root, provider)
return solver.solve()
......@@ -18,12 +18,12 @@ from poetry.mixology.partial_solution import PartialSolution
from poetry.mixology.result import SolverResult
from poetry.mixology.set_relation import SetRelation
from poetry.mixology.term import Term
from poetry.packages import DependencyPackage
if TYPE_CHECKING:
from poetry.core.packages.project_package import ProjectPackage
from poetry.packages import DependencyPackage
from poetry.puzzle.provider import Provider
......@@ -82,23 +82,10 @@ class VersionSolver:
on how this solver works.
"""
def __init__(
self,
root: ProjectPackage,
provider: Provider,
locked: dict[str, list[DependencyPackage]] | None = None,
use_latest: list[str] | None = None,
) -> None:
def __init__(self, root: ProjectPackage, provider: Provider) -> None:
self._root = root
self._provider = provider
self._dependency_cache = DependencyCache(provider)
self._locked = locked or {}
if use_latest is None:
use_latest = []
self._use_latest = use_latest
self._incompatibilities: dict[str, list[Incompatibility]] = {}
self._contradicted_incompatibilities: set[Incompatibility] = set()
self._solution = PartialSolution()
......@@ -384,12 +371,12 @@ class VersionSolver:
if dependency.is_direct_origin():
return False, -1
if dependency.name in self._use_latest:
if dependency.name in self._provider.use_latest:
# If we're forced to use the latest version of a package, it effectively
# only has one version to choose from.
return not dependency.marker.is_any(), 1
locked = self._get_locked(dependency)
locked = self._provider.get_locked(dependency)
if locked:
return not dependency.marker.is_any(), 1
......@@ -406,7 +393,7 @@ class VersionSolver:
else:
dependency = min(*unsatisfied, key=_get_min)
locked = self._get_locked(dependency)
locked = self._provider.get_locked(dependency)
if locked is None:
try:
packages = self._dependency_cache.search_for(dependency)
......@@ -499,23 +486,5 @@ class VersionSolver:
incompatibility
)
def _get_locked(self, dependency: Dependency) -> DependencyPackage | None:
if dependency.name in self._use_latest:
return None
locked = self._locked.get(dependency.name, [])
for dependency_package in locked:
package = dependency_package.package
if (
# Locked dependencies are always without features.
# Thus, we can't use is_same_package_as() here because it compares
# the complete_name (including features).
dependency.name == package.name
and dependency.is_same_source_as(package)
and dependency.constraint.allows(package.version)
):
return DependencyPackage(dependency, package)
return None
def _log(self, text: str) -> None:
self._provider.debug(text, self._solution.attempted_solutions)
......@@ -12,6 +12,7 @@ from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Collection
from typing import cast
from cleo.ui.progress_indicator import ProgressIndicator
......@@ -42,6 +43,7 @@ if TYPE_CHECKING:
from collections.abc import Iterator
from cleo.io.io import IO
from packaging.utils import NormalizedName
from poetry.core.packages.dependency import Dependency
from poetry.core.packages.directory_dependency import DirectoryDependency
from poetry.core.packages.file_dependency import FileDependency
......@@ -127,6 +129,7 @@ class Provider:
io: IO,
*,
installed: list[Package] | None = None,
locked: list[Package] | None = None,
) -> None:
self._package = package
self._pool = pool
......@@ -140,11 +143,27 @@ class Provider:
self._source_root: Path | None = None
self._installed_packages = installed if installed is not None else []
self._direct_origin_packages: dict[str, Package] = {}
self._locked: dict[NormalizedName, list[DependencyPackage]] = defaultdict(list)
self._use_latest: Collection[NormalizedName] = []
for package in locked or []:
self._locked[package.name].append(
DependencyPackage(package.to_dependency(), package)
)
for dependency_packages in self._locked.values():
dependency_packages.sort(
key=lambda p: p.package.version,
reverse=True,
)
@property
def pool(self) -> Pool:
return self._pool
@property
def use_latest(self) -> Collection[NormalizedName]:
return self._use_latest
def is_debugging(self) -> bool:
return self._is_debugging
......@@ -161,9 +180,10 @@ class Provider:
original_source_root = self._source_root
self._source_root = source_root
yield self
self._source_root = original_source_root
try:
yield self
finally:
self._source_root = original_source_root
@contextmanager
def use_environment(self, env: Env) -> Iterator[Provider]:
......@@ -172,10 +192,20 @@ class Provider:
self._env = env
self._python_constraint = Version.parse(env.marker_env["python_full_version"])
yield self
try:
yield self
finally:
self._env = None
self._python_constraint = original_python_constraint
self._env = None
self._python_constraint = original_python_constraint
@contextmanager
def use_latest_for(self, names: Collection[NormalizedName]) -> Iterator[Provider]:
self._use_latest = names
try:
yield self
finally:
self._use_latest = []
@staticmethod
def validate_package_for_dependency(
......@@ -801,6 +831,24 @@ class Provider:
return dependency_package
def get_locked(self, dependency: Dependency) -> DependencyPackage | None:
if dependency.name in self._use_latest:
return None
locked = self._locked.get(dependency.name, [])
for dependency_package in locked:
package = dependency_package.package
if (
# Locked dependencies are always without features.
# Thus, we can't use is_same_package_as() here because it compares
# the complete_name (including features).
dependency.name == package.name
and dependency.is_same_source_as(package)
and dependency.constraint.allows(package.version)
):
return DependencyPackage(dependency, package)
return None
def debug(self, message: str, depth: int = 0) -> None:
if not (self._io.is_very_verbose() or self._io.is_debug()):
return
......
......@@ -5,6 +5,7 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING
from typing import Collection
from typing import FrozenSet
from typing import Tuple
from typing import TypeVar
......@@ -13,7 +14,6 @@ from poetry.core.packages.dependency_group import MAIN_GROUP
from poetry.mixology import resolve_version
from poetry.mixology.failure import SolveFailure
from poetry.packages import DependencyPackage
from poetry.puzzle.exceptions import OverrideNeeded
from poetry.puzzle.exceptions import SolverProblemError
from poetry.puzzle.provider import Indicator
......@@ -24,10 +24,12 @@ if TYPE_CHECKING:
from collections.abc import Iterator
from cleo.io.io import IO
from packaging.utils import NormalizedName
from poetry.core.packages.dependency import Dependency
from poetry.core.packages.package import Package
from poetry.core.packages.project_package import ProjectPackage
from poetry.packages import DependencyPackage
from poetry.puzzle.transaction import Transaction
from poetry.repositories import Pool
from poetry.utils.env import Env
......@@ -49,7 +51,7 @@ class Solver:
self._io = io
self._provider = Provider(
self._package, self._pool, self._io, installed=installed
self._package, self._pool, self._io, installed=installed, locked=locked
)
self._overrides: list[dict[DependencyPackage, dict[str, Dependency]]] = []
......@@ -62,12 +64,14 @@ class Solver:
with self.provider.use_environment(env):
yield
def solve(self, use_latest: list[str] | None = None) -> Transaction:
def solve(
self, use_latest: Collection[NormalizedName] | None = None
) -> Transaction:
from poetry.puzzle.transaction import Transaction
with self._progress():
with self._progress(), self._provider.use_latest_for(use_latest or []):
start = time.time()
packages, depths = self._solve(use_latest=use_latest)
packages, depths = self._solve()
end = time.time()
if len(self._overrides) > 1:
......@@ -116,7 +120,6 @@ class Solver:
def _solve_in_compatibility_mode(
self,
overrides: tuple[dict[DependencyPackage, dict[str, Dependency]], ...],
use_latest: list[str] | None = None,
) -> tuple[list[Package], list[int]]:
packages = []
depths = []
......@@ -126,7 +129,7 @@ class Solver:
f"with the following overrides ({override}).</comment>"
)
self._provider.set_overrides(override)
_packages, _depths = self._solve(use_latest=use_latest)
_packages, _depths = self._solve()
for index, package in enumerate(_packages):
if package not in packages:
packages.append(package)
......@@ -143,31 +146,16 @@ class Solver:
return packages, depths
def _solve(
self, use_latest: list[str] | None = None
) -> tuple[list[Package], list[int]]:
def _solve(self) -> tuple[list[Package], list[int]]:
if self._provider._overrides:
self._overrides.append(self._provider._overrides)
locked: dict[str, list[DependencyPackage]] = defaultdict(list)
for package in self._locked_packages:
locked[package.name].append(
DependencyPackage(package.to_dependency(), package)
)
for dependency_packages in locked.values():
dependency_packages.sort(
key=lambda p: p.package.version,
reverse=True,
)
try:
result = resolve_version(
self._package, self._provider, locked=locked, use_latest=use_latest
)
result = resolve_version(self._package, self._provider)
packages = result.packages
except OverrideNeeded as e:
return self._solve_in_compatibility_mode(e.overrides, use_latest=use_latest)
return self._solve_in_compatibility_mode(e.overrides)
except SolveFailure as e:
raise SolverProblemError(e)
......
......@@ -7,10 +7,11 @@ from poetry.core.packages.package import Package
from poetry.factory import Factory
from poetry.mixology.failure import SolveFailure
from poetry.mixology.version_solver import VersionSolver
from poetry.packages import DependencyPackage
if TYPE_CHECKING:
from packaging.utils import NormalizedName
from poetry.core.factory import DependencyConstraint
from poetry.core.packages.project_package import ProjectPackage
from poetry.mixology import SolverResult
......@@ -22,7 +23,7 @@ def add_to_repo(
repository: Repository,
name: str,
version: str,
deps: dict[str, str] | None = None,
deps: dict[str, DependencyConstraint] | None = None,
python: str | None = None,
yanked: bool = False,
) -> None:
......@@ -43,32 +44,26 @@ def check_solver_result(
result: dict[str, str] | None = None,
error: str | None = None,
tries: int | None = None,
locked: dict[str, Package] | None = None,
use_latest: list[str] | None = None,
use_latest: list[NormalizedName] | None = None,
) -> SolverResult | None:
if locked is not None:
locked = {
k: [DependencyPackage(l.to_dependency(), l)] for k, l in locked.items()
}
solver = VersionSolver(root, provider)
with provider.use_latest_for(use_latest or []):
try:
solution = solver.solve()
except SolveFailure as e:
if error:
assert str(e) == error
solver = VersionSolver(root, provider, locked=locked, use_latest=use_latest)
try:
solution = solver.solve()
except SolveFailure as e:
if error:
assert str(e) == error
if tries is not None:
assert solver.solution.attempted_solutions == tries
if tries is not None:
assert solver.solution.attempted_solutions == tries
return None
raise
except AssertionError as e:
if error:
assert str(e) == error
return None
raise
except AssertionError as e:
if error:
assert str(e) == error
return
raise
packages = {}
for package in solution.packages:
......
......@@ -2,21 +2,25 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from cleo.io.null_io import NullIO
from packaging.utils import canonicalize_name
from poetry.factory import Factory
from tests.helpers import get_package
from tests.mixology.helpers import add_to_repo
from tests.mixology.helpers import check_solver_result
from tests.mixology.version_solver.conftest import Provider
if TYPE_CHECKING:
from poetry.core.packages.project_package import ProjectPackage
from poetry.repositories import Pool
from poetry.repositories import Repository
from tests.mixology.version_solver.conftest import Provider
def test_with_compatible_locked_dependencies(
root: ProjectPackage, provider: Provider, repo: Repository
root: ProjectPackage, repo: Repository, pool: Pool
):
root.add_dependency(Factory.create_dependency("foo", "*"))
......@@ -27,16 +31,18 @@ def test_with_compatible_locked_dependencies(
add_to_repo(repo, "bar", "1.0.1")
add_to_repo(repo, "bar", "1.0.2")
locked = [get_package("foo", "1.0.1"), get_package("bar", "1.0.1")]
provider = Provider(root, pool, NullIO(), locked=locked)
check_solver_result(
root,
provider,
result={"foo": "1.0.1", "bar": "1.0.1"},
locked={"foo": get_package("foo", "1.0.1"), "bar": get_package("bar", "1.0.1")},
)
def test_with_incompatible_locked_dependencies(
root: ProjectPackage, provider: Provider, repo: Repository
root: ProjectPackage, repo: Repository, pool: Pool
):
root.add_dependency(Factory.create_dependency("foo", ">1.0.1"))
......@@ -47,16 +53,18 @@ def test_with_incompatible_locked_dependencies(
add_to_repo(repo, "bar", "1.0.1")
add_to_repo(repo, "bar", "1.0.2")
locked = [get_package("foo", "1.0.1"), get_package("bar", "1.0.1")]
provider = Provider(root, pool, NullIO(), locked=locked)
check_solver_result(
root,
provider,
result={"foo": "1.0.2", "bar": "1.0.2"},
locked={"foo": get_package("foo", "1.0.1"), "bar": get_package("bar", "1.0.1")},
)
def test_with_unrelated_locked_dependencies(
root: ProjectPackage, provider: Provider, repo: Repository
root: ProjectPackage, repo: Repository, pool: Pool
):
root.add_dependency(Factory.create_dependency("foo", "*"))
......@@ -68,16 +76,18 @@ def test_with_unrelated_locked_dependencies(
add_to_repo(repo, "bar", "1.0.2")
add_to_repo(repo, "baz", "1.0.0")
locked = [get_package("baz", "1.0.1")]
provider = Provider(root, pool, NullIO(), locked=locked)
check_solver_result(
root,
provider,
result={"foo": "1.0.2", "bar": "1.0.2"},
locked={"baz": get_package("baz", "1.0.1")},
)
def test_unlocks_dependencies_if_necessary_to_ensure_that_a_new_dependency_is_satisfied(
root: ProjectPackage, provider: Provider, repo: Repository
root: ProjectPackage, repo: Repository, pool: Pool
):
root.add_dependency(Factory.create_dependency("foo", "*"))
root.add_dependency(Factory.create_dependency("newdep", "2.0.0"))
......@@ -92,6 +102,14 @@ def test_unlocks_dependencies_if_necessary_to_ensure_that_a_new_dependency_is_sa
add_to_repo(repo, "qux", "2.0.0")
add_to_repo(repo, "newdep", "2.0.0", deps={"baz": ">=1.5.0"})
locked = [
get_package("foo", "2.0.0"),
get_package("bar", "1.0.0"),
get_package("baz", "1.0.0"),
get_package("qux", "1.0.0"),
]
provider = Provider(root, pool, NullIO(), locked=locked)
check_solver_result(
root,
provider,
......@@ -102,17 +120,11 @@ def test_unlocks_dependencies_if_necessary_to_ensure_that_a_new_dependency_is_sa
"qux": "1.0.0",
"newdep": "2.0.0",
},
locked={
"foo": get_package("foo", "2.0.0"),
"bar": get_package("bar", "1.0.0"),
"baz": get_package("baz", "1.0.0"),
"qux": get_package("qux", "1.0.0"),
},
)
def test_with_compatible_locked_dependencies_use_latest(
root: ProjectPackage, provider: Provider, repo: Repository
root: ProjectPackage, repo: Repository, pool: Pool
):
root.add_dependency(Factory.create_dependency("foo", "*"))
root.add_dependency(Factory.create_dependency("baz", "*"))
......@@ -126,21 +138,23 @@ def test_with_compatible_locked_dependencies_use_latest(
add_to_repo(repo, "baz", "1.0.0")
add_to_repo(repo, "baz", "1.0.1")
locked = [
get_package("foo", "1.0.1"),
get_package("bar", "1.0.1"),
get_package("baz", "1.0.0"),
]
provider = Provider(root, pool, NullIO(), locked=locked)
check_solver_result(
root,
provider,
result={"foo": "1.0.2", "bar": "1.0.2", "baz": "1.0.0"},
locked={
"foo": get_package("foo", "1.0.1"),
"bar": get_package("bar", "1.0.1"),
"baz": get_package("baz", "1.0.0"),
},
use_latest=["foo"],
use_latest=[canonicalize_name("foo")],
)
def test_with_compatible_locked_dependencies_with_extras(
root: ProjectPackage, provider: Provider, repo: Repository
root: ProjectPackage, repo: Repository, pool: Pool
):
root.add_dependency(Factory.create_dependency("foo", "^1.0"))
......@@ -159,20 +173,22 @@ def test_with_compatible_locked_dependencies_with_extras(
add_to_repo(repo, "baz", "1.0.0")
add_to_repo(repo, "baz", "1.0.1")
locked = [
get_package("foo", "1.0.0"),
get_package("bar", "1.0.0"),
get_package("baz", "1.0.0"),
]
provider = Provider(root, pool, NullIO(), locked=locked)
check_solver_result(
root,
provider,
result={"foo": "1.0.0", "bar": "1.0.0", "baz": "1.0.0"},
locked={
"foo": get_package("foo", "1.0.0"),
"bar": get_package("bar", "1.0.0"),
"baz": get_package("baz", "1.0.0"),
},
)
def test_with_yanked_package_in_lock(
root: ProjectPackage, provider: Provider, repo: Repository
root: ProjectPackage, repo: Repository, pool: Pool
):
root.add_dependency(Factory.create_dependency("foo", "*"))
......@@ -182,16 +198,17 @@ def test_with_yanked_package_in_lock(
# yanked version is kept in lock file
locked_foo = get_package("foo", "2")
assert not locked_foo.yanked
provider = Provider(root, pool, NullIO(), locked=[locked_foo])
result = check_solver_result(
root,
provider,
result={"foo": "2"},
locked={"foo": locked_foo},
)
foo = result.packages[0]
assert foo.yanked
# without considering the lock file, the other version is chosen
provider = Provider(root, pool, NullIO())
check_solver_result(
root,
provider,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment