Skip to content

Commit

Permalink
[red-knot] Support TypeGuard and TypeIs
Browse files Browse the repository at this point in the history
  • Loading branch information
InSyncWithFoo committed Feb 22, 2025
1 parent 64effa4 commit 1f02c40
Show file tree
Hide file tree
Showing 10 changed files with 583 additions and 33 deletions.
227 changes: 227 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/narrow/type_guards.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
# User-defined type guards

User-defined type guards are functions of which the return type is either `TypeGuard[...]` or
`TypeIs[...]`.

## Display

```py
from knot_extensions import Intersection, Not, TypeOf
from typing_extensions import TypeGuard, TypeIs

def _(
a: TypeGuard[str],
b: TypeIs[str | int],
c: TypeGuard[Intersection[complex, Not[int], Not[float]]],
d: TypeIs[tuple[TypeOf[bytes]]],
):
reveal_type(a) # revealed: TypeGuard[str]
reveal_type(b) # revealed: TypeIs[str | int]
reveal_type(c) # revealed: TypeGuard[complex & ~int & ~float]
reveal_type(d) # revealed: TypeIs[tuple[Literal[bytes]]]

def f(a) -> TypeGuard[str]: ...
def g(a) -> TypeIs[str]: ...

def _(a: object):
reveal_type(f(a)) # revealed: TypeGuard[a, str]
reveal_type(g(a)) # revealed: TypeIs[a, str]
```

## Parameters

A user-defined type guard must accept at least one positional argument, (in addition to `self`/`cls`
for non-static methods).

```py
from typing_extensions import TypeGuard, TypeIs

# error: [invalid-type-guard-definition]
def _() -> TypeGuard[str]: ...

# error: [invalid-type-guard-definition]
def _(**kwargs) -> TypeIs[str]: ...

class _:
def _(self, /, a) -> TypeGuard[str]: ...
@classmethod
def _(cls, a) -> TypeGuard[str]: ...
@staticmethod
def _(a) -> TypeIs[str]: ...

def _(self) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition]
def _(self, /, *, a) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition]
@classmethod
def _(cls) -> TypeIs[str]: ... # error: [invalid-type-guard-definition]
@classmethod
def _() -> TypeIs[str]: ... # error: [invalid-type-guard-definition]
@staticmethod
def _(*, a) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition]
```

For `TypeIs` functions, the narrowed type must be assignable to the declared type of that parameter,
if any.

```py
from typing import Any
from typing_extensions import TypeGuard, TypeIs

def _(a: object) -> TypeIs[str]: ...
def _(a: Any) -> TypeIs[str]: ...
def _(a: tuple[object]) -> TypeIs[tuple[str]]: ...
def _(a: str | Any) -> TypeIs[str]: ...
def _(a) -> TypeIs[str]: ...

# error: [invalid-type-guard-definition]
def _(a: int) -> TypeIs[str]: ...
# error: [invalid-type-guard-definition]
def _(a: bool | str) -> TypeIs[int]: ...
```

## Arguments to special forms

`TypeGuard` and `TypeIs` accept exactly one type argument.

```py
from typing_extensions import TypeGuard, TypeIs

a = 123

# error: [invalid-type-form]
def f(_) -> TypeGuard[int, str]: ...
# error: [invalid-type-form]
def g(_) -> TypeIs[a, str]: ...

reveal_type(f(0)) # revealed: Unknown
reveal_type(g(0)) # revealed: Unknown
```

## Return types

All code paths in a type guard function must return booleans.

```py
from typing_extensions import Literal, TypeGuard, TypeIs, assert_never

def f(a: object, flag: bool) -> TypeGuard[str]:
if flag:
# TODO: Emit a diagnostic
return 1

# TODO: Emit a diagnostic
return ''

def g(a: Literal['foo', 'bar']) -> TypeIs[Literal['foo']]:
match a:
case 'foo':
# Logically wrong, but allowed regardless
return False
case 'bar':
return False
case _:
assert_never(a)
```

## Invalid calls

```py
from typing import Any
from typing_extensions import TypeGuard, TypeIs

def f(a: object) -> TypeGuard[str]: ...
def g(a: object) -> TypeIs[int]: ...

def _(d: Any):
if f(): # error: [missing-argument]
...

# TODO: Is this error correct?
if g(*d): # error: [missing-argument]
...

if f("foo"): # error: [invalid-type-guard-call]
...

if g(a=d): # error: [invalid-type-guard-call]
...

def _(a: tuple[str, int] | tuple[int, str]):
if g(a[0]): # error: [invalid-type-guard-call]
# TODO: Should be `tuple[str, int]`
reveal_type(a) # revealed: tuple[str, int] | tuple[int, str]
```

## Narrowing

```py
from typing import Any
from typing_extensions import TypeGuard, TypeIs

def guard_str(a: object) -> TypeGuard[str]: ...
def is_int(a: object) -> TypeIs[int]: ...

def _(a: str | int):
if guard_str(a):
reveal_type(a) # revealed: str
else:
reveal_type(a) # revealed: str | int

if is_int(a):
reveal_type(a) # revealed: int
else:
reveal_type(a) # revealed: str & ~int

def _(a: str | int):
b = guard_str(a)
c = is_int(a)

reveal_type(a) # revealed: str | int
reveal_type(b) # revealed: TypeGuard[a, str]
reveal_type(c) # revealed: TypeIs[a, int]

if b:
reveal_type(a) # revealed: str
else:
reveal_type(a) # revealed: str | int

if c:
reveal_type(a) # revealed: int
else:
reveal_type(a) # revealed: str

def _(x: str | int, flag: bool) -> None:
b = is_int(x)
reveal_type(b) # revealed: TypeIs[x, int]

if flag:
x = ''

if b:
reveal_type(x) # revealed: str | int
```

## `TypeGuard` special cases

```py
from typing import Any
from typing_extensions import TypeGuard

def guard_int(a: object) -> TypeGuard[int]: ...
def is_int(a: object) -> TypeGuard[int]: ...

def does_not_narrow_in_negative_case(a: str | int):
if not guard_int(a):
reveal_type(a) # revealed: str | int
else:
reveal_type(a) # revealed: int

def narrowed_type_must_be_exact(a: object, b: bool):
if guard_int(b):
reveal_type(b) # revealed: int

if isinstance(a, bool) and is_int(a):
reveal_type(a) # revealed: bool

if isinstance(a, bool) and guard_int(a):
reveal_type(a) # revealed: int
```
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,31 @@ static_assert(not is_subtype_of(str, AlwaysTruthy))
static_assert(not is_subtype_of(str, AlwaysFalsy))
```

### `TypeGuard` and `TypeIs`

`TypeGuard[...]` and `TypeIs[...]` are subtypes of `bool`.

```py
from knot_extensions import is_subtype_of, static_assert
from typing_extensions import TypeGuard, TypeIs

static_assert(is_subtype_of(TypeGuard[int], bool))
static_assert(is_subtype_of(TypeIs[str], bool))
```

`TypeIs` is invariant. `TypeGuard` is covariant.

```py
from knot_extensions import is_subtype_of, static_assert
from typing_extensions import TypeGuard, TypeIs

static_assert(is_subtype_of(TypeGuard[bool], TypeGuard[int]))

static_assert(not is_subtype_of(TypeGuard[int], TypeGuard[bool]))
static_assert(not is_subtype_of(TypeIs[bool], TypeIs[int]))
static_assert(not is_subtype_of(TypeIs[int], TypeIs[bool]))
```

### Module literals

```py
Expand Down
Loading

0 comments on commit 1f02c40

Please sign in to comment.