Skip to content

Commit

Permalink
Fixed all the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aryan26roy committed Jan 23, 2024
1 parent b7131b0 commit dfb5ed9
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 159 deletions.
12 changes: 7 additions & 5 deletions src/formulate/AST.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ class Symbol(AST): # Symbol: value referenced by name
def __str__(self):
return self.symbol

def check_CNAME(self):
regex = "((\.)\2{2,})"
x = re.search(regex, self.symbol)
print(x)
return x
# def check_CNAME(self):
# regex = "((\.)\2{2,})"
# x = re.search(regex, self.symbol)
# print(x)
# return x

def to_python(self):
return self.symbol
Expand Down Expand Up @@ -186,6 +186,8 @@ def to_python(self):
return "np.exp(1)"
case "e":
return "np.exp(1)"
case "Math::sqrt":
return f"np.sqrt({self.arguments[0]})"
case "max":
return f"root_max({self.arguments[0]})"
case "min":
Expand Down
4 changes: 3 additions & 1 deletion src/formulate/toast.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def toast(ptnode: matching_tree.ptnode):
return AST.Call(funcs, func_arguments, index=func_names[0].start_pos)

case matching_tree.ptnode("symbol", children):
temp_symbol = AST.Symbol(str(children[0]), index=children[0].start_pos)
var_name = _get_func_names(children[0])[0]
print(var_name)
temp_symbol = AST.Symbol(str(var_name), index=var_name.start_pos)
# if temp_symbol.check_CNAME() is not None:
return temp_symbol
# else:
Expand Down
8 changes: 4 additions & 4 deletions src/formulate/ttreeformula.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@
| "-" factor -> neg
pow: atom | atom "**" factor -> pow
matpos: "[" [sum] "]"
atom: "(" expression ")" | CNAME -> symbol
atom: "(" expression ")" | var_name -> symbol
| NUMBER -> literal
| func_name trailer -> func
func_name: CNAME | CNAME "::" func_name
func_name: NAME | NAME "::" func_name
var_name: NAME | NAME "." var_name
trailer: "(" [arglist] ")"
arglist: expression ("," expression)* [","]
CNAME: /[A-Za-z_$]([A-Za-z0-9_$]|\s*\.\s*)*/
NAME: /[A-Za-z_][A-Za-z0-9_]*\$?/
%import common.NUMBER
%import common.WS
%ignore WS
Expand Down
228 changes: 79 additions & 149 deletions tests/test_parsing_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,226 +2,156 @@

import formulate

from formulate.toast import toast

import ast


def test_simple_add():
a = formulate.ttreeformula.exp_to_ptree("a+2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a+2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a+2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a+2.0)"))


def test_simple_sub():
a = formulate.ttreeformula.exp_to_ptree("a-2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a-2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a-2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a-2.0)"))


def test_simple_mul():
a = formulate.ttreeformula.exp_to_ptree("f*2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(f*2)"
a = toast(formulate.ttreeformula.exp_to_ptree("f*2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(f*2.0)"))


def test_simple_div():
a = formulate.ttreeformula.exp_to_ptree("a/2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a/2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a/2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a/2.0)"))


def test_simple_lt():
a = formulate.ttreeformula.exp_to_ptree("a<2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a<2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a<2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a<2.0)"))


def test_simple_lte():
a = formulate.ttreeformula.exp_to_ptree("a<=2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a<=2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a<=2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a<=2.0)"))


def test_simple_gt():
a = formulate.ttreeformula.exp_to_ptree("a>2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a>2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a>2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a>2.0)"))


def test_simple_gte():
a = formulate.ttreeformula.exp_to_ptree("a>=2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a>=2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a>=2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a>=2.0)"))


def test_simple_eq():
a = formulate.ttreeformula.exp_to_ptree("a==2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a==2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a==2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a==2.0)"))


def test_simple_neq():
a = formulate.ttreeformula.exp_to_ptree("a!=2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a!=2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a!=2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a!=2.0)"))


def test_simple_bor():
a = formulate.ttreeformula.exp_to_ptree("a|b")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a|b)"
a = toast(formulate.ttreeformula.exp_to_ptree("a|b"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.bitwise_or(a,b)"))


def test_simple_band():
a = formulate.ttreeformula.exp_to_ptree("a&c")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a&c)"
a = toast(formulate.ttreeformula.exp_to_ptree("a&c"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.bitwise_and(a,c)"))


def test_simple_bxor():
a = formulate.ttreeformula.exp_to_ptree("a^2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a^2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a^2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a^2.0)"))


def test_simple_land():
a = formulate.ttreeformula.exp_to_ptree("a&&2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a&&2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a&&2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a and 2.0)"))


def test_simple_lor():
a = formulate.ttreeformula.exp_to_ptree("a||2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a||2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a||2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a or 2.0)"))


def test_simple_pow():
a = formulate.ttreeformula.exp_to_ptree("a**2")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(a**2)"
a = toast(formulate.ttreeformula.exp_to_ptree("a**2.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a**2.0)"))


def test_simple_matrix():
a = formulate.ttreeformula.exp_to_ptree("a[45][1]")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "a[45][1]"
a = toast(formulate.ttreeformula.exp_to_ptree("a[45][1]"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("a[:, 45.0, 1.0]"))


def test_simple_function():
a = formulate.ttreeformula.exp_to_ptree("Math::sqrt(4)")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "Math::sqrt(4)"


def test_simple_function2():
a = formulate.ttreeformula.exp_to_ptree("Math::sqrt::three(4)")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "Math::sqrt::three(4)"


def test_simple_function3():
a = formulate.ttreeformula.exp_to_ptree("Math::sqrt::three::four(4)")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "Math::sqrt::three::four(4)"
a = toast(formulate.ttreeformula.exp_to_ptree("Math::sqrt(4)"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.sqrt(4.0)"))


def test_simple_unary_pos():
a = formulate.ttreeformula.exp_to_ptree("+5")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(+5)"
a = toast(formulate.ttreeformula.exp_to_ptree("+5.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(+5.0)"))


def test_simple_unary_neg():
a = formulate.ttreeformula.exp_to_ptree("-5")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(-5)"
a = toast(formulate.ttreeformula.exp_to_ptree("-5.0"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(-5.0)"))


def test_simple_unary_binv():
a = formulate.ttreeformula.exp_to_ptree("~bool")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(~bool)"
a = toast(formulate.ttreeformula.exp_to_ptree("~bool"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.invert(bool)"))


def test_simple_unary_linv():
a = formulate.ttreeformula.exp_to_ptree("!bool")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(!bool)"


def test_simple_matrix_unary_pos():
a = formulate.ttreeformula.exp_to_ptree("Math::sqrt::three::four(-4)")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "Math::sqrt::three::four((-4))"
a = toast(formulate.ttreeformula.exp_to_ptree("!bool"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.logical_not(bool)"))


def test_unary_binary_pos():
a = formulate.ttreeformula.exp_to_ptree("2 - -6")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(2-(-6))"
a = toast(formulate.ttreeformula.exp_to_ptree("2.0 - -6"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(2.0-(-6.0))"))


def test_complex_matrix():
a = formulate.ttreeformula.exp_to_ptree("mat1[a**23][mat2[45 - -34]]")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "mat1[(a**23)][mat2[(45-(-34))]]"
a = toast(formulate.ttreeformula.exp_to_ptree("mat1[a**23][mat2[45 - -34]]"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(mat1[:,(a**23.0),(mat2[:,(45.0-(-34.0))])])"))


def test_complex_exp():
a = formulate.ttreeformula.exp_to_ptree("~a**b*23/(var||45)")
aa = []
tree = formulate._utils._ptree_to_string(a, aa)
out = "".join(tree)
assert out == "(~((a**b)*(23/(var||45))))"
a = toast(formulate.ttreeformula.exp_to_ptree("~a**b*23/(var||45)"))
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.invert(((a**b)*(23.0/(var or 45.0))))"))

0 comments on commit dfb5ed9

Please sign in to comment.