diff --git a/dora-frontend/src/exhaustiveness.rs b/dora-frontend/src/exhaustiveness.rs index b41ce2f03..8c9e434de 100644 --- a/dora-frontend/src/exhaustiveness.rs +++ b/dora-frontend/src/exhaustiveness.rs @@ -1,8 +1,9 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt::{self, Write}; use dora_parser::ast; use dora_parser::ast::visit::{self, Visitor}; +use dora_parser::Span; use fixedbitset::FixedBitSet; use crate::sema::{AnalysisData, EnumDefinitionId, IdentType, Sema, SourceFileId}; @@ -175,7 +176,7 @@ fn check_match2( if arm.cond.is_some() { row.push(Pattern::Guard); } else { - row.push(Pattern::Any); + row.push(Pattern::any_no_span()); } } row.push(convert_pattern(sa, analysis, &arm.pattern)); @@ -218,11 +219,11 @@ fn check_match2( fn display_pattern(sa: &Sema, pattern: Pattern, output: &mut String) -> fmt::Result { match pattern { - Pattern::Alt(params) => { - assert!(params.len() > 1); + Pattern::Alt { alts, .. } => { + assert!(alts.len() > 1); let mut first = true; - for param in params.into_iter().rev() { + for param in alts.into_iter().rev() { if !first { output.write_str(" | ")?; } @@ -232,7 +233,7 @@ fn display_pattern(sa: &Sema, pattern: Pattern, output: &mut String) -> fmt::Res Ok(()) } - Pattern::Literal(value) => match value { + Pattern::Literal { value, .. } => match value { LiteralValue::Bool(value) => write!(output, "{}", value), LiteralValue::Char(value) => write!(output, "{}", value), LiteralValue::Int(value) => write!(output, "{}", value), @@ -240,7 +241,12 @@ fn display_pattern(sa: &Sema, pattern: Pattern, output: &mut String) -> fmt::Res LiteralValue::String(value) => write!(output, "{:?}", value), }, - Pattern::EnumVariant(enum_id, variant_id, params) => { + Pattern::EnumVariant { + enum_id, + variant_id, + params, + .. + } => { let enum_ = sa.enum_(enum_id); let variant = enum_.variants()[variant_id].name; write!(output, "{}::{}", enum_.name(sa), sa.interner.str(variant))?; @@ -263,9 +269,9 @@ fn display_pattern(sa: &Sema, pattern: Pattern, output: &mut String) -> fmt::Res Ok(()) } - Pattern::Any => write!(output, "_"), + Pattern::Any { .. } => write!(output, "_"), - Pattern::Tuple(params) => { + Pattern::Tuple { params, .. } => { let mut first = true; write!(output, "(")?; @@ -290,7 +296,7 @@ fn check_exhaustive(sa: &Sema, matrix: Vec>, n: usize) -> Vec>, n: usize) -> Vec>, n: usize) -> Vec>, n: usize) -> Vec { assert!(params.is_empty()); assert!(id == 0 || id == 1); - Pattern::Literal(LiteralValue::Bool(id == 1)) + Pattern::Literal { + span: Span::new(1, 1), + value: LiteralValue::Bool(id == 1), + } } CtorKind::Tuple => { assert!(id == 0); - Pattern::Tuple(params) + Pattern::Tuple { + span: Span::new(1, 1), + params, + } } - CtorKind::Enum(enum_id) => Pattern::EnumVariant(*enum_id, id, params), + CtorKind::Enum(enum_id) => Pattern::EnumVariant { + span: Span::new(1, 1), + enum_id: *enum_id, + variant_id: id, + params, + }, } } @@ -469,12 +486,12 @@ fn discover_signature_for_pattern( kind: &mut Option, ) { match pattern { - Pattern::Alt(ref params) => { - for param in params { + Pattern::Alt { ref alts, .. } => { + for param in alts { discover_signature_for_pattern(param, ctors, kind); } } - Pattern::Literal(value) => match value { + Pattern::Literal { value, .. } => match value { LiteralValue::Bool(value) => { match kind { None => *kind = Some(CtorKind::Bool), @@ -488,16 +505,21 @@ fn discover_signature_for_pattern( | LiteralValue::Int(..) | LiteralValue::String(..) => {} }, - Pattern::Any => (), - Pattern::EnumVariant(enum_id, id, params) => { + Pattern::Any { .. } => (), + Pattern::EnumVariant { + enum_id, + variant_id, + params, + .. + } => { match *kind { None => *kind = Some(CtorKind::Enum(*enum_id)), Some(CtorKind::Enum(exp_enum_id)) => assert_eq!(exp_enum_id, *enum_id), Some(_) => unreachable!(), } - ctors.insert(*id as usize, params.len()); + ctors.insert(*variant_id as usize, params.len()); } - Pattern::Tuple(ref params) => { + Pattern::Tuple { ref params, .. } => { match *kind { None => *kind = Some(CtorKind::Tuple), Some(CtorKind::Tuple) => (), @@ -509,6 +531,36 @@ fn discover_signature_for_pattern( } } +#[allow(unused)] +fn check_useful_expand( + sa: &Sema, + matrix: Vec>, + mut row: Vec, +) -> HashSet { + check_useful_expand_inner( + sa, + matrix, + Vec::new(), + Vec::new(), + row, + Vec::new(), + Vec::new(), + ) +} + +#[allow(unused)] +fn check_useful_expand_inner( + sa: &Sema, + matrix_p: Vec>, + matrix_q: Vec>, + matrix_r: Vec>, + mut pattern_p: Vec, + mut pattern_q: Vec, + mut pattern_r: Vec, +) -> HashSet { + unimplemented!() +} + fn check_useful(sa: &Sema, matrix: Vec>, mut pattern: Vec) -> bool { let n = pattern.len(); @@ -527,8 +579,8 @@ fn check_useful(sa: &Sema, matrix: Vec>, mut pattern: Vec) let last = pattern.pop().expect("missing pattern"); match last { - Pattern::Alt(params) => { - for param in params { + Pattern::Alt { alts, .. } => { + for param in alts { let mut param_pattern = pattern.clone(); param_pattern.push(param); if check_useful(sa, matrix.clone(), param_pattern) { @@ -539,16 +591,16 @@ fn check_useful(sa: &Sema, matrix: Vec>, mut pattern: Vec) false } - Pattern::Literal(literal) => { + Pattern::Literal { value, .. } => { let new_matrix = matrix .iter() - .flat_map(|r| specialize_row_for_literal(r, literal.clone())) + .flat_map(|r| specialize_row_for_literal(r, value.clone())) .collect::>(); check_useful(sa, new_matrix, pattern) } - Pattern::Any => { + Pattern::Any { .. } => { let signature = discover_signature(&matrix); match signature { @@ -573,7 +625,8 @@ fn check_useful(sa: &Sema, matrix: Vec>, mut pattern: Vec) .collect::>(); let mut new_pattern = pattern.clone(); - new_pattern.extend(std::iter::repeat(Pattern::Any).take(arity)); + new_pattern + .extend(std::iter::repeat(Pattern::any_no_span()).take(arity)); if check_useful(sa, new_matrix, new_pattern) { return true; @@ -593,7 +646,11 @@ fn check_useful(sa: &Sema, matrix: Vec>, mut pattern: Vec) } } - Pattern::EnumVariant(_enum_id, variant_id, mut params) => { + Pattern::EnumVariant { + variant_id, + mut params, + .. + } => { let arity = params.len(); let new_matrix = matrix .iter() @@ -604,7 +661,7 @@ fn check_useful(sa: &Sema, matrix: Vec>, mut pattern: Vec) check_useful(sa, new_matrix, pattern) } - Pattern::Tuple(mut params) => { + Pattern::Tuple { mut params, .. } => { let arity = params.len(); let new_matrix = matrix @@ -634,7 +691,7 @@ fn specialize_row_for_any(row: &[Pattern]) -> Vec> { let last = result_row.pop().expect("missing pattern"); match last { - Pattern::Alt(params) => params + Pattern::Alt { alts, .. } => alts .into_iter() .flat_map(|p| { result_row.push(p); @@ -643,13 +700,13 @@ fn specialize_row_for_any(row: &[Pattern]) -> Vec> { rows }) .collect::>(), - Pattern::Literal(..) | Pattern::EnumVariant(..) => Vec::new(), + Pattern::Literal { .. } | Pattern::EnumVariant { .. } => Vec::new(), Pattern::Guard => Vec::new(), - Pattern::Any => vec![result_row], + Pattern::Any { .. } => vec![result_row], // This should never be reached as long as all patterns are useful. // This is because tuples conceptually have a single constructor and thus // should always reach the complete signature code path. - Pattern::Tuple(..) => unreachable!(), + Pattern::Tuple { .. } => unreachable!(), } } @@ -658,7 +715,7 @@ fn specialize_row_for_literal(row: &[Pattern], literal: LiteralValue) -> Vec params + Pattern::Alt { alts, .. } => alts .into_iter() .flat_map(|p| { result_row.push(p); @@ -667,17 +724,17 @@ fn specialize_row_for_literal(row: &[Pattern], literal: LiteralValue) -> Vec>(), - Pattern::Literal(lit) => { - if lit == literal { + Pattern::Literal { value, .. } => { + if value == literal { vec![result_row] } else { Vec::new() } } - Pattern::EnumVariant(..) => Vec::new(), - Pattern::Any => vec![result_row], + Pattern::EnumVariant { .. } => Vec::new(), + Pattern::Any { .. } => vec![result_row], // This should never be reached because literals and tuples shouldn't type check. - Pattern::Tuple(..) => unreachable!(), + Pattern::Tuple { .. } => unreachable!(), Pattern::Guard => unimplemented!(), } } @@ -687,7 +744,7 @@ fn specialize_row_for_constructor(row: &[Pattern], id: usize, arity: usize) -> V let last = result_row.pop().expect("missing pattern"); match last { - Pattern::Alt(params) => params + Pattern::Alt { alts, .. } => alts .into_iter() .flat_map(|p| { result_row.push(p); @@ -696,7 +753,7 @@ fn specialize_row_for_constructor(row: &[Pattern], id: usize, arity: usize) -> V rows }) .collect::>(), - Pattern::Literal(value) => match value { + Pattern::Literal { value, .. } => match value { LiteralValue::Bool(value) => { assert_eq!(arity, 0); if id == value as usize { @@ -713,8 +770,12 @@ fn specialize_row_for_constructor(row: &[Pattern], id: usize, arity: usize) -> V | LiteralValue::Int(..) | LiteralValue::String(..) => unreachable!(), }, - Pattern::EnumVariant(_enum_id, ctor_id, mut params) => { - if id == ctor_id { + Pattern::EnumVariant { + variant_id, + mut params, + .. + } => { + if id == variant_id { assert_eq!(arity, params.len()); result_row.append(&mut params); vec![result_row] @@ -722,12 +783,12 @@ fn specialize_row_for_constructor(row: &[Pattern], id: usize, arity: usize) -> V Vec::new() } } - Pattern::Any => { - result_row.extend(std::iter::repeat(Pattern::Any).take(arity)); + Pattern::Any { .. } => { + result_row.extend(std::iter::repeat(Pattern::any_no_span()).take(arity)); vec![result_row] } - Pattern::Tuple(mut params) => { + Pattern::Tuple { mut params, .. } => { assert_eq!(id, 0); result_row.append(&mut params); vec![result_row] @@ -758,22 +819,62 @@ impl fmt::Debug for LiteralValue { } #[derive(Clone)] +#[allow(unused)] enum Pattern { - Any, - Literal(LiteralValue), - Tuple(Vec), - EnumVariant(EnumDefinitionId, usize, Vec), - Alt(Vec), + Any { + span: Option, + }, + Literal { + span: Span, + value: LiteralValue, + }, + Tuple { + span: Span, + params: Vec, + }, + EnumVariant { + span: Span, + enum_id: EnumDefinitionId, + variant_id: usize, + params: Vec, + }, + Alt { + span: Span, + alts: Vec, + }, Guard, } +impl Pattern { + fn any_no_span() -> Pattern { + Pattern::Any { span: None } + } + + #[allow(unused)] + fn span(&self) -> Span { + match self { + Pattern::Any { span } => span.clone().expect("missing span"), + Pattern::Literal { span, .. } => span.clone(), + Pattern::Tuple { span, .. } => span.clone(), + Pattern::EnumVariant { span, .. } => span.clone(), + Pattern::Alt { span, .. } => span.clone(), + Pattern::Guard => unreachable!(), + } + } +} + impl fmt::Debug for Pattern { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Pattern::Any => write!(f, "_"), - Pattern::Literal(lit) => write!(f, "{:?}", lit), - Pattern::EnumVariant(enum_id, variant_idx, params) => { - write!(f, "e{}::{}", enum_id.index(), *variant_idx)?; + Pattern::Any { .. } => write!(f, "_"), + Pattern::Literal { value, .. } => write!(f, "{:?}", value), + Pattern::EnumVariant { + enum_id, + variant_id, + params, + .. + } => { + write!(f, "e{}::{}", enum_id.index(), *variant_id)?; if !params.is_empty() { write!(f, "(")?; @@ -790,7 +891,7 @@ impl fmt::Debug for Pattern { Ok(()) } } - Pattern::Tuple(params) => { + Pattern::Tuple { params, .. } => { write!(f, "(")?; let mut first = true; for param in params { @@ -802,13 +903,13 @@ impl fmt::Debug for Pattern { } write!(f, ")") } - Pattern::Alt(params) => { + Pattern::Alt { alts, .. } => { let mut first = true; - for param in params { + for alt in alts { if !first { write!(f, " | ")?; } - write!(f, "{:?}", param)?; + write!(f, "{:?}", alt)?; first = false; } Ok(()) @@ -820,17 +921,21 @@ impl fmt::Debug for Pattern { fn convert_pattern(sa: &Sema, analysis: &AnalysisData, pattern: &ast::Pattern) -> Pattern { match pattern { - ast::Pattern::Underscore(..) => Pattern::Any, + ast::Pattern::Underscore(ref p) => Pattern::Any { span: Some(p.span) }, ast::Pattern::Rest(..) => unreachable!(), - ast::Pattern::LitBool(ref lit) => { - Pattern::Literal(LiteralValue::Bool(lit.expr.is_lit_true())) - } + ast::Pattern::LitBool(ref lit) => Pattern::Literal { + span: pattern.span(), + value: LiteralValue::Bool(lit.expr.is_lit_true()), + }, ast::Pattern::LitInt(ref lit) => { let value = analysis.const_value(lit.id).to_i64().expect("i64 expected"); - Pattern::Literal(LiteralValue::Int(value)) + Pattern::Literal { + span: pattern.span(), + value: LiteralValue::Int(value), + } } ast::Pattern::LitString(ref lit) => { @@ -839,17 +944,26 @@ fn convert_pattern(sa: &Sema, analysis: &AnalysisData, pattern: &ast::Pattern) - .to_string() .cloned() .expect("string expected"); - Pattern::Literal(LiteralValue::String(value)) + Pattern::Literal { + span: lit.span, + value: LiteralValue::String(value), + } } ast::Pattern::LitFloat(ref lit) => { let value = analysis.const_value(lit.id).to_f64().expect("f64 expected"); - Pattern::Literal(LiteralValue::Float(value)) + Pattern::Literal { + span: lit.span, + value: LiteralValue::Float(value), + } } ast::Pattern::LitChar(ref lit) => { let value = analysis.const_value(lit.id).to_char(); - Pattern::Literal(LiteralValue::Char(value)) + Pattern::Literal { + span: lit.span, + value: LiteralValue::Char(value), + } } ast::Pattern::Tuple(ref tuple) => { @@ -859,7 +973,10 @@ fn convert_pattern(sa: &Sema, analysis: &AnalysisData, pattern: &ast::Pattern) - .rev() .map(|p| convert_pattern(sa, analysis, &p)) .collect(); - Pattern::Tuple(patterns) + Pattern::Tuple { + span: tuple.span, + params: patterns, + } } ast::Pattern::Ident(ref pattern_ident) => { @@ -868,25 +985,34 @@ fn convert_pattern(sa: &Sema, analysis: &AnalysisData, pattern: &ast::Pattern) - .get(pattern_ident.id) .expect("missing ident"); match ident { - IdentType::EnumVariant(pattern_enum_id, _type_params, variant_idx) => { - Pattern::EnumVariant(*pattern_enum_id, *variant_idx as usize, Vec::new()) + IdentType::EnumVariant(pattern_enum_id, _type_params, variant_id) => { + Pattern::EnumVariant { + span: pattern_ident.span, + enum_id: *pattern_enum_id, + variant_id: *variant_id as usize, + params: Vec::new(), + } } - IdentType::Var(_var_id) => Pattern::Any, + IdentType::Var(_var_id) => Pattern::Any { + span: Some(pattern_ident.span), + }, _ => unreachable!(), } } - ast::Pattern::Alt(ref p) => Pattern::Alt( - p.alts + ast::Pattern::Alt(ref p) => Pattern::Alt { + span: p.span, + alts: p + .alts .iter() .map(|alt| convert_pattern(sa, analysis, alt.as_ref())) .collect(), - ), + }, - ast::Pattern::ClassOrStructOrEnum(ref ident) => { - let subpatterns = ident + ast::Pattern::ClassOrStructOrEnum(ref p) => { + let subpatterns = p .params .as_ref() .map(|v| { @@ -896,11 +1022,16 @@ fn convert_pattern(sa: &Sema, analysis: &AnalysisData, pattern: &ast::Pattern) - .collect::>() }) .unwrap_or_default(); - let ident = analysis.map_idents.get(ident.id).expect("missing ident"); + let ident = analysis.map_idents.get(p.id).expect("missing ident"); match ident { - IdentType::EnumVariant(pattern_enum_id, _type_params, variant_idx) => { - Pattern::EnumVariant(*pattern_enum_id, *variant_idx as usize, subpatterns) + IdentType::EnumVariant(pattern_enum_id, _type_params, variant_id) => { + Pattern::EnumVariant { + span: p.span, + enum_id: *pattern_enum_id, + variant_id: *variant_id as usize, + params: subpatterns, + } } _ => unreachable!(), @@ -1143,6 +1274,8 @@ mod tests { @NewExhaustiveness fn f(v: Bool) { match v { + true if true => {} + true if true => {} true if true => {} true => {} false => {} @@ -1300,6 +1433,22 @@ mod tests { "); } + #[test] + fn exhaustive_enum_through_underscore_with_guard() { + ok(" + enum Foo { A, B, C, D } + @NewExhaustiveness + fn f(v: Foo) { + match v { + Foo::A => {} + Foo::C if true => {} + Foo::D if true => {} + _ => {} + } + } + "); + } + #[test] fn exhaustive_only_underscore() { ok("