Commit 338e02a8 by David Hotham Committed by GitHub

Simplify extras handling (#6372)

Another random bit of code-tidying:

- remove some dead code that constructs an unused dictionary
- simplify (considerably) the code that walks the dependency tree looking for packages introduced by extras
parent e803c3df
...@@ -524,21 +524,12 @@ class Installer: ...@@ -524,21 +524,12 @@ class Installer:
op.skip("Not needed for the current environment") op.skip("Not needed for the current environment")
continue continue
if self._update:
extras = {}
for extra, dependencies in self._package.extras.items():
extras[extra] = [dependency.name for dependency in dependencies]
else:
extras = {}
for extra, deps in self._locker.lock_data.get("extras", {}).items():
extras[extra] = [dep.lower() for dep in deps]
# If a package is optional and not requested # If a package is optional and not requested
# in any extra we skip it # in any extra we skip it
if package.optional and package.name not in extra_packages: if package.optional and package.name not in extra_packages:
op.skip("Not required") op.skip("Not required")
def _get_extra_packages(self, repo: Repository) -> list[str]: def _get_extra_packages(self, repo: Repository) -> set[NormalizedName]:
""" """
Returns all package names required by extras. Returns all package names required by extras.
...@@ -550,7 +541,7 @@ class Installer: ...@@ -550,7 +541,7 @@ class Installer:
else: else:
extras = self._locker.lock_data.get("extras", {}) extras = self._locker.lock_data.get("extras", {})
return list(get_extra_package_names(repo.packages, extras, self._extras)) return get_extra_package_names(repo.packages, extras, self._extras)
def _get_installer(self) -> BaseInstaller: def _get_installer(self) -> BaseInstaller:
return PipInstaller(self._env, self._io, self._pool) return PipInstaller(self._env, self._io, self._pool)
......
...@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING ...@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Collection from collections.abc import Collection
from collections.abc import Iterable from collections.abc import Iterable
from collections.abc import Iterator
from typing import Mapping from typing import Mapping
from packaging.utils import NormalizedName from packaging.utils import NormalizedName
...@@ -17,7 +16,7 @@ def get_extra_package_names( ...@@ -17,7 +16,7 @@ def get_extra_package_names(
packages: Iterable[Package], packages: Iterable[Package],
extras: Mapping[str, list[str]], extras: Mapping[str, list[str]],
extra_names: Collection[str], extra_names: Collection[str],
) -> Iterable[NormalizedName]: ) -> set[NormalizedName]:
""" """
Returns all package names required by the given extras. Returns all package names required by the given extras.
...@@ -29,41 +28,30 @@ def get_extra_package_names( ...@@ -29,41 +28,30 @@ def get_extra_package_names(
from packaging.utils import canonicalize_name from packaging.utils import canonicalize_name
if not extra_names: if not extra_names:
return [] return set()
# lookup for packages by name, faster than looping over packages repeatedly # lookup for packages by name, faster than looping over packages repeatedly
packages_by_name = {package.name: package for package in packages} packages_by_name = {package.name: package for package in packages}
# get and flatten names of packages we've opted into as extras # Depth-first search, with our entry points being the packages directly required by
extra_package_names = [ # extras.
seen_package_names = set()
stack = [
canonicalize_name(extra_package_name) canonicalize_name(extra_package_name)
for extra_name in extra_names for extra_name in extra_names
for extra_package_name in extras.get(extra_name, ()) for extra_package_name in extras.get(extra_name, ())
] ]
# keep record of packages seen during recursion in order to avoid recursion error while stack:
seen_package_names = set() package_name = stack.pop()
# We expect to find all packages, but can just carry on if we don't.
package = packages_by_name.get(package_name)
if package is None or package.name in seen_package_names:
continue
seen_package_names.add(package.name)
def _extra_packages( stack += [dependency.name for dependency in package.requires]
package_names: Iterable[NormalizedName],
) -> Iterator[NormalizedName]:
"""Recursively find dependencies for packages names"""
# for each extra package name
for package_name in package_names:
# Find the actual Package object. A missing key indicates an implicit
# dependency (like setuptools), which should be ignored
package = packages_by_name.get(package_name)
if package:
if package.name not in seen_package_names:
seen_package_names.add(package.name)
yield package.name
# Recurse for dependencies
for dependency_package_name in _extra_packages(
dependency.name
for dependency in package.requires
if dependency.name not in seen_package_names
):
seen_package_names.add(dependency_package_name)
yield dependency_package_name
return _extra_packages(extra_package_names) return seen_package_names
...@@ -24,37 +24,37 @@ _PACKAGE_QUIX.add_dependency(Factory.create_dependency("baz", "*")) ...@@ -24,37 +24,37 @@ _PACKAGE_QUIX.add_dependency(Factory.create_dependency("baz", "*"))
["packages", "extras", "extra_names", "expected_extra_package_names"], ["packages", "extras", "extra_names", "expected_extra_package_names"],
[ [
# Empty edge case # Empty edge case
([], {}, [], []), ([], {}, [], set()),
# Selecting no extras is fine # Selecting no extras is fine
([_PACKAGE_FOO], {}, [], []), ([_PACKAGE_FOO], {}, [], set()),
# An empty extras group should return an empty list # An empty extras group should return an empty list
([_PACKAGE_FOO], {"group0": []}, ["group0"], []), ([_PACKAGE_FOO], {"group0": []}, ["group0"], set()),
# Selecting an extras group should return the contained packages # Selecting an extras group should return the contained packages
( (
[_PACKAGE_FOO, _PACKAGE_SPAM, _PACKAGE_BAR], [_PACKAGE_FOO, _PACKAGE_SPAM, _PACKAGE_BAR],
{"group0": ["foo"]}, {"group0": ["foo"]},
["group0"], ["group0"],
["foo"], {"foo"},
), ),
# If a package has dependencies, we should also get their names # If a package has dependencies, we should also get their names
( (
[_PACKAGE_FOO, _PACKAGE_SPAM, _PACKAGE_BAR], [_PACKAGE_FOO, _PACKAGE_SPAM, _PACKAGE_BAR],
{"group0": ["bar"], "group1": ["spam"]}, {"group0": ["bar"], "group1": ["spam"]},
["group0"], ["group0"],
["bar", "foo"], {"bar", "foo"},
), ),
# Selecting multiple extras should get us the union of all package names # Selecting multiple extras should get us the union of all package names
( (
[_PACKAGE_FOO, _PACKAGE_SPAM, _PACKAGE_BAR], [_PACKAGE_FOO, _PACKAGE_SPAM, _PACKAGE_BAR],
{"group0": ["bar"], "group1": ["spam"]}, {"group0": ["bar"], "group1": ["spam"]},
["group0", "group1"], ["group0", "group1"],
["bar", "foo", "spam"], {"bar", "foo", "spam"},
), ),
( (
[_PACKAGE_BAZ, _PACKAGE_QUIX], [_PACKAGE_BAZ, _PACKAGE_QUIX],
{"group0": ["baz"], "group1": ["quix"]}, {"group0": ["baz"], "group1": ["quix"]},
["group0", "group1"], ["group0", "group1"],
["baz", "quix"], {"baz", "quix"},
), ),
], ],
) )
...@@ -62,9 +62,9 @@ def test_get_extra_package_names( ...@@ -62,9 +62,9 @@ def test_get_extra_package_names(
packages: list[Package], packages: list[Package],
extras: dict[str, list[str]], extras: dict[str, list[str]],
extra_names: list[str], extra_names: list[str],
expected_extra_package_names: list[str], expected_extra_package_names: set[str],
) -> None: ) -> None:
assert ( assert (
list(get_extra_package_names(packages, extras, extra_names)) get_extra_package_names(packages, extras, extra_names)
== expected_extra_package_names == expected_extra_package_names
) )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment