Skip to content

Commit

Permalink
Added numexpr and recreated old API
Browse files Browse the repository at this point in the history
  • Loading branch information
aryan26roy committed Apr 28, 2024
1 parent cfd97c4 commit fc6a284
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 49 deletions.
2 changes: 2 additions & 0 deletions src/formulate/AST.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ def to_python(self):
return f"np.arctanh({self.arguments[0]})"
case "Math::sqrt":
return f"np.sqrt({self.arguments[0]})"
case "sqrt":
return f"np.sqrt({self.arguments[0]})"
case "ceil":
return f"np.ceil({self.arguments[0]})"
case "abs":
Expand Down
16 changes: 15 additions & 1 deletion src/formulate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,22 @@

from __future__ import annotations

from . import _utils, ttreeformula # noqa # noqa
from . import ttreeformula, numexpr # noqa # noqa

from . import AST

from . import toast

__version__ = "0.1.0"

__all__ = ("__version__",)



def from_root(exp : str, **kwargs) -> AST :
ptree = ttreeformula.exp_to_ptree(exp)
return toast.toast(ptree, nxp=False)

def from_numexpr(exp : str, **kwargs) -> AST :
ptree = numexpr.exp_to_ptree(exp)
return toast.toast(ptree, nxp=True)
17 changes: 9 additions & 8 deletions src/formulate/numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
from . import matching_tree

expression_grammar = r'''
start: disjunction
disjunction: conjunction | conjunction "||" conjunction -> lor
conjunction: inversion | inversion "&&" inversion -> land
inversion: comparison | "!" inversion -> linv
start: comparison
comparison: bitwise_or | comparison ">" bitwise_or -> gt
| comparison ">=" bitwise_or -> gte
| comparison "<" bitwise_or -> lt
Expand All @@ -30,11 +27,15 @@
term: factor | factor "*" term -> mul
| factor "/" term -> div
| factor "%" term -> mod
factor: pow | factor matpos+ -> matr
| "+" factor -> pos
factor: pow | "+" factor -> pos
| "-" factor -> neg
pow: CNAME | CNAME "**" factor -> pow
matpos: "[" [sum] "]"
pow: atom | atom "**" factor -> pow
atom: "(" comparison ")" | CNAME -> symbol
| NUMBER -> literal
| func_name trailer -> func
func_name: CNAME
trailer: "(" [arglist] ")"
arglist: comparison ("," comparison)* [","]
%import common.CNAME
%import common.NUMBER
%import common.WS
Expand Down
37 changes: 22 additions & 15 deletions src/formulate/toast.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,16 @@
"TMATH::QUIETNAN": "nan",
"TMATH::SQRT2": "sqrt2",
"SQRT2": "sqrt2",
"SQRT": "sqrt",
"TMATH::PIOVER2": "piby2",
"TMATH::PIOVER4": "piby4",
"TMATH::TWOPI": "2pi",
"LN10": "ln10",
"TMATH::LN10": "ln10",
"TMATH::LogE": "loge",
"TMATH::Log": "log",
"TMATH::Log2": "log2",
"TMATH::LOGE": "loge",
"TMATH::LOG": "log",
"LOG": "log",
"TMATH::LOG2": "log2",
"EXP": "exp",
"TMATH::EXP": "exp",
"TMATH::DEGTORAD": "degtorad",
Expand Down Expand Up @@ -112,16 +114,17 @@

def _get_func_names(func_names):
children = []
print(func_names)
if len(func_names.children) > 1:
children.extend(_get_func_names(func_names.children[1]))
children.append(func_names.children[0])
return children


def toast(ptnode: matching_tree.ptnode):
def toast(ptnode: matching_tree.ptnode, nxp : bool):
match ptnode:
case matching_tree.ptnode(operator, (left, right)) if operator in BINARY_OP:
arguments = [toast(left), toast(right)]
arguments = [toast(left,nxp), toast(right,nxp)]
return AST.BinaryOperator(
AST.Symbol(val_to_sign[operator], index=arguments[1].index),
arguments[0],
Expand All @@ -130,14 +133,14 @@ def toast(ptnode: matching_tree.ptnode):
)

case matching_tree.ptnode(operator, operand) if operator in UNARY_OP:
argument = toast(operand[0])
argument = toast(operand[0],nxp)
return AST.UnaryOperator(
AST.Symbol(val_to_sign[operator], index=argument.index), argument
)

case matching_tree.ptnode("multi_out", (exp1, exp2)):
exp_node1 = toast(exp1)
exp_node2 = toast(exp2)
exp_node1 = toast(exp1,nxp)
exp_node2 = toast(exp2,nxp)
exps = [exp_node1, exp_node2]
if isinstance(exp_node2, AST.Call) and exp_node2.function == ":":
del exps[-1]
Expand All @@ -146,14 +149,14 @@ def toast(ptnode: matching_tree.ptnode):
return AST.Call(val_to_sign["multi_out"], exps, index=exp_node1.index)

case matching_tree.ptnode("matr", (array, *slice)):
var = toast(array)
paren = [toast(elem) for elem in slice]
var = toast(array,nxp)
paren = [toast(elem,nxp) for elem in slice]
return AST.Matrix(var, paren, index=var.index)

case matching_tree.ptnode("matpos", child):
if child[0] is None:
return AST.Empty()
slice = toast(child[0])
slice = toast(child[0],nxp)
return AST.Slice(slice, index=slice.index)

case matching_tree.ptnode("func", (func_name, trailer)):
Expand All @@ -171,15 +174,19 @@ def toast(ptnode: matching_tree.ptnode):
func_arguments,
index=func_names[0].start_pos,
)

func_arguments = [toast(elem) for elem in trailer.children[0].children]
for elem in trailer.children[0].children:
print(elem)
func_arguments = [toast(elem,nxp) for elem in trailer.children[0].children]

funcs = root_to_common(func_names, func_names[0].start_pos)

return AST.Call(funcs, func_arguments, index=func_names[0].start_pos)

case matching_tree.ptnode("symbol", children):
var_name = _get_func_names(children[0])[0]
if not nxp:
var_name = _get_func_names(children[0])[0]
else:
var_name = children[0]
print(var_name)
temp_symbol = AST.Symbol(str(var_name), index=var_name.start_pos)
# if temp_symbol.check_CNAME() is not None:
Expand All @@ -191,7 +198,7 @@ def toast(ptnode: matching_tree.ptnode):
return AST.Literal(float(children[0]), index=children[0].start_pos)

case matching_tree.ptnode(_, (child,)):
return toast(child)
return toast(child,nxp)

case _:
raise TypeError(f"Unknown Node Type: {ptnode!r}.")
Expand Down
127 changes: 127 additions & 0 deletions tests/test_numexpr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

import formulate

from formulate.toast import toast

import ast


def test_simple_add():
a = toast(formulate.numexpr.exp_to_ptree("a+2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a+2.0)"))


def test_simple_sub():
a = toast(formulate.numexpr.exp_to_ptree("a-2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a-2.0)"))


def test_simple_mul():
a = toast(formulate.numexpr.exp_to_ptree("f*2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(f*2.0)"))


def test_simple_div():
a = toast(formulate.numexpr.exp_to_ptree("a/2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a/2.0)"))


def test_simple_lt():
a = toast(formulate.numexpr.exp_to_ptree("a<2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a<2.0)"))


def test_simple_lte():
a = toast(formulate.numexpr.exp_to_ptree("a<=2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a<=2.0)"))


def test_simple_gt():
a = toast(formulate.numexpr.exp_to_ptree("a>2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a>2.0)"))


def test_simple_gte():
a = toast(formulate.numexpr.exp_to_ptree("a>=2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a>=2.0)"))


def test_simple_eq():
a = toast(formulate.numexpr.exp_to_ptree("a==2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a==2.0)"))


def test_simple_neq():
a = toast(formulate.numexpr.exp_to_ptree("a!=2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a!=2.0)"))


def test_simple_bor():
a = toast(formulate.numexpr.exp_to_ptree("a|b"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.bitwise_or(a,b)"))


def test_simple_band():
a = toast(formulate.numexpr.exp_to_ptree("a&c"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.bitwise_and(a,c)"))


def test_simple_bxor():
a = toast(formulate.numexpr.exp_to_ptree("a^2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a^2.0)"))

def test_simple_pow():
a = toast(formulate.numexpr.exp_to_ptree("a**2.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a**2.0)"))


def test_simple_function():
a = toast(formulate.numexpr.exp_to_ptree("sqrt(4)"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.sqrt(4.0)"))


def test_simple_unary_pos():
a = toast(formulate.numexpr.exp_to_ptree("+5.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(+5.0)"))


def test_simple_unary_neg():
a = toast(formulate.numexpr.exp_to_ptree("-5.0"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(-5.0)"))


def test_simple_unary_binv():
a = toast(formulate.numexpr.exp_to_ptree("~bool"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.invert(bool)"))



def test_unary_binary_pos():
a = toast(formulate.numexpr.exp_to_ptree("2.0 - -6"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(2.0-(-6.0))"))


def test_complex_exp():
a = toast(formulate.numexpr.exp_to_ptree("~a**b*23/(var|45)"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.invert(((a**b)*(23.0/np.bitwise_or(var,45.0))))"))
Loading

0 comments on commit fc6a284

Please sign in to comment.