Skip to content

Commit

Permalink
frontend: Support patterns in parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
dinfuehr committed Feb 19, 2025
1 parent e670376 commit 380217e
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 127 deletions.
4 changes: 4 additions & 0 deletions dora-frontend/src/error/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ pub enum ErrorMessage {
IndexSetNotImplemented(String),
IndexGetAndIndexSetDoNotMatch,
MissingAssocType(String),
NameBoundMultipleTimesInParams(String),
}

impl ErrorMessage {
Expand Down Expand Up @@ -763,6 +764,9 @@ impl ErrorMessage {
ErrorMessage::MissingAssocType(ref name) => {
format!("Missing associated type `{}`.", name)
}
ErrorMessage::NameBoundMultipleTimesInParams(ref name) => {
format!("Name `{}` bound multiple times in parameter list.", name)
}
}
}
}
Expand Down
35 changes: 19 additions & 16 deletions dora-frontend/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,12 @@ impl<'a> AstBytecodeGen<'a> {
}

for param in &ast.params {
let var_id = *self.analysis.map_vars.get(param.id).unwrap();
let ty = self.var_ty(var_id);

let ty = self.analysis.ty(param.id);
let bty = bty_from_ty(ty.clone());
params.push(bty);

self.allocate_register_for_var(var_id);
let bty: BytecodeType = register_bty_from_ty(ty);
self.alloc_var(bty);
}

self.builder.set_params(params);
Expand All @@ -188,18 +187,26 @@ impl<'a> AstBytecodeGen<'a> {
};

for (param_idx, param) in ast.params.iter().enumerate() {
let var_id = *self.analysis.map_vars.get(param.id).unwrap();
let var = self.analysis.vars.get_var(var_id);
let reg = Register(next_register_idx + param_idx);

match var.location {
VarLocation::Context(scope_id, field_id) => {
self.store_in_context(reg, scope_id, field_id, self.loc(self.span));
}
if let Some(ident) = param.pattern.to_ident() {
let var_id = *self.analysis.map_vars.get(ident.id).unwrap();

VarLocation::Stack => {
// Nothing to do.
let var = self.analysis.vars.get_var(var_id);

match var.location {
VarLocation::Context(scope_id, field_id) => {
self.store_in_context(reg, scope_id, field_id, self.loc(self.span));
}

VarLocation::Stack => {
self.set_var_reg(var_id, reg);
}
}
} else {
let ty = self.analysis.ty(param.id);
self.setup_pattern_vars(&param.pattern);
self.destruct_pattern_or_fail(&param.pattern, reg, ty);
}
}
}
Expand Down Expand Up @@ -3301,10 +3308,6 @@ impl<'a> AstBytecodeGen<'a> {
self.analysis.ty(id)
}

fn var_ty(&self, id: VarId) -> SourceType {
self.analysis.vars.get_var(id).ty.clone()
}

fn get_intrinsic(&self, id: ast::NodeId) -> Option<IntrinsicInfo> {
let call_type = self.analysis.map_calls.get(id).expect("missing CallType");

Expand Down
43 changes: 25 additions & 18 deletions dora-frontend/src/typeck/function.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::cell::OnceCell;
use std::collections::HashSet;
use std::rc::Rc;
use std::str::Chars;
use std::{f32, f64};
Expand All @@ -11,7 +12,7 @@ use crate::sema::{
PackageDefinitionId, Param, ScopeId, Sema, SourceFileId, TypeParamDefinition, Var, VarAccess,
VarId, VarLocation, Visibility,
};
use crate::typeck::{check_expr, check_stmt, CallArguments};
use crate::typeck::{check_expr, check_pattern, check_stmt, CallArguments};
use crate::{
always_returns, expr_always_returns, replace_type, report_sym_shadow_span, ModuleSymTable,
SourceType, SourceTypeArray, SymbolKind,
Expand Down Expand Up @@ -305,35 +306,41 @@ impl<'a> TypeCheck<'a> {
let self_count = if self.has_hidden_self_argument { 1 } else { 0 };
assert_eq!(ast.params.len() + self_count, self.param_types.len());

for (ind, (ast_param, param)) in ast
.params
let param_types = self
.param_types
.iter()
.zip(self.param_types.iter().skip(self_count))
.enumerate()
.skip(self_count)
.map(|p| p.ty())
.collect::<Vec<_>>();

let mut bound_params = HashSet::new();

for (ind, (ast_param, param_ty)) in
ast.params.iter().zip(param_types.into_iter()).enumerate()
{
// is this last argument of function with variadic arguments?
let ty = if ind == ast.params.len() - 1
&& ast.params.last().expect("missing param").variadic
{
// type of variable is Array[T]
self.sa.known.array_ty(param.ty())
self.sa.known.array_ty(param_ty)
} else {
param.ty()
param_ty
};

let ident_pattern = ast_param.pattern.to_ident().expect("missing name");

let name = self.sa.interner.intern(&ident_pattern.name.name_as_string);
self.analysis.set_ty(ast_param.id, ty.clone());

let var_id = self.vars.add_var(name, ty, ident_pattern.mutable);
self.analysis
.map_vars
.insert(ast_param.id, self.vars.local_var_id(var_id));
let local_bound_params = check_pattern(self, &ast_param.pattern, ty);

// params are only allowed to replace functions, vars cannot be replaced
let replaced_sym = self.symtable.insert(name, SymbolKind::Var(var_id));
if let Some(replaced_sym) = replaced_sym {
report_sym_shadow_span(self.sa, name, self.file_id, ast_param.span, replaced_sym)
for (name, data) in local_bound_params {
if !bound_params.insert(name) {
let name = self.sa.interner.str(name).to_string();
self.sa.report(
self.file_id,
data.span,
ErrorMessage::NameBoundMultipleTimesInParams(name),
);
}
}
}
}
Expand Down
173 changes: 85 additions & 88 deletions dora-frontend/src/typeck/stmt.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::sync::Arc;

use dora_parser::ast;
use dora_parser::{ast, Span};

use crate::access::{
class_accessible_from, enum_accessible_from, is_default_accessible, struct_accessible_from,
Expand Down Expand Up @@ -76,44 +76,29 @@ fn check_stmt_let(ck: &mut TypeCheck, s: &ast::StmtLetType) {
}

#[derive(Debug, Clone)]
struct BindingData {
var_id: VarId,
ty: SourceType,
}

#[derive(Clone)]
pub struct Bindings {
map: HashMap<Name, BindingData>,
}

impl Bindings {
pub fn new() -> Bindings {
Bindings {
map: HashMap::new(),
}
}

fn insert(&mut self, name: Name, var_id: VarId, ty: SourceType) {
let old = self.map.insert(name, BindingData { var_id, ty });
assert!(old.is_none());
}

fn get(&self, name: Name) -> Option<BindingData> {
self.map.get(&name).cloned()
}
pub struct BindingData {
pub var_id: VarId,
pub ty: SourceType,
pub span: Span,
}

struct Context {
alt: Bindings,
current: HashSet<Name>,
alt_bindings: HashMap<Name, BindingData>,
current: HashMap<Name, BindingData>,
}

pub(super) fn check_pattern(ck: &mut TypeCheck, pattern: &ast::Pattern, ty: SourceType) {
pub(super) fn check_pattern(
ck: &mut TypeCheck,
pattern: &ast::Pattern,
ty: SourceType,
) -> HashMap<Name, BindingData> {
let mut ctxt = Context {
alt: Bindings::new(),
current: HashSet::new(),
alt_bindings: HashMap::new(),
current: HashMap::new(),
};
check_pattern_inner(ck, &mut ctxt, pattern, ty);

ctxt.current
}

fn check_pattern_inner(
Expand Down Expand Up @@ -183,48 +168,56 @@ fn check_pattern_inner(
}

ast::Pattern::Alt(ref p) => {
let mut bindings = Bindings::new();
let mut alt_bindings: Vec<HashSet<Name>> = Vec::with_capacity(p.alts.len());
let mut bindings_per_alt: Vec<HashMap<Name, BindingData>> =
Vec::with_capacity(p.alts.len());
let mut all_bindings = HashMap::new();

for alt in &p.alts {
let mut alt_ctxt = Context {
alt: bindings,
alt_bindings: all_bindings,
current: ctxt.current.clone(),
};

check_pattern_inner(ck, &mut alt_ctxt, alt.as_ref(), ty.clone());
bindings = alt_ctxt.alt;

let new_bindings = alt_ctxt
.current
.difference(&ctxt.current)
.map(|n| *n)
.collect::<HashSet<Name>>();
alt_bindings.push(new_bindings);
}
all_bindings = alt_ctxt.alt_bindings;

let mut all = alt_bindings.pop().expect("no element");
let mut intersect = all.clone();
let mut local_bindings = HashMap::new();

for alt in alt_bindings {
all = all.union(&alt).map(|n| *n).collect::<HashSet<Name>>();
intersect = intersect
.intersection(&alt)
.map(|n| *n)
.collect::<HashSet<Name>>();
}
for (name, data) in &alt_ctxt.current {
if !ctxt.current.contains_key(name) {
local_bindings.insert(*name, data.clone());
all_bindings.entry(*name).or_insert_with(|| data.clone());
}
}

for &name in all.difference(&intersect) {
let name = ck.sa.interner.str(name).to_string();
let msg = ErrorMessage::PatternBindingNotDefinedInAllAlternatives(name);
ck.sa.report(ck.file_id, p.span, msg);
bindings_per_alt.push(local_bindings);
}

for (name, data) in bindings.map {
let old = ctxt.alt.map.insert(name, data.clone());
assert!(old.is_none());
for (name, data) in all_bindings {
let mut defined_in_all_alternatives = true;

for local_bindings in &bindings_per_alt {
if let Some(local_data) = local_bindings.get(&name) {
if !data.ty.allows(ck.sa, local_data.ty.clone())
&& !local_data.ty.is_error()
{
let ty = local_data.ty.name(ck.sa);
let expected_ty = data.ty.name(ck.sa);
let msg = ErrorMessage::PatternBindingWrongType(ty, expected_ty);
ck.sa.report(ck.file_id, local_data.span, msg);
}
} else {
defined_in_all_alternatives = false;
}
}

assert!(ctxt.current.insert(name));
if !defined_in_all_alternatives {
let name = ck.sa.interner.str(name).to_string();
let msg = ErrorMessage::PatternBindingNotDefinedInAllAlternatives(name);
ck.sa.report(ck.file_id, data.span, msg);
}

assert!(ctxt.current.insert(name, data).is_none());
}
}

Expand Down Expand Up @@ -656,40 +649,44 @@ fn check_pattern_var(
) {
let name = ck.sa.interner.intern(&pattern.name.name_as_string);

if ctxt.current.contains(&name) {
if ctxt.current.contains_key(&name) {
let msg = ErrorMessage::PatternDuplicateBinding;
ck.sa.report(ck.file_id, pattern.span, msg);
} else if let Some(data) = ctxt.alt.get(name) {
if !data.ty.allows(ck.sa, ty.clone()) && !ty.is_error() {
let ty = ty.name(ck.sa);
let expected_ty = data.ty.name(ck.sa);
let msg = ErrorMessage::PatternBindingWrongType(ty, expected_ty);
ck.sa.report(ck.file_id, pattern.span, msg);
}

assert!(ctxt.current.insert(name));

ck.analysis
.map_idents
.insert(pattern.id, IdentType::Var(data.var_id));
} else {
let nested_var_id = ck.vars.add_var(name, ty.clone(), pattern.mutable);
let var_id = ck.vars.local_var_id(nested_var_id);

add_local(
ck.sa,
ck.symtable,
ck.vars,
nested_var_id,
ck.file_id,
pattern.span,
);
let var_id = if let Some(data) = ctxt.alt_bindings.get(&name) {
data.var_id
} else {
let nested_var_id = ck.vars.add_var(name, ty.clone(), pattern.mutable);
let var_id = ck.vars.local_var_id(nested_var_id);

add_local(
ck.sa,
ck.symtable,
ck.vars,
nested_var_id,
ck.file_id,
pattern.span,
);

ctxt.alt.insert(name, var_id, ty.clone());
assert!(ctxt.current.insert(name));
var_id
};

assert!(ctxt
.current
.insert(
name,
BindingData {
var_id,
ty,
span: pattern.span
}
)
.is_none());

ck.analysis
.map_idents
.insert(pattern.id, IdentType::Var(var_id));

ck.analysis.map_vars.insert(pattern.id, var_id);
}
}
Loading

0 comments on commit 380217e

Please sign in to comment.