diff --git a/ptpython/python_input.py b/ptpython/python_input.py index d66b5ae..975d3d9 100644 --- a/ptpython/python_input.py +++ b/ptpython/python_input.py @@ -347,14 +347,6 @@ def __init__( "classic": ClassicPrompt(), } - self.get_input_prompt = lambda: self.all_prompt_styles[ - self.prompt_style - ].in_prompt() - - self.get_output_prompt = lambda: self.all_prompt_styles[ - self.prompt_style - ].out_prompt() - #: Load styles. self.code_styles: dict[str, BaseStyle] = get_all_code_styles() self.ui_styles = get_all_ui_styles() @@ -425,6 +417,12 @@ def __init__( else: self._app = None + def get_input_prompt(self) -> AnyFormattedText: + return self.all_prompt_styles[self.prompt_style].in_prompt() + + def get_output_prompt(self) -> AnyFormattedText: + return self.all_prompt_styles[self.prompt_style].out_prompt() + def _accept_handler(self, buff: Buffer) -> bool: app = get_app() app.exit(result=buff.text) diff --git a/ptpython/repl.py b/ptpython/repl.py index bbbd852..ea2d84f 100644 --- a/ptpython/repl.py +++ b/ptpython/repl.py @@ -19,7 +19,8 @@ import types import warnings from dis import COMPILER_FLAG_NAMES -from typing import Any, Callable, ContextManager, Iterable +from pathlib import Path +from typing import Any, Callable, ContextManager, Iterable, Sequence from prompt_toolkit.formatted_text import OneStyleAndTextTuple from prompt_toolkit.patch_stdout import patch_stdout as patch_stdout_context @@ -64,7 +65,7 @@ def _has_coroutine_flag(code: types.CodeType) -> bool: class PythonRepl(PythonInput): def __init__(self, *a, **kw) -> None: - self._startup_paths = kw.pop("startup_paths", None) + self._startup_paths: Sequence[str | Path] | None = kw.pop("startup_paths", None) super().__init__(*a, **kw) self._load_start_paths() @@ -348,7 +349,7 @@ def _store_eval_result(self, result: object) -> None: def get_compiler_flags(self) -> int: return super().get_compiler_flags() | PyCF_ALLOW_TOP_LEVEL_AWAIT - def _compile_with_flags(self, code: str, mode: str): + def _compile_with_flags(self, code: str, mode: str) -> Any: "Compile code with the right compiler flags." return compile( code, @@ -459,13 +460,13 @@ def enter_to_continue() -> None: def embed( - globals=None, - locals=None, + globals: dict[str, Any] | None = None, + locals: dict[str, Any] | None = None, configure: Callable[[PythonRepl], None] | None = None, vi_mode: bool = False, history_filename: str | None = None, title: str | None = None, - startup_paths=None, + startup_paths: Sequence[str | Path] | None = None, patch_stdout: bool = False, return_asyncio_coroutine: bool = False, ) -> None: @@ -494,10 +495,10 @@ def embed( locals = locals or globals - def get_globals(): + def get_globals() -> dict[str, Any]: return globals - def get_locals(): + def get_locals() -> dict[str, Any]: return locals # Create REPL.