Commit f54864e4 by Chris Kuehl Committed by Randy Döring

perf: don't clear the entire dependency cache when backtracking (#7950)

parent fba13093
...@@ -39,15 +39,20 @@ class DependencyCache: ...@@ -39,15 +39,20 @@ class DependencyCache:
""" """
def __init__(self, provider: Provider) -> None: def __init__(self, provider: Provider) -> None:
self.provider = provider self._provider = provider
self.cache: dict[ self._cache: dict[
int,
dict[
tuple[str, str | None, str | None, str | None, str | None], tuple[str, str | None, str | None, str | None, str | None],
list[DependencyPackage], list[DependencyPackage],
] = {} ],
] = collections.defaultdict(dict)
self.search_for = functools.lru_cache(maxsize=128)(self._search_for) self.search_for = functools.lru_cache(maxsize=128)(self._search_for)
def _search_for(self, dependency: Dependency) -> list[DependencyPackage]: def _search_for(
self, dependency: Dependency, level: int
) -> list[DependencyPackage]:
key = ( key = (
dependency.complete_name, dependency.complete_name,
dependency.source_type, dependency.source_type,
...@@ -56,12 +61,17 @@ class DependencyCache: ...@@ -56,12 +61,17 @@ class DependencyCache:
dependency.source_subdirectory, dependency.source_subdirectory,
) )
packages = self.cache.get(key) for check_level in range(level, -1, -1):
packages = self._cache[check_level].get(key)
if packages: if packages is not None:
packages = [ packages = [
p for p in packages if dependency.constraint.allows(p.package.version) p
for p in packages
if dependency.constraint.allows(p.package.version)
] ]
break
else:
packages = None
# provider.search_for() normally does not include pre-release packages # provider.search_for() normally does not include pre-release packages
# (unless requested), but will include them if there are no other # (unless requested), but will include them if there are no other
...@@ -71,14 +81,14 @@ class DependencyCache: ...@@ -71,14 +81,14 @@ class DependencyCache:
# nothing, we need to call provider.search_for() again as it may return # nothing, we need to call provider.search_for() again as it may return
# additional results this time. # additional results this time.
if not packages: if not packages:
packages = self.provider.search_for(dependency) packages = self._provider.search_for(dependency)
self.cache[key] = packages
self._cache[level][key] = packages
return packages return packages
def clear(self) -> None: def clear_level(self, level: int) -> None:
self.cache.clear() self.search_for.cache_clear()
self._cache.pop(level, None)
class VersionSolver: class VersionSolver:
...@@ -318,9 +328,9 @@ class VersionSolver: ...@@ -318,9 +328,9 @@ class VersionSolver:
self._solution.decision_level, previous_satisfier_level, -1 self._solution.decision_level, previous_satisfier_level, -1
): ):
self._contradicted_incompatibilities.pop(level, None) self._contradicted_incompatibilities.pop(level, None)
self._dependency_cache.clear_level(level)
self._solution.backtrack(previous_satisfier_level) self._solution.backtrack(previous_satisfier_level)
self._dependency_cache.clear()
if new_incompatibility: if new_incompatibility:
self._add_incompatibility(incompatibility) self._add_incompatibility(incompatibility)
...@@ -418,7 +428,11 @@ class VersionSolver: ...@@ -418,7 +428,11 @@ class VersionSolver:
if locked: if locked:
return is_specific_marker, Preference.LOCKED, 1 return is_specific_marker, Preference.LOCKED, 1
num_packages = len(self._dependency_cache.search_for(dependency)) num_packages = len(
self._dependency_cache.search_for(
dependency, self._solution.decision_level
)
)
if num_packages < 2: if num_packages < 2:
preference = Preference.NO_CHOICE preference = Preference.NO_CHOICE
...@@ -435,7 +449,9 @@ class VersionSolver: ...@@ -435,7 +449,9 @@ class VersionSolver:
locked = self._provider.get_locked(dependency) locked = self._provider.get_locked(dependency)
if locked is None: if locked is None:
packages = self._dependency_cache.search_for(dependency) packages = self._dependency_cache.search_for(
dependency, self._solution.decision_level
)
package = next(iter(packages), None) package = next(iter(packages), None)
if package is None: if package is None:
......
...@@ -2,6 +2,7 @@ from __future__ import annotations ...@@ -2,6 +2,7 @@ from __future__ import annotations
from copy import deepcopy from copy import deepcopy
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from unittest import mock
from poetry.factory import Factory from poetry.factory import Factory
from poetry.mixology.version_solver import DependencyCache from poetry.mixology.version_solver import DependencyCache
...@@ -32,14 +33,14 @@ def test_solver_dependency_cache_respects_source_type( ...@@ -32,14 +33,14 @@ def test_solver_dependency_cache_respects_source_type(
cache.search_for.cache_clear() cache.search_for.cache_clear()
# ensure cache was never hit for both calls # ensure cache was never hit for both calls
cache.search_for(dependency_pypi) cache.search_for(dependency_pypi, 0)
cache.search_for(dependency_git) cache.search_for(dependency_git, 0)
assert not cache.search_for.cache_info().hits assert not cache.search_for.cache_info().hits
# increase test coverage by searching for copies # increase test coverage by searching for copies
# (when searching for the exact same object, __eq__ is never called) # (when searching for the exact same object, __eq__ is never called)
packages_pypi = cache.search_for(deepcopy(dependency_pypi)) packages_pypi = cache.search_for(deepcopy(dependency_pypi), 0)
packages_git = cache.search_for(deepcopy(dependency_git)) packages_git = cache.search_for(deepcopy(dependency_git), 0)
assert cache.search_for.cache_info().hits == 2 assert cache.search_for.cache_info().hits == 2
assert cache.search_for.cache_info().currsize == 2 assert cache.search_for.cache_info().currsize == 2
...@@ -60,6 +61,44 @@ def test_solver_dependency_cache_respects_source_type( ...@@ -60,6 +61,44 @@ def test_solver_dependency_cache_respects_source_type(
assert package_git.package.source_resolved_reference == MOCK_DEFAULT_GIT_REVISION assert package_git.package.source_resolved_reference == MOCK_DEFAULT_GIT_REVISION
def test_solver_dependency_cache_pulls_from_prior_level_cache(
root: ProjectPackage, provider: Provider, repo: Repository
) -> None:
dependency_pypi = Factory.create_dependency("demo", ">=0.1.0")
root.add_dependency(dependency_pypi)
add_to_repo(repo, "demo", "1.0.0")
wrapped_provider = mock.Mock(wraps=provider)
cache = DependencyCache(wrapped_provider)
cache.search_for.cache_clear()
# On first call, provider.search_for() should be called and the level-0
# cache populated.
cache.search_for(dependency_pypi, 0)
assert len(wrapped_provider.search_for.mock_calls) == 1
assert ("demo", None, None, None, None) in cache._cache[0]
assert cache.search_for.cache_info().hits == 0
assert cache.search_for.cache_info().misses == 1
# On second call at level 1, provider.search_for() should not be called
# again and the level-1 cache should be populated from the level-0 cache.
cache.search_for(dependency_pypi, 1)
assert len(wrapped_provider.search_for.mock_calls) == 1
assert ("demo", None, None, None, None) in cache._cache[1]
assert cache._cache[0] == cache._cache[1]
assert cache.search_for.cache_info().hits == 0
assert cache.search_for.cache_info().misses == 2
# Clearing the level 1 cache should invalidate the lru_cache on
# cache.search_for and wipe out the level 1 cache while preserving the
# level 0 cache.
cache.clear_level(1)
assert set(cache._cache.keys()) == {0}
assert ("demo", None, None, None, None) in cache._cache[0]
assert cache.search_for.cache_info().hits == 0
assert cache.search_for.cache_info().misses == 0
def test_solver_dependency_cache_respects_subdirectories( def test_solver_dependency_cache_respects_subdirectories(
root: ProjectPackage, provider: Provider, repo: Repository root: ProjectPackage, provider: Provider, repo: Repository
) -> None: ) -> None:
...@@ -87,14 +126,14 @@ def test_solver_dependency_cache_respects_subdirectories( ...@@ -87,14 +126,14 @@ def test_solver_dependency_cache_respects_subdirectories(
cache.search_for.cache_clear() cache.search_for.cache_clear()
# ensure cache was never hit for both calls # ensure cache was never hit for both calls
cache.search_for(dependency_one) cache.search_for(dependency_one, 0)
cache.search_for(dependency_one_copy) cache.search_for(dependency_one_copy, 0)
assert not cache.search_for.cache_info().hits assert not cache.search_for.cache_info().hits
# increase test coverage by searching for copies # increase test coverage by searching for copies
# (when searching for the exact same object, __eq__ is never called) # (when searching for the exact same object, __eq__ is never called)
packages_one = cache.search_for(deepcopy(dependency_one)) packages_one = cache.search_for(deepcopy(dependency_one), 0)
packages_one_copy = cache.search_for(deepcopy(dependency_one_copy)) packages_one_copy = cache.search_for(deepcopy(dependency_one_copy), 0)
assert cache.search_for.cache_info().hits == 2 assert cache.search_for.cache_info().hits == 2
assert cache.search_for.cache_info().currsize == 2 assert cache.search_for.cache_info().currsize == 2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment