From 3b9334290d1fe26c0493d59896aa6b63a6ececea Mon Sep 17 00:00:00 2001
From: "Jeremy G. Siek"
Date: Sun, 5 Jan 2025 11:45:57 -0700
Subject: [PATCH] added support for arrays

---
 Deduce.lark                 |   3 +
 abstract_syntax.py          | 150 ++++++++++++++++++++++++++++++++++++-
 actual_error.tmp            |   0
 parser.py                   |  10 +++
 proof_checker.py            |  20 +++++
 rec_desc_parser.py          |  66 +++++++++++++++-
 test/should-error/array2.pf |   7 ++
 test/should-pass/array1.pf  |   8 ++
 8 files changed, 257 insertions(+), 7 deletions(-)
 delete mode 100644 actual_error.tmp
 create mode 100644 test/should-error/array2.pf
 create mode 100644 test/should-pass/array1.pf

diff --git a/Deduce.lark b/Deduce.lark
index e6def47..9bfaca3 100644
--- a/Deduce.lark
+++ b/Deduce.lark
@@ -83,7 +83,9 @@ ident: IDENT -> ident
   | "not" term_hi -> logical_not
   | IDENT -> term_var
   | "@" term_hi "<" type_list ">" -> term_inst
+  | "array" "(" term ")" -> make_array
   | term_hi "(" term_list ")" -> call
+  | term_hi "[" INT "]" -> array_get
   | "λ" var_list "{" term "}" -> lambda
   | "fun" var_list "{" term "}" -> lambda
   | "generic" ident_list "{" term "}" -> generic
@@ -230,6 +232,7 @@ ident: IDENT -> ident
   | "bool" -> bool_type
   | "type" -> type_type
   | "(" type ")" -> paren
+  | "[" type "]" -> array_type
 
 ?type_list: -> empty
   | type -> single
diff --git a/abstract_syntax.py b/abstract_syntax.py
index c04350c..12d2bab 100644
--- a/abstract_syntax.py
+++ b/abstract_syntax.py
@@ -246,6 +246,35 @@ def reduce(self, env):
                         [ty.reduce(env) for ty in self.param_types],
                         self.return_type.reduce(env))
 
+@dataclass
+class ArrayType(Type):
+  elt_type: Type
+
+  def copy(self):
+    return ArrayType(self.location, self.elt_type.copy())
+
+  def __str__(self):
+    return '[' + str(self.elt_type) + ']'
+
+  def __eq__(self, other):
+    match other:
+      case ArrayType(loc, elt_type):
+        return self.elt_type == elt_type
+      case _:
+        return False
+
+  def free_vars(self):
+    return self.elt_type.free_vars()
+
+  def substitute(self, sub):
+    return ArrayType(self.location, self.elt_type.substitute(sub))
+
+  def uniquify(self, env):
+    self.elt_type.uniquify(env)
+
+  def reduce(self, env):
+    return ArrayType(self.location, self.elt_type.reduce(env))
+
 @dataclass
 class TypeInst(Type):
   typ: Type
@@ -801,7 +830,7 @@ def __str__(self):
     elif isDeduceInt(self):
       return deduceIntToInt(self)
     elif isNodeList(self):
-      return '[' + nodeListToList(self)[:-2] + ']'
+      return '[' + nodeListToString(self)[:-2] + ']'
     elif isEmptySet(self) and not get_verbose():
       return '∅'
     else:
@@ -1089,7 +1118,109 @@ def uniquify(self, env):
     for ty in self.type_args:
       ty.uniquify(env)
 
+@dataclass
+class Array(Term):
+  elements: List[Term]
+
+  def __eq__(self, other):
+    if isinstance(other, Array):
+      return all([elt == other_elt for (elt, other_elt) in zip(self.elements,
+                                                               other.elements)])
+    else:
+      return False
+
+  def copy(self):
+    return Array(self.location, self.typeof, [elt.copy() for elt in self.elements])
+
+  def __str__(self):
+    return 'array(' + ', '.join([str(elt) for elt in self.elements]) + ')'
+
+  def reduce(self, env):
+    return Array(self.location, self.typeof,
+                 [elt.reduce(env) for elt in self.elements])
+
+  def substitute(self, sub):
+    return Array(self.location, self.typeof,
+                 [elt.substitute(sub) for elt in self.elements])
+
+  def uniquify(self, env):
+    for elt in self.elements:
+      elt.uniquify(env)
+
+@dataclass
+class MakeArray(Term):
+  subject: Term
+
+  def __eq__(self, other):
+    if isinstance(other, MakeArray):
+      return self.subject == other.subject
+    else:
+      return False
+
+  def copy(self):
+    return MakeArray(self.location, self.typeof,
+                     self.subject.copy())
+
+  def __str__(self):
+    return 'array(' + str(self.subject) + ')'
+
+  def reduce(self, env):
+    subject_red = self.subject.reduce(env)
+    if isNodeList(subject_red):
+      elements = nodeListToList(subject_red)
+      return Array(self.location, self.typeof, elements)
+    else:
+      return MakeArray(self.location, self.typeof, subject_red)
+
+  def substitute(self, sub):
+    return MakeArray(self.location, self.typeof,
+                     self.subject.substitute(sub))
+
+  def uniquify(self, env):
+    self.subject.uniquify(env)
+
+@dataclass
+class ArrayGet(Term):
+  subject: Term
+  index: Term
+
+  def __eq__(self, other):
+    if isinstance(other, ArrayGet):
+      return self.subject == other.subject \
+        and self.index == other.index
+    else:
+      return False
+
+  def copy(self):
+    return ArrayGet(self.location, self.typeof,
+                    self.subject.copy(), self.index)
+
+  def __str__(self):
+    return str(self.subject) + '[' + str(self.index) + ']'
+
+  def reduce(self, env):
+    subject_red = self.subject.reduce(env)
+    index_red = self.index.reduce(env)
+    match subject_red:
+      case Array(loc2, _, elements):
+        if isNat(index_red):
+          index = natToInt(index_red)
+          if 0 <= index and index < len(elements):
+            return elements[index].reduce(env)
+          else:
+            error(self.location, 'array index out of bounds\n' \
+                  + 'index: ' + str(index) + '\n' \
+                  + 'array length: ' + str(len(elements)))
+    return ArrayGet(self.location, self.typeof, subject_red, index_red)
+
+  def substitute(self, sub):
+    return ArrayGet(self.location, self.typeof,
+                    self.subject.substitute(sub),
+                    self.index)
+
+  def uniquify(self, env):
+    self.subject.uniquify(env)
 
 @dataclass
 class TLet(Term):
   var: str
@@ -2595,10 +2726,21 @@ def nodeListToList(t):
   match t:
     case TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs), tyargs, inferred) \
       if base_name(name) == 'empty':
-      return ''
-    case Call(loc, tyof1, TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs), tyargs, inferred),
+      return []
+    case Call(loc, tyof1, TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs),
+              tyargs, inferred),
+              [arg, ls]) if base_name(name) == 'node':
+      return [arg] + nodeListToList(ls)
+
+def nodeListToString(t):
+  match t:
+    case TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs), tyargs, inferred) \
+      if base_name(name) == 'empty':
+      return ''
+    case Call(loc, tyof1, TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs),
+              tyargs, inferred),
               [arg, ls]) if base_name(name) == 'node':
-      return str(arg) + ', ' + nodeListToList(ls)
+      return str(arg) + ', ' + nodeListToString(ls)
 
 def mkEmpty(loc):
   return Var(loc, None, 'empty', [])
diff --git a/actual_error.tmp b/actual_error.tmp
deleted file mode 100644
index e69de29..0000000
diff --git a/parser.py b/parser.py
index fb87f39..cb414a0 100644
--- a/parser.py
+++ b/parser.py
@@ -193,6 +193,9 @@ def parse_tree_to_ast(e, parent):
     return IntType(e.meta)
   elif e.data == 'bool_type':
     return BoolType(e.meta)
+  elif e.data == 'array_type':
+    elt_type = parse_tree_to_ast(e.children[0], e)
+    return ArrayType(e.meta, elt_type)
   elif e.data == 'type_type':
     return TypeType(e.meta)
   elif e.data == 'function_type':
@@ -216,6 +219,13 @@ def parse_tree_to_ast(e, parent):
                 parse_tree_to_ast(e.children[0], e),
                 parse_tree_to_list(e.children[1], e),
                 False)
+  elif e.data == 'array_get':
+    return ArrayGet(e.meta, None,
+                    parse_tree_to_ast(e.children[0], e),
+                    intToNat(e.meta, int(e.children[1])))
+  elif e.data == 'make_array':
+    return MakeArray(e.meta, None,
+                     parse_tree_to_ast(e.children[0], e))
   elif e.data == 'mark':
     return Mark(e.meta, None, parse_tree_to_ast(e.children[0], e))
   elif e.data == 'list_literal':
diff --git a/proof_checker.py b/proof_checker.py
index c5271c0..33d5a90 100644
--- a/proof_checker.py
+++ b/proof_checker.py
@@ -1918,6 +1918,26 @@ def type_synth_term(term, env, recfun, subterms):
       ty = BoolType(loc)
       ret = Some(loc, ty, vars, new_body)
 
+    case MakeArray(loc, _, arg):
+      lst = type_synth_term(arg, env, recfun, subterms)
+      match lst.typeof:
+        case TypeInst(loc2, lst_ty, [elt_type]):
+          union_def = lookup_union(loc, lst_ty, env)
+          if base_name(union_def.name) == 'List':
+            ret = MakeArray(loc, ArrayType(loc, elt_type), lst)
+          else:
+            error(loc, 'expected List, not union ' + union_def.name)
+        case _:
+          error(loc, 'expected List, not ' + str(lst.typeof))
+
+    case ArrayGet(loc, _, array, index):
+      new_array = type_synth_term(array, env, recfun, subterms)
+      match new_array.typeof:
+        case ArrayType(loc2, elt_type):
+          ret = ArrayGet(loc, elt_type, new_array, index)
+        case _:
+          error(loc, 'expected an array, not ' + str(new_array.typeof))
+
     case Call(loc, _, Var(loc2, ty2, name, rs), args) \
       if name == '=' or name == '≠':
       lhs = type_synth_term(args[0], env, recfun, subterms)
diff --git a/rec_desc_parser.py b/rec_desc_parser.py
index 64887f5..792937d 100644
--- a/rec_desc_parser.py
+++ b/rec_desc_parser.py
@@ -360,10 +360,35 @@ def parse_term_hi():
     error(meta_from_tokens(token, current_token()),
           'expected a term, not\n\t' + quote(current_token().value))
 
+def parse_array_get():
+  while_parsing = 'while parsing array access\n' \
+      + '\tterm ::= term "[" integer "]"\n'
+  term = parse_term_hi()
+
+  while (not end_of_file()) and current_token().type == 'LSQB':
+    try:
+      start_token = current_token()
+      advance()
+      index = intToNat(meta_from_tokens(current_token(), current_token()),
+                       int(current_token().value))
+      advance()
+      if current_token().type != 'RSQB':
+        error(meta_from_tokens(start_token, current_token()),
+              'expected closing "]", not\n\t' \
+              + current_token().value)
+      term = ArrayGet(meta_from_tokens(start_token, current_token()), None,
+                      term, index)
+      advance()
+    except Exception as e:
+      meta = meta_from_tokens(start_token, previous_token())
+      raise Exception(str(e) + '\n' + error_header(meta) + while_parsing)
+
+  return term
+
 def parse_call():
   while_parsing = 'while parsing function call\n' \
       + '\tterm ::= term "(" term_list ")"\n'
-  term = parse_term_hi()
+  term = parse_array_get()
 
   while (not end_of_file()) and current_token().type == 'LPAR':
     try:
@@ -382,14 +407,41 @@ def parse_call():
       raise Exception(str(e) + '\n' + error_header(meta) + while_parsing)
 
   return term
+
+def parse_make_array():
+  if current_token().value == 'array':
+    while_parsing = 'while parsing array creation\n' \
+        + '\tterm ::= "array" "(" term ")"\n'
+    start_token = current_token()
+    advance()
+    try:
+      if current_token().type != 'LPAR':
+        error(meta_from_tokens(start_token, current_token()),
+              'expected open parenthesis "(", not\n\t' \
+              + current_token().value)
+      advance()
+      arg = parse_term()
+      if current_token().type != 'RPAR':
+        error(meta_from_tokens(start_token, current_token()),
+              'expected closing parenthesis ")", not\n\t' \
+              + current_token().value)
+      term = MakeArray(meta_from_tokens(start_token, current_token()), None, arg)
+      advance()
+    except Exception as e:
+      meta = meta_from_tokens(start_token, previous_token())
+      raise Exception(str(e) + '\n' + error_header(meta) + while_parsing)
+  else:
+    term = parse_call()
+  return term
 
 def parse_term_mult():
-  term = parse_call()
+  term = parse_make_array()
   while (not end_of_file()) and current_token().value in mult_operators:
     start_token = current_token()
     rator = Var(meta_from_tokens(current_token(), current_token()),
-                None, to_unicode.get(current_token().value, current_token().value))
+                None, to_unicode.get(current_token().value,
+                                     current_token().value))
     advance()
     right = parse_term_mult()
     term = Call(meta_from_tokens(start_token, previous_token()), None,
@@ -1408,6 +1460,14 @@ def parse_type():
     return TypeType(meta_from_tokens(token,token))
   elif token.type == 'FN':
     return parse_function_type()
+  elif token.type == 'LSQB':
+    advance()
+    elt_type = parse_type()
+    if current_token().type != 'RSQB':
+      error(meta_from_tokens(token, current_token()),
+            'expected closing "]", not\n\t' + current_token().value)
+    advance()
+    return ArrayType(meta_from_tokens(token, previous_token()), elt_type)
  elif token.type == 'LPAR':
    start_token = current_token()
    advance()
diff --git a/test/should-error/array2.pf b/test/should-error/array2.pf
new file mode 100644
index 0000000..8f84609
--- /dev/null
+++ b/test/should-error/array2.pf
@@ -0,0 +1,7 @@
+import Nat
+import List
+
+define L = [1,2,3]
+define A = array(L)
+assert A[3] = 1 // out of bounds
+
diff --git a/test/should-pass/array1.pf b/test/should-pass/array1.pf
new file mode 100644
index 0000000..7ec6b69
--- /dev/null
+++ b/test/should-pass/array1.pf
@@ -0,0 +1,8 @@
+import Nat
+import List
+
+define L = [1,2,3]
+define A = array(L)
+assert A[0] = 1
+assert A[1] = 2
+assert A[2] = 3
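
A small end-to-end sketch of the new surface syntax, combining the pieces added above: the `[T]` array type from parse_type, `array(...)` construction (MakeArray), and literal-index access (ArrayGet). The type annotations on `define` are an assumption that the usual `define x : T = e` form accepts the new `[Nat]` type; they are not taken from the test files above.

    import Nat
    import List

    define L : List<Nat> = [1, 2, 3]
    define A : [Nat] = array(L)   // MakeArray converts a List into an Array
    assert A[1] = 2               // ArrayGet reduces to the element at index 1

Note that the grammar rule `term_hi "[" INT "]" -> array_get` only admits integer literals as indices, and an out-of-range literal is reported during reduction, which is what test/should-error/array2.pf exercises.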