Skip to content

Commit

Permalink
type checking refactor and generalization
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Nov 4, 2024
1 parent 19e4143 commit e540704
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 97 deletions.
83 changes: 15 additions & 68 deletions cobaya/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
from inspect import cleandoc
from packaging import version
from importlib import import_module, resources
from numbers import Integral, Real
from typing import ClassVar, ForwardRef, Optional, Union, List, Set
from typing import Optional, Union, List, Set

from cobaya.log import HasLogger, LoggedError, get_logger
from cobaya.typing import Any, InfoDict, InfoDictIn, ParamDict, empty_dict
from cobaya.typing import Any, InfoDict, InfoDictIn, empty_dict, validate_type
from cobaya.tools import resolve_packages_path, load_module, get_base_classes, \
get_internal_class_component_name, deepcopy_where_possible, NumberWithUnits, VersionCheckError
get_internal_class_component_name, deepcopy_where_possible, VersionCheckError
from cobaya.conventions import kinds, cobaya_package, reserved_attributes
from cobaya.yaml import yaml_load_file, yaml_dump, yaml_load
from cobaya.mpi import is_main_process
import cobaya


class Timer:
Expand Down Expand Up @@ -352,12 +352,14 @@ def __init__(self, info: InfoDictIn = empty_dict,
# set attributes from the info (from yaml file or directly input dictionary)
annotations = self.get_annotations()
for k, value in info.items():
self.validate_bool(k, value, annotations)
self.validate_info(k, value, annotations)
try:
setattr(self, k, value)
except AttributeError:
raise AttributeError("Cannot set {} attribute for {}!".format(k, self))
self.set_logger(name=self._name)
self.validate_attributes(annotations)

self.set_timing_on(timing)
try:
if initialize:
Expand All @@ -369,9 +371,6 @@ def __init__(self, info: InfoDictIn = empty_dict,
" are set (%s, %s)", self, e)
raise

if self._enforce_types:
self.validate_attributes()

def set_timing_on(self, on):
self.timer = Timer() if on else None

Expand Down Expand Up @@ -418,13 +417,13 @@ def has_version(self):
"""
return True

def validate_bool(self, name: str, value: Any, annotations: dict):
def validate_info(self, name: str, value: Any, annotations: dict):
"""
Does any validation on parameter k read from an input dictionary or yaml file,
before setting the corresponding class attribute.
You could enforce consistency with annotations here, but does not by default.
:param k: name of parameter
:param name: name of parameter
:param value: value
:param annotations: resolved inherited dictionary of attributes for this class
"""
Expand All @@ -433,64 +432,12 @@ def validate_bool(self, name: str, value: Any, annotations: dict):
raise AttributeError("Class '%s' parameter '%s' should be True "
"or False, got '%s'" % (self, name, value))

# Check one named value against the type annotation declared for it.
# Names that do not appear in ``annotations`` are accepted without any check.
#
# :param name: attribute name being validated
# :param value: value to check
# :param annotations: mapping of attribute name to expected type
# :raises TypeError: if ``self._validate_type`` reports that the value does not
#     match the annotated type
def validate_info(self, name: str, value: Any, annotations: dict):
if name in annotations:
expected_type = annotations[name]
if not self._validate_type(expected_type, value):
# the message includes both the offending type and the value itself
msg = f"Attribute '{name}' must be of type {expected_type}, not {type(value)}(value={value})"
raise TypeError(msg)

# Validate ``value`` against a parametrized typing construct (anything exposing
# ``__origin__``), dispatching on the origin and recursing into the type arguments
# through ``self._validate_type``. Returns a truthy/falsy result rather than raising.
def _validate_composite_type(self, expected_type, value):
origin = expected_type.__origin__
try: # for Callable and Sequence types, which have no __args__
args = expected_type.__args__
except AttributeError:
# NOTE(review): ``args`` stays unbound here; only the final ``isinstance``
# fallback below is safe in that case.
pass

# Union: the value has to match at least one member type
if origin is Union:
return any(self._validate_type(t, value) for t in args)
# NOTE(review): typing.Optional[X] normally reports Union as its origin, so this
# branch is presumably never taken -- confirm.
elif origin is Optional:
return value is None or self._validate_type(args[0], value)
# list: every item must match the single element type
elif origin is list:
return all(self._validate_type(args[0], item) for item in value)
# dict: every key must match args[0] and every value must match args[1]
elif origin is dict:
return all(
self._validate_type(args[0], k) and self._validate_type(args[1], v)
for k, v in value.items()
)
# tuple: same length as the type arguments, compared position by position
elif origin is tuple:
return len(args) == len(value) and all(
self._validate_type(t, v) for t, v in zip(args, value)
)
# ClassVar: validate against the wrapped type
elif origin is ClassVar:
return self._validate_type(args[0], value)
# any other origin: plain isinstance check against the unsubscripted origin
else:
return isinstance(value, origin)

# Return whether ``value`` is acceptable for ``expected_type``.
# ``None`` values and the ``Any`` type are always accepted; types exposing
# ``__origin__`` are delegated to ``self._validate_composite_type``; a few types
# get relaxed checks; everything else is a plain ``isinstance`` test.
def _validate_type(self, expected_type, value):
if value is None or expected_type is Any: # Any is always valid
return True

if hasattr(expected_type, "__origin__"):
return self._validate_composite_type(expected_type, value)
else:
# Exceptions for some types
if expected_type is ParamDict:
# only checked for being a dict; contents are not inspected
return isinstance(value, dict)
elif expected_type is int:
if value == float('inf'): # for infinite values parsed as floats
return isinstance(value, float)
return isinstance(value, Integral)
elif expected_type is float:
# any real number (including integers) is accepted where a float is expected
return isinstance(value, Real)
elif expected_type is NumberWithUnits:
# accepts either a real number or a string
return isinstance(value, Real) or isinstance(value, str)
return isinstance(value, expected_type)

# Validate every annotated attribute of this instance against its annotation.
# Attributes that are not set are read as ``None`` and passed to
# ``self.validate_info`` with that value.
def validate_attributes(self):
annotations = self.get_annotations()
for name in annotations.keys():
self.validate_info(name, getattr(self, name, None), annotations)
def validate_attributes(self, annotations: dict):
    """
    Validates the current values of all annotated attributes against their types.

    The module-level switch ``cobaya.typing.enforce_type_checking`` takes precedence:
    a true value forces the check, ``False`` disables it, and any other false value
    (e.g. ``None``) defers to this instance's ``_enforce_types`` setting.

    :param annotations: dictionary mapping attribute names to their annotated types
    :raises TypeError: raised by ``validate_type`` if an attribute value does not match
    """
    override = cobaya.typing.enforce_type_checking
    # skip unless the global switch is on, or the class opts in and is not overridden
    if not override and (not self._enforce_types or override is False):
        return
    for attr, expected in annotations.items():
        # unset attributes are read as None
        current = getattr(self, attr, None)
        validate_type(expected, current, self.get_name() + ':' + attr)

@classmethod
def get_kind(cls):
Expand Down
159 changes: 157 additions & 2 deletions cobaya/typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Dict, Any, Optional, Union, Sequence, Type, Callable, Mapping
from typing import TypedDict, Literal
from collections.abc import Mapping, Callable, Sequence, Iterable
from typing import Dict, Any, Optional, Union, Type, TypedDict, Literal
from types import MappingProxyType
import typing
import numbers
import numpy as np
import sys

InfoDict = Dict[str, Any]
InfoDictIn = Mapping[str, Any]
Expand Down Expand Up @@ -102,3 +106,154 @@ class InputDict(ModelDict, total=False):
packages_path: Optional[str]
output: Optional[str]
version: Optional[Union[str, InfoDict]]


enforce_type_checking = None


def validate_type(expected_type: type, value: Any, path: str = ''):
    """
    Checks for soft compatibility of a value with a type.
    Raises TypeError with descriptive messages when validation fails.

    The check is deliberately lenient: None is always accepted, integers may be
    infinite, any real number (or zero-rank numpy array) is accepted as a float,
    and numpy arrays are accepted where sequences are expected.
    Un-parametrized containers (``List``, ``Dict``, ...) only have the container
    checked, ``Tuple[X, ...]`` is treated as a homogeneous sequence, ``Literal``
    values are compared with the allowed constants, and parametrized generics with
    no dedicated handling (e.g. ``Callable[[int], int]``) are checked against their
    unsubscripted origin.

    :param expected_type: from annotation
    :param value: value to validate
    :param path: string tracking the nested path for error messages
    :return: None (success is signalled by the absence of an exception)
    :raises TypeError: with descriptive message when validation fails
    """
    curr_path = f"'{path}'" if path else 'value'
    # not every typing construct has a __name__ (e.g. typing aliases on Python < 3.10)
    type_name = getattr(expected_type, '__name__', None) or str(expected_type)

    if value is None or expected_type is Any:
        return

    if expected_type is int:
        if not isinstance(value, numbers.Integral):
            try:
                infinite = bool(value in (np.inf, -np.inf))
            except ValueError:
                # comparing a multi-element array to a scalar has no single truth value
                infinite = False
            if not infinite:
                raise TypeError(
                    f"{curr_path} must be an integer or infinity, "
                    f"got {type(value).__name__}"
                )
        return

    if expected_type is float:
        if not (isinstance(value, numbers.Real) or
                (isinstance(value, np.ndarray) and value.shape == ())):
            raise TypeError(f"{curr_path} must be a float, got {type(value).__name__}")
        return

    if expected_type is bool:
        if not hasattr(value, '__bool__') and not isinstance(value, (str, np.ndarray)):
            raise TypeError(
                f"{curr_path} must be boolean, got {type(value).__name__}"
            )
        return

    # special case for Cobaya
    if type_name == 'NumberWithUnits':
        if not isinstance(value, (numbers.Real, str)):
            raise TypeError(f"{curr_path} must be a number or string for NumberWithUnits,"
                            f" got {type(value).__name__}")
        return

    if sys.version_info < (3, 10):
        from typing_extensions import is_typeddict
    else:
        from typing import is_typeddict

    if is_typeddict(expected_type):
        type_hints = typing.get_type_hints(expected_type)
        if not isinstance(value, Mapping):
            raise TypeError(f"{curr_path} must be a mapping for TypedDict "
                            f"'{expected_type.__name__}', got {type(value).__name__}")
        if invalid_keys := set(value) - set(type_hints):
            raise TypeError(f"{curr_path} contains invalid keys for TypedDict "
                            f"'{expected_type.__name__}': {invalid_keys}")
        for key, val in value.items():
            validate_type(type_hints[key], val, f"{path}.{key}" if path else str(key))
        return

    origin = typing.get_origin(expected_type)
    if origin:
        args = typing.get_args(expected_type)

        if origin is Union:
            errors = []
            structural_errors = []

            for t in args:
                try:
                    return validate_type(t, value, path)
                except TypeError as e:
                    error_msg = str(e)
                    error_path = error_msg.split(' ')[0].strip("'")

                    # If error is about the current path, it's a structural error
                    if error_path == path:
                        # Skip uninformative "must be of type NoneType" errors
                        if "must be of type NoneType" not in error_msg:
                            structural_errors.append(error_msg)
                        else:
                            errors.append((error_path, error_msg))
                    else:
                        errors.append((error_path, error_msg))

            # If we have structural errors, show those first
            if structural_errors:
                if len(structural_errors) == 1:
                    raise TypeError(structural_errors[0])
                raise TypeError(
                    f"{curr_path} failed to match any Union type:\n" +
                    "\n".join(f"- {e}" for e in set(structural_errors))
                )

            # Otherwise, show the deepest validation errors
            longest_path = max((p for p, _ in errors), key=len)
            path_errors = list(set(e for p, e in errors if p == longest_path))
            raise TypeError(
                f"{longest_path} failed to match any Union type:\n" +
                "\n".join(f"- {e}" for e in path_errors)
            )

        if origin is typing.ClassVar:
            return validate_type(args[0], value, path)

        if origin is typing.Literal:
            # Literal[...] cannot be used with isinstance: compare with the constants
            try:
                allowed = bool(value in args)
            except ValueError:  # e.g. element-wise comparison of a numpy array
                allowed = False
            if not allowed:
                raise TypeError(f"{curr_path} must be one of {args}, got {value!r}")
            return

        if origin in (list, tuple, set, Sequence, Iterable, np.ndarray):
            if isinstance(value, np.ndarray):
                if not value.shape:
                    raise TypeError(f"{curr_path} numpy array zero rank")
                if len(args) == 1 and not np.issubdtype(value.dtype, args[0]):
                    raise TypeError(
                        f"{curr_path} numpy array has wrong dtype: "
                        f"expected {args[0]}, got {value.dtype}"
                    )
                return

            if not isinstance(value, Iterable):
                raise TypeError(
                    f"{curr_path} must be iterable, got {type(value).__name__}"
                )

            if not args:
                # un-parametrized container (e.g. bare List): nothing to check per item
                return

            if len(args) == 1 or (len(args) == 2 and args[1] is Ellipsis):
                # homogeneous container, including variable-length Tuple[X, ...]
                for i, item in enumerate(value):
                    validate_type(args[0], item, f"{path}[{i}]" if path else f"[{i}]")
            else:
                if not isinstance(value, Sequence):
                    raise TypeError(f"{curr_path} must be a sequence for "
                                    f"tuple types, got {type(value).__name__}")
                if len(args) != len(value):
                    raise TypeError(f"{curr_path} has wrong length: "
                                    f"expected {len(args)}, got {len(value)}")
                for i, (t, v) in enumerate(zip(args, value)):
                    validate_type(t, v, f"{path}[{i}]" if path else f"[{i}]")
            return

        if origin in (dict, Mapping):
            if not isinstance(value, Mapping):
                raise TypeError(f"{curr_path} must be a mapping, "
                                f"got {type(value).__name__}")
            if len(args) == 2:  # bare Dict carries no key/value types to check
                for k, v in value.items():
                    key_path = f"{path}[{k!r}]" if path else f"[{k!r}]"
                    validate_type(args[0], k, f"{key_path} (key)")
                    validate_type(args[1], v, key_path)
            return

    try:
        matches = isinstance(value, expected_type)
    except TypeError:
        # subscripted generics without dedicated handling above cannot be used with
        # isinstance; fall back to a soft check against the unsubscripted origin
        matches = isinstance(origin, type) and isinstance(value, origin)

    if not (matches or
            expected_type is Sequence and isinstance(value, np.ndarray)):
        raise TypeError(f"{curr_path} must be of type {type_name}, "
                        f"got {type(value).__name__}")
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from cobaya.conventions import packages_path_env, packages_path_arg_posix, \
test_skip_env
from cobaya.tools import resolve_packages_path
import cobaya.typing

cobaya.typing.enforce_type_checking = True


def pytest_addoption(parser):
Expand Down
Loading

0 comments on commit e540704

Please sign in to comment.