Skip to content

Commit

Permalink
Added support for numexpr and root
Browse files Browse the repository at this point in the history
Signed-off-by: Aryan Roy <[email protected]>
  • Loading branch information
aryan26roy committed Jul 3, 2024
1 parent fc6a284 commit 20965e1
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 3 deletions.
212 changes: 211 additions & 1 deletion src/formulate/AST.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class Literal(AST): # Literal: value that appears in the program text
def __str__(self):
return str(self.value)

def to_numexpr(self):
return repr(self.value)

def to_root(self):
return repr(self.value)

def to_python(self):
return repr(self.value)

Expand All @@ -40,6 +46,12 @@ def __str__(self):
# print(x)
# return x

def to_numexpr(self):
return self.symbol

def to_root(self):
return self.symbol

def to_python(self):
return self.symbol

Expand All @@ -57,6 +69,13 @@ def unary_to_ufunc(self, sign):
signmap = {"~": "np.invert", "!": "np.logical_not"}
return signmap[str(sign)]

def to_numexpr(self):
return "(" + self.sign.to_root() + self.operand.to_numexpr() + ")"

def to_root(self):
return "(" + self.sign.to_root() + self.operand.to_root() + ")"


def to_python(self):
if str(self.sign) in {"~", "!"}:
pycode = (
Expand Down Expand Up @@ -91,6 +110,12 @@ def binary_to_ufunc(self, sign):
}
return sign_mapping[str(sign)]

def to_numexpr(self):
return self.left.to_numexpr() + str(self.sign.to_numexpr()) + self.right.to_numexpr()

def to_root(self):
return self.left.to_root() + str(self.sign.to_root()) + self.right.to_root()

def to_python(self):
if str(self.sign) in {
"&",
Expand Down Expand Up @@ -137,6 +162,15 @@ class Matrix(AST): # Matrix: A matrix call
def __str__(self):
return "{0}[{1}]".format(str(self.var), ",".join(str(x) for x in self.paren))

def to_numexpr(self):
raise ValueError("Matrix operations are forbidden in Numexpr, please check the formula at index : " + str(self.index))

def to_root(self):
index = ""
for elem in self.paren:
index += "[" + str(elem.to_root()) + "]"
return self.var.to_root() + index

def to_python(self):
temp_str = [ "," + elem.to_python() for elem in self.paren]
return "(" + str(self.var.to_python()) + "[:" + "".join(temp_str)+"]" + ")"
Expand All @@ -150,6 +184,12 @@ class Slice(AST): # Slice: The slice for matrix
def __str__(self):
return "{0}".format(self.slices)

def to_numexpr(self):
raise ValueError("Matrix operations are forbidden in Numexpr, please check the formula at index : " + str(self.index))

def to_root(self):
return self.slices.to_root()

def to_python(self):
return self.slices.to_python()

Expand All @@ -161,6 +201,12 @@ class Empty(AST): # Slice: The slice for matrix
def __str__(self):
return ""

def to_numexpr(self):
raise ""

def to_root(self):
return ""

def to_python(self):
return ""

Expand All @@ -176,7 +222,171 @@ def __str__(self):
self.function,
", ".join(str(x) for x in self.arguments),
)


def to_numexpr(self):
print(str(self.function))
match str(self.function):
case "pi":
return "arccos(-1)"
case "e":
return "exp(1)"
case "inf":
return "inf"
case "nan":
raise ValueError("No equivalent in Numexpr!")
case "sqrt2":
return "sqrt(2)"
case "piby2":
return "(arccos(-1)/2)"
case "piby4":
return "(arccos(-1)/4)"
case "2pi":
return "(arccos(-1)*2.0)"
case "ln10":
return f"log(10)"
case "loge":
return f"np.log10(np.exp(1))"
case "log":
return f"log({self.arguments[0]})"
case "log10":
return f"(log10({self.arguments[0]})/log(2))"
case "degtorad":
return f"np.radians({self.arguments[0]})"
case "radtodeg":
return f"np.degrees({self.arguments[0]})"
case "exp":
return f"np.exp({self.arguments[0]})"
case "sin":
return f"sin({self.arguments[0]})"
case "asin":
return f"arcsin({self.arguments[0]})"
case "sinh":
return f"sinh({self.arguments[0]})"
case "asinh":
return f"arcsinh({self.arguments[0]})"
case "cos":
return f"cos({self.arguments[0]})"
case "arccos":
return f"arccos({self.arguments[0]})"
case "cosh":
return f"cosh({self.arguments[0]})"
case "acosh":
return f"arccosh({self.arguments[0]})"
case "tan":
return f"tan({self.arguments[0]})"
case "arctan":
return f"arctan({self.arguments[0]})"
case "tanh":
return f"tanh({self.arguments[0]})"
case "atanh":
return f"arctanh({self.arguments[0]})"
case "Math::sqrt":
return f"sqrt({self.arguments[0]})"
case "sqrt":
return f"sqrt({self.arguments[0]})"
case "ceil":
return f"ceil({self.arguments[0]})"
case "abs":
return f"abs({self.arguments[0]})"
case "even":
return f"not ({self.arguments[0]} % 2)"
case "factorial":
raise ValueError("Cannot translate to Numexpr!")
case "floor":
return f"! np.floor({self.arguments[0]})"
case "where":
return f"where({self.arguments[0]},{self.arguments[1]},{self.arguments[3]})"
case _ :
raise ValueError("Not a valid function!")

def to_root(self):
match str(self.function):
case "pi":
return "TMath::Pi"
case "e":
return "TMath::E"
case "inf":
return "TMATH::Infinity"
case "nan":
return "TMATH::QuietNan"
case "sqrt2":
return "TMATH::Sqrt2({self.arguments[0]})"
case "piby2":
return "TMATH::PiOver4"
case "piby4":
return "TMATH::PiOver4"
case "2pi":
return "TMATH::TwoPi"
case "ln10":
return f"TMATH::Ln10({self.arguments[0]})"
case "loge":
return f"TMATH::LogE({self.arguments[0]})"
case "log":
return f"TMATH::Log({self.arguments[0]})"
case "log2":
return f"TMATH::Log2({self.arguments[0]})"
case "degtorad":
return f"TMATH::DegToRad({self.arguments[0]})"
case "radtodeg":
return f"TMATH::RadToDeg({self.arguments[0]})"
case "exp":
return f"TMATH::Exp({self.arguments[0]})"
case "sin":
return f"TMATH::Sin({self.arguments[0]})"
case "asin":
return f"TMATH::ASin({self.arguments[0]})"
case "sinh":
return f"TMATH::SinH({self.arguments[0]})"
case "asinh":
return f"TMATH::ASinH({self.arguments[0]})"
case "cos":
return f"TMATH::Cos({self.arguments[0]})"
case "arccos":
return f"TMATH::ACos({self.arguments[0]})"
case "cosh":
return f"TMATH::CosH({self.arguments[0]})"
case "acosh":
return f"TMATH::ACosH({self.arguments[0]})"
case "tan":
return f"TMATH::Tan({self.arguments[0]})"
case "arctan":
return f"TMATH::ATan({self.arguments[0]})"
case "tanh":
return f"TMATH::TanH({self.arguments[0]})"
case "atanh":
return f"TMATH::ATanH({self.arguments[0]})"
case "Math::sqrt":
return f"TMATH::Sqrt({self.arguments[0]})"
case "sqrt":
return f"TMATH::Sqrt({self.arguments[0]})"
case "ceil":
return f"TMATH::Ceil({self.arguments[0]})"
case "abs":
return f"TMATH::Abs({self.arguments[0]})"
case "even":
return f"TMATH::Even({self.arguments[0]})"
case "factorial":
return f"TMATH::Factorial({self.arguments[0]})"
case "floor":
return f"TMATH::Floor({self.arguments[0]})"
case "abs":
return f"TMATH::Abs({self.arguments[0]})"
case "max":
return f"Max$({self.arguments[0]})"
case "min":
return f"Min$({self.arguments[0]})"
case "sum":
return f"Sum$({self.arguments[0]})"
case "no_of_entries":
return f"Length$({self.arguments[0]})"
case "min_if":
return f"MinIf$({self.arguments[0]})"
case "max_if":
return f"MaxIf$({self.arguments[0]})"
case _ :
raise ValueError("Not a valid function!")


def to_python(self):
print(str(self.function))
match str(self.function):
Expand Down
6 changes: 4 additions & 2 deletions src/formulate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@


def from_root(exp : str, **kwargs) -> AST :
ptree = ttreeformula.exp_to_ptree(exp)
parser = ttreeformula.Lark_StandAlone()
ptree = parser.parse(exp)
return toast.toast(ptree, nxp=False)

def from_numexpr(exp : str, **kwargs) -> AST :
ptree = numexpr.exp_to_ptree(exp)
parser = numexpr.Lark_StandAlone()
ptree = parser.parse(exp)
return toast.toast(ptree, nxp=True)

0 comments on commit 20965e1

Please sign in to comment.