Skip to content

Commit

Permalink
Added more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Aryan Roy <[email protected]>
  • Loading branch information
aryan26roy committed Jul 21, 2024
1 parent 20965e1 commit 37702f2
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 28 deletions.
4 changes: 2 additions & 2 deletions src/formulate/AST.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def binary_to_ufunc(self, sign):
return sign_mapping[str(sign)]

def to_numexpr(self):
return self.left.to_numexpr() + str(self.sign.to_numexpr()) + self.right.to_numexpr()
return "(" + self.left.to_numexpr() + str(self.sign.to_numexpr()) + self.right.to_numexpr() + ")"

def to_root(self):
return self.left.to_root() + str(self.sign.to_root()) + self.right.to_root()
return "(" + self.left.to_root() + str(self.sign.to_root()) + self.right.to_root() + ")"

def to_python(self):
if str(self.sign) in {
Expand Down
52 changes: 26 additions & 26 deletions tests/test_numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,119 +9,119 @@

def test_simple_add():
a = toast(formulate.numexpr.exp_to_ptree("a+2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a+2.0)"))


def test_simple_sub():
a = toast(formulate.numexpr.exp_to_ptree("a-2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a-2.0)"))


def test_simple_mul():
a = toast(formulate.numexpr.exp_to_ptree("f*2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(f*2.0)"))


def test_simple_div():
a = toast(formulate.numexpr.exp_to_ptree("a/2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a/2.0)"))


def test_simple_lt():
a = toast(formulate.numexpr.exp_to_ptree("a<2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a<2.0)"))


def test_simple_lte():
a = toast(formulate.numexpr.exp_to_ptree("a<=2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a<=2.0)"))


def test_simple_gt():
a = toast(formulate.numexpr.exp_to_ptree("a>2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a>2.0)"))


def test_simple_gte():
a = toast(formulate.numexpr.exp_to_ptree("a>=2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a>=2.0)"))


def test_simple_eq():
a = toast(formulate.numexpr.exp_to_ptree("a==2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a==2.0)"))


def test_simple_neq():
a = toast(formulate.numexpr.exp_to_ptree("a!=2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a!=2.0)"))


def test_simple_bor():
a = toast(formulate.numexpr.exp_to_ptree("a|b"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.bitwise_or(a,b)"))
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("a | b"))


def test_simple_band():
a = toast(formulate.numexpr.exp_to_ptree("a&c"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.bitwise_and(a,c)"))
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("a & c"))


def test_simple_bxor():
a = toast(formulate.numexpr.exp_to_ptree("a^2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a^2.0)"))

def test_simple_pow():
a = toast(formulate.numexpr.exp_to_ptree("a**2.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a**2.0)"))


def test_simple_function():
a = toast(formulate.numexpr.exp_to_ptree("sqrt(4)"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.sqrt(4.0)"))
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("sqrt(4.0)"))


def test_simple_unary_pos():
a = toast(formulate.numexpr.exp_to_ptree("+5.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(+5.0)"))


def test_simple_unary_neg():
a = toast(formulate.numexpr.exp_to_ptree("-5.0"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(-5.0)"))


def test_simple_unary_binv():
a = toast(formulate.numexpr.exp_to_ptree("~bool"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.invert(bool)"))
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("~bool"))



def test_unary_binary_pos():
a = toast(formulate.numexpr.exp_to_ptree("2.0 - -6"), nxp = True)
out = a.to_python()
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(2.0-(-6.0))"))


def test_complex_exp():
a = toast(formulate.numexpr.exp_to_ptree("~a**b*23/(var|45)"), nxp = True)
out = a.to_python()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("np.invert(((a**b)*(23.0/np.bitwise_or(var,45.0))))"))
a = toast(formulate.numexpr.exp_to_ptree("(~a**b)*23/(var|45)"), nxp = True)
out = a.to_numexpr()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("((~(a**b))*(23.0/(var|45.0)))"))
127 changes: 127 additions & 0 deletions tests/test_root.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

import formulate

from formulate.toast import toast

import ast


def test_simple_add():
a = toast(formulate.numexpr.exp_to_ptree("a+2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a+2.0)"))


def test_simple_sub():
a = toast(formulate.numexpr.exp_to_ptree("a-2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a-2.0)"))


def test_simple_mul():
a = toast(formulate.numexpr.exp_to_ptree("f*2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(f*2.0)"))


def test_simple_div():
a = toast(formulate.numexpr.exp_to_ptree("a/2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a/2.0)"))


def test_simple_lt():
a = toast(formulate.numexpr.exp_to_ptree("a<2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a<2.0)"))


def test_simple_lte():
a = toast(formulate.numexpr.exp_to_ptree("a<=2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a<=2.0)"))


def test_simple_gt():
a = toast(formulate.numexpr.exp_to_ptree("a>2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a>2.0)"))


def test_simple_gte():
a = toast(formulate.numexpr.exp_to_ptree("a>=2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a>=2.0)"))


def test_simple_eq():
a = toast(formulate.numexpr.exp_to_ptree("a==2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a==2.0)"))


def test_simple_neq():
a = toast(formulate.numexpr.exp_to_ptree("a!=2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a!=2.0)"))


def test_simple_bor():
a = toast(formulate.numexpr.exp_to_ptree("a|b"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("a | b"))


def test_simple_band():
a = toast(formulate.numexpr.exp_to_ptree("a&c"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("a & c"))


def test_simple_bxor():
a = toast(formulate.numexpr.exp_to_ptree("a^2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a^2.0)"))

def test_simple_pow():
a = toast(formulate.numexpr.exp_to_ptree("a**2.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(a**2.0)"))


def test_simple_function():
a = toast(formulate.numexpr.exp_to_ptree("sqrt(4)"), nxp = True)
out = a.to_root()
assert out == "TMATH::Sqrt(4.0)"


def test_simple_unary_pos():
a = toast(formulate.numexpr.exp_to_ptree("+5.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(+5.0)"))


def test_simple_unary_neg():
a = toast(formulate.numexpr.exp_to_ptree("-5.0"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(-5.0)"))


def test_simple_unary_binv():
a = toast(formulate.numexpr.exp_to_ptree("~bool"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("~bool"))



def test_unary_binary_pos():
a = toast(formulate.numexpr.exp_to_ptree("2.0 - -6"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("(2.0-(-6.0))"))


def test_complex_exp():
a = toast(formulate.numexpr.exp_to_ptree("(~a**b)*23/(var|45)"), nxp = True)
out = a.to_root()
assert ast.unparse(ast.parse(out)) == ast.unparse(ast.parse("((~(a**b))*(23.0/(var|45.0)))"))

0 comments on commit 37702f2

Please sign in to comment.