Skip to content

Commit

Permalink
removed infix flag from Call, to fix issue #45
Browse files Browse the repository at this point in the history
  • Loading branch information
jsiek committed Dec 19, 2024
1 parent 8a558c4 commit e42fb33
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 81 deletions.
75 changes: 39 additions & 36 deletions abstract_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def is_match(pattern, arg, subst):

case PatternCons(loc1, constr, params):
match arg:
case Call(loc2, cty, rator, args, infix):
case Call(loc2, cty, rator, args):
match rator:
case Var(loc3, ty3, name, rs):
if constr == Var(loc3, ty3, name, rs) and len(params) == len(args):
Expand Down Expand Up @@ -731,10 +731,16 @@ def operator_name(trm):
return operator_name(subject)
case _:
raise Exception('operator_name, unexpected term ' + str(trm))


def is_infix_operator(trm):
return is_operator(trm) and operator_name(trm) in infix_precedence.keys()

def is_prefix_operator(trm):
return is_operator(trm) and operator_name(trm) in prefix_precedence.keys()

def precedence(trm):
match trm:
case Call(loc1, tyof, rator, args, infix) if is_operator(rator):
case Call(loc1, tyof, rator, args) if is_operator(rator):
op_name = operator_name(rator)
if len(args) == 2:
return infix_precedence.get(op_name, None)
Expand All @@ -755,22 +761,20 @@ def op_arg_str(trm, arg):
class Call(Term):
rator: Term
args: list[Term]
infix: bool

def copy(self):
ret = Call(self.location, self.typeof,
self.rator.copy(),
[arg.copy() for arg in self.args],
self.infix)
[arg.copy() for arg in self.args])
if hasattr(self, 'type_args'):
ret.type_args = self.type_args
return ret

def __str__(self):
if self.infix:
if is_infix_operator(self.rator) and len(self.args) == 2:
return op_arg_str(self, self.args[0]) + " " + operator_name(self.rator) \
+ " " + op_arg_str(self, self.args[1])
elif is_operator(self.rator): # prefix operator
elif is_prefix_operator(self.rator) and len(self.args) == 1:
return operator_name(self.rator) + " " + op_arg_str(self, self.args[0])
elif isNat(self) and not get_verbose():
return str(natToInt(self))
Expand Down Expand Up @@ -809,7 +813,7 @@ def reduce(self, env):
elif constructor_conflict(args[0], args[1], env):
ret = Bool(loc, BoolType(loc), False)
else:
ret = Call(self.location, self.typeof, fun, args, self.infix)
ret = Call(self.location, self.typeof, fun, args)
case Lambda(loc, ty, vars, body):
subst = {k: v for ((k,t),v) in zip(vars, args)}
for (k,v) in subst.items():
Expand Down Expand Up @@ -855,7 +859,7 @@ def reduce(self, env):
return result
else:
pass
ret = Call(self.location, self.typeof, fun, args, self.infix)
ret = Call(self.location, self.typeof, fun, args)

case RecFun(loc, name, [], params, returns, cases):
if get_verbose():
Expand Down Expand Up @@ -889,14 +893,14 @@ def reduce(self, env):
return result
else:
pass
ret = Call(self.location, self.typeof, fun, args, self.infix)
ret = Call(self.location, self.typeof, fun, args)

case Generic(loc2, tyof, typarams, body):
error(self.location, 'in reduction, call to generic\n\t' + str(self))
case _:
if get_verbose():
print('not reducing call because neutral function: ' + str(fun))
ret = Call(self.location, self.typeof, fun, args, self.infix)
ret = Call(self.location, self.typeof, fun, args)
if hasattr(self, 'type_args'):
ret.type_args = self.type_args
if get_verbose():
Expand All @@ -905,8 +909,7 @@ def reduce(self, env):

def substitute(self, sub):
ret = Call(self.location, self.typeof, self.rator.substitute(sub),
[arg.substitute(sub) for arg in self.args],
self.infix)
[arg.substitute(sub) for arg in self.args])
if hasattr(self, 'type_args'):
ret.type_args = self.type_args
return ret
Expand Down Expand Up @@ -1344,7 +1347,7 @@ def __str__(self):
case Bool(loc, tyof, False):
return str(Call(self.location, self.typeof,
Var(self.location, None, 'not'),
[self.premise], False))
[self.premise]))
case _:
return '(if ' + str(self.premise) \
+ ' then ' + str(self.conclusion) + ')'
Expand Down Expand Up @@ -2439,19 +2442,19 @@ def uniquify_body(self, env):
pass

def mkEqual(loc, arg1, arg2):
ret = Call(loc, None, Var(loc, None, '=', []), [arg1, arg2], True)
ret = Call(loc, None, Var(loc, None, '=', []), [arg1, arg2])
return ret

def split_equation(loc, equation):
match equation:
case Call(loc1, tyof, Var(loc2, tyof2, '=', rs2), [L, R], _):
case Call(loc1, tyof, Var(loc2, tyof2, '=', rs2), [L, R]):
return (L, R)
case _:
error(loc, 'expected an equality, not ' + str(equation))

def is_equation(formula):
match formula:
case Call(loc1, tyof, Var(loc2, tyof2, '=', rs2), [L, R], _):
case Call(loc1, tyof, Var(loc2, tyof2, '=', rs2), [L, R]):
return True
case _:
return False
Expand All @@ -2460,7 +2463,7 @@ def mkZero(loc):
return Var(loc, None, 'zero', [])

def mkSuc(loc, arg):
return Call(loc, None, Var(loc, None, 'suc', []), [arg], False)
return Call(loc, None, Var(loc, None, 'suc', []), [arg])

def intToNat(loc, n):
if n == 0:
Expand All @@ -2472,7 +2475,7 @@ def isNat(t):
match t:
case Var(loc, tyof, name, rs) if base_name(name) == 'zero':
return True
case Call(loc, tyof1, Var(loc2, tyof2, name, rs), [arg], infix) \
case Call(loc, tyof1, Var(loc2, tyof2, name, rs), [arg]) \
if base_name(name) == 'suc':
return isNat(arg)
case _:
Expand All @@ -2482,15 +2485,15 @@ def natToInt(t):
match t:
case Var(loc, tyof, name, rs) if base_name(name) == 'zero':
return 0
case Call(loc, tyof1, Var(loc2, tyof2, name, rs), [arg], infix) \
case Call(loc, tyof1, Var(loc2, tyof2, name, rs), [arg]) \
if base_name(name) == 'suc':
return 1 + natToInt(arg)

def mkPos(loc, arg):
return Call(loc, None, Var(loc, None, 'pos', []), [arg], False)
return Call(loc, None, Var(loc, None, 'pos', []), [arg])

def mkNeg(loc, arg):
return Call(loc, None, Var(loc, None, 'negsuc', []), [arg], False)
return Call(loc, None, Var(loc, None, 'negsuc', []), [arg])

def intToDeduceInt(loc, n, sign):
if sign == 'PLUS':
Expand All @@ -2500,19 +2503,19 @@ def intToDeduceInt(loc, n, sign):

def isDeduceInt(t):
match t:
case Call(loc, tyof1, Var(loc2, tyof2, name), [arg], infix) if base_name(name) == 'pos':
case Call(loc, tyof1, Var(loc2, tyof2, name), [arg]) if base_name(name) == 'pos':
return isNat(arg)
case Call(loc, tyof1, Var(loc2, tyof2, name), [arg], infix) if base_name(name) == 'negsuc':
case Call(loc, tyof1, Var(loc2, tyof2, name), [arg]) if base_name(name) == 'negsuc':
return isNat(arg)
case _:
return False


def deduceIntToInt(t):
match t:
case Call(loc, tyof1, Var(loc2, tyof2, name), [arg], infix) if base_name(name) == 'pos':
case Call(loc, tyof1, Var(loc2, tyof2, name), [arg]) if base_name(name) == 'pos':
return '+' + str(natToInt(arg))
case Call(loc, tyof1, Var(loc2, tyof2, name), [arg], infix) if base_name(name) == 'negsuc':
case Call(loc, tyof1, Var(loc2, tyof2, name), [arg]) if base_name(name) == 'negsuc':
return '-' + str(1 + natToInt(arg))

def is_constructor(constr_name, env):
Expand Down Expand Up @@ -2558,7 +2561,7 @@ def isNodeList(t):
if base_name(name) == 'empty':
return True
case Call(loc, tyof1, TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs3), tyargs, inferred),
[arg, ls], infix) if base_name(name) == 'node':
[arg, ls]) if base_name(name) == 'node':
return isNodeList(ls)
case _:
return False
Expand All @@ -2569,14 +2572,14 @@ def nodeListToList(t):
if base_name(name) == 'empty':
return ''
case Call(loc, tyof1, TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs), tyargs, inferred),
[arg, ls], infix) if base_name(name) == 'node':
[arg, ls]) if base_name(name) == 'node':
return str(arg) + ', ' + nodeListToList(ls)

def mkEmpty(loc):
return Var(loc, None, 'empty', [])

def mkNode(loc, arg, ls):
return Call(loc, None, Var(loc, None, 'node', []), [arg, ls], False)
return Call(loc, None, Var(loc, None, 'node', []), [arg, ls])

def listToNodeList(loc, lst):
if len(lst) == 0:
Expand All @@ -2587,7 +2590,7 @@ def listToNodeList(loc, lst):
def isEmptySet(t):
match t:
case Call(loc2, tyof2, TermInst(loc1, tyof1, Var(loc5, tyof5, name, rs), tyargs, implicit),
[Lambda(loc3, tyof3, vars, Bool(loc4, tyof4, False))], infix) \
[Lambda(loc3, tyof3, vars, Bool(loc4, tyof4, False))]) \
if base_name(name) == 'char_fun':
return True
case _:
Expand Down Expand Up @@ -2850,7 +2853,7 @@ def count_marks(formula):
return count_marks(frm2)
case Some(loc2, tyof, vars, frm2):
return count_marks(frm2)
case Call(loc2, tyof, rator, args, infix):
case Call(loc2, tyof, rator, args):
return count_marks(rator) + sum([count_marks(arg) for arg in args])
case Switch(loc2, tyof, subject, cases):
return count_marks(subject) + sum([count_marks(c) for c in cases])
Expand Down Expand Up @@ -2896,7 +2899,7 @@ def find_mark(formula):
find_mark(frm2)
case Some(loc2, tyof, vars, frm2):
find_mark(frm2)
case Call(loc2, tyof, rator, args, infix):
case Call(loc2, tyof, rator, args):
find_mark(rator)
for arg in args:
find_mark(arg)
Expand Down Expand Up @@ -2947,9 +2950,9 @@ def replace_mark(formula, replacement):
return All(loc2, tyof, var, pos, replace_mark(frm2, replacement))
case Some(loc2, tyof, vars, frm2):
return Some(loc2, tyof, vars, replace_mark(frm2, replacement))
case Call(loc2, tyof, rator, args, infix):
case Call(loc2, tyof, rator, args):
return Call(loc2, tyof, replace_mark(rator, replacement),
[replace_mark(arg, replacement) for arg in args], infix)
[replace_mark(arg, replacement) for arg in args])
case Switch(loc2, tyof, subject, cases):
return Switch(loc2, tyof, replace_mark(subject, replacement),
[replace_mark(c, replacement) for c in cases])
Expand Down
15 changes: 6 additions & 9 deletions parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def parse_tree_to_ast(e, parent):
return intToDeduceInt(e.meta, int(e.children[0].value), 'PLUS')
elif e.data == 'neg_int':
arg = parse_tree_to_ast(e.children[0], e)
return Call(e.meta, None, Var(e.meta, None, '-'), [arg], False)
return Call(e.meta, None, Var(e.meta, None, '-'), [arg])
elif e.data == 'hole_term':
return Hole(e.meta, None)
elif e.data == 'omitted_term':
Expand Down Expand Up @@ -270,16 +270,15 @@ def parse_tree_to_ast(e, parent):
return Call(e.meta, None,
Var(e.meta, None, 'char_fun', []),
[Lambda(e.meta, None, [('_',None)],
Bool(e.meta, None, False))],
False)
Bool(e.meta, None, False))])
# elif e.data == 'field_access':
# subject = parse_tree_to_ast(e.children[0], e)
# field_name = str(e.children[1].value)
# return FieldAccess(e.meta, None, subject, field_name)
elif e.data == 'call':
rator = parse_tree_to_ast(e.children[0], e)
rands = parse_tree_to_list(e.children[1], e)
return Call(e.meta, None, rator, rands, False)
return Call(e.meta, None, rator, rands)
elif e.data == 'lambda':
return Lambda(e.meta, None,
parse_tree_to_list(e.children[0], e),
Expand All @@ -292,16 +291,14 @@ def parse_tree_to_ast(e, parent):
kids = [parse_tree_to_ast(c, e) for c in e.children]
return IfThen(e.meta, None,
Call(e.meta, None, Var(e.meta, None, '=', []),
kids, True),
kids),
Bool(e.meta, None, False))
elif e.data in infix_ops:
return Call(e.meta, None, Var(e.meta, None, operator_symbol[e.data], []),
[parse_tree_to_ast(c, e) for c in e.children],
True)
[parse_tree_to_ast(c, e) for c in e.children])
elif e.data in prefix_ops:
return Call(e.meta, None, Var(e.meta, None, operator_symbol[e.data], []),
[parse_tree_to_ast(c, e) for c in e.children],
False)
[parse_tree_to_ast(c, e) for c in e.children])
elif e.data == 'switch_case':
e1 , e2 = e.children
return SwitchCase(e.meta, parse_tree_to_ast(e1, e),
Expand Down
Loading

0 comments on commit e42fb33

Please sign in to comment.