Skip to content

Commit

Permalink
added support for arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
jsiek committed Jan 5, 2025
1 parent c86fe87 commit 3b93342
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 7 deletions.
3 changes: 3 additions & 0 deletions Deduce.lark
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ ident: IDENT -> ident
| "not" term_hi -> logical_not
| IDENT -> term_var
| "@" term_hi "<" type_list ">" -> term_inst
| "array" "(" term ")" -> make_array
| term_hi "(" term_list ")" -> call
| term_hi "[" INT "]" -> array_get
| "λ" var_list "{" term "}" -> lambda
| "fun" var_list "{" term "}" -> lambda
| "generic" ident_list "{" term "}" -> generic
Expand Down Expand Up @@ -230,6 +232,7 @@ ident: IDENT -> ident
| "bool" -> bool_type
| "type" -> type_type
| "(" type ")" -> paren
| "[" type "]" -> array_type

?type_list: -> empty
| type -> single
Expand Down
150 changes: 146 additions & 4 deletions abstract_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,35 @@ def reduce(self, env):
[ty.reduce(env) for ty in self.param_types],
self.return_type.reduce(env))

@dataclass
class ArrayType(Type):
elt_type: Type

def copy(self):
return ArrayType(self.location, self.elt_type.copy())

def __str__(self):
return '[' + (self.elt_type) + ']'

def __eq__(self, other):
match other:
case ArrayType(loc, elt_type):
return self.elt_type == elt_type
case _:
return False

def free_vars(self):
return self.elt_type.free_vars()

def substitute(self, sub):
return ArrayType(self.location, self.elt_type.substitute(sub))

def uniquify(self, env):
self.elt_type.uniquify(env)

def reduce(self, env):
return ArrayType(self.location, self.elt_type.reduce(env))

@dataclass
class TypeInst(Type):
typ: Type
Expand Down Expand Up @@ -801,7 +830,7 @@ def __str__(self):
elif isDeduceInt(self):
return deduceIntToInt(self)
elif isNodeList(self):
return '[' + nodeListToList(self)[:-2] + ']'
return '[' + nodeListToString(self)[:-2] + ']'
elif isEmptySet(self) and not get_verbose():
return '∅'
else:
Expand Down Expand Up @@ -1089,7 +1118,109 @@ def uniquify(self, env):
for ty in self.type_args:
ty.uniquify(env)

@dataclass
class Array(Term):
elements: List[Term]

def __eq__(self, other):
if isinstance(other, Array):
return all([elt == other_elt for (elt, other_elt) in zip(self.elements,
other.elements)])
else:
return False

def copy(self):
return Array(self.location, [elt.copy() for elt in self.elements])

def __str__(self):
return 'array(' + ', '.join([str(elt) for elt in self.elements]) + ')'

def reduce(self, env):
return Array(self.location, self.typeof,
[elt.reduce(env) for elt in self.elements])

def substitute(self, sub):
return Array(self.location, self.typeof,
[elt.substitute(sub) for elt in self.elements])

def uniquify(self, env):
for elt in self.elements:
elt.uniquify(env)

@dataclass
class MakeArray(Term):
subject: Term

def __eq__(self, other):
if isinstance(other, MakeArray):
return self.subject == other.subject
else:
return False

def copy(self):
return MakeArray(self.location, self.typeof,
self.subject.copy())

def __str__(self):
return 'array(' + str(self.subject) + ')'

def reduce(self, env):
subject_red = self.subject.reduce(env)
if isNodeList(subject_red):
elements = nodeListToList(subject_red)
return Array(self.location, self.typeof, elements)
else:
return MakeArray(self.location, self.typeof, self.subject.reduce(env))

def substitute(self, sub):
return MakeArray(self.location, self.typeof,
self.subject.substitute(sub))

def uniquify(self, env):
self.subject.uniquify(env)

@dataclass
class ArrayGet(Term):
subject: Term
index: int

def __eq__(self, other):
if isinstance(other, ArrayGet):
return self.subject == other.subject \
and self.index == other.index
else:
return False

def copy(self):
return ArrayGet(self.location, self.typeof,
self.subject.copy(), self.index)

def __str__(self):
return str(self.subject) + '[' + str(self.index) + ']'

def reduce(self, env):
subject_red = self.subject.reduce(env)
index_red = self.index.reduce(env)
match subject_red:
case Array(loc2, _, elements):
if isNat(index_red):
index = natToInt(index_red)
if 0 <= index and index < len(elements):
return elements[index].reduce(env)
else:
error(self.location, 'array index out of bounds\n' \
+ 'index: ' + str(index) + '\n' \
+ 'array length: ' + str(len(elements)))
return ArrayGet(self.location, self.typeof, subject_red, index_red)

def substitute(self, sub):
return ArrayGet(self.location, self.typeof,
self.subject.substitute(sub),
self.index)

def uniquify(self, env):
self.subject.uniquify(env)

@dataclass
class TLet(Term):
var: str
Expand Down Expand Up @@ -2595,10 +2726,21 @@ def nodeListToList(t):
match t:
case TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs), tyargs, inferred) \
if base_name(name) == 'empty':
return ''
case Call(loc, tyof1, TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs), tyargs, inferred),
return []
case Call(loc, tyof1, TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs),
tyargs, inferred),
[arg, ls]) if base_name(name) == 'node':
return [arg] + nodeListToList(ls)

def nodeListToString(t):
match t:
case TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs), tyargs, inferred) \
if base_name(name) == 'empty':
return ''
case Call(loc, tyof1, TermInst(loc2, tyof2, Var(loc3, tyof3, name, rs),
tyargs, inferred),
[arg, ls]) if base_name(name) == 'node':
return str(arg) + ', ' + nodeListToList(ls)
return str(arg) + ', ' + nodeListToString(ls)

def mkEmpty(loc):
return Var(loc, None, 'empty', [])
Expand Down
Empty file removed actual_error.tmp
Empty file.
10 changes: 10 additions & 0 deletions parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def parse_tree_to_ast(e, parent):
return IntType(e.meta)
elif e.data == 'bool_type':
return BoolType(e.meta)
elif e.data == 'array_type':
elt_type = parse_tree_to_ast(e.children[0])
return ArrayType(e.meta, elt_type)
elif e.data == 'type_type':
return TypeType(e.meta)
elif e.data == 'function_type':
Expand All @@ -216,6 +219,13 @@ def parse_tree_to_ast(e, parent):
parse_tree_to_ast(e.children[0], e),
parse_tree_to_list(e.children[1], e),
False)
elif e.data == 'array_get':
return ArrayGet(e.meta, None,
parse_tree_to_ast(e.children[0], e),
intToNat(e.meta, int(e.children[1])))
elif e.data == 'make_array':
return MakeArray(e.meta, None,
parse_tree_to_ast(e.children[0], e))
elif e.data == 'mark':
return Mark(e.meta, None, parse_tree_to_ast(e.children[0], e))
elif e.data == 'list_literal':
Expand Down
20 changes: 20 additions & 0 deletions proof_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,26 @@ def type_synth_term(term, env, recfun, subterms):
ty = BoolType(loc)
ret = Some(loc, ty, vars, new_body)

case MakeArray(loc, _, arg):
lst = type_synth_term(arg, env, recfun, subterms)
match lst.typeof:
case TypeInst(loc2, lst_ty, [elt_type]):
union_def = lookup_union(loc, lst_ty, env)
if base_name(union_def.name) == 'List':
ret = MakeArray(loc, ArrayType(loc, elt_type), lst)
else:
error(loc, 'expected List, not union ' + union_def.name)
case _:
error(loc, 'expected List, not ' + str(lst.typeof))

case ArrayGet(loc, _, array, index):
new_array = type_synth_term(array, env, recfun, subterms)
match new_array.typeof:
case ArrayType(loc2, elt_type):
ret = ArrayGet(loc, elt_type, new_array, index)
case _:
error(loc, 'expected an array, not ' + str(new_array.typeof))

case Call(loc, _, Var(loc2, ty2, name, rs), args) \
if name == '=' or name == '≠':
lhs = type_synth_term(args[0], env, recfun, subterms)
Expand Down
66 changes: 63 additions & 3 deletions rec_desc_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,35 @@ def parse_term_hi():
error(meta_from_tokens(token, current_token()),
'expected a term, not\n\t' + quote(current_token().value))

def parse_array_get():
while_parsing = 'while parsing array access\n' \
+ '\tterm ::= term "[" integer "]"\n'
term = parse_term_hi()

while (not end_of_file()) and current_token().type == 'LSQB':
try:
start_token = current_token()
advance()
index = intToNat(meta_from_tokens(current_token(),current_token()),
int(current_token().value))
advance()
if current_token().type != 'RSQB':
error(meta_from_tokens(start_token, current_token()),
'expected closing "]", not\n\t' \
+ current_token().value)
term = ArrayGet(meta_from_tokens(start_token, current_token()), None,
term, index)
advance()
except Exception as e:
meta = meta_from_tokens(start_token, previous_token())
raise Exception(str(e) + '\n' + error_header(meta) + while_parsing)

return term

def parse_call():
while_parsing = 'while parsing function call\n' \
+ '\tterm ::= term "(" term_list ")"\n'
term = parse_term_hi()
term = parse_array_get()

while (not end_of_file()) and current_token().type == 'LPAR':
try:
Expand All @@ -382,14 +407,41 @@ def parse_call():
raise Exception(str(e) + '\n' + error_header(meta) + while_parsing)

return term

def parse_make_array():
if current_token().value == 'array':
while_parsing = 'while parsing array creation\n' \
+ '\tterm ::= "array" "(" term ")"\n'
start_token = current_token()
advance()
try:
if current_token().type != 'LPAR':
error(meta_from_tokens(start_token, current_token()),
'expected open parenthesis "(", not\n\t' \
+ current_token().value)
advance()
arg = parse_term()
if current_token().type != 'RPAR':
error(meta_from_tokens(start_token, current_token()),
'expected closing parenthesis ")", not\n\t' \
+ current_token().value)
term = MakeArray(meta_from_tokens(start_token, current_token()),None,arg)
advance()
except Exception as e:
meta = meta_from_tokens(start_token, previous_token())
raise Exception(str(e) + '\n' + error_header(meta) + while_parsing)
else:
term = parse_call()
return term

def parse_term_mult():
term = parse_call()
term = parse_make_array()

while (not end_of_file()) and current_token().value in mult_operators:
start_token = current_token()
rator = Var(meta_from_tokens(current_token(), current_token()),
None, to_unicode.get(current_token().value, current_token().value))
None, to_unicode.get(current_token().value,
current_token().value))
advance()
right = parse_term_mult()
term = Call(meta_from_tokens(start_token, previous_token()), None,
Expand Down Expand Up @@ -1408,6 +1460,14 @@ def parse_type():
return TypeType(meta_from_tokens(token,token))
elif token.type == 'FN':
return parse_function_type()
elif token.type == 'LSQB':
advance()
elt_type = parse_type()
if current_token().type != 'RSQB':
error(meta_from_tokens(start_token, current_token()),
'expected closing "]", not\n\t' + current_token().value)
advance()
return ArrayType(meta_from_tokens(token, previous_token()))
elif token.type == 'LPAR':
start_token = current_token()
advance()
Expand Down
7 changes: 7 additions & 0 deletions test/should-error/array2.pf
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import Nat
import List

define L = [1,2,3]
define A = array(L)
assert A[3] = 1 // out of bounds

8 changes: 8 additions & 0 deletions test/should-pass/array1.pf
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import Nat
import List

define L = [1,2,3]
define A = array(L)
assert A[0] = 1
assert A[1] = 2
assert A[2] = 3

0 comments on commit 3b93342

Please sign in to comment.