Commit f172b602 by David Hotham Committed by GitHub

typechecking for poetry.utils (#5464)

parent 0d38a394
...@@ -129,11 +129,7 @@ module = [ ...@@ -129,11 +129,7 @@ module = [
'poetry.mixology.term', 'poetry.mixology.term',
'poetry.mixology.version_solver', 'poetry.mixology.version_solver',
'poetry.repositories.installed_repository', 'poetry.repositories.installed_repository',
'poetry.utils.appdirs',
'poetry.utils.authenticator',
'poetry.utils.env', 'poetry.utils.env',
'poetry.utils.exporter',
'poetry.utils.setup_reader',
] ]
ignore_errors = true ignore_errors = true
......
...@@ -7,17 +7,11 @@ from __future__ import annotations ...@@ -7,17 +7,11 @@ from __future__ import annotations
import os import os
import sys import sys
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pathlib import Path
WINDOWS = sys.platform.startswith("win") or (sys.platform == "cli" and os.name == "nt") WINDOWS = sys.platform.startswith("win") or (sys.platform == "cli" and os.name == "nt")
def expanduser(path: str | Path) -> str: def expanduser(path: str) -> str:
""" """
Expand ~ and ~user constructions. Expand ~ and ~user constructions.
...@@ -214,14 +208,15 @@ def _get_win_folder_with_ctypes(csidl_name: str) -> str: ...@@ -214,14 +208,15 @@ def _get_win_folder_with_ctypes(csidl_name: str) -> str:
}[csidl_name] }[csidl_name]
buf = ctypes.create_unicode_buffer(1024) buf = ctypes.create_unicode_buffer(1024)
ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf) windll = ctypes.windll # type: ignore[attr-defined]
windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)
# Downgrade to short path name if have highbit chars. See # Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>. # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = any(ord(c) > 255 for c in buf) has_high_char = any(ord(c) > 255 for c in buf)
if has_high_char: if has_high_char:
buf2 = ctypes.create_unicode_buffer(1024) buf2 = ctypes.create_unicode_buffer(1024)
if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024): if windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
buf = buf2 buf = buf2
return buf.value return buf.value
......
...@@ -33,9 +33,9 @@ class Authenticator: ...@@ -33,9 +33,9 @@ class Authenticator:
def __init__(self, config: Config, io: IO | None = None) -> None: def __init__(self, config: Config, io: IO | None = None) -> None:
self._config = config self._config = config
self._io = io self._io = io
self._session = None self._session: requests.Session | None = None
self._credentials = {} self._credentials: dict[str, tuple[str, str]] = {}
self._certs = {} self._certs: dict[str, dict[str, Path | None]] = {}
self._password_manager = PasswordManager(self._config) self._password_manager = PasswordManager(self._config)
def _log(self, message: str, level: str = "debug") -> None: def _log(self, message: str, level: str = "debug") -> None:
...@@ -118,7 +118,9 @@ class Authenticator: ...@@ -118,7 +118,9 @@ class Authenticator:
netloc = parsed_url.netloc netloc = parsed_url.netloc
credentials = self._credentials.get(netloc, (None, None)) credentials: tuple[str | None, str | None] = self._credentials.get(
netloc, (None, None)
)
if credentials == (None, None): if credentials == (None, None):
if "@" not in netloc: if "@" not in netloc:
...@@ -131,25 +133,27 @@ class Authenticator: ...@@ -131,25 +133,27 @@ class Authenticator:
# Split from the left because that's how urllib.parse.urlsplit() # Split from the left because that's how urllib.parse.urlsplit()
# behaves if more than one : is present (which again can be checked # behaves if more than one : is present (which again can be checked
# using the password attribute of the return value) # using the password attribute of the return value)
credentials = auth.split(":", 1) if ":" in auth else (auth, None) user, password = auth.split(":", 1) if ":" in auth else (auth, "")
credentials = tuple( credentials = (
None if x is None else urllib.parse.unquote(x) for x in credentials urllib.parse.unquote(user),
urllib.parse.unquote(password),
) )
if credentials[0] is not None or credentials[1] is not None: if any(credential is not None for credential in credentials):
credentials = (credentials[0] or "", credentials[1] or "") credentials = (credentials[0] or "", credentials[1] or "")
self._credentials[netloc] = credentials self._credentials[netloc] = credentials
return credentials[0], credentials[1] return credentials
def get_pypi_token(self, name: str) -> str: def get_pypi_token(self, name: str) -> str | None:
return self._password_manager.get_pypi_token(name) return self._password_manager.get_pypi_token(name)
def get_http_auth(self, name: str) -> dict[str, str] | None: def get_http_auth(self, name: str) -> dict[str, str | None] | None:
return self._get_http_auth(name, None) return self._get_http_auth(name, None)
def _get_http_auth(self, name: str, netloc: str | None) -> dict[str, str] | None: def _get_http_auth(
self, name: str, netloc: str | None
) -> dict[str, str | None] | None:
if name == "pypi": if name == "pypi":
url = "https://upload.pypi.org/legacy/" url = "https://upload.pypi.org/legacy/"
else: else:
...@@ -161,15 +165,18 @@ class Authenticator: ...@@ -161,15 +165,18 @@ class Authenticator:
if netloc is None or netloc == parsed_url.netloc: if netloc is None or netloc == parsed_url.netloc:
auth = self._password_manager.get_http_auth(name) auth = self._password_manager.get_http_auth(name)
auth = auth or {}
if auth is None or auth["password"] is None: if auth.get("password") is None:
username = auth["username"] if auth else None username = auth.get("username")
auth = self._get_credentials_for_netloc_from_keyring( auth = self._get_credentials_for_netloc_from_keyring(
url, parsed_url.netloc, username url, parsed_url.netloc, username
) )
return auth return auth
return None
def _get_credentials_for_netloc(self, netloc: str) -> tuple[str | None, str | None]: def _get_credentials_for_netloc(self, netloc: str) -> tuple[str | None, str | None]:
for repository_name, _ in self._get_repository_netlocs(): for repository_name, _ in self._get_repository_netlocs():
auth = self._get_http_auth(repository_name, netloc) auth = self._get_http_auth(repository_name, netloc)
...@@ -177,7 +184,7 @@ class Authenticator: ...@@ -177,7 +184,7 @@ class Authenticator:
if auth is None: if auth is None:
continue continue
return auth["username"], auth["password"] return auth.get("username"), auth.get("password")
return None, None return None, None
...@@ -199,7 +206,7 @@ class Authenticator: ...@@ -199,7 +206,7 @@ class Authenticator:
def _get_credentials_for_netloc_from_keyring( def _get_credentials_for_netloc_from_keyring(
self, url: str, netloc: str, username: str | None self, url: str, netloc: str, username: str | None
) -> dict[str, str] | None: ) -> dict[str, str | None] | None:
import keyring import keyring
cred = keyring.get_credential(url, username) cred = keyring.get_credential(url, username)
...@@ -225,7 +232,7 @@ class Authenticator: ...@@ -225,7 +232,7 @@ class Authenticator:
return None return None
def _get_certs_for_netloc_from_config(self, netloc: str) -> dict[str, Path | None]: def _get_certs_for_netloc_from_config(self, netloc: str) -> dict[str, Path | None]:
certs = {"cert": None, "verify": None} certs: dict[str, Path | None] = {"cert": None, "verify": None}
for repository_name, repository_netloc in self._get_repository_netlocs(): for repository_name, repository_netloc in self._get_repository_netlocs():
if netloc == repository_netloc: if netloc == repository_netloc:
......
...@@ -152,7 +152,7 @@ class PasswordManager: ...@@ -152,7 +152,7 @@ class PasswordManager:
self.keyring.delete_password(name, "__token__") self.keyring.delete_password(name, "__token__")
def get_http_auth(self, name: str) -> dict[str, str] | None: def get_http_auth(self, name: str) -> dict[str, str | None] | None:
auth = self._config.get(f"http-basic.{name}") auth = self._config.get(f"http-basic.{name}")
if not auth: if not auth:
username = self._config.get(f"http-basic.{name}.username") username = self._config.get(f"http-basic.{name}.username")
...@@ -181,10 +181,14 @@ class PasswordManager: ...@@ -181,10 +181,14 @@ class PasswordManager:
def delete_http_password(self, name: str) -> None: def delete_http_password(self, name: str) -> None:
auth = self.get_http_auth(name) auth = self.get_http_auth(name)
if not auth or "username" not in auth: if not auth:
return
username = auth.get("username")
if username is None:
return return
with suppress(KeyRingError): with suppress(KeyRingError):
self.keyring.delete_password(name, auth["username"]) self.keyring.delete_password(name, username)
self._config.auth_config_source.remove_property(f"http-basic.{name}") self._config.auth_config_source.remove_property(f"http-basic.{name}")
...@@ -5,7 +5,6 @@ import ast ...@@ -5,7 +5,6 @@ import ast
from configparser import ConfigParser from configparser import ConfigParser
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from typing import Iterable
from poetry.core.semver.version import Version from poetry.core.semver.version import Version
...@@ -15,7 +14,7 @@ class SetupReader: ...@@ -15,7 +14,7 @@ class SetupReader:
Class that reads a setup.py file without executing it. Class that reads a setup.py file without executing it.
""" """
DEFAULT = { DEFAULT: dict[str, Any] = {
"name": None, "name": None,
"version": None, "version": None,
"install_requires": [], "install_requires": [],
...@@ -52,26 +51,27 @@ class SetupReader: ...@@ -52,26 +51,27 @@ class SetupReader:
with filepath.open(encoding="utf-8") as f: with filepath.open(encoding="utf-8") as f:
content = f.read() content = f.read()
result = {} result: dict[str, Any] = {}
body = ast.parse(content).body body = ast.parse(content).body
setup_call, body = self._find_setup_call(body) setup_call = self._find_setup_call(body)
if not setup_call: if setup_call is None:
return self.DEFAULT return self.DEFAULT
# Inspecting keyword arguments # Inspecting keyword arguments
result["name"] = self._find_single_string(setup_call, body, "name") call, body = setup_call
result["version"] = self._find_single_string(setup_call, body, "version") result["name"] = self._find_single_string(call, body, "name")
result["install_requires"] = self._find_install_requires(setup_call, body) result["version"] = self._find_single_string(call, body, "version")
result["extras_require"] = self._find_extras_require(setup_call, body) result["install_requires"] = self._find_install_requires(call, body)
result["extras_require"] = self._find_extras_require(call, body)
result["python_requires"] = self._find_single_string( result["python_requires"] = self._find_single_string(
setup_call, body, "python_requires" call, body, "python_requires"
) )
return result return result
def read_setup_cfg(self, filepath: str | Path) -> dict[str, list | dict]: def read_setup_cfg(self, filepath: str | Path) -> dict[str, Any]:
parser = ConfigParser() parser = ConfigParser()
parser.read(str(filepath)) parser.read(str(filepath))
...@@ -85,7 +85,7 @@ class SetupReader: ...@@ -85,7 +85,7 @@ class SetupReader:
version = Version.parse(parser.get("metadata", "version")).text version = Version.parse(parser.get("metadata", "version")).text
install_requires = [] install_requires = []
extras_require = {} extras_require: dict[str, list[str]] = {}
python_requires = None python_requires = None
if parser.has_section("options"): if parser.has_section("options"):
if parser.has_option("options", "install_requires"): if parser.has_option("options", "install_requires"):
...@@ -119,9 +119,9 @@ class SetupReader: ...@@ -119,9 +119,9 @@ class SetupReader:
} }
def _find_setup_call( def _find_setup_call(
self, elements: list[Any] self, elements: list[ast.stmt]
) -> tuple[ast.Call | None, list[Any] | None]: ) -> tuple[ast.Call, list[ast.stmt]] | None:
funcdefs = [] funcdefs: list[ast.stmt] = []
for i, element in enumerate(elements): for i, element in enumerate(elements):
if isinstance(element, ast.If) and i == len(elements) - 1: if isinstance(element, ast.If) and i == len(elements) - 1:
# Checking if the last element is an if statement # Checking if the last element is an if statement
...@@ -138,11 +138,13 @@ class SetupReader: ...@@ -138,11 +138,13 @@ class SetupReader:
if left.id != "__name__": if left.id != "__name__":
continue continue
setup_call, body = self._find_sub_setup_call([element]) setup_call = self._find_sub_setup_call([element])
if not setup_call: if setup_call is None:
continue continue
return setup_call, body + elements call, body = setup_call
return call, body + elements
if not isinstance(element, ast.Expr): if not isinstance(element, ast.Expr):
if isinstance(element, ast.FunctionDef): if isinstance(element, ast.FunctionDef):
funcdefs.append(element) funcdefs.append(element)
...@@ -156,8 +158,7 @@ class SetupReader: ...@@ -156,8 +158,7 @@ class SetupReader:
func = value.func func = value.func
if not (isinstance(func, ast.Name) and func.id == "setup") and not ( if not (isinstance(func, ast.Name) and func.id == "setup") and not (
isinstance(func, ast.Attribute) isinstance(func, ast.Attribute)
and hasattr(func.value, "id") and getattr(func.value, "id", None) == "setuptools"
and func.value.id == "setuptools"
and func.attr == "setup" and func.attr == "setup"
): ):
continue continue
...@@ -168,24 +169,24 @@ class SetupReader: ...@@ -168,24 +169,24 @@ class SetupReader:
return self._find_sub_setup_call(funcdefs) return self._find_sub_setup_call(funcdefs)
def _find_sub_setup_call( def _find_sub_setup_call(
self, elements: list[Any] self, elements: list[ast.stmt]
) -> tuple[ast.Call | None, list[Any] | None]: ) -> tuple[ast.Call, list[ast.stmt]] | None:
for element in elements: for element in elements:
if not isinstance(element, (ast.FunctionDef, ast.If)): if not isinstance(element, (ast.FunctionDef, ast.If)):
continue continue
setup_call = self._find_setup_call(element.body) setup_call = self._find_setup_call(element.body)
if setup_call != (None, None): if setup_call is not None:
setup_call, body = setup_call sub_call, body = setup_call
body = elements + body body = elements + body
return setup_call, body return sub_call, body
return None, None return None
def _find_install_requires(self, call: ast.Call, body: Iterable[Any]) -> list[str]: def _find_install_requires(self, call: ast.Call, body: list[ast.stmt]) -> list[str]:
install_requires = [] install_requires: list[str] = []
value = self._find_in_call(call, "install_requires") value = self._find_in_call(call, "install_requires")
if value is None: if value is None:
# Trying to find in kwargs # Trying to find in kwargs
...@@ -214,20 +215,22 @@ class SetupReader: ...@@ -214,20 +215,22 @@ class SetupReader:
if isinstance(value, ast.List): if isinstance(value, ast.List):
for el in value.elts: for el in value.elts:
install_requires.append(el.s) if isinstance(el, ast.Str):
install_requires.append(el.s)
elif isinstance(value, ast.Name): elif isinstance(value, ast.Name):
variable = self._find_variable_in_body(body, value.id) variable = self._find_variable_in_body(body, value.id)
if variable is not None and isinstance(variable, ast.List): if variable is not None and isinstance(variable, ast.List):
for el in variable.elts: for el in variable.elts:
install_requires.append(el.s) if isinstance(el, ast.Str):
install_requires.append(el.s)
return install_requires return install_requires
def _find_extras_require( def _find_extras_require(
self, call: ast.Call, body: Iterable[Any] self, call: ast.Call, body: list[ast.stmt]
) -> dict[str, list]: ) -> dict[str, list[str]]:
extras_require = {} extras_require: dict[str, list[str]] = {}
value = self._find_in_call(call, "extras_require") value = self._find_in_call(call, "extras_require")
if value is None: if value is None:
# Trying to find in kwargs # Trying to find in kwargs
...@@ -255,12 +258,18 @@ class SetupReader: ...@@ -255,12 +258,18 @@ class SetupReader:
return extras_require return extras_require
if isinstance(value, ast.Dict): if isinstance(value, ast.Dict):
val: ast.expr | None
for key, val in zip(value.keys, value.values): for key, val in zip(value.keys, value.values):
if not isinstance(key, ast.Str):
continue
if isinstance(val, ast.Name): if isinstance(val, ast.Name):
val = self._find_variable_in_body(body, val.id) val = self._find_variable_in_body(body, val.id)
if isinstance(val, ast.List): if isinstance(val, ast.List):
extras_require[key.s] = [e.s for e in val.elts] extras_require[key.s] = [
e.s for e in val.elts if isinstance(e, ast.Str)
]
elif isinstance(value, ast.Name): elif isinstance(value, ast.Name):
variable = self._find_variable_in_body(body, value.id) variable = self._find_variable_in_body(body, value.id)
...@@ -268,16 +277,21 @@ class SetupReader: ...@@ -268,16 +277,21 @@ class SetupReader:
return extras_require return extras_require
for key, val in zip(variable.keys, variable.values): for key, val in zip(variable.keys, variable.values):
if not isinstance(key, ast.Str):
continue
if isinstance(val, ast.Name): if isinstance(val, ast.Name):
val = self._find_variable_in_body(body, val.id) val = self._find_variable_in_body(body, val.id)
if isinstance(val, ast.List): if isinstance(val, ast.List):
extras_require[key.s] = [e.s for e in val.elts] extras_require[key.s] = [
e.s for e in val.elts if isinstance(e, ast.Str)
]
return extras_require return extras_require
def _find_single_string( def _find_single_string(
self, call: ast.Call, body: list[Any], name: str self, call: ast.Call, body: list[ast.stmt], name: str
) -> str | None: ) -> str | None:
value = self._find_in_call(call, name) value = self._find_in_call(call, name)
if value is None: if value is None:
...@@ -313,6 +327,8 @@ class SetupReader: ...@@ -313,6 +327,8 @@ class SetupReader:
if variable is not None and isinstance(variable, ast.Str): if variable is not None and isinstance(variable, ast.Str):
return variable.s return variable.s
return None
def _find_in_call(self, call: ast.Call, name: str) -> Any | None: def _find_in_call(self, call: ast.Call, name: str) -> Any | None:
for keyword in call.keywords: for keyword in call.keywords:
if keyword.arg == name: if keyword.arg == name:
...@@ -327,12 +343,10 @@ class SetupReader: ...@@ -327,12 +343,10 @@ class SetupReader:
return kwargs return kwargs
def _find_variable_in_body(self, body: Iterable[Any], name: str) -> Any | None: def _find_variable_in_body(
found = None self, body: list[ast.stmt], name: str
) -> ast.expr | None:
for elem in body: for elem in body:
if found:
break
if not isinstance(elem, ast.Assign): if not isinstance(elem, ast.Assign):
continue continue
...@@ -343,8 +357,11 @@ class SetupReader: ...@@ -343,8 +357,11 @@ class SetupReader:
if target.id == name: if target.id == name:
return elem.value return elem.value
def _find_in_dict(self, dict_: ast.Dict | ast.Call, name: str) -> Any | None: return None
def _find_in_dict(self, dict_: ast.Dict, name: str) -> ast.expr | None:
for key, val in zip(dict_.keys, dict_.values): for key, val in zip(dict_.keys, dict_.values):
if isinstance(key, ast.Str) and key.s == name: if isinstance(key, ast.Str) and key.s == name:
return val return val
return None return None
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment