Commit cd4247f4 by David Hotham Committed by Bjorn Neergaard

typechecking poetry.utils.env

parent 59a9fd3f
...@@ -115,18 +115,6 @@ enable_error_code = [ ...@@ -115,18 +115,6 @@ enable_error_code = [
"truthy-bool", "truthy-bool",
] ]
# The following whitelist is used to allow for incremental adoption
# of Mypy. Modules should be removed from this whitelist as and when
# their respective type errors have been addressed. No new modules
# should be added to this whitelist.
# see https://github.com/python-poetry/poetry/pull/4510.
[[tool.mypy.overrides]]
module = [
'poetry.utils.env',
]
ignore_errors = true
# use of importlib-metadata backport at python3.7 makes it impossible to # use of importlib-metadata backport at python3.7 makes it impossible to
# satisfy mypy without some ignores: but we get a different set of ignores at # satisfy mypy without some ignores: but we get a different set of ignores at
# different python versions. # different python versions.
...@@ -152,6 +140,7 @@ module = [ ...@@ -152,6 +140,7 @@ module = [
'poetry.core.*', 'poetry.core.*',
'requests_toolbelt.*', 'requests_toolbelt.*',
'shellingham.*', 'shellingham.*',
'virtualenv.*',
] ]
ignore_missing_imports = true ignore_missing_imports = true
......
...@@ -18,7 +18,6 @@ from pathlib import Path ...@@ -18,7 +18,6 @@ from pathlib import Path
from subprocess import CalledProcessError from subprocess import CalledProcessError
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Any from typing import Any
from typing import TypeVar
import packaging.tags import packaging.tags
import tomlkit import tomlkit
...@@ -29,7 +28,6 @@ from packaging.tags import Tag ...@@ -29,7 +28,6 @@ from packaging.tags import Tag
from packaging.tags import interpreter_name from packaging.tags import interpreter_name
from packaging.tags import interpreter_version from packaging.tags import interpreter_version
from packaging.tags import sys_tags from packaging.tags import sys_tags
from poetry.core.poetry import Poetry
from poetry.core.semver.helpers import parse_constraint from poetry.core.semver.helpers import parse_constraint
from poetry.core.semver.version import Version from poetry.core.semver.version import Version
from poetry.core.toml.file import TOMLFile from poetry.core.toml.file import TOMLFile
...@@ -51,10 +49,11 @@ if TYPE_CHECKING: ...@@ -51,10 +49,11 @@ if TYPE_CHECKING:
from collections.abc import Iterator from collections.abc import Iterator
from cleo.io.io import IO from cleo.io.io import IO
from poetry.core.poetry import Poetry as CorePoetry
from poetry.core.version.markers import BaseMarker from poetry.core.version.markers import BaseMarker
from virtualenv.seed.wheels.util import Wheel
from poetry.poetry import Poetry
P = TypeVar("P", bound=Poetry)
GET_SYS_TAGS = f""" GET_SYS_TAGS = f"""
...@@ -226,7 +225,7 @@ class SitePackages: ...@@ -226,7 +225,7 @@ class SitePackages:
self, self,
purelib: Path, purelib: Path,
platlib: Path | None = None, platlib: Path | None = None,
fallbacks: list[Path] = None, fallbacks: list[Path] | None = None,
skip_write_checks: bool = False, skip_write_checks: bool = False,
) -> None: ) -> None:
self._purelib = purelib self._purelib = purelib
...@@ -302,7 +301,7 @@ class SitePackages: ...@@ -302,7 +301,7 @@ class SitePackages:
def distributions( def distributions(
self, name: str | None = None, writable_only: bool = False self, name: str | None = None, writable_only: bool = False
) -> Iterable[metadata.PathDistribution]: ) -> Iterable[metadata.Distribution]:
path = list( path = list(
map( map(
str, self._candidates if not writable_only else self.writable_candidates str, self._candidates if not writable_only else self.writable_candidates
...@@ -313,7 +312,7 @@ class SitePackages: ...@@ -313,7 +312,7 @@ class SitePackages:
def find_distribution( def find_distribution(
self, name: str, writable_only: bool = False self, name: str, writable_only: bool = False
) -> metadata.PathDistribution | None: ) -> metadata.Distribution | None:
for distribution in self.distributions(name=name, writable_only=writable_only): for distribution in self.distributions(name=name, writable_only=writable_only):
return distribution return distribution
return None return None
...@@ -324,6 +323,7 @@ class SitePackages: ...@@ -324,6 +323,7 @@ class SitePackages:
for distribution in self.distributions( for distribution in self.distributions(
name=distribution_name, writable_only=writable_only name=distribution_name, writable_only=writable_only
): ):
assert distribution.files is not None
for file in distribution.files: for file in distribution.files:
if file.name.endswith(suffix): if file.name.endswith(suffix):
yield Path(distribution.locate_file(file)) yield Path(distribution.locate_file(file))
...@@ -334,6 +334,7 @@ class SitePackages: ...@@ -334,6 +334,7 @@ class SitePackages:
for distribution in self.distributions( for distribution in self.distributions(
name=distribution_name, writable_only=writable_only name=distribution_name, writable_only=writable_only
): ):
assert distribution.files is not None
for file in distribution.files: for file in distribution.files:
if file.name == name: if file.name == name:
yield Path(distribution.locate_file(file)) yield Path(distribution.locate_file(file))
...@@ -362,16 +363,18 @@ class SitePackages: ...@@ -362,16 +363,18 @@ class SitePackages:
for distribution in self.distributions( for distribution in self.distributions(
name=distribution_name, writable_only=True name=distribution_name, writable_only=True
): ):
assert distribution.files is not None
for file in distribution.files: for file in distribution.files:
file = Path(distribution.locate_file(file)) path = Path(distribution.locate_file(file))
# We can't use unlink(missing_ok=True) because it's not always available # We can't use unlink(missing_ok=True) because it's not always available
if file.exists(): if path.exists():
file.unlink() path.unlink()
if distribution._path.exists(): distribution_path: Path = distribution._path # type: ignore[attr-defined]
remove_directory(str(distribution._path), force=True) if distribution_path.exists():
remove_directory(str(distribution_path), force=True)
paths.append(distribution._path) paths.append(distribution_path)
return paths return paths
...@@ -409,10 +412,14 @@ class SitePackages: ...@@ -409,10 +412,14 @@ class SitePackages:
raise OSError(f"Unable to access any of {paths_csv(candidates)}") raise OSError(f"Unable to access any of {paths_csv(candidates)}")
def write_text(self, path: str | Path, *args: Any, **kwargs: Any) -> Path: def write_text(self, path: str | Path, *args: Any, **kwargs: Any) -> Path:
return self._path_method_wrapper(path, "write_text", *args, **kwargs)[0] paths = self._path_method_wrapper(path, "write_text", *args, **kwargs)
assert isinstance(paths, tuple)
return paths[0]
def mkdir(self, path: str | Path, *args: Any, **kwargs: Any) -> Path: def mkdir(self, path: str | Path, *args: Any, **kwargs: Any) -> Path:
return self._path_method_wrapper(path, "mkdir", *args, **kwargs)[0] paths = self._path_method_wrapper(path, "mkdir", *args, **kwargs)
assert isinstance(paths, tuple)
return paths[0]
def exists(self, path: str | Path) -> bool: def exists(self, path: str | Path) -> bool:
return any( return any(
...@@ -498,7 +505,7 @@ class EnvManager: ...@@ -498,7 +505,7 @@ class EnvManager:
ENVS_FILE = "envs.toml" ENVS_FILE = "envs.toml"
def __init__(self, poetry: P) -> None: def __init__(self, poetry: Poetry) -> None:
self._poetry = poetry self._poetry = poetry
def _full_python_path(self, python: str) -> str: def _full_python_path(self, python: str) -> str:
...@@ -516,7 +523,7 @@ class EnvManager: ...@@ -516,7 +523,7 @@ class EnvManager:
return executable return executable
def _detect_active_python(self, io: IO) -> str: def _detect_active_python(self, io: IO) -> str | None:
executable = None executable = None
try: try:
...@@ -558,7 +565,7 @@ class EnvManager: ...@@ -558,7 +565,7 @@ class EnvManager:
python = self._full_python_path(python) python = self._full_python_path(python)
try: try:
python_version = decode( python_version_string = decode(
subprocess.check_output( subprocess.check_output(
list_to_shell_command([python, "-c", GET_PYTHON_VERSION_ONELINER]), list_to_shell_command([python, "-c", GET_PYTHON_VERSION_ONELINER]),
shell=True, shell=True,
...@@ -567,7 +574,7 @@ class EnvManager: ...@@ -567,7 +574,7 @@ class EnvManager:
except CalledProcessError as e: except CalledProcessError as e:
raise EnvCommandError(e) raise EnvCommandError(e)
python_version = Version.parse(python_version.strip()) python_version = Version.parse(python_version_string.strip())
minor = f"{python_version.major}.{python_version.minor}" minor = f"{python_version.major}.{python_version.minor}"
patch = python_version.text patch = python_version.text
...@@ -649,7 +656,7 @@ class EnvManager: ...@@ -649,7 +656,7 @@ class EnvManager:
envs_file.write(envs) envs_file.write(envs)
def get(self, reload: bool = False) -> VirtualEnv | SystemEnv: def get(self, reload: bool = False) -> Env:
if self._env is not None and not reload: if self._env is not None and not reload:
return self._env return self._env
...@@ -798,7 +805,7 @@ class EnvManager: ...@@ -798,7 +805,7 @@ class EnvManager:
pass pass
try: try:
python_version = decode( python_version_string = decode(
subprocess.check_output( subprocess.check_output(
list_to_shell_command([python, "-c", GET_PYTHON_VERSION_ONELINER]), list_to_shell_command([python, "-c", GET_PYTHON_VERSION_ONELINER]),
shell=True, shell=True,
...@@ -807,13 +814,13 @@ class EnvManager: ...@@ -807,13 +814,13 @@ class EnvManager:
except CalledProcessError as e: except CalledProcessError as e:
raise EnvCommandError(e) raise EnvCommandError(e)
python_version = Version.parse(python_version.strip()) python_version = Version.parse(python_version_string.strip())
minor = f"{python_version.major}.{python_version.minor}" minor = f"{python_version.major}.{python_version.minor}"
name = f"{base_env_name}-py{minor}" name = f"{base_env_name}-py{minor}"
venv = venv_path / name venv_path = venv_path / name
if not venv.exists(): if not venv_path.exists():
raise ValueError(f'<warning>Environment "{name}" does not exist.</warning>') raise ValueError(f'<warning>Environment "{name}" does not exist.</warning>')
if envs_file.exists(): if envs_file.exists():
...@@ -826,9 +833,9 @@ class EnvManager: ...@@ -826,9 +833,9 @@ class EnvManager:
del envs[base_env_name] del envs[base_env_name]
envs_file.write(envs) envs_file.write(envs)
self.remove_venv(venv) self.remove_venv(venv_path)
return VirtualEnv(venv, venv) return VirtualEnv(venv_path, venv_path)
def create_venv( def create_venv(
self, self,
...@@ -836,7 +843,7 @@ class EnvManager: ...@@ -836,7 +843,7 @@ class EnvManager:
name: str | None = None, name: str | None = None,
executable: str | None = None, executable: str | None = None,
force: bool = False, force: bool = False,
) -> SystemEnv | VirtualEnv: ) -> Env:
if self._env is not None and not force: if self._env is not None and not force:
return self._env return self._env
...@@ -1019,7 +1026,7 @@ class EnvManager: ...@@ -1019,7 +1026,7 @@ class EnvManager:
cls, cls,
path: Path | str, path: Path | str,
executable: str | Path | None = None, executable: str | Path | None = None,
flags: dict[str, bool] = None, flags: dict[str, bool] | None = None,
with_pip: bool | None = None, with_pip: bool | None = None,
with_wheel: bool | None = None, with_wheel: bool | None = None,
with_setuptools: bool | None = None, with_setuptools: bool | None = None,
...@@ -1086,7 +1093,7 @@ class EnvManager: ...@@ -1086,7 +1093,7 @@ class EnvManager:
remove_directory(file_path, force=True) remove_directory(file_path, force=True)
@classmethod @classmethod
def get_system_env(cls, naive: bool = False) -> SystemEnv | GenericEnv: def get_system_env(cls, naive: bool = False) -> Env:
""" """
Retrieve the current Python environment. Retrieve the current Python environment.
...@@ -1100,7 +1107,7 @@ class EnvManager: ...@@ -1100,7 +1107,7 @@ class EnvManager:
(e.g. plugin installation or self update). (e.g. plugin installation or self update).
""" """
prefix, base_prefix = Path(sys.prefix), Path(cls.get_base_prefix()) prefix, base_prefix = Path(sys.prefix), Path(cls.get_base_prefix())
env = SystemEnv(prefix) env: Env = SystemEnv(prefix)
if not naive: if not naive:
if prefix.joinpath("poetry_env").exists(): if prefix.joinpath("poetry_env").exists():
env = GenericEnv(base_prefix, child_env=env) env = GenericEnv(base_prefix, child_env=env)
...@@ -1118,11 +1125,13 @@ class EnvManager: ...@@ -1118,11 +1125,13 @@ class EnvManager:
@classmethod @classmethod
def get_base_prefix(cls) -> Path: def get_base_prefix(cls) -> Path:
if hasattr(sys, "real_prefix"): real_prefix = getattr(sys, "real_prefix", None)
return Path(sys.real_prefix) if real_prefix is not None:
return Path(real_prefix)
if hasattr(sys, "base_prefix"): base_prefix = getattr(sys, "base_prefix", None)
return Path(sys.base_prefix) if base_prefix is not None:
return Path(base_prefix)
return Path(sys.prefix) return Path(sys.prefix)
...@@ -1131,10 +1140,10 @@ class EnvManager: ...@@ -1131,10 +1140,10 @@ class EnvManager:
name = name.lower() name = name.lower()
sanitized_name = re.sub(r'[ $`!*@"\\\r\n\t]', "_", name)[:42] sanitized_name = re.sub(r'[ $`!*@"\\\r\n\t]', "_", name)[:42]
normalized_cwd = os.path.normcase(cwd) normalized_cwd = os.path.normcase(cwd)
h = hashlib.sha256(encode(normalized_cwd)).digest() h_bytes = hashlib.sha256(encode(normalized_cwd)).digest()
h = base64.urlsafe_b64encode(h).decode()[:8] h_str = base64.urlsafe_b64encode(h_bytes).decode()[:8]
return f"{sanitized_name}-{h}" return f"{sanitized_name}-{h_str}"
class Env: class Env:
...@@ -1161,16 +1170,16 @@ class Env: ...@@ -1161,16 +1170,16 @@ class Env:
self._base = base or path self._base = base or path
self._marker_env = None self._marker_env: dict[str, Any] | None = None
self._pip_version = None self._pip_version: Version | None = None
self._site_packages = None self._site_packages: SitePackages | None = None
self._paths = None self._paths: dict[str, str] | None = None
self._supported_tags = None self._supported_tags: list[Tag] | None = None
self._purelib = None self._purelib: Path | None = None
self._platlib = None self._platlib: Path | None = None
self._script_dirs = None self._script_dirs: list[Path] | None = None
self._embedded_pip_path = None self._embedded_pip_path: str | None = None
@property @property
def path(self) -> Path: def path(self) -> Path:
...@@ -1181,12 +1190,13 @@ class Env: ...@@ -1181,12 +1190,13 @@ class Env:
return self._base return self._base
@property @property
def version_info(self) -> tuple[int]: def version_info(self) -> tuple[Any, ...]:
return tuple(self.marker_env["version_info"]) return tuple(self.marker_env["version_info"])
@property @property
def python_implementation(self) -> str: def python_implementation(self) -> str:
return self.marker_env["platform_python_implementation"] implementation: str = self.marker_env["platform_python_implementation"]
return implementation
@property @property
def python(self) -> str: def python(self) -> str:
...@@ -1242,9 +1252,11 @@ class Env: ...@@ -1242,9 +1252,11 @@ class Env:
self._find_pip_executable() self._find_pip_executable()
def get_embedded_wheel(self, distribution: str) -> Path: def get_embedded_wheel(self, distribution: str) -> Path:
return get_embed_wheel( wheel: Wheel = get_embed_wheel(
distribution, f"{self.version_info[0]}.{self.version_info[1]}" distribution, f"{self.version_info[0]}.{self.version_info[1]}"
).path )
path: Path = wheel.path
return path
@property @property
def pip_embedded(self) -> str: def pip_embedded(self) -> str:
...@@ -1351,15 +1363,17 @@ class Env: ...@@ -1351,15 +1363,17 @@ class Env:
@classmethod @classmethod
def get_base_prefix(cls) -> Path: def get_base_prefix(cls) -> Path:
if hasattr(sys, "real_prefix"): real_prefix = getattr(sys, "real_prefix", None)
return Path(sys.real_prefix) if real_prefix is not None:
return Path(real_prefix)
if hasattr(sys, "base_prefix"): base_prefix = getattr(sys, "base_prefix", None)
return Path(sys.base_prefix) if base_prefix is not None:
return Path(base_prefix)
return Path(sys.prefix) return Path(sys.prefix)
def get_version_info(self) -> tuple[int]: def get_version_info(self) -> tuple[Any, ...]:
raise NotImplementedError() raise NotImplementedError()
def get_python_implementation(self) -> str: def get_python_implementation(self) -> str:
...@@ -1421,12 +1435,15 @@ class Env: ...@@ -1421,12 +1435,15 @@ class Env:
if self._is_windows: if self._is_windows:
kwargs["shell"] = True kwargs["shell"] = True
command: str | list[str]
if kwargs.get("shell", False): if kwargs.get("shell", False):
cmd = list_to_shell_command(cmd) command = list_to_shell_command(cmd)
else:
command = cmd
if input_: if input_:
output = subprocess.run( output = subprocess.run(
cmd, command,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
input=encode(input_), input=encode(input_),
...@@ -1434,10 +1451,12 @@ class Env: ...@@ -1434,10 +1451,12 @@ class Env:
**kwargs, **kwargs,
).stdout ).stdout
elif call: elif call:
return subprocess.call(cmd, stderr=subprocess.STDOUT, env=env, **kwargs) return subprocess.call(
command, stderr=subprocess.STDOUT, env=env, **kwargs
)
else: else:
output = subprocess.check_output( output = subprocess.check_output(
cmd, stderr=subprocess.STDOUT, env=env, **kwargs command, stderr=subprocess.STDOUT, env=env, **kwargs
) )
except CalledProcessError as e: except CalledProcessError as e:
raise EnvCommandError(e, input=input_) raise EnvCommandError(e, input=input_)
...@@ -1461,11 +1480,10 @@ class Env: ...@@ -1461,11 +1480,10 @@ class Env:
@property @property
def script_dirs(self) -> list[Path]: def script_dirs(self) -> list[Path]:
if self._script_dirs is None: if self._script_dirs is None:
self._script_dirs = ( scripts = self.paths.get("scripts")
[Path(self.paths["scripts"])] self._script_dirs = [
if "scripts" in self.paths Path(scripts) if scripts is not None else self._bin_dir
else self._bin_dir ]
)
if self.userbase: if self.userbase:
self._script_dirs.append(self.userbase / self._script_dirs[0].name) self._script_dirs.append(self.userbase / self._script_dirs[0].name)
return self._script_dirs return self._script_dirs
...@@ -1498,6 +1516,9 @@ class Env: ...@@ -1498,6 +1516,9 @@ class Env:
return str(bin_path) return str(bin_path)
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, Env):
return False
return other.__class__ == self.__class__ and other.path == self.path return other.__class__ == self.__class__ and other.path == self.path
def __repr__(self) -> str: def __repr__(self) -> str:
...@@ -1517,8 +1538,8 @@ class SystemEnv(Env): ...@@ -1517,8 +1538,8 @@ class SystemEnv(Env):
def sys_path(self) -> list[str]: def sys_path(self) -> list[str]:
return sys.path return sys.path
def get_version_info(self) -> tuple[int]: def get_version_info(self) -> tuple[Any, ...]:
return sys.version_info return tuple(sys.version_info)
def get_python_implementation(self) -> str: def get_python_implementation(self) -> str:
return platform.python_implementation() return platform.python_implementation()
...@@ -1543,6 +1564,7 @@ class SystemEnv(Env): ...@@ -1543,6 +1564,7 @@ class SystemEnv(Env):
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", "setup.py install is deprecated") warnings.filterwarnings("ignore", "setup.py install is deprecated")
obj = d.get_command_obj("install", create=True) obj = d.get_command_obj("install", create=True)
assert obj is not None
obj.finalize_options() obj.finalize_options()
paths = sysconfig.get_paths().copy() paths = sysconfig.get_paths().copy()
...@@ -1553,9 +1575,12 @@ class SystemEnv(Env): ...@@ -1553,9 +1575,12 @@ class SystemEnv(Env):
paths[key] = getattr(obj, f"install_{key}") paths[key] = getattr(obj, f"install_{key}")
if site.check_enableusersite() and hasattr(obj, "install_usersite"): if site.check_enableusersite():
paths["usersite"] = obj.install_usersite usersite = getattr(obj, "install_usersite", None)
paths["userbase"] = obj.install_userbase userbase = getattr(obj, "install_userbase", None)
if usersite is not None and userbase is not None:
paths["usersite"] = usersite
paths["userbase"] = userbase
return paths return paths
...@@ -1614,20 +1639,26 @@ class VirtualEnv(Env): ...@@ -1614,20 +1639,26 @@ class VirtualEnv(Env):
# In this case we need to get sys.base_prefix # In this case we need to get sys.base_prefix
# from inside the virtualenv. # from inside the virtualenv.
if base is None: if base is None:
self._base = Path(self.run_python_script(GET_BASE_PREFIX).strip()) output = self.run_python_script(GET_BASE_PREFIX)
assert isinstance(output, str)
self._base = Path(output.strip())
@property @property
def sys_path(self) -> list[str]: def sys_path(self) -> list[str]:
output = self.run_python_script(GET_SYS_PATH) output = self.run_python_script(GET_SYS_PATH)
return json.loads(output) assert isinstance(output, str)
paths: list[str] = json.loads(output)
return paths
def get_version_info(self) -> tuple[int]: def get_version_info(self) -> tuple[Any, ...]:
output = self.run_python_script(GET_PYTHON_VERSION) output = self.run_python_script(GET_PYTHON_VERSION)
assert isinstance(output, str)
return tuple(int(s) for s in output.strip().split(".")) return tuple(int(s) for s in output.strip().split("."))
def get_python_implementation(self) -> str: def get_python_implementation(self) -> str:
return self.marker_env["platform_python_implementation"] implementation: str = self.marker_env["platform_python_implementation"]
return implementation
def get_pip_command(self, embedded: bool = False) -> list[str]: def get_pip_command(self, embedded: bool = False) -> list[str]:
# We're in a virtualenv that is known to be sane, # We're in a virtualenv that is known to be sane,
...@@ -1639,16 +1670,22 @@ class VirtualEnv(Env): ...@@ -1639,16 +1670,22 @@ class VirtualEnv(Env):
def get_supported_tags(self) -> list[Tag]: def get_supported_tags(self) -> list[Tag]:
output = self.run_python_script(GET_SYS_TAGS) output = self.run_python_script(GET_SYS_TAGS)
assert isinstance(output, str)
return [Tag(*t) for t in json.loads(output)] return [Tag(*t) for t in json.loads(output)]
def get_marker_env(self) -> dict[str, Any]: def get_marker_env(self) -> dict[str, Any]:
output = self.run_python_script(GET_ENVIRONMENT_INFO) output = self.run_python_script(GET_ENVIRONMENT_INFO)
assert isinstance(output, str)
return json.loads(output) env: dict[str, Any] = json.loads(output)
return env
def get_pip_version(self) -> Version: def get_pip_version(self) -> Version:
output = self.run_pip("--version").strip() output = self.run_pip("--version")
assert isinstance(output, str)
output = output.strip()
m = re.match("pip (.+?)(?: from .+)?$", output) m = re.match("pip (.+?)(?: from .+)?$", output)
if not m: if not m:
return Version.parse("0.0") return Version.parse("0.0")
...@@ -1657,7 +1694,9 @@ class VirtualEnv(Env): ...@@ -1657,7 +1694,9 @@ class VirtualEnv(Env):
def get_paths(self) -> dict[str, str]: def get_paths(self) -> dict[str, str]:
output = self.run_python_script(GET_PATHS) output = self.run_python_script(GET_PATHS)
return json.loads(output) assert isinstance(output, str)
paths: dict[str, str] = json.loads(output)
return paths
def is_venv(self) -> bool: def is_venv(self) -> bool:
return True return True
...@@ -1666,7 +1705,7 @@ class VirtualEnv(Env): ...@@ -1666,7 +1705,7 @@ class VirtualEnv(Env):
# A virtualenv is considered sane if "python" exists. # A virtualenv is considered sane if "python" exists.
return os.path.exists(self.python) return os.path.exists(self.python)
def _run(self, cmd: list[str], **kwargs: Any) -> int | None: def _run(self, cmd: list[str], **kwargs: Any) -> int | str:
kwargs["env"] = self.get_temp_environ(environ=kwargs.get("env")) kwargs["env"] = self.get_temp_environ(environ=kwargs.get("env"))
return super()._run(cmd, **kwargs) return super()._run(cmd, **kwargs)
...@@ -1771,8 +1810,10 @@ class GenericEnv(VirtualEnv): ...@@ -1771,8 +1810,10 @@ class GenericEnv(VirtualEnv):
def get_paths(self) -> dict[str, str]: def get_paths(self) -> dict[str, str]:
output = self.run_python_script(GET_PATHS_FOR_GENERIC_ENVS) output = self.run_python_script(GET_PATHS_FOR_GENERIC_ENVS)
assert isinstance(output, str)
return json.loads(output) paths: dict[str, str] = json.loads(output)
return paths
def execute(self, bin: str, *args: str, **kwargs: Any) -> int | None: def execute(self, bin: str, *args: str, **kwargs: Any) -> int | None:
command = self.get_command_from_bin(bin) + list(args) command = self.get_command_from_bin(bin) + list(args)
...@@ -1786,7 +1827,7 @@ class GenericEnv(VirtualEnv): ...@@ -1786,7 +1827,7 @@ class GenericEnv(VirtualEnv):
return exe.returncode return exe.returncode
def _run(self, cmd: list[str], **kwargs: Any) -> int | None: def _run(self, cmd: list[str], **kwargs: Any) -> int | str:
return super(VirtualEnv, self)._run(cmd, **kwargs) return super(VirtualEnv, self)._run(cmd, **kwargs)
def is_venv(self) -> bool: def is_venv(self) -> bool:
...@@ -1795,7 +1836,7 @@ class GenericEnv(VirtualEnv): ...@@ -1795,7 +1836,7 @@ class GenericEnv(VirtualEnv):
class NullEnv(SystemEnv): class NullEnv(SystemEnv):
def __init__( def __init__(
self, path: Path = None, base: Path | None = None, execute: bool = False self, path: Path | None = None, base: Path | None = None, execute: bool = False
) -> None: ) -> None:
if path is None: if path is None:
path = Path(sys.prefix) path = Path(sys.prefix)
...@@ -1803,7 +1844,7 @@ class NullEnv(SystemEnv): ...@@ -1803,7 +1844,7 @@ class NullEnv(SystemEnv):
super().__init__(path, base=base) super().__init__(path, base=base)
self._execute = execute self._execute = execute
self.executed = [] self.executed: list[list[str]] = []
def get_pip_command(self, embedded: bool = False) -> list[str]: def get_pip_command(self, embedded: bool = False) -> list[str]:
return [ return [
...@@ -1811,12 +1852,12 @@ class NullEnv(SystemEnv): ...@@ -1811,12 +1852,12 @@ class NullEnv(SystemEnv):
self.pip_embedded if embedded else self.pip, self.pip_embedded if embedded else self.pip,
] ]
def _run(self, cmd: list[str], **kwargs: Any) -> int | None: def _run(self, cmd: list[str], **kwargs: Any) -> int | str:
self.executed.append(cmd) self.executed.append(cmd)
if self._execute: if self._execute:
return super()._run(cmd, **kwargs) return super()._run(cmd, **kwargs)
return None return 0
def execute(self, bin: str, *args: str, **kwargs: Any) -> int | None: def execute(self, bin: str, *args: str, **kwargs: Any) -> int | None:
self.executed.append([bin] + list(args)) self.executed.append([bin] + list(args))
...@@ -1832,7 +1873,7 @@ class NullEnv(SystemEnv): ...@@ -1832,7 +1873,7 @@ class NullEnv(SystemEnv):
@contextmanager @contextmanager
def ephemeral_environment( def ephemeral_environment(
executable: str | Path | None = None, executable: str | Path | None = None,
flags: dict[str, bool] = None, flags: dict[str, bool] | None = None,
) -> Iterator[VirtualEnv]: ) -> Iterator[VirtualEnv]:
with temporary_directory() as tmp_dir: with temporary_directory() as tmp_dir:
# TODO: cache PEP 517 build environment corresponding to each project venv # TODO: cache PEP 517 build environment corresponding to each project venv
...@@ -1847,7 +1888,7 @@ def ephemeral_environment( ...@@ -1847,7 +1888,7 @@ def ephemeral_environment(
@contextmanager @contextmanager
def build_environment( def build_environment(
poetry: P, env: Env | None = None, io: IO | None = None poetry: CorePoetry, env: Env | None = None, io: IO | None = None
) -> Iterator[Env]: ) -> Iterator[Env]:
""" """
If a build script is specified for the project, there could be additional build If a build script is specified for the project, there could be additional build
...@@ -1897,8 +1938,8 @@ class MockEnv(NullEnv): ...@@ -1897,8 +1938,8 @@ class MockEnv(NullEnv):
is_venv: bool = False, is_venv: bool = False,
pip_version: str = "19.1", pip_version: str = "19.1",
sys_path: list[str] | None = None, sys_path: list[str] | None = None,
marker_env: dict[str, Any] = None, marker_env: dict[str, Any] | None = None,
supported_tags: list[Tag] = None, supported_tags: list[Tag] | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -1908,7 +1949,7 @@ class MockEnv(NullEnv): ...@@ -1908,7 +1949,7 @@ class MockEnv(NullEnv):
self._platform = platform self._platform = platform
self._os_name = os_name self._os_name = os_name
self._is_venv = is_venv self._is_venv = is_venv
self._pip_version = Version.parse(pip_version) self._pip_version: Version = Version.parse(pip_version)
self._sys_path = sys_path self._sys_path = sys_path
self._mock_marker_env = marker_env self._mock_marker_env = marker_env
self._supported_tags = supported_tags self._supported_tags = supported_tags
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment