Skip to content

Commit

Permalink
Add enabled flag to main Enforcer function and update docs accordingly.
Browse files Browse the repository at this point in the history
  • Loading branch information
connor-makowski committed Mar 1, 2024
1 parent 8003db3 commit 67ecd6c
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 50 deletions.
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ pip install type_enforced
```py
import type_enforced

@type_enforced.Enforcer
@type_enforced.Enforcer(enabled=True)
def my_fn(a: int , b: [int, str] =2, c: int =3) -> None:
pass
```
- Note: `enabled=True` by default if not specified. You can set `enabled=False` to disable type checking for a specific function, method, or class. This is useful for a production vs debugging environment or for undecorating a single method in a larger wrapped class.

# Getting Started

Expand Down Expand Up @@ -159,15 +160,19 @@ class my_class:
pass
```

You can ignore a specific class method if you wrap it in `type_enforced.EnforcerIgnore`
You can skip enforcement if you add the argument `enabled=False` in the `Enforcer` call.
- This is useful for a production vs debugging environment.
- This is also useful for undecorating a single method in a larger wrapped class.
- Note: You can set `enabled=False` for an entire class or simply disable a specific method in a larger wrapped class.
- Note: Method level wrapper `enabled` values take precedence over class level wrappers.
```py
import type_enforced
@type_enforced.Enforcer
class my_class:
def my_fn(self, a: int) -> None:
pass

@type_enforced.EnforcerIgnore
@type_enforced.Enforcer(enabled=False)
def my_other_fn(self, a: int) -> None:
pass
```
Expand Down
2 changes: 1 addition & 1 deletion test/test_class_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(self):

@type_enforced.Enforcer
def my_fn(self, b: int):
print(self.a, b)
pass


mc = my_class()
Expand Down
2 changes: 1 addition & 1 deletion test/test_class_02.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(self):
self.a = 10

def my_fn(self, b: int):
print(self.a, b)
pass


mc = my_class()
Expand Down
4 changes: 3 additions & 1 deletion test/test_class_05.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def subtract(a: int, b: int) -> int:
pass

# classmethod and staticmethod wrappers do not contain annotations prior to 3.9
if success or sys.version_info <= (3, 10, 0):
if success:
print("test_class_05.py passed")
elif sys.version_info <= (3, 10, 0):
print("test_class_05.py skipped")
else:
print("test_class_05.py failed")
15 changes: 14 additions & 1 deletion test/test_class_10.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,27 @@ class Foo:
def bar(self, a: int) -> None:
pass

@type_enforced.EnforcerIgnore
@type_enforced.Enforcer(enabled=False)
def baz(self, a: int) -> None:
pass

@type_enforced.Enforcer(enabled=False)
class Boo:
@type_enforced.Enforcer(enabled=True)
def bar(self, a: int) -> None:
pass

def baz(self, a: int) -> None:
pass

try:
foo = Foo()
foo.bar(a=1) #=> No Exception
foo.baz(a='a') #=> No Exception

boo = Boo()
boo.bar(a=1) #=> No Exception
boo.baz(a='a') #=> No Exception
print("test_class_10.py passed")
except:
print("test_class_10.py failed")
1 change: 0 additions & 1 deletion test/test_fn_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
def my_fn(a: int, b: [int, str], c: int) -> None:
return None


success_1 = True
try:
my_fn(a=1, b=2, c=3) # No Error
Expand Down
2 changes: 0 additions & 2 deletions test/test_fn_06.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ def run_tests(fn, success):
fn("a", "b") # No Error
try:
fn(a="a", b=2, c="c") # Error (b can only accept str)
print("a")
success = False
except:
pass
try:
fn("a", b=2) # Error (b can only accept str)
print("b")
success = False
except:
pass
Expand Down
29 changes: 29 additions & 0 deletions test/test_fn_15.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import type_enforced

@type_enforced.Enforcer()
def my_fn(a: int):
return None

@type_enforced.Enforcer(enabled = False)
def my_fn2(a: int):
return None

success = True

try:
my_fn(a=1)
my_fn2(a=1)
my_fn2(a='1')
except:
success = False

try:
my_fn(a='1')
success = False
except:
pass

if success:
print("test_fn_15.py passed")
else:
print("test_fn_15.py failed")
2 changes: 1 addition & 1 deletion type_enforced/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import sys

from .enforcer import Enforcer, FunctionMethodEnforcer, EnforcerIgnore
from .enforcer import Enforcer, FunctionMethodEnforcer
48 changes: 9 additions & 39 deletions type_enforced/enforcer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from types import FunctionType, MethodType, GenericAlias, GeneratorType, BuiltinFunctionType, BuiltinMethodType
from typing import Type, Union, Sized, Literal, Callable
from functools import update_wrapper, wraps
from type_enforced.utils import Partial

# Python 3.10+ has a UnionType object that is used to represent Union types
try:
Expand Down Expand Up @@ -231,8 +232,7 @@ def __check_type__(self, obj, acceptable_types, key):
def __repr__(self):
return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"


def Enforcer(clsFnMethod):
def Enforcer(clsFnMethod, enabled):
"""
A wrapper to enforce types within a function or method given argument annotations.
Expand Down Expand Up @@ -271,11 +271,15 @@ def Enforcer(clsFnMethod):
Exception: (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` instead.
```
"""
if not hasattr(clsFnMethod, "__enforcer_enabled__"):
clsFnMethod.__enforcer_enabled__ = enabled
if clsFnMethod.__enforcer_enabled__ == False:
return clsFnMethod
if isinstance(
clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
):
# Only apply the enforcer if annotations are specified
if getattr(clsFnMethod, "__annotations__", {}) == {} or getattr(clsFnMethod, "__no_type_check__", False):
if getattr(clsFnMethod, "__annotations__", {}) == {}:
return clsFnMethod
elif isinstance(clsFnMethod, staticmethod):
return staticmethod(FunctionMethodEnforcer(clsFnMethod.__func__))
Expand All @@ -288,45 +292,11 @@ def Enforcer(clsFnMethod):
if hasattr(value, "__call__") or isinstance(
value, (classmethod, staticmethod)
):
setattr(clsFnMethod, key, Enforcer(value))
setattr(clsFnMethod, key, Enforcer(value, enabled=enabled))
return clsFnMethod
else:
raise Exception(
"Enforcer can only be used on class methods, functions, or classes."
)

def EnforcerIgnore(fnMethod):
"""
A wrapper to ignore type enforcement for or method in a larger class wrapped by `Enforcer`.
Requires:
- `fnMethod`:
- What: The method or function that should have input types enforced
- Type: method | classmethod | staticmethod | function
Example Use:
```
import type_enforced
@type_enforced.Enforcer
class Foo:
def bar(self, a: int) -> None:
pass
@type_enforced.EnforcerIgnore
def baz(self, a: int) -> None:
pass
foo = Foo()
foo.bar(a=1) #=> No Exception
foo.baz(a='a') #=> No Exception
```
"""
if isinstance(
fnMethod, (staticmethod, classmethod, FunctionType, MethodType)
):
setattr(fnMethod, "__no_type_check__", True)
else:
raise Exception("EnforcerIgnore can only be used on methods, classmethods, staticmethods, or functions.")
return fnMethod
Enforcer = Partial(Enforcer, enabled=True)
57 changes: 57 additions & 0 deletions type_enforced/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import types
from functools import update_wrapper

def WithSubclasses(obj):
"""
A special helper function to allow a class type to be passed and also allow all subclasses of that type.
Expand All @@ -23,3 +26,57 @@ def WithSubclasses(obj):
for i in obj.__subclasses__():
out += WithSubclasses(i)
return out

class Partial:
"""
A special class wrapper to allow for easy partial function wrappings and calls.
"""
def __init__(
self,
__fn__,
*__args__,
**__kwargs__,
):
update_wrapper(self, __fn__)
self.__fn__ = __fn__
self.__args__ = __args__
self.__kwargs__ = __kwargs__
self.__fnArity__ = self.__getFnArity__()
self.__arity__ = self.__getArity__(__args__, __kwargs__)

def __exception__(self, message):
pre_message = (
f"({self.__fn__.__module__}.{self.__fn__.__qualname__}_partial): "
)
raise Exception(pre_message + message)

def __call__(self, *args, **kwargs):
new_args = self.__args__ + args
new_kwargs = {**self.__kwargs__, **kwargs}
self.__arity__ = self.__getArity__(new_args, new_kwargs)
if self.__arity__ < 0:
self.__exception__("Too many arguments were supplied")
if self.__arity__ == 0:
results = self.__fn__(*new_args, **new_kwargs)
return results
return Partial(
self.__fn__,
*new_args,
**new_kwargs,
)

def __repr__(self):
return f"<Partial {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"

def __getArity__(self, args, kwargs):
return self.__fnArity__ - (len(args) + len(kwargs))

def __getFnArity__(self):
if not isinstance(self.__fn__, (types.MethodType, types.FunctionType)):
self.__exception__(
"A non function was passed as a function and does not have any arity. See the stack trace above for more information."
)
extra_method_input_count = (
1 if isinstance(self.__fn__, (types.MethodType)) else 0
)
return self.__fn__.__code__.co_argcount - extra_method_input_count

0 comments on commit 67ecd6c

Please sign in to comment.