diff --git a/src/ocptv/output/emit.py b/src/ocptv/output/emit.py index 4ad976e..005d672 100644 --- a/src/ocptv/output/emit.py +++ b/src/ocptv/output/emit.py @@ -16,7 +16,7 @@ JSON = ty.Union[ty.Dict[str, "JSON"], ty.List["JSON"], Primitive] -def _is_optional(field: ty.Type): +def _is_optional(field: ty.Type) -> bool: # type hackery incoming # ty.Optional[T] == ty.Union[T, None] # since ty.Union[ty.Union[T,U]] = ty.Union[T,U] we can the @@ -30,7 +30,7 @@ class ArtifactEmitter: Uses the low level dataclass models for the spec, but should not be used in user code. """ - def __init__(self, writer: Writer): + def __init__(self, writer: Writer) -> None: self._seq_lock = threading.Lock() self._seq = 0 @@ -41,7 +41,7 @@ def __init__(self, writer: Writer): self._version_emitted = threading.Event() @staticmethod - def _serialize(artifact: ArtifactType): + def _serialize(artifact: ArtifactType) -> str: def visit( value: ty.Union[ArtifactType, ty.Dict, ty.List, Primitive], formatter: ty.Optional[ty.Callable[[ty.Any], str]] = None, @@ -56,7 +56,7 @@ def visit( val = getattr(value, field.name) if val is None: - if not _is_optional(field.type): + if not _is_optional(ty.cast(ty.Type, field.type)): # TODO: fix exception text/type raise RuntimeError("unacceptable none where not optional") diff --git a/src/ocptv/output/runtime_checks.py b/src/ocptv/output/runtime_checks.py index c8f83d5..085f692 100644 --- a/src/ocptv/output/runtime_checks.py +++ b/src/ocptv/output/runtime_checks.py @@ -119,8 +119,8 @@ def __str__(self): def _check_type_any(obj: CheckedValue, hint: ty.Type, trace: ty.List[str]): - type_origin = get_origin(hint) - type_args = get_args(hint) + type_origin = ty.cast(ty.Type | None, get_origin(hint)) + type_args = ty.cast(ty.Tuple[ty.Type, ...], get_args(hint)) if type_origin is list: # generic type: typ == ty.List[...] @@ -186,7 +186,7 @@ def _check_type_any(obj: CheckedValue, hint: ty.Type, trace: ty.List[str]): elif dc.is_dataclass(obj): for field in dc.fields(obj): subtrace = trace + [f"{obj.__class__.__name__}.{field.name}"] - _check_type_any(getattr(obj, field.name), field.type, subtrace) + _check_type_any(getattr(obj, field.name), ty.cast(ty.Type, field.type), subtrace) elif not isinstance(obj, hint): raise TypeCheckError(obj, expected=hint.__name__, trace=trace)