diff --git a/tests/test_type_view.py b/tests/test_type_view.py index 2a42af6..2704b59 100644 --- a/tests/test_type_view.py +++ b/tests/test_type_view.py @@ -259,7 +259,14 @@ def test_parsed_type_is_optional_predicate() -> None: def test_parsed_type_is_subtype_of() -> None: - """Test ParsedType.is_type_of.""" + """Test TypeView.is_subtype_of.""" + + class Foo: + pass + + class Bar(Foo): + pass + assert TypeView(bool).is_subtype_of(int) is True assert TypeView(bool).is_subtype_of(str) is False assert TypeView(Union[int, str]).is_subtype_of(int) is False @@ -268,6 +275,31 @@ def test_parsed_type_is_subtype_of() -> None: assert TypeView(Optional[int]).is_subtype_of(int) is False assert TypeView(Union[bool, int]).is_subtype_of(int) is True + assert TypeView(Foo).is_subtype_of(Foo) is True + assert TypeView(Bar).is_subtype_of(Foo) is True + + +def test_is_subclass_of() -> None: + class Foo: + pass + + class Bar(Foo): + pass + + assert TypeView(bool).is_subclass_of(int) is True + assert TypeView(bool).is_subclass_of(str) is False + assert TypeView(Union[int, str]).is_subclass_of(int) is False + assert TypeView(List[int]).is_subclass_of(int) is False + assert TypeView(Optional[int]).is_subclass_of(int) is False + assert TypeView(None).is_subclass_of(int) is False + assert TypeView(Literal[1]).is_subclass_of(int) is False + assert TypeView(Union[bool, int]).is_subclass_of(int) is False + + assert TypeView(bool).is_subclass_of(bool) is True + assert TypeView(List[int]).is_subclass_of(list) is True + assert TypeView(Foo).is_subclass_of(Foo) is True + assert TypeView(Bar).is_subclass_of(Foo) is True + def test_parsed_type_has_inner_subtype_of() -> None: """Test ParsedType.has_type_of.""" diff --git a/type_lens/type_view.py b/type_lens/type_view.py index 12684ec..9516297 100644 --- a/type_lens/type_view.py +++ b/type_lens/type_view.py @@ -226,6 +226,17 @@ def is_subtype_of(self, typ: Any | tuple[Any, ...], /) -> bool: return issubclass(str, typ) or issubclass(bytes, typ) return self.annotation is not Any and not self.is_type_var and issubclass(self.annotation, typ) + def is_subclass_of(self, typ: Any | tuple[Any, ...], /) -> bool: + """Whether the annotation is a subclass of the given type. + + Args: + typ: The type to check, or tuple of types. Passed as 2nd argument to ``issubclass()``. + + Returns: + Whether the annotation is a subclass of the given type(s). + """ + return isinstance(self.fallback_origin, type) and issubclass(self.fallback_origin, typ) + def strip_optional(self) -> TypeView: if not self.is_optional: return self