Skip to content

Commit

Permalink
Allow patterns in for loops
Browse files Browse the repository at this point in the history
Closes #58.
  • Loading branch information
osa1 committed Jan 25, 2025
1 parent 3b311ea commit 5fbe49c
Show file tree
Hide file tree
Showing 11 changed files with 1,277 additions and 1,242 deletions.
2 changes: 1 addition & 1 deletion src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ pub enum AssignOp {

#[derive(Debug, Clone)]
pub struct ForStmt {
pub var: Id,
pub pat: L<Pat>,
pub ty: Option<Type>,
pub expr: L<Expr>,
pub expr_ty: Option<Ty>, // filled in by the type checker
Expand Down
4 changes: 2 additions & 2 deletions src/ast/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,14 @@ impl Stmt {
Stmt::Expr(expr) => expr.node.print(buffer, indent),

Stmt::For(ForStmt {
var,
pat,
ty,
expr,
expr_ty: _,
body,
}) => {
buffer.push_str("for ");
buffer.push_str(var);
pat.node.print(buffer);
assert!(ty.is_none()); // TODO
buffer.push_str(" in ");
expr.node.print(buffer, 0);
Expand Down
4 changes: 2 additions & 2 deletions src/closure_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,15 @@ fn visit_stmt(
}

ast::Stmt::For(ast::ForStmt {
var,
pat,
ty: _,
expr,
expr_ty: _,
body,
}) => {
visit_expr(&mut expr.node, closures, top_vars, local_vars, free_vars);
local_vars.enter();
local_vars.insert(var.clone());
bind_pat_binders(&pat.node, local_vars);
for stmt in body {
visit_stmt(&mut stmt.node, closures, top_vars, local_vars, free_vars);
}
Expand Down
20 changes: 14 additions & 6 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ fn exec<W: Write>(
},

ast::Stmt::For(ast::ForStmt {
var,
pat,
ty: _,
expr,
expr_ty: _,
Expand Down Expand Up @@ -965,24 +965,32 @@ fn exec<W: Write>(
}

let value = heap[next_item_option + 1];
locals.insert(var.clone(), value);

let binds = try_bind_pat(pgm, heap, pat, value).unwrap_or_else(|| {
panic!(
"{}: For loop pattern failed to match item",
loc_display(&pat.loc)
)
});
locals.extend(binds.clone());
match exec(w, pgm, heap, locals, body) {
ControlFlow::Val(_) => {}
ControlFlow::Ret(val) => {
locals.remove(var);
for var in binds.keys() {
locals.remove(var);
}
return ControlFlow::Ret(val);
}
ControlFlow::Break => break,
ControlFlow::Continue => continue,
ControlFlow::Unwind(val) => {
locals.remove(var);
for var in binds.keys() {
locals.remove(var);
}
return ControlFlow::Unwind(val);
}
}
}

locals.remove(var);
0
}
};
Expand Down
4 changes: 2 additions & 2 deletions src/monomorph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ fn mono_stmt(
ast::Stmt::Expr(expr) => ast::Stmt::Expr(mono_l_expr(expr, ty_map, poly_pgm, mono_pgm)),

ast::Stmt::For(ast::ForStmt {
var,
pat,
ty,
expr,
expr_ty,
Expand Down Expand Up @@ -345,7 +345,7 @@ fn mono_stmt(
);

ast::Stmt::For(ast::ForStmt {
var: var.clone(),
pat: mono_l_pat(pat, ty_map, poly_pgm, mono_pgm),
ty: ty
.as_ref()
.map(|ty| mono_ty(ty, ty_map, poly_pgm, mono_pgm)),
Expand Down
4 changes: 2 additions & 2 deletions src/parser.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -564,9 +564,9 @@ Stmt: Stmt = {
<l:@L> <expr:BlockExpr> <r:@R> =>
Stmt::Expr(L::new(module, l, r, expr)),

"for" <id:LowerId> "in" <expr:LExpr> ":" NEWLINE INDENT <statements:LStmts> DEDENT =>
"for" <pat:LPat> "in" <expr:LExpr> ":" NEWLINE INDENT <statements:LStmts> DEDENT =>
Stmt::For(ForStmt {
var: id.smol_str(),
pat,
ty: None,
expr,
expr_ty: None,
Expand Down
2,449 changes: 1,226 additions & 1,223 deletions src/parser.rs

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/record_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,13 @@ fn visit_stmt(stmt: &ast::Stmt, records: &mut Set<RecordShape>, variants: &mut S
ast::Stmt::Expr(expr) => visit_expr(&expr.node, records, variants),

ast::Stmt::For(ast::ForStmt {
var: _,
pat,
ty,
expr,
expr_ty: _,
body,
}) => {
visit_pat(&pat.node, records, variants);
if let Some(ty) = ty {
visit_ty(ty, records, variants);
}
Expand Down
3 changes: 2 additions & 1 deletion src/type_checker/instantiation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ pub(super) fn normalize_instantiation_types(stmt: &mut ast::Stmt, cons: &ScopeMa
ast::Stmt::Expr(expr) => normalize_expr(&mut expr.node, cons),

ast::Stmt::For(ast::ForStmt {
var: _,
pat,
ty: _,
expr,
expr_ty,
body,
}) => {
normalize_pat(&mut pat.node, cons);
normalize_expr(&mut expr.node, cons);
for stmt in body {
normalize_instantiation_types(&mut stmt.node, cons);
Expand Down
14 changes: 12 additions & 2 deletions src/type_checker/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ fn check_stmt(
ast::Stmt::Expr(expr) => check_expr(tc_state, expr, expected_ty, level, loop_depth),

ast::Stmt::For(ast::ForStmt {
var,
pat,
ty,
expr,
expr_ty,
Expand Down Expand Up @@ -302,7 +302,17 @@ fn check_stmt(
));

tc_state.env.enter();
tc_state.env.insert(var.clone(), item_ty);

let pat_ty = check_pat(tc_state, pat, level);
unify(
&pat_ty,
&item_ty,
tc_state.tys.tys.cons(),
tc_state.var_gen,
level,
&pat.loc,
);

check_stmts(tc_state, body, None, level, loop_depth + 1);
tc_state.env.exit();
unify_expected_ty(
Expand Down
12 changes: 12 additions & 0 deletions tests/ForPatBind.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
main
let v: Vec[(x: U32, y: U32)] = Vec.withCapacity(10)
v.push((x = 1, y = 2))
v.push((x = 3, y = 4))
v.push((x = 5, y = 6))
for (x = x, y = y) in v.iter():
printStr("x = `x.toStr()`, y = `y.toStr()`")

# expected stdout:
# x = 1, y = 2
# x = 3, y = 4
# x = 5, y = 6

0 comments on commit 5fbe49c

Please sign in to comment.