Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type functions part 1 #812

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .idea/watcherTasks.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import DefaultDict, Generator, Iterator, List, NamedTuple, Optional, Tuple, Union
from typing_extensions import TypeAlias as _TypeAlias

import mypy.options
from mypy.erasetype import remove_instance_last_known_values
from mypy.join import join_simple
from mypy.literals import Key, literal, literal_hash, subkeys
Expand Down Expand Up @@ -331,7 +332,8 @@ def assign_type(
) -> None:
# We should erase last known value in binder, because if we are using it,
# it means that the target is not final, and therefore can't hold a literal.
type = remove_instance_last_known_values(type)
# HUUHHH?????
# type = remove_instance_last_known_values(type)

if self.type_assignments is not None:
# We are in a multiassign from union, defer the actual binding,
Expand Down
19 changes: 16 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3589,19 +3589,30 @@ def check_assignment(
):
lvalue.node.type = remove_instance_last_known_values(lvalue_type)

elif lvalue.node and lvalue.node.is_inferred and rvalue_type:
# for literal values
# Don't use type binder for definitions of special forms, like named tuples.
if not (isinstance(lvalue, NameExpr) and lvalue.is_special_form):
self.binder.assign_type(lvalue, rvalue_type, lvalue_type, False)

elif index_lvalue:
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)

if inferred:
type_context = self.get_variable_type_context(inferred)
rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context)
original_rvalue_type = rvalue_type
if not (
inferred.is_final
or inferred.is_index_var
or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__")
):
rvalue_type = remove_instance_last_known_values(rvalue_type)
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
if self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue):
self.binder.assign_type(
lvalue, original_rvalue_type, original_rvalue_type, False
)

self.check_assignment_to_slots(lvalue)

# (type, operator) tuples for augmented assignments supported with partial types
Expand Down Expand Up @@ -4553,12 +4564,13 @@ def is_definition(self, s: Lvalue) -> bool:

def infer_variable_type(
self, name: Var, lvalue: Lvalue, init_type: Type, context: Context
) -> None:
) -> bool:
"""Infer the type of initialized variables from initializer type."""
valid = True
if isinstance(init_type, DeletedType):
self.msg.deleted_as_rvalue(init_type, context)
elif (
not is_valid_inferred_type(init_type, is_lvalue_final=name.is_final)
not (valid := is_valid_inferred_type(init_type, is_lvalue_final=name.is_final))
and not self.no_partial_types
):
# We cannot use the type of the initialization expression for full type
Expand All @@ -4585,6 +4597,7 @@ def infer_variable_type(
init_type = strip_type(init_type)

self.set_inferred_type(name, lvalue, init_type)
return valid

def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool:
init_type = get_proper_type(init_type)
Expand Down
135 changes: 131 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@

from __future__ import annotations

import builtins
import contextlib
import enum
import importlib
import io
import itertools
import time
from collections import defaultdict
from contextlib import contextmanager
from types import GetSetDescriptorType
from typing import Callable, ClassVar, Final, Iterable, Iterator, List, Optional, Sequence, cast
from typing_extensions import TypeAlias as _TypeAlias, assert_never, overload

from basedtyping import TypeFunctionError

import mypy.checker
import mypy.errorcodes as codes
from mypy import applytype, erasetype, errorcodes, join, message_registry, nodes, operators, types
Expand All @@ -23,6 +30,7 @@
)
from mypy.checkstrformat import StringFormatterChecker
from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
from mypy.errorcodes import ErrorCode
from mypy.errors import ErrorInfo, ErrorWatcher, report_internal_error
from mypy.expandtype import (
expand_type,
Expand Down Expand Up @@ -205,12 +213,14 @@
from mypy.typestate import type_state
from mypy.typevars import fill_typevars
from mypy.util import split_module_names
from mypy.valuetotype import type_to_value, value_to_type
from mypy.visitor import ExpressionVisitor

# Type of callback user for checking individual function arguments. See
# check_args() below for details.
ArgChecker: _TypeAlias = Callable[
[Type, Type, ArgKind, Type, int, int, CallableType, Optional[Type], Context, Context], None
[Type, Type, ArgKind, Type, int, int, CallableType, Optional[Type], Context, Context, bool],
None,
]

# Maximum nesting level for math union in overloads, setting this to large values
Expand Down Expand Up @@ -1846,12 +1856,13 @@ def check_callable_call(
fresh_ret_type = freshen_all_functions_type_vars(callee.ret_type)
freeze_all_type_vars(fresh_ret_type)
callee = callee.copy_modified(ret_type=fresh_ret_type)

if callee.is_generic():
need_refresh = any(
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
)
# IT"S HERE!
callee = freshen_function_type_vars(callee)
# IT"S HERE!
callee = self.infer_function_type_arguments_using_context(callee, context)
if need_refresh:
# Argument kinds etc. may have changed due to
Expand Down Expand Up @@ -1909,7 +1920,6 @@ def check_callable_call(
self.check_argument_types(
arg_types, arg_kinds, args, callee, formal_to_actual, context, object_type=object_type
)

if (
callee.is_type_obj()
and (len(arg_types) == 1)
Expand All @@ -1921,6 +1931,38 @@ def check_callable_call(
# Store the inferred callable type.
self.chk.store_type(callable_node, callee)

if callee.is_type_function:
with self.msg.filter_errors(filter_errors=True) as error_watcher:
if object_type:
self.check_arg(
caller_type=object_type,
original_caller_type=object_type,
caller_kind=ArgKind.ARG_POS,
callee_type=callee.bound_args[0],
n=0,
m=0,
callee=callee,
object_type=object_type,
context=context,
outer_context=context,
type_function=True,
)

self.check_argument_types(
arg_types,
arg_kinds,
args,
callee,
formal_to_actual,
context,
object_type=object_type,
type_function=True,
)
if not error_watcher.has_new_errors() and "." in callable_name:
ret_type = self.call_type_function(callable_name, object_type, arg_types, context)
if ret_type:
callee = callee.copy_modified(ret_type=ret_type)

if callable_name and (
(object_type is None and self.plugin.get_function_hook(callable_name))
or (object_type is not None and self.plugin.get_method_hook(callable_name))
Expand All @@ -1939,6 +1981,71 @@ def check_callable_call(
callee = callee.copy_modified(ret_type=new_ret_type)
return callee.ret_type, callee

def call_type_function(
self,
callable_name: str,
object_type: ProperType | None,
arg_types: list[ProperType],
context: Context,
) -> Type | None:
container_name, fn_name = callable_name.rsplit(".", maxsplit=1)
resolved = None
for part in container_name.split("."):
if resolved:
m = resolved.names.get(part)
else:
m = self.chk.modules.get(part)
if m:
resolved = m
is_method = not isinstance(resolved, MypyFile)
if is_method:
container = resolved.node
module_name = container.module_name
else:
container = resolved
module_name = container.fullname

all_sigs = []
object_type = object_type and [object_type] or []
for arg in object_type + arg_types:
if isinstance(arg, UnionType):
if not all_sigs:
all_sigs = [[x] for x in arg.items]
else:
from itertools import product

all_sigs = product(all_sigs, arg.items)
all_sigs = all_sigs or [object_type + arg_types]
all_rets = []
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
mod = importlib.import_module(module_name)
container = getattr(mod, container.name) if is_method else mod
fn = getattr(container, fn_name)
for sig in all_sigs:
if isinstance(fn, (GetSetDescriptorType, property)):
fn = fn.__get__
args = [type_to_value(arg, self.chk) for arg in sig]
try:
return_value = fn(*args)
except RecursionError:
self.chk.fail(
"maximum recursion depth exceeded while evaluating type function",
context=context,
)
except TypeFunctionError as type_function_error:
code = type_function_error.code and ErrorCode(type_function_error.code, "", "")
self.chk.fail(type_function_error.message, code=code, context=context)
except Exception as exception:
self.chk.fail(
f"Invocation raises {type(exception).__name__}: {exception}",
context,
code=errorcodes.CALL_RAISES,
)
else:
all_rets.append(value_to_type(return_value, chk=self.chk))

return make_simplified_union(all_rets)

def can_return_none(self, type: TypeInfo, attr_name: str) -> bool:
"""Is the given attribute a method with a None-compatible return type?

Expand Down Expand Up @@ -2175,6 +2282,13 @@ def infer_function_type_arguments(
Return a derived callable type that has the arguments applied.
"""
if self.chk.in_checked_function():
if isinstance(callee_type.ret_type, TypeVarType):
# if the return type is constant, infer as literal
rvalue_type = [
remove_instance_last_known_values(arg) if isinstance(arg, Instance) else arg
for arg in args
]

# Disable type errors during type inference. There may be errors
# due to partial available context information at this time, but
# these errors can be safely ignored as the arguments will be
Expand Down Expand Up @@ -2581,6 +2695,8 @@ def check_argument_types(
context: Context,
check_arg: ArgChecker | None = None,
object_type: Type | None = None,
*,
type_function=False,
) -> None:
"""Check argument types against a callable type.

Expand Down Expand Up @@ -2712,6 +2828,7 @@ def check_argument_types(
object_type,
args[actual],
context,
type_function,
)

def check_arg(
Expand All @@ -2726,12 +2843,16 @@ def check_arg(
object_type: Type | None,
context: Context,
outer_context: Context,
type_function=False,
) -> None:
"""Check the type of a single argument in a call."""
caller_type = get_proper_type(caller_type)
original_caller_type = get_proper_type(original_caller_type)
callee_type = get_proper_type(callee_type)

if type_function:
# TODO: make this work at all
if not isinstance(caller_type, Instance) or not caller_type.last_known_value:
caller_type = self.named_type("builtins.object")
if isinstance(caller_type, DeletedType):
self.msg.deleted_as_rvalue(caller_type, context)
# Only non-abstract non-protocol class can be given where Type[...] is expected...
Expand Down Expand Up @@ -3348,6 +3469,7 @@ def check_arg(
object_type: Type | None,
context: Context,
outer_context: Context,
type_function: bool,
) -> None:
if not arg_approximate_similarity(caller_type, callee_type):
# No match -- exit early since none of the remaining work can change
Expand Down Expand Up @@ -3580,10 +3702,14 @@ def visit_bytes_expr(self, e: BytesExpr) -> Type:

def visit_float_expr(self, e: FloatExpr) -> Type:
"""Type check a float literal (trivial)."""
if mypy.options._based:
return self.infer_literal_expr_type(e.value, "builtins.float")
return self.named_type("builtins.float")

def visit_complex_expr(self, e: ComplexExpr) -> Type:
"""Type check a complex literal."""
if mypy.options._based:
return self.infer_literal_expr_type(e.value, "builtins.complex")
return self.named_type("builtins.complex")

def visit_ellipsis(self, e: EllipsisExpr) -> Type:
Expand Down Expand Up @@ -6502,6 +6628,7 @@ def narrow_type_from_binder(
known_type, restriction, prohibit_none_typevar_overlap=True
):
return None

return narrow_declared_type(known_type, restriction)
return known_type

Expand Down
3 changes: 3 additions & 0 deletions mypy/errorcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,9 @@ def __hash__(self) -> int:
TYPE_CHECK_ONLY: Final[ErrorCode] = ErrorCode(
"type-check-only", "Value doesn't exist at runtime", "General"
)
CALL_RAISES: Final[ErrorCode] = ErrorCode(
"call-raises", "function call raises an error", "General"
)
REVEAL: Final = ErrorCode("reveal", "Reveal types at check time", "General")

# Syntax errors are often blocking.
Expand Down
Loading
Loading