Commit 63c86bf9 by David Hotham Committed by GitHub

refactor the search for direct-origin dependencies (#5904)

* refactor the search for direct-origin dependencies

* clarify where the interface is

* fix treatment of dependency via cache
parent 8b640886
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import cast
from cleo.helpers import argument from cleo.helpers import argument
from cleo.helpers import option from cleo.helpers import option
from poetry.core.packages.directory_dependency import DirectoryDependency
from poetry.core.packages.file_dependency import FileDependency
from poetry.core.packages.vcs_dependency import VCSDependency
from poetry.console.commands.group_command import GroupCommand from poetry.console.commands.group_command import GroupCommand
from poetry.utils.helpers import canonicalize_name from poetry.utils.helpers import canonicalize_name
...@@ -499,16 +495,7 @@ lists all packages available.""" ...@@ -499,16 +495,7 @@ lists all packages available."""
for dep in requires: for dep in requires:
if dep.name == package.name: if dep.name == package.name:
provider = Provider(root, self.poetry.pool, NullIO()) provider = Provider(root, self.poetry.pool, NullIO())
return provider.search_for_direct_origin_dependency(dep)
if dep.is_vcs():
dep = cast(VCSDependency, dep)
return provider.search_for_vcs(dep)[0]
if dep.is_file():
dep = cast(FileDependency, dep)
return provider.search_for_file(dep)[0]
if dep.is_directory():
dep = cast(DirectoryDependency, dep)
return provider.search_for_directory(dep)[0]
name = package.name name = package.name
selector = VersionSelector(self.poetry.pool) selector = VersionSelector(self.poetry.pool)
......
...@@ -221,6 +221,44 @@ class Provider: ...@@ -221,6 +221,44 @@ class Provider:
) )
return packages return packages
def search_for_direct_origin_dependency(self, dependency: Dependency) -> Package:
package = self._deferred_cache.get(dependency)
if package is not None:
pass
elif dependency.is_vcs():
dependency = cast(VCSDependency, dependency)
package = self._search_for_vcs(dependency)
elif dependency.is_file():
dependency = cast(FileDependency, dependency)
package = self._search_for_file(dependency)
elif dependency.is_directory():
dependency = cast(DirectoryDependency, dependency)
package = self._search_for_directory(dependency)
elif dependency.is_url():
dependency = cast(URLDependency, dependency)
package = self._search_for_url(dependency)
else:
raise RuntimeError(
f"Unknown direct dependency type {dependency.source_type}"
)
if dependency.is_vcs():
dependency._source_reference = package.source_reference
dependency._source_resolved_reference = package.source_resolved_reference
dependency._source_subdirectory = package.source_subdirectory
dependency._constraint = package.version
dependency._pretty_constraint = package.version.text
self._deferred_cache[dependency] = package
return package
def search_for(self, dependency: Dependency) -> list[DependencyPackage]: def search_for(self, dependency: Dependency) -> list[DependencyPackage]:
""" """
Search for the specifications that match the given dependency. Search for the specifications that match the given dependency.
...@@ -231,18 +269,9 @@ class Provider: ...@@ -231,18 +269,9 @@ class Provider:
if dependency.is_root: if dependency.is_root:
return PackageCollection(dependency, [self._package]) return PackageCollection(dependency, [self._package])
if dependency.is_vcs(): if dependency.is_direct_origin():
dependency = cast(VCSDependency, dependency) packages = [self.search_for_direct_origin_dependency(dependency)]
packages = self.search_for_vcs(dependency)
elif dependency.is_file():
dependency = cast(FileDependency, dependency)
packages = self.search_for_file(dependency)
elif dependency.is_directory():
dependency = cast(DirectoryDependency, dependency)
packages = self.search_for_directory(dependency)
elif dependency.is_url():
dependency = cast(URLDependency, dependency)
packages = self.search_for_url(dependency)
else: else:
packages = self._pool.find_packages(dependency) packages = self._pool.find_packages(dependency)
...@@ -259,7 +288,7 @@ class Provider: ...@@ -259,7 +288,7 @@ class Provider:
return PackageCollection(dependency, packages) return PackageCollection(dependency, packages)
def search_for_vcs(self, dependency: VCSDependency) -> list[Package]: def _search_for_vcs(self, dependency: VCSDependency) -> Package:
""" """
Search for the specifications that match the given VCS dependency. Search for the specifications that match the given VCS dependency.
...@@ -281,16 +310,7 @@ class Provider: ...@@ -281,16 +310,7 @@ class Provider:
package.develop = dependency.develop package.develop = dependency.develop
dependency._constraint = package.version return package
dependency._pretty_constraint = package.version.text
dependency._source_reference = package.source_reference
dependency._source_resolved_reference = package.source_resolved_reference
dependency._source_subdirectory = package.source_subdirectory
self._deferred_cache[dependency] = package
return [package]
@staticmethod @staticmethod
def get_package_from_vcs( def get_package_from_vcs(
...@@ -314,18 +334,8 @@ class Provider: ...@@ -314,18 +334,8 @@ class Provider:
source_root=source_root, source_root=source_root,
) )
def search_for_file(self, dependency: FileDependency) -> list[Package]: def _search_for_file(self, dependency: FileDependency) -> Package:
if dependency in self._deferred_cache: package = self.get_package_from_file(dependency.full_path)
_package = self._deferred_cache[dependency]
package = _package.clone()
else:
package = self.get_package_from_file(dependency.full_path)
dependency._constraint = package.version
dependency._pretty_constraint = package.version.text
self._deferred_cache[dependency] = package
self.validate_package_for_dependency(dependency=dependency, package=package) self.validate_package_for_dependency(dependency=dependency, package=package)
...@@ -336,7 +346,7 @@ class Provider: ...@@ -336,7 +346,7 @@ class Provider:
{"file": dependency.path.name, "hash": "sha256:" + dependency.hash()} {"file": dependency.path.name, "hash": "sha256:" + dependency.hash()}
] ]
return [package] return package
@classmethod @classmethod
def get_package_from_file(cls, file_path: Path) -> Package: def get_package_from_file(cls, file_path: Path) -> Package:
...@@ -351,18 +361,8 @@ class Provider: ...@@ -351,18 +361,8 @@ class Provider:
return package return package
def search_for_directory(self, dependency: DirectoryDependency) -> list[Package]: def _search_for_directory(self, dependency: DirectoryDependency) -> Package:
if dependency in self._deferred_cache: package = self.get_package_from_directory(dependency.full_path)
_package = self._deferred_cache[dependency]
package = _package.clone()
else:
package = self.get_package_from_directory(dependency.full_path)
dependency._constraint = package.version
dependency._pretty_constraint = package.version.text
self._deferred_cache[dependency] = package
self.validate_package_for_dependency(dependency=dependency, package=package) self.validate_package_for_dependency(dependency=dependency, package=package)
...@@ -371,16 +371,13 @@ class Provider: ...@@ -371,16 +371,13 @@ class Provider:
if dependency.base is not None: if dependency.base is not None:
package.root_dir = dependency.base package.root_dir = dependency.base
return [package] return package
@classmethod @classmethod
def get_package_from_directory(cls, directory: Path) -> Package: def get_package_from_directory(cls, directory: Path) -> Package:
return PackageInfo.from_directory(path=directory).to_package(root_dir=directory) return PackageInfo.from_directory(path=directory).to_package(root_dir=directory)
def search_for_url(self, dependency: URLDependency) -> list[Package]: def _search_for_url(self, dependency: URLDependency) -> Package:
if dependency in self._deferred_cache:
return [self._deferred_cache[dependency]]
package = self.get_package_from_url(dependency.url) package = self.get_package_from_url(dependency.url)
self.validate_package_for_dependency(dependency=dependency, package=package) self.validate_package_for_dependency(dependency=dependency, package=package)
...@@ -393,12 +390,7 @@ class Provider: ...@@ -393,12 +390,7 @@ class Provider:
for extra_dep in package.extras[extra]: for extra_dep in package.extras[extra]:
package.add_dependency(extra_dep) package.add_dependency(extra_dep)
dependency._constraint = package.version return package
dependency._pretty_constraint = package.version.text
self._deferred_cache[dependency] = package
return [package]
@classmethod @classmethod
def get_package_from_url(cls, url: str) -> Package: def get_package_from_url(cls, url: str) -> Package:
...@@ -538,14 +530,8 @@ class Provider: ...@@ -538,14 +530,8 @@ class Provider:
if self._load_deferred: if self._load_deferred:
# Retrieving constraints for deferred dependencies # Retrieving constraints for deferred dependencies
for r in requires: for r in requires:
if r.is_directory(): if r.is_direct_origin():
self.search_for_directory(r) self.search_for_direct_origin_dependency(r)
elif r.is_file():
self.search_for_file(r)
elif r.is_vcs():
self.search_for_vcs(r)
elif r.is_url():
self.search_for_url(r)
optional_dependencies = [] optional_dependencies = []
_dependencies = [] _dependencies = []
......
...@@ -60,14 +60,14 @@ def test_search_for_vcs_retains_develop_flag(provider: Provider, value: bool): ...@@ -60,14 +60,14 @@ def test_search_for_vcs_retains_develop_flag(provider: Provider, value: bool):
dependency = VCSDependency( dependency = VCSDependency(
"demo", "git", "https://github.com/demo/demo.git", develop=value "demo", "git", "https://github.com/demo/demo.git", develop=value
) )
package = provider.search_for_vcs(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.develop == value assert package.develop == value
def test_search_for_vcs_setup_egg_info(provider: Provider): def test_search_for_vcs_setup_egg_info(provider: Provider):
dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git") dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git")
package = provider.search_for_vcs(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.2" assert package.version.text == "0.1.2"
...@@ -87,7 +87,7 @@ def test_search_for_vcs_setup_egg_info_with_extras(provider: Provider): ...@@ -87,7 +87,7 @@ def test_search_for_vcs_setup_egg_info_with_extras(provider: Provider):
"demo", "git", "https://github.com/demo/demo.git", extras=["foo"] "demo", "git", "https://github.com/demo/demo.git", extras=["foo"]
) )
package = provider.search_for_vcs(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.2" assert package.version.text == "0.1.2"
...@@ -107,7 +107,7 @@ def test_search_for_vcs_read_setup(provider: Provider, mocker: MockerFixture): ...@@ -107,7 +107,7 @@ def test_search_for_vcs_read_setup(provider: Provider, mocker: MockerFixture):
dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git") dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git")
package = provider.search_for_vcs(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.2" assert package.version.text == "0.1.2"
...@@ -131,7 +131,7 @@ def test_search_for_vcs_read_setup_with_extras( ...@@ -131,7 +131,7 @@ def test_search_for_vcs_read_setup_with_extras(
"demo", "git", "https://github.com/demo/demo.git", extras=["foo"] "demo", "git", "https://github.com/demo/demo.git", extras=["foo"]
) )
package = provider.search_for_vcs(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.2" assert package.version.text == "0.1.2"
...@@ -153,7 +153,7 @@ def test_search_for_vcs_read_setup_raises_error_if_no_version( ...@@ -153,7 +153,7 @@ def test_search_for_vcs_read_setup_raises_error_if_no_version(
dependency = VCSDependency("demo", "git", "https://github.com/demo/no-version.git") dependency = VCSDependency("demo", "git", "https://github.com/demo/no-version.git")
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
provider.search_for_vcs(dependency) provider.search_for_direct_origin_dependency(dependency)
@pytest.mark.parametrize("directory", ["demo", "non-canonical-name"]) @pytest.mark.parametrize("directory", ["demo", "non-canonical-name"])
...@@ -168,7 +168,7 @@ def test_search_for_directory_setup_egg_info(provider: Provider, directory: str) ...@@ -168,7 +168,7 @@ def test_search_for_directory_setup_egg_info(provider: Provider, directory: str)
/ directory, / directory,
) )
package = provider.search_for_directory(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.2" assert package.version.text == "0.1.2"
...@@ -195,7 +195,7 @@ def test_search_for_directory_setup_egg_info_with_extras(provider: Provider): ...@@ -195,7 +195,7 @@ def test_search_for_directory_setup_egg_info_with_extras(provider: Provider):
extras=["foo"], extras=["foo"],
) )
package = provider.search_for_directory(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.2" assert package.version.text == "0.1.2"
...@@ -228,7 +228,7 @@ def test_search_for_directory_setup_with_base(provider: Provider, directory: str ...@@ -228,7 +228,7 @@ def test_search_for_directory_setup_with_base(provider: Provider, directory: str
/ directory, / directory,
) )
package = provider.search_for_directory(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.2" assert package.version.text == "0.1.2"
...@@ -266,7 +266,7 @@ def test_search_for_directory_setup_read_setup( ...@@ -266,7 +266,7 @@ def test_search_for_directory_setup_read_setup(
/ "demo", / "demo",
) )
package = provider.search_for_directory(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.2" assert package.version.text == "0.1.2"
...@@ -297,7 +297,7 @@ def test_search_for_directory_setup_read_setup_with_extras( ...@@ -297,7 +297,7 @@ def test_search_for_directory_setup_read_setup_with_extras(
extras=["foo"], extras=["foo"],
) )
package = provider.search_for_directory(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.2" assert package.version.text == "0.1.2"
...@@ -323,7 +323,7 @@ def test_search_for_directory_setup_read_setup_with_no_dependencies(provider: Pr ...@@ -323,7 +323,7 @@ def test_search_for_directory_setup_read_setup_with_no_dependencies(provider: Pr
/ "no-dependencies", / "no-dependencies",
) )
package = provider.search_for_directory(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.2" assert package.version.text == "0.1.2"
...@@ -337,7 +337,7 @@ def test_search_for_directory_poetry(provider: Provider): ...@@ -337,7 +337,7 @@ def test_search_for_directory_poetry(provider: Provider):
Path(__file__).parent.parent / "fixtures" / "project_with_extras", Path(__file__).parent.parent / "fixtures" / "project_with_extras",
) )
package = provider.search_for_directory(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "project-with-extras" assert package.name == "project-with-extras"
assert package.version.text == "1.2.3" assert package.version.text == "1.2.3"
...@@ -366,7 +366,7 @@ def test_search_for_directory_poetry_with_extras(provider: Provider): ...@@ -366,7 +366,7 @@ def test_search_for_directory_poetry_with_extras(provider: Provider):
extras=["extras_a"], extras=["extras_a"],
) )
package = provider.search_for_directory(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "project-with-extras" assert package.name == "project-with-extras"
assert package.version.text == "1.2.3" assert package.version.text == "1.2.3"
...@@ -397,7 +397,7 @@ def test_search_for_file_sdist(provider: Provider): ...@@ -397,7 +397,7 @@ def test_search_for_file_sdist(provider: Provider):
/ "demo-0.1.0.tar.gz", / "demo-0.1.0.tar.gz",
) )
package = provider.search_for_file(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.0" assert package.version.text == "0.1.0"
...@@ -429,7 +429,7 @@ def test_search_for_file_sdist_with_extras(provider: Provider): ...@@ -429,7 +429,7 @@ def test_search_for_file_sdist_with_extras(provider: Provider):
extras=["foo"], extras=["foo"],
) )
package = provider.search_for_file(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.0" assert package.version.text == "0.1.0"
...@@ -460,7 +460,7 @@ def test_search_for_file_wheel(provider: Provider): ...@@ -460,7 +460,7 @@ def test_search_for_file_wheel(provider: Provider):
/ "demo-0.1.0-py2.py3-none-any.whl", / "demo-0.1.0-py2.py3-none-any.whl",
) )
package = provider.search_for_file(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.0" assert package.version.text == "0.1.0"
...@@ -492,7 +492,7 @@ def test_search_for_file_wheel_with_extras(provider: Provider): ...@@ -492,7 +492,7 @@ def test_search_for_file_wheel_with_extras(provider: Provider):
extras=["foo"], extras=["foo"],
) )
package = provider.search_for_file(dependency)[0] package = provider.search_for_direct_origin_dependency(dependency)
assert package.name == "demo" assert package.name == "demo"
assert package.version.text == "0.1.0" assert package.version.text == "0.1.0"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment