Skip to content

Commit

Permalink
frontend: Support enums in new exhaustiveness check
Browse files Browse the repository at this point in the history
  • Loading branch information
dinfuehr committed Dec 9, 2024
1 parent 0d18f5e commit c854919
Showing 1 changed file with 166 additions and 31 deletions.
197 changes: 166 additions & 31 deletions dora-frontend/src/exhaustiveness.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::collections::HashSet;
use std::collections::HashMap;

use dora_parser::ast;
use dora_parser::ast::visit::{self, Visitor};
use fixedbitset::FixedBitSet;

use crate::sema::{AnalysisData, IdentType, Sema, SourceFileId};
use crate::sema::{AnalysisData, EnumDefinitionId, IdentType, Sema, SourceFileId};
use crate::ErrorMessage;

pub fn check(sa: &Sema) {
Expand Down Expand Up @@ -178,7 +178,7 @@ fn check_match2(
}
}

let missing_patterns = check_exhaustive(matrix, 1);
let missing_patterns = check_exhaustive(sa, matrix, 1);

if !missing_patterns.is_empty() {
sa.report(
Expand All @@ -189,7 +189,7 @@ fn check_match2(
}
}

fn check_exhaustive(matrix: Vec<Vec<Pattern>>, n: usize) -> Vec<Vec<Pattern>> {
fn check_exhaustive(sa: &Sema, matrix: Vec<Vec<Pattern>>, n: usize) -> Vec<Vec<Pattern>> {
if matrix.is_empty() {
return vec![vec![Pattern::Any; n]];
}
Expand All @@ -208,7 +208,7 @@ fn check_exhaustive(matrix: Vec<Vec<Pattern>>, n: usize) -> Vec<Vec<Pattern>> {
.filter_map(|r| specialize_row_for_any(r))
.collect::<Vec<_>>();

let mut result = check_exhaustive(new_matrix, n - 1);
let mut result = check_exhaustive(sa, new_matrix, n - 1);

if result.is_empty() {
return Vec::new();
Expand All @@ -222,26 +222,53 @@ fn check_exhaustive(matrix: Vec<Vec<Pattern>>, n: usize) -> Vec<Vec<Pattern>> {
}

Signature::Complete => {
let ctors = discover_constructors(&matrix);
let arity = 0;
let ctors_count = 2;
let ctor_data = discover_constructors(&matrix);

if ctors.len() == ctors_count {
let mut combined_result = Vec::new();

for id in 0..ctors.len() {
if ctor_data.ctors.len() == ctor_data.total(sa) {
for (&id, &arity) in &ctor_data.ctors {
let new_matrix = matrix
.iter()
.flat_map(|r| specialize_row_for_constructor(r, id, arity))
.collect::<Vec<_>>();

let mut result = check_exhaustive(new_matrix, n - 1);
combined_result.append(&mut result);
let uncovered = check_exhaustive(sa, new_matrix, n + arity - 1);

if !uncovered.is_empty() {
let tail = n - arity;

let result = uncovered
.into_iter()
.map(|mut row| {
let ctor_params = row.drain(tail..).collect::<Vec<_>>();
row.push(ctor_data.kind.pattern(id, ctor_params));
row
})
.collect::<Vec<_>>();

if !result.is_empty() {
return result;
}
}
}

combined_result
Vec::new()
} else {
unimplemented!()
let new_matrix = matrix
.iter()
.filter_map(|r| specialize_row_for_any(r))
.collect::<Vec<_>>();

let mut result = check_exhaustive(sa, new_matrix, n - 1);

if result.is_empty() {
return Vec::new();
}

for row in &mut result {
row.push(Pattern::Any);
}

result
}
}
}
Expand All @@ -265,30 +292,93 @@ fn discover_signature(matrix: &[Vec<Pattern>]) -> Signature {
| LiteralValue::String(..) => Signature::Incomplete,
},
Pattern::Any => Signature::Incomplete,
Pattern::Ctor(..) => Signature::Complete,
Pattern::EnumVariant(..) => Signature::Complete,
Pattern::Tuple(..) => Signature::Complete,
}
}

fn discover_constructors(matrix: &[Vec<Pattern>]) -> HashSet<usize> {
let mut constructors = HashSet::new();
enum CtorKind {
Bool,
Tuple,
Enum(EnumDefinitionId),
}

impl CtorKind {
fn pattern(&self, id: usize, params: Vec<Pattern>) -> Pattern {
match self {
CtorKind::Bool => {
assert!(params.is_empty());
assert!(id == 0 || id == 1);
Pattern::Literal(LiteralValue::Bool(id == 1))
}

CtorKind::Tuple => {
assert!(id == 0);
Pattern::Tuple(params)
}

CtorKind::Enum(enum_id) => Pattern::EnumVariant(*enum_id, id, params),
}
}
}

struct CtorData {
kind: CtorKind,
ctors: HashMap<usize, usize>,
}

impl CtorData {
fn total(&self, sa: &Sema) -> usize {
match self.kind {
CtorKind::Bool => 2,
CtorKind::Enum(enum_id) => sa.enum_(enum_id).variants.len(),
CtorKind::Tuple => 1,
}
}
}

fn discover_constructors(matrix: &[Vec<Pattern>]) -> CtorData {
let mut ctors = HashMap::new();
let mut kind = None;

for row in matrix {
match row.last().expect("missing pattern") {
Pattern::Alt(..) => unimplemented!(),
Pattern::Literal(value) => match value {
LiteralValue::Bool(value) => {
constructors.insert(*value as usize);
match kind {
None => kind = Some(CtorKind::Bool),
Some(CtorKind::Bool) => (),
Some(_) => unreachable!(),
}
ctors.insert(*value as usize, 0);
}
_ => unimplemented!(),
},
Pattern::Any => unimplemented!(),
Pattern::Ctor(..) => unimplemented!(),
Pattern::Tuple(..) => unimplemented!(),
Pattern::Any => (),
Pattern::EnumVariant(enum_id, id, params) => {
match kind {
None => kind = Some(CtorKind::Enum(*enum_id)),
Some(CtorKind::Enum(exp_enum_id)) => assert_eq!(exp_enum_id, *enum_id),
Some(_) => unreachable!(),
}
ctors.insert(*id as usize, params.len());
}
Pattern::Tuple(ref params) => {
match kind {
None => kind = Some(CtorKind::Tuple),
Some(CtorKind::Tuple) => (),
Some(_) => unreachable!(),
}
ctors.insert(0, params.len());
}
}
}

constructors
CtorData {
kind: kind.expect("missing kind"),
ctors: ctors,
}
}

fn check_useful(_matrix: &[Vec<Pattern>], _new_pattern: &[Pattern]) -> bool {
Expand All @@ -300,7 +390,7 @@ fn specialize_row_for_any(row: &[Pattern]) -> Option<Vec<Pattern>> {

match last {
Pattern::Alt(..) => unimplemented!(),
Pattern::Literal(..) | Pattern::Ctor(..) => None,
Pattern::Literal(..) | Pattern::EnumVariant(..) => None,
Pattern::Any => {
let count = row.len();
Some(row[0..count - 1].to_vec())
Expand Down Expand Up @@ -330,7 +420,7 @@ fn specialize_row_for_constructor(

_ => unimplemented!(),
},
Pattern::Ctor(ctor_id, params) => {
Pattern::EnumVariant(_enum_id, ctor_id, params) => {
assert_eq!(arity, params.len());
if id == *ctor_id {
let mut result = row[0..row.len() - 1].to_vec();
Expand Down Expand Up @@ -361,7 +451,7 @@ enum Pattern {
Any,
Literal(LiteralValue),
Tuple(Vec<Pattern>),
Ctor(usize, Vec<Pattern>),
EnumVariant(EnumDefinitionId, usize, Vec<Pattern>),
Alt(Vec<Pattern>),
}

Expand Down Expand Up @@ -414,8 +504,8 @@ fn convert_pattern(sa: &Sema, analysis: &AnalysisData, pattern: &ast::Pattern) -
.get(pattern_ident.id)
.expect("missing ident");
match ident {
IdentType::EnumVariant(_pattern_enum_id, _type_params, variant_idx) => {
Pattern::Ctor(*variant_idx as usize, Vec::new())
IdentType::EnumVariant(pattern_enum_id, _type_params, variant_idx) => {
Pattern::EnumVariant(*pattern_enum_id, *variant_idx as usize, Vec::new())
}

IdentType::Var(_var_id) => Pattern::Any,
Expand Down Expand Up @@ -444,8 +534,8 @@ fn convert_pattern(sa: &Sema, analysis: &AnalysisData, pattern: &ast::Pattern) -
let ident = analysis.map_idents.get(ident.id).expect("missing ident");

match ident {
IdentType::EnumVariant(_pattern_enum_id, _type_params, variant_idx) => {
Pattern::Ctor(*variant_idx as usize, subpatterns)
IdentType::EnumVariant(pattern_enum_id, _type_params, variant_idx) => {
Pattern::EnumVariant(*pattern_enum_id, *variant_idx as usize, subpatterns)
}

_ => unreachable!(),
Expand Down Expand Up @@ -557,4 +647,49 @@ mod tests {
}
");
}

#[test]
fn exhaustive_enum() {
ok("
enum Foo { A, B, C, D }
@NewExhaustiveness
fn f(v: Foo) {
match v {
Foo::A => {}
Foo::B => {}
Foo::C => {}
Foo::D => {}
}
}
");

ok("
enum Foo { A, B, C, D }
@NewExhaustiveness
fn f(v: Foo) {
match v {
Foo::A => {}
Foo::C => {}
Foo::D => {}
_ => {}
}
}
");
}

#[test]
fn exhaustive_enum_through_underscore() {
ok("
enum Foo { A, B, C, D }
@NewExhaustiveness
fn f(v: Foo) {
match v {
Foo::A => {}
Foo::C => {}
Foo::D => {}
_ => {}
}
}
");
}
}

0 comments on commit c854919

Please sign in to comment.