Commit 63c86bf9 by David Hotham Committed by GitHub

refactor the search for direct-origin dependencies (#5904)

* refactor the search for direct-origin dependencies

* clarify where the interface is

* fix treatment of dependency via cache
parent 8b640886
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import cast
from cleo.helpers import argument
from cleo.helpers import option
from poetry.core.packages.directory_dependency import DirectoryDependency
from poetry.core.packages.file_dependency import FileDependency
from poetry.core.packages.vcs_dependency import VCSDependency
from poetry.console.commands.group_command import GroupCommand
from poetry.utils.helpers import canonicalize_name
......@@ -499,16 +495,7 @@ lists all packages available."""
for dep in requires:
if dep.name == package.name:
provider = Provider(root, self.poetry.pool, NullIO())
if dep.is_vcs():
dep = cast(VCSDependency, dep)
return provider.search_for_vcs(dep)[0]
if dep.is_file():
dep = cast(FileDependency, dep)
return provider.search_for_file(dep)[0]
if dep.is_directory():
dep = cast(DirectoryDependency, dep)
return provider.search_for_directory(dep)[0]
return provider.search_for_direct_origin_dependency(dep)
name = package.name
selector = VersionSelector(self.poetry.pool)
......
......@@ -221,6 +221,44 @@ class Provider:
)
return packages
def search_for_direct_origin_dependency(self, dependency: Dependency) -> Package:
package = self._deferred_cache.get(dependency)
if package is not None:
pass
elif dependency.is_vcs():
dependency = cast(VCSDependency, dependency)
package = self._search_for_vcs(dependency)
elif dependency.is_file():
dependency = cast(FileDependency, dependency)
package = self._search_for_file(dependency)
elif dependency.is_directory():
dependency = cast(DirectoryDependency, dependency)
package = self._search_for_directory(dependency)
elif dependency.is_url():
dependency = cast(URLDependency, dependency)
package = self._search_for_url(dependency)
else:
raise RuntimeError(
f"Unknown direct dependency type {dependency.source_type}"
)
if dependency.is_vcs():
dependency._source_reference = package.source_reference
dependency._source_resolved_reference = package.source_resolved_reference
dependency._source_subdirectory = package.source_subdirectory
dependency._constraint = package.version
dependency._pretty_constraint = package.version.text
self._deferred_cache[dependency] = package
return package
def search_for(self, dependency: Dependency) -> list[DependencyPackage]:
"""
Search for the specifications that match the given dependency.
......@@ -231,18 +269,9 @@ class Provider:
if dependency.is_root:
return PackageCollection(dependency, [self._package])
if dependency.is_vcs():
dependency = cast(VCSDependency, dependency)
packages = self.search_for_vcs(dependency)
elif dependency.is_file():
dependency = cast(FileDependency, dependency)
packages = self.search_for_file(dependency)
elif dependency.is_directory():
dependency = cast(DirectoryDependency, dependency)
packages = self.search_for_directory(dependency)
elif dependency.is_url():
dependency = cast(URLDependency, dependency)
packages = self.search_for_url(dependency)
if dependency.is_direct_origin():
packages = [self.search_for_direct_origin_dependency(dependency)]
else:
packages = self._pool.find_packages(dependency)
......@@ -259,7 +288,7 @@ class Provider:
return PackageCollection(dependency, packages)
def search_for_vcs(self, dependency: VCSDependency) -> list[Package]:
def _search_for_vcs(self, dependency: VCSDependency) -> Package:
"""
Search for the specifications that match the given VCS dependency.
......@@ -281,16 +310,7 @@ class Provider:
package.develop = dependency.develop
dependency._constraint = package.version
dependency._pretty_constraint = package.version.text
dependency._source_reference = package.source_reference
dependency._source_resolved_reference = package.source_resolved_reference
dependency._source_subdirectory = package.source_subdirectory
self._deferred_cache[dependency] = package
return [package]
return package
@staticmethod
def get_package_from_vcs(
......@@ -314,18 +334,8 @@ class Provider:
source_root=source_root,
)
def search_for_file(self, dependency: FileDependency) -> list[Package]:
if dependency in self._deferred_cache:
_package = self._deferred_cache[dependency]
package = _package.clone()
else:
package = self.get_package_from_file(dependency.full_path)
dependency._constraint = package.version
dependency._pretty_constraint = package.version.text
self._deferred_cache[dependency] = package
def _search_for_file(self, dependency: FileDependency) -> Package:
package = self.get_package_from_file(dependency.full_path)
self.validate_package_for_dependency(dependency=dependency, package=package)
......@@ -336,7 +346,7 @@ class Provider:
{"file": dependency.path.name, "hash": "sha256:" + dependency.hash()}
]
return [package]
return package
@classmethod
def get_package_from_file(cls, file_path: Path) -> Package:
......@@ -351,18 +361,8 @@ class Provider:
return package
def search_for_directory(self, dependency: DirectoryDependency) -> list[Package]:
if dependency in self._deferred_cache:
_package = self._deferred_cache[dependency]
package = _package.clone()
else:
package = self.get_package_from_directory(dependency.full_path)
dependency._constraint = package.version
dependency._pretty_constraint = package.version.text
self._deferred_cache[dependency] = package
def _search_for_directory(self, dependency: DirectoryDependency) -> Package:
package = self.get_package_from_directory(dependency.full_path)
self.validate_package_for_dependency(dependency=dependency, package=package)
......@@ -371,16 +371,13 @@ class Provider:
if dependency.base is not None:
package.root_dir = dependency.base
return [package]
return package
@classmethod
def get_package_from_directory(cls, directory: Path) -> Package:
return PackageInfo.from_directory(path=directory).to_package(root_dir=directory)
def search_for_url(self, dependency: URLDependency) -> list[Package]:
if dependency in self._deferred_cache:
return [self._deferred_cache[dependency]]
def _search_for_url(self, dependency: URLDependency) -> Package:
package = self.get_package_from_url(dependency.url)
self.validate_package_for_dependency(dependency=dependency, package=package)
......@@ -393,12 +390,7 @@ class Provider:
for extra_dep in package.extras[extra]:
package.add_dependency(extra_dep)
dependency._constraint = package.version
dependency._pretty_constraint = package.version.text
self._deferred_cache[dependency] = package
return [package]
return package
@classmethod
def get_package_from_url(cls, url: str) -> Package:
......@@ -538,14 +530,8 @@ class Provider:
if self._load_deferred:
# Retrieving constraints for deferred dependencies
for r in requires:
if r.is_directory():
self.search_for_directory(r)
elif r.is_file():
self.search_for_file(r)
elif r.is_vcs():
self.search_for_vcs(r)
elif r.is_url():
self.search_for_url(r)
if r.is_direct_origin():
self.search_for_direct_origin_dependency(r)
optional_dependencies = []
_dependencies = []
......
......@@ -60,14 +60,14 @@ def test_search_for_vcs_retains_develop_flag(provider: Provider, value: bool):
dependency = VCSDependency(
"demo", "git", "https://github.com/demo/demo.git", develop=value
)
package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.develop == value
def test_search_for_vcs_setup_egg_info(provider: Provider):
dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git")
package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.2"
......@@ -87,7 +87,7 @@ def test_search_for_vcs_setup_egg_info_with_extras(provider: Provider):
"demo", "git", "https://github.com/demo/demo.git", extras=["foo"]
)
package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.2"
......@@ -107,7 +107,7 @@ def test_search_for_vcs_read_setup(provider: Provider, mocker: MockerFixture):
dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git")
package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.2"
......@@ -131,7 +131,7 @@ def test_search_for_vcs_read_setup_with_extras(
"demo", "git", "https://github.com/demo/demo.git", extras=["foo"]
)
package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.2"
......@@ -153,7 +153,7 @@ def test_search_for_vcs_read_setup_raises_error_if_no_version(
dependency = VCSDependency("demo", "git", "https://github.com/demo/no-version.git")
with pytest.raises(RuntimeError):
provider.search_for_vcs(dependency)
provider.search_for_direct_origin_dependency(dependency)
@pytest.mark.parametrize("directory", ["demo", "non-canonical-name"])
......@@ -168,7 +168,7 @@ def test_search_for_directory_setup_egg_info(provider: Provider, directory: str)
/ directory,
)
package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.2"
......@@ -195,7 +195,7 @@ def test_search_for_directory_setup_egg_info_with_extras(provider: Provider):
extras=["foo"],
)
package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.2"
......@@ -228,7 +228,7 @@ def test_search_for_directory_setup_with_base(provider: Provider, directory: str
/ directory,
)
package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.2"
......@@ -266,7 +266,7 @@ def test_search_for_directory_setup_read_setup(
/ "demo",
)
package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.2"
......@@ -297,7 +297,7 @@ def test_search_for_directory_setup_read_setup_with_extras(
extras=["foo"],
)
package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.2"
......@@ -323,7 +323,7 @@ def test_search_for_directory_setup_read_setup_with_no_dependencies(provider: Pr
/ "no-dependencies",
)
package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.2"
......@@ -337,7 +337,7 @@ def test_search_for_directory_poetry(provider: Provider):
Path(__file__).parent.parent / "fixtures" / "project_with_extras",
)
package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "project-with-extras"
assert package.version.text == "1.2.3"
......@@ -366,7 +366,7 @@ def test_search_for_directory_poetry_with_extras(provider: Provider):
extras=["extras_a"],
)
package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "project-with-extras"
assert package.version.text == "1.2.3"
......@@ -397,7 +397,7 @@ def test_search_for_file_sdist(provider: Provider):
/ "demo-0.1.0.tar.gz",
)
package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.0"
......@@ -429,7 +429,7 @@ def test_search_for_file_sdist_with_extras(provider: Provider):
extras=["foo"],
)
package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.0"
......@@ -460,7 +460,7 @@ def test_search_for_file_wheel(provider: Provider):
/ "demo-0.1.0-py2.py3-none-any.whl",
)
package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.0"
......@@ -492,7 +492,7 @@ def test_search_for_file_wheel_with_extras(provider: Provider):
extras=["foo"],
)
package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo"
assert package.version.text == "0.1.0"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment