Commit 2deaf594 by David Hotham Committed by GitHub

tighten typechecking (#6203)

parent 3bd4c7a3
...@@ -410,7 +410,7 @@ class Provider: ...@@ -410,7 +410,7 @@ class Provider:
file_name = os.path.basename(urllib.parse.urlparse(url).path) file_name = os.path.basename(urllib.parse.urlparse(url).path)
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
dest = Path(temp_dir) / file_name dest = Path(temp_dir) / file_name
download_file(url, str(dest)) download_file(url, dest)
package = cls.get_package_from_file(dest) package = cls.get_package_from_file(dest)
package._source_type = "url" package._source_type = "url"
...@@ -528,7 +528,7 @@ class Provider: ...@@ -528,7 +528,7 @@ class Provider:
dependency, dependency,
self._pool.package( self._pool.package(
package.name, package.name,
package.version.text, package.version,
extras=list(dependency.extras), extras=list(dependency.extras),
repository=dependency.source_name, repository=dependency.source_name,
), ),
......
...@@ -13,7 +13,9 @@ from poetry.repositories.repository import Repository ...@@ -13,7 +13,9 @@ from poetry.repositories.repository import Repository
if TYPE_CHECKING: if TYPE_CHECKING:
from packaging.utils import NormalizedName
from poetry.core.packages.package import Package from poetry.core.packages.package import Package
from poetry.core.semver.version import Version
from poetry.inspection.info import PackageInfo from poetry.inspection.info import PackageInfo
...@@ -40,10 +42,12 @@ class CachedRepository(Repository, ABC): ...@@ -40,10 +42,12 @@ class CachedRepository(Repository, ABC):
) )
@abstractmethod @abstractmethod
def _get_release_info(self, name: str, version: str) -> dict[str, Any]: def _get_release_info(
self, name: NormalizedName, version: Version
) -> dict[str, Any]:
raise NotImplementedError() raise NotImplementedError()
def get_release_info(self, name: str, version: str) -> PackageInfo: def get_release_info(self, name: NormalizedName, version: Version) -> PackageInfo:
""" """
Return the release information given a package name and a version. Return the release information given a package name and a version.
...@@ -74,8 +78,8 @@ class CachedRepository(Repository, ABC): ...@@ -74,8 +78,8 @@ class CachedRepository(Repository, ABC):
def package( def package(
self, self,
name: str, name: NormalizedName,
version: str, version: Version,
extras: list[str] | None = None, extras: list[str] | None = None,
) -> Package: ) -> Package:
return self.get_release_info(name, version).to_package(name=name, extras=extras) return self.get_release_info(name, version).to_package(name=name, extras=extras)
...@@ -68,7 +68,7 @@ class HTTPRepository(CachedRepository, ABC): ...@@ -68,7 +68,7 @@ class HTTPRepository(CachedRepository, ABC):
def authenticated_url(self) -> str: def authenticated_url(self) -> str:
return self._authenticator.authenticated_url(url=self.url) return self._authenticator.authenticated_url(url=self.url)
def _download(self, url: str, dest: str) -> None: def _download(self, url: str, dest: Path) -> None:
return download_file(url, dest, session=self.session) return download_file(url, dest, session=self.session)
def _get_info_from_wheel(self, url: str) -> PackageInfo: def _get_info_from_wheel(self, url: str) -> PackageInfo:
...@@ -81,7 +81,7 @@ class HTTPRepository(CachedRepository, ABC): ...@@ -81,7 +81,7 @@ class HTTPRepository(CachedRepository, ABC):
with temporary_directory() as temp_dir: with temporary_directory() as temp_dir:
filepath = Path(temp_dir) / filename filepath = Path(temp_dir) / filename
self._download(url, str(filepath)) self._download(url, filepath)
return PackageInfo.from_wheel(filepath) return PackageInfo.from_wheel(filepath)
...@@ -97,7 +97,7 @@ class HTTPRepository(CachedRepository, ABC): ...@@ -97,7 +97,7 @@ class HTTPRepository(CachedRepository, ABC):
with temporary_directory() as temp_dir: with temporary_directory() as temp_dir:
filepath = Path(temp_dir) / filename filepath = Path(temp_dir) / filename
self._download(url, str(filepath)) self._download(url, filepath)
return PackageInfo.from_sdist(filepath) return PackageInfo.from_sdist(filepath)
...@@ -226,7 +226,7 @@ class HTTPRepository(CachedRepository, ABC): ...@@ -226,7 +226,7 @@ class HTTPRepository(CachedRepository, ABC):
): ):
with temporary_directory() as temp_dir: with temporary_directory() as temp_dir:
filepath = Path(temp_dir) / link.filename filepath = Path(temp_dir) / link.filename
self._download(link.url, str(filepath)) self._download(link.url, filepath)
known_hash = ( known_hash = (
getattr(hashlib, link.hash_name)() if link.hash_name else None getattr(hashlib, link.hash_name)() if link.hash_name else None
......
...@@ -3,9 +3,7 @@ from __future__ import annotations ...@@ -3,9 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Any from typing import Any
from packaging.utils import canonicalize_name
from poetry.core.packages.package import Package from poetry.core.packages.package import Package
from poetry.core.semver.version import Version
from poetry.inspection.info import PackageInfo from poetry.inspection.info import PackageInfo
from poetry.repositories.exceptions import PackageNotFound from poetry.repositories.exceptions import PackageNotFound
...@@ -16,6 +14,7 @@ from poetry.repositories.link_sources.html import SimpleRepositoryPage ...@@ -16,6 +14,7 @@ from poetry.repositories.link_sources.html import SimpleRepositoryPage
if TYPE_CHECKING: if TYPE_CHECKING:
from packaging.utils import NormalizedName from packaging.utils import NormalizedName
from poetry.core.packages.utils.link import Link from poetry.core.packages.utils.link import Link
from poetry.core.semver.version import Version
from poetry.core.semver.version_constraint import VersionConstraint from poetry.core.semver.version_constraint import VersionConstraint
from poetry.config.config import Config from poetry.config.config import Config
...@@ -35,7 +34,7 @@ class LegacyRepository(HTTPRepository): ...@@ -35,7 +34,7 @@ class LegacyRepository(HTTPRepository):
super().__init__(name, url.rstrip("/"), config, disable_cache) super().__init__(name, url.rstrip("/"), config, disable_cache)
def package( def package(
self, name: str, version: str, extras: list[str] | None = None self, name: NormalizedName, version: Version, extras: list[str] | None = None
) -> Package: ) -> Package:
""" """
Retrieve the release information. Retrieve the release information.
...@@ -49,7 +48,7 @@ class LegacyRepository(HTTPRepository): ...@@ -49,7 +48,7 @@ class LegacyRepository(HTTPRepository):
should be much faster. should be much faster.
""" """
try: try:
index = self._packages.index(Package(name, version, version)) index = self._packages.index(Package(name, version))
return self._packages[index] return self._packages[index]
except ValueError: except ValueError:
...@@ -106,18 +105,20 @@ class LegacyRepository(HTTPRepository): ...@@ -106,18 +105,20 @@ class LegacyRepository(HTTPRepository):
for version in versions for version in versions
] ]
def _get_release_info(self, name: str, version: str) -> dict[str, Any]: def _get_release_info(
page = self._get_page(f"/{canonicalize_name(name)}/") self, name: NormalizedName, version: Version
) -> dict[str, Any]:
page = self._get_page(f"/{name}/")
if page is None: if page is None:
raise PackageNotFound(f'No package named "{name}"') raise PackageNotFound(f'No package named "{name}"')
links = list(page.links_for_version(name, Version.parse(version))) links = list(page.links_for_version(name, version))
return self._links_to_data( return self._links_to_data(
links, links,
PackageInfo( PackageInfo(
name=name, name=name,
version=version, version=version.text,
summary="", summary="",
platform=None, platform=None,
requires_dist=[], requires_dist=[],
......
...@@ -17,6 +17,7 @@ from poetry.utils.patterns import wheel_file_re ...@@ -17,6 +17,7 @@ from poetry.utils.patterns import wheel_file_re
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Iterator from collections.abc import Iterator
from packaging.utils import NormalizedName
from poetry.core.packages.utils.link import Link from poetry.core.packages.utils.link import Link
...@@ -98,9 +99,9 @@ class LinkSource: ...@@ -98,9 +99,9 @@ class LinkSource:
pkg = Package(name, version, source_url=link.url) pkg = Package(name, version, source_url=link.url)
return pkg return pkg
def links_for_version(self, name: str, version: Version) -> Iterator[Link]: def links_for_version(
name = canonicalize_name(name) self, name: NormalizedName, version: Version
) -> Iterator[Link]:
for link in self.links: for link in self.links:
pkg = self.link_package_data(link) pkg = self.link_package_data(link)
......
...@@ -7,8 +7,10 @@ from poetry.repositories.repository import Repository ...@@ -7,8 +7,10 @@ from poetry.repositories.repository import Repository
if TYPE_CHECKING: if TYPE_CHECKING:
from packaging.utils import NormalizedName
from poetry.core.packages.dependency import Dependency from poetry.core.packages.dependency import Dependency
from poetry.core.packages.package import Package from poetry.core.packages.package import Package
from poetry.core.semver.version import Version
class Pool(Repository): class Pool(Repository):
...@@ -118,8 +120,8 @@ class Pool(Repository): ...@@ -118,8 +120,8 @@ class Pool(Repository):
def package( def package(
self, self,
name: str, name: NormalizedName,
version: str, version: Version,
extras: list[str] | None = None, extras: list[str] | None = None,
repository: str | None = None, repository: str | None = None,
) -> Package: ) -> Package:
......
...@@ -168,7 +168,7 @@ class PyPiRepository(HTTPRepository): ...@@ -168,7 +168,7 @@ class PyPiRepository(HTTPRepository):
return links return links
def _get_release_info( def _get_release_info(
self, name: str, version: str self, name: NormalizedName, version: Version
) -> dict[str, str | list[str] | None]: ) -> dict[str, str | list[str] | None]:
from poetry.inspection.info import PackageInfo from poetry.inspection.info import PackageInfo
......
...@@ -16,6 +16,7 @@ if TYPE_CHECKING: ...@@ -16,6 +16,7 @@ if TYPE_CHECKING:
from poetry.core.packages.dependency import Dependency from poetry.core.packages.dependency import Dependency
from poetry.core.packages.package import Package from poetry.core.packages.package import Package
from poetry.core.packages.utils.link import Link from poetry.core.packages.utils.link import Link
from poetry.core.semver.version import Version
class Repository: class Repository:
...@@ -133,12 +134,10 @@ class Repository: ...@@ -133,12 +134,10 @@ class Repository:
return [] return []
def package( def package(
self, name: str, version: str, extras: list[str] | None = None self, name: NormalizedName, version: Version, extras: list[str] | None = None
) -> Package: ) -> Package:
name = name.lower()
for package in self.packages: for package in self.packages:
if name == package.name and package.version.text == version: if name == package.name and package.version == version:
return package.clone() return package.clone()
raise PackageNotFound(f"Package {name} ({version}) not found.") raise PackageNotFound(f"Package {name} ({version}) not found.")
...@@ -72,7 +72,7 @@ def merge_dicts(d1: dict[str, Any], d2: dict[str, Any]) -> None: ...@@ -72,7 +72,7 @@ def merge_dicts(d1: dict[str, Any], d2: dict[str, Any]) -> None:
def download_file( def download_file(
url: str, url: str,
dest: str, dest: Path,
session: Authenticator | Session | None = None, session: Authenticator | Session | None = None,
chunk_size: int = 1024, chunk_size: int = 1024,
) -> None: ) -> None:
......
...@@ -125,13 +125,13 @@ def mock_clone( ...@@ -125,13 +125,13 @@ def mock_clone(
return MockDulwichRepo(dest) return MockDulwichRepo(dest)
def mock_download(url: str, dest: str, **__: Any) -> None: def mock_download(url: str, dest: Path) -> None:
parts = urllib.parse.urlparse(url) parts = urllib.parse.urlparse(url)
fixtures = Path(__file__).parent / "fixtures" fixtures = Path(__file__).parent / "fixtures"
fixture = fixtures / parts.path.lstrip("/") fixture = fixtures / parts.path.lstrip("/")
copy_or_symlink(fixture, Path(dest)) copy_or_symlink(fixture, dest)
class TestExecutor(Executor): class TestExecutor(Executor):
......
...@@ -12,6 +12,8 @@ import pytest ...@@ -12,6 +12,8 @@ import pytest
from cleo.io.null_io import NullIO from cleo.io.null_io import NullIO
from deepdiff import DeepDiff from deepdiff import DeepDiff
from packaging.utils import canonicalize_name
from poetry.core.semver.version import Version
from poetry.factory import Factory from poetry.factory import Factory
from poetry.masonry.builders.editable import EditableBuilder from poetry.masonry.builders.editable import EditableBuilder
...@@ -224,7 +226,7 @@ def test_builder_falls_back_on_setup_and_pip_for_packages_with_build_scripts( ...@@ -224,7 +226,7 @@ def test_builder_falls_back_on_setup_and_pip_for_packages_with_build_scripts(
assert [] == env.executed assert [] == env.executed
def test_builder_setup_generation_runs_with_pip_editable(tmp_dir: str): def test_builder_setup_generation_runs_with_pip_editable(tmp_dir: str) -> None:
# create an isolated copy of the project # create an isolated copy of the project
fixture = Path(__file__).parent.parent.parent / "fixtures" / "extended_project" fixture = Path(__file__).parent.parent.parent / "fixtures" / "extended_project"
extended_project = Path(tmp_dir) / "extended_project" extended_project = Path(tmp_dir) / "extended_project"
...@@ -241,7 +243,10 @@ def test_builder_setup_generation_runs_with_pip_editable(tmp_dir: str): ...@@ -241,7 +243,10 @@ def test_builder_setup_generation_runs_with_pip_editable(tmp_dir: str):
# is the package installed? # is the package installed?
repository = InstalledRepository.load(venv) repository = InstalledRepository.load(venv)
assert repository.package("extended-project", "1.2.3") package = repository.package(
canonicalize_name("extended-project"), Version.parse("1.2.3")
)
assert package.name == "extended-project"
# check for the module built by build.py # check for the module built by build.py
try: try:
......
...@@ -2,18 +2,21 @@ from __future__ import annotations ...@@ -2,18 +2,21 @@ from __future__ import annotations
import pytest import pytest
from packaging.utils import canonicalize_name
from poetry.core.semver.version import Version
from poetry.repositories import Pool from poetry.repositories import Pool
from poetry.repositories import Repository from poetry.repositories import Repository
from poetry.repositories.exceptions import PackageNotFound from poetry.repositories.exceptions import PackageNotFound
from poetry.repositories.legacy_repository import LegacyRepository from poetry.repositories.legacy_repository import LegacyRepository
def test_pool_raises_package_not_found_when_no_package_is_found(): def test_pool_raises_package_not_found_when_no_package_is_found() -> None:
pool = Pool() pool = Pool()
pool.add_repository(Repository("repo")) pool.add_repository(Repository("repo"))
with pytest.raises(PackageNotFound): with pytest.raises(PackageNotFound):
pool.package("foo", "1.0.0") pool.package(canonicalize_name("foo"), Version.parse("1.0.0"))
def test_pool(): def test_pool():
......
...@@ -9,7 +9,9 @@ from typing import TYPE_CHECKING ...@@ -9,7 +9,9 @@ from typing import TYPE_CHECKING
import pytest import pytest
from packaging.utils import canonicalize_name
from poetry.core.packages.dependency import Dependency from poetry.core.packages.dependency import Dependency
from poetry.core.semver.version import Version
from requests.exceptions import TooManyRedirects from requests.exceptions import TooManyRedirects
from requests.models import Response from requests.models import Response
...@@ -50,7 +52,7 @@ class MockRepository(PyPiRepository): ...@@ -50,7 +52,7 @@ class MockRepository(PyPiRepository):
fixture = self.JSON_FIXTURES / (name + ".json") fixture = self.JSON_FIXTURES / (name + ".json")
if not fixture.exists(): if not fixture.exists():
return return None
with fixture.open(encoding="utf-8") as f: with fixture.open(encoding="utf-8") as f:
return json.loads(f.read()) return json.loads(f.read())
...@@ -63,21 +65,21 @@ class MockRepository(PyPiRepository): ...@@ -63,21 +65,21 @@ class MockRepository(PyPiRepository):
shutil.copyfile(str(fixture), dest) shutil.copyfile(str(fixture), dest)
def test_find_packages(): def test_find_packages() -> None:
repo = MockRepository() repo = MockRepository()
packages = repo.find_packages(Factory.create_dependency("requests", "^2.18")) packages = repo.find_packages(Factory.create_dependency("requests", "^2.18"))
assert len(packages) == 5 assert len(packages) == 5
def test_find_packages_with_prereleases(): def test_find_packages_with_prereleases() -> None:
repo = MockRepository() repo = MockRepository()
packages = repo.find_packages(Factory.create_dependency("toga", ">=0.3.0.dev2")) packages = repo.find_packages(Factory.create_dependency("toga", ">=0.3.0.dev2"))
assert len(packages) == 7 assert len(packages) == 7
def test_find_packages_does_not_select_prereleases_if_not_allowed(): def test_find_packages_does_not_select_prereleases_if_not_allowed() -> None:
repo = MockRepository() repo = MockRepository()
packages = repo.find_packages(Factory.create_dependency("pyyaml", "*")) packages = repo.find_packages(Factory.create_dependency("pyyaml", "*"))
...@@ -87,17 +89,17 @@ def test_find_packages_does_not_select_prereleases_if_not_allowed(): ...@@ -87,17 +89,17 @@ def test_find_packages_does_not_select_prereleases_if_not_allowed():
@pytest.mark.parametrize( @pytest.mark.parametrize(
["constraint", "count"], [("*", 1), (">=1", 0), (">=19.0.0a0", 1)] ["constraint", "count"], [("*", 1), (">=1", 0), (">=19.0.0a0", 1)]
) )
def test_find_packages_only_prereleases(constraint: str, count: int): def test_find_packages_only_prereleases(constraint: str, count: int) -> None:
repo = MockRepository() repo = MockRepository()
packages = repo.find_packages(Factory.create_dependency("black", constraint)) packages = repo.find_packages(Factory.create_dependency("black", constraint))
assert len(packages) == count assert len(packages) == count
def test_package(): def test_package() -> None:
repo = MockRepository() repo = MockRepository()
package = repo.package("requests", "2.18.4") package = repo.package(canonicalize_name("requests"), Version.parse("2.18.4"))
assert package.name == "requests" assert package.name == "requests"
assert len(package.requires) == 9 assert len(package.requires) == 9
...@@ -126,10 +128,10 @@ def test_package(): ...@@ -126,10 +128,10 @@ def test_package():
) )
def test_fallback_on_downloading_packages(): def test_fallback_on_downloading_packages() -> None:
repo = MockRepository(fallback=True) repo = MockRepository(fallback=True)
package = repo.package("jupyter", "1.0.0") package = repo.package(canonicalize_name("jupyter"), Version.parse("1.0.0"))
assert package.name == "jupyter" assert package.name == "jupyter"
assert len(package.requires) == 6 assert len(package.requires) == 6
...@@ -145,10 +147,10 @@ def test_fallback_on_downloading_packages(): ...@@ -145,10 +147,10 @@ def test_fallback_on_downloading_packages():
] ]
def test_fallback_inspects_sdist_first_if_no_matching_wheels_can_be_found(): def test_fallback_inspects_sdist_first_if_no_matching_wheels_can_be_found() -> None:
repo = MockRepository(fallback=True) repo = MockRepository(fallback=True)
package = repo.package("isort", "4.3.4") package = repo.package(canonicalize_name("isort"), Version.parse("4.3.4"))
assert package.name == "isort" assert package.name == "isort"
assert len(package.requires) == 1 assert len(package.requires) == 1
...@@ -158,10 +160,10 @@ def test_fallback_inspects_sdist_first_if_no_matching_wheels_can_be_found(): ...@@ -158,10 +160,10 @@ def test_fallback_inspects_sdist_first_if_no_matching_wheels_can_be_found():
assert dep.python_versions == "~2.7" assert dep.python_versions == "~2.7"
def test_fallback_can_read_setup_to_get_dependencies(): def test_fallback_can_read_setup_to_get_dependencies() -> None:
repo = MockRepository(fallback=True) repo = MockRepository(fallback=True)
package = repo.package("sqlalchemy", "1.2.12") package = repo.package(canonicalize_name("sqlalchemy"), Version.parse("1.2.12"))
assert package.name == "sqlalchemy" assert package.name == "sqlalchemy"
assert len(package.requires) == 9 assert len(package.requires) == 9
...@@ -180,10 +182,10 @@ def test_fallback_can_read_setup_to_get_dependencies(): ...@@ -180,10 +182,10 @@ def test_fallback_can_read_setup_to_get_dependencies():
} }
def test_pypi_repository_supports_reading_bz2_files(): def test_pypi_repository_supports_reading_bz2_files() -> None:
repo = MockRepository(fallback=True) repo = MockRepository(fallback=True)
package = repo.package("twisted", "18.9.0") package = repo.package(canonicalize_name("twisted"), Version.parse("18.9.0"))
assert package.name == "twisted" assert package.name == "twisted"
assert len(package.requires) == 71 assert len(package.requires) == 71
...@@ -220,7 +222,7 @@ def test_pypi_repository_supports_reading_bz2_files(): ...@@ -220,7 +222,7 @@ def test_pypi_repository_supports_reading_bz2_files():
) )
def test_invalid_versions_ignored(): def test_invalid_versions_ignored() -> None:
repo = MockRepository() repo = MockRepository()
# the json metadata for this package contains one malformed version # the json metadata for this package contains one malformed version
...@@ -229,7 +231,9 @@ def test_invalid_versions_ignored(): ...@@ -229,7 +231,9 @@ def test_invalid_versions_ignored():
assert len(packages) == 1 assert len(packages) == 1
def test_get_should_invalid_cache_on_too_many_redirects_error(mocker: MockerFixture): def test_get_should_invalid_cache_on_too_many_redirects_error(
mocker: MockerFixture,
) -> None:
delete_cache = mocker.patch("cachecontrol.caches.file_cache.FileCache.delete") delete_cache = mocker.patch("cachecontrol.caches.file_cache.FileCache.delete")
response = Response() response = Response()
...@@ -246,14 +250,14 @@ def test_get_should_invalid_cache_on_too_many_redirects_error(mocker: MockerFixt ...@@ -246,14 +250,14 @@ def test_get_should_invalid_cache_on_too_many_redirects_error(mocker: MockerFixt
assert delete_cache.called assert delete_cache.called
def test_urls(): def test_urls() -> None:
repository = PyPiRepository() repository = PyPiRepository()
assert repository.url == "https://pypi.org/simple/" assert repository.url == "https://pypi.org/simple/"
assert repository.authenticated_url == "https://pypi.org/simple/" assert repository.authenticated_url == "https://pypi.org/simple/"
def test_use_pypi_pretty_name(): def test_use_pypi_pretty_name() -> None:
repo = MockRepository(fallback=True) repo = MockRepository(fallback=True)
package = repo.find_packages(Factory.create_dependency("twisted", "*")) package = repo.find_packages(Factory.create_dependency("twisted", "*"))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment