diff --git a/Cargo.lock b/Cargo.lock index 10b4aff613..a42e32d102 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -593,6 +593,7 @@ dependencies = [ name = "fe-codegen" version = "0.14.0-alpha" dependencies = [ + "fe-analyzer", "fe-common", "fe-mir", "fe-new_abi", @@ -676,6 +677,7 @@ version = "0.14.0-alpha" dependencies = [ "fe-abi", "fe-analyzer", + "fe-codegen", "fe-common", "fe-lowering", "fe-mir", diff --git a/crates/codegen/Cargo.toml b/crates/codegen/Cargo.toml index 31f40afbb3..fa5c17df36 100644 --- a/crates/codegen/Cargo.toml +++ b/crates/codegen/Cargo.toml @@ -5,10 +5,12 @@ authors = ["The Fe Developers "] edition = "2021" [dependencies] +fe-analyzer = { path = "../analyzer", version = "^0.14.0-alpha"} fe-mir = { path = "../mir", version = "^0.14.0-alpha" } fe-common = { path = "../common", version = "^0.14.0-alpha" } fe-new_abi = { path = "../new_abi", version = "^0.14.0-alpha" } salsa = "0.16.1" fxhash = "0.2.1" smol_str = "0.1.21" +id-arena = "2.2.1" yultsur = { git = "https://github.com/g-r-a-n-t/yultsur", rev = "ae85470" } \ No newline at end of file diff --git a/crates/codegen/src/db.rs b/crates/codegen/src/db.rs index 835697d507..e8bde2461c 100644 --- a/crates/codegen/src/db.rs +++ b/crates/codegen/src/db.rs @@ -1,8 +1,9 @@ use std::rc::Rc; -use fe_common::db::{Upcast, UpcastMut}; +use fe_analyzer::{db::AnalyzerDbStorage, AnalyzerDb}; +use fe_common::db::{SourceDb, SourceDbStorage, Upcast, UpcastMut}; use fe_mir::{ - db::MirDb, + db::{MirDb, MirDbStorage}, ir::{FunctionBody, FunctionId, FunctionSignature, TypeId}, }; use fe_new_abi::{function::AbiFunction, types::AbiType}; @@ -15,9 +16,55 @@ pub trait CodegenDb: MirDb + Upcast + UpcastMut { fn codegen_legalized_signature(&self, function_id: FunctionId) -> Rc; #[salsa::invoke(queries::function::legalized_body)] fn codegen_legalized_body(&self, function_id: FunctionId) -> Rc; + #[salsa::invoke(queries::function::lower_function)] + fn codegen_lower_function(&self, function_id: FunctionId) -> Rc; 
#[salsa::invoke(queries::abi::abi_type)] fn codegen_abi_type(&self, ty: TypeId) -> AbiType; #[salsa::invoke(queries::abi::abi_function)] fn codegen_abi_function(&self, function_id: FunctionId) -> Rc; } + +// TODO: Move this to driver. +#[salsa::database(SourceDbStorage, AnalyzerDbStorage, MirDbStorage, CodegenDbStorage)] +#[derive(Default)] +pub struct NewDb { + storage: salsa::Storage, +} +impl salsa::Database for NewDb {} + +impl Upcast for NewDb { + fn upcast(&self) -> &(dyn MirDb + 'static) { + &*self + } +} + +impl UpcastMut for NewDb { + fn upcast_mut(&mut self) -> &mut (dyn MirDb + 'static) { + &mut *self + } +} + +impl Upcast for NewDb { + fn upcast(&self) -> &(dyn SourceDb + 'static) { + &*self + } +} + +impl UpcastMut for NewDb { + fn upcast_mut(&mut self) -> &mut (dyn SourceDb + 'static) { + &mut *self + } +} + +impl Upcast for NewDb { + fn upcast(&self) -> &(dyn AnalyzerDb + 'static) { + &*self + } +} + +impl UpcastMut for NewDb { + fn upcast_mut(&mut self) -> &mut (dyn AnalyzerDb + 'static) { + &mut *self + } +} diff --git a/crates/codegen/src/db/queries/abi.rs b/crates/codegen/src/db/queries/abi.rs index 2720503cc0..7d5e176a6d 100644 --- a/crates/codegen/src/db/queries/abi.rs +++ b/crates/codegen/src/db/queries/abi.rs @@ -79,7 +79,10 @@ pub fn abi_type(db: &dyn CodegenDb, ty: TypeId) -> AbiType { AbiType::Tuple(fields) } - ir::Type::Event(_) | ir::Type::Contract(_) => unreachable!(), - ir::Type::Map(_) => todo!("map type can't be used in parameter or return type"), + ir::Type::Event(_) + | ir::Type::Contract(_) + | ir::Type::Map(_) + | ir::Type::MPtr(_) + | ir::Type::SPtr(_) => unreachable!(), } } diff --git a/crates/codegen/src/db/queries/function.rs b/crates/codegen/src/db/queries/function.rs index 4878dd4be8..814c770cb3 100644 --- a/crates/codegen/src/db/queries/function.rs +++ b/crates/codegen/src/db/queries/function.rs @@ -2,7 +2,10 @@ use std::rc::Rc; use fe_mir::ir::{FunctionBody, FunctionId, FunctionSignature}; -use crate::{db::CodegenDb, 
yul::legalize}; +use crate::{ + db::CodegenDb, + yul::{isel, legalize}, +}; pub fn legalized_signature(db: &dyn CodegenDb, function: FunctionId) -> Rc { let mut sig = function.signature(db.upcast()).as_ref().clone(); @@ -15,3 +18,7 @@ pub fn legalized_body(db: &dyn CodegenDb, function: FunctionId) -> Rc Rc { + isel::lower_function(db, function).into() +} diff --git a/crates/codegen/src/yul/inst_order.rs b/crates/codegen/src/yul/inst_order.rs deleted file mode 100644 index 1e599100a4..0000000000 --- a/crates/codegen/src/yul/inst_order.rs +++ /dev/null @@ -1,309 +0,0 @@ -#![allow(unused)] - -use fe_mir::{ - analysis::{ - domtree::DFSet, loop_tree::LoopId, post_domtree::PostIDom, ControlFlowGraph, DomTree, - LoopTree, PostDomTree, - }, - ir::{inst::BranchInfo, BasicBlockId, FunctionBody, InstId, ValueId}, -}; - -#[derive(Debug, Clone, Default)] -pub struct InstOrder { - pub order: Vec, -} - -#[derive(Debug, Clone)] -pub enum StructuralInst { - Inst(InstId), - If { - cond: ValueId, - then: Vec, - else_: Vec, - }, - For { - body: Vec, - }, - Break, - Continue, -} - -struct InstSerializer<'a> { - body: &'a FunctionBody, - cfg: ControlFlowGraph, - loop_tree: LoopTree, - df: DFSet, - pd_tree: PostDomTree, - scope: Option, -} - -impl<'a> InstSerializer<'a> { - fn new(body: &'a FunctionBody) -> Self { - let cfg = ControlFlowGraph::compute(body); - let domtree = DomTree::compute(&cfg); - let df = domtree.compute_df(&cfg); - let pd_tree = PostDomTree::compute(body); - let loop_tree = LoopTree::compute(&cfg, &domtree); - - Self { - body, - cfg, - loop_tree, - df, - pd_tree, - scope: None, - } - } - - fn serialize_insts(&mut self) -> InstOrder { - self.scope = None; - let entry = self.cfg.entry(); - let mut order = vec![]; - self.analyze_block(entry, &mut order); - InstOrder { order } - } - - fn analyze_block(&mut self, block: BasicBlockId, order: &mut Vec) { - match self.loop_tree.loop_of_block(block) { - Some(lp) - if block == self.loop_tree.loop_header(lp) - && Some(block) != 
self.scope.as_ref().and_then(Scope::loop_header) => - { - let loop_exit = self.find_loop_exit(lp); - self.enter_loop_scope(block, loop_exit); - let mut body = vec![]; - self.analyze_block(block, &mut body); - self.exit_scope(); - order.push(StructuralInst::For { body }); - - if let Some(exit) = loop_exit { - self.analyze_block(exit, order); - } - return; - } - _ => {} - }; - - for inst in self.body.order.iter_inst(block) { - if self.body.store.is_terminator(inst) { - break; - } - order.push(StructuralInst::Inst(inst)); - } - - let terminator = self.body.order.terminator(&self.body.store, block).unwrap(); - match self.analyze_terminator(terminator) { - TerminatorInfo::If { - cond, - then, - else_, - merge_block, - } => { - let mut then_body = vec![]; - let mut else_body = vec![]; - - self.enter_if_scope(merge_block); - if let Some(merge_block) = merge_block { - if merge_block == else_ { - self.analyze_block(then, &mut then_body) - } else if merge_block == then { - self.analyze_block(else_, &mut else_body); - } else { - self.analyze_block(then, &mut then_body); - self.analyze_block(else_, &mut else_body); - } - order.push(StructuralInst::If { - cond, - then: then_body, - else_: else_body, - }); - self.exit_scope(); - self.analyze_block(merge_block, order) - } else { - self.analyze_block(then, &mut then_body); - self.analyze_block(else_, &mut else_body); - self.exit_scope(); - order.push(StructuralInst::If { - cond, - then: then_body, - else_: else_body, - }); - } - } - TerminatorInfo::ToMergeBlock => {} - TerminatorInfo::Continue => order.push(StructuralInst::Continue), - TerminatorInfo::Break => order.push(StructuralInst::Break), - TerminatorInfo::FallThrough(next) => self.analyze_block(next, order), - TerminatorInfo::NormalInst(inst) => order.push(StructuralInst::Inst(inst)), - } - } - - fn enter_loop_scope(&mut self, header: BasicBlockId, exit: Option) { - let kind = ScopeKind::Loop { header, exit }; - let current_scope = std::mem::take(&mut self.scope); - 
self.scope = Some(Scope { - kind, - parent: current_scope.map(Into::into), - }); - } - - fn enter_if_scope(&mut self, merge_block: Option) { - let kind = ScopeKind::If { merge_block }; - let current_scope = std::mem::take(&mut self.scope); - self.scope = Some(Scope { - kind, - parent: current_scope.map(Into::into), - }); - } - - fn exit_scope(&mut self) { - let current_scope = std::mem::take(&mut self.scope); - self.scope = current_scope.unwrap().parent.map(|parent| *parent); - } - - // NOTE: We assume loop has at most one canonical loop exit. - fn find_loop_exit(&self, lp: LoopId) -> Option { - let mut exit_candidates = vec![]; - for block_in_loop in self.loop_tree.iter_blocks_post_order(&self.cfg, lp) { - for &succ in self.cfg.succs(block_in_loop) { - if !self.loop_tree.is_block_in_loop(succ, lp) { - exit_candidates.push(succ); - } - } - } - - if exit_candidates.is_empty() { - return None; - } - - for &cand in &exit_candidates { - // `cand` is true loop exit if the `cand` is contained in the dominance frontier - // of all other candidates. and yeset foo - if exit_candidates.iter().all(|&block| { - if block == cand { - true - } else if let Some(mut df) = self.df.frontiers(block) { - df.any(|frontier| frontier == cand) - } else { - true - } - }) { - return Some(cand); - } - } - - None - } - - fn analyze_terminator(&self, inst: InstId) -> TerminatorInfo { - debug_assert!(self.body.store.is_terminator(inst)); - - match self.body.store.branch_info(inst) { - BranchInfo::Jump(dest) => self.analyze_jump(dest), - BranchInfo::Branch(cond, then, else_) => { - self.analyze_branch(self.body.order.inst_block(inst), cond, then, else_) - } - BranchInfo::NotBranch => TerminatorInfo::NormalInst(inst), - } - } - - // NOTE: We remove critical edges in legalization pass, so `break` and - // `continue` never appear in branch info. 
- fn analyze_branch( - &self, - block: BasicBlockId, - cond: ValueId, - then: BasicBlockId, - else_: BasicBlockId, - ) -> TerminatorInfo { - let merge_block = match self.pd_tree.post_idom(block) { - PostIDom::DummyEntry | PostIDom::DummyExit => None, - PostIDom::Block(block) => Some(block), - }; - - TerminatorInfo::If { - cond, - then, - else_, - merge_block, - } - } - - fn analyze_jump(&self, dest: BasicBlockId) -> TerminatorInfo { - match &self.scope { - Some(scope) => { - if Some(dest) == scope.loop_header_recursive() { - TerminatorInfo::Continue - } else if Some(dest) == scope.loop_exit_recursive() { - TerminatorInfo::Break - } else if Some(dest) == scope.if_merge_block() { - TerminatorInfo::ToMergeBlock - } else { - TerminatorInfo::FallThrough(dest) - } - } - - None => TerminatorInfo::FallThrough(dest), - } - } -} - -struct Scope { - kind: ScopeKind, - parent: Option>, -} - -#[derive(Debug, Clone, Copy)] -enum ScopeKind { - Loop { - header: BasicBlockId, - exit: Option, - }, - If { - merge_block: Option, - }, -} - -impl Scope { - fn loop_header(&self) -> Option { - match self.kind { - ScopeKind::Loop { header, .. } => Some(header), - _ => None, - } - } - fn loop_header_recursive(&self) -> Option { - match self.kind { - ScopeKind::Loop { header, .. } => Some(header), - _ => self.parent.as_ref()?.loop_header_recursive(), - } - } - - fn loop_exit_recursive(&self) -> Option { - match self.kind { - ScopeKind::Loop { exit, .. 
} => exit, - _ => self.parent.as_ref()?.loop_exit_recursive(), - } - } - - fn if_merge_block(&self) -> Option { - match self.kind { - ScopeKind::If { merge_block } => merge_block, - _ => None, - } - } -} - -#[derive(Debug, Clone)] -enum TerminatorInfo { - If { - cond: ValueId, - then: BasicBlockId, - else_: BasicBlockId, - merge_block: Option, - }, - ToMergeBlock, - Continue, - Break, - FallThrough(BasicBlockId), - NormalInst(InstId), -} diff --git a/crates/codegen/src/yul/isel/function.rs b/crates/codegen/src/yul/isel/function.rs new file mode 100644 index 0000000000..f35a89dcdd --- /dev/null +++ b/crates/codegen/src/yul/isel/function.rs @@ -0,0 +1,399 @@ +#![allow(unused)] +use super::inst_order::InstSerializer; + +use fe_mir::ir::{ + constant::ConstantValue, + inst::{AssignableValue, BinOp, InstKind, UnOp}, + Constant, FunctionBody, FunctionId, FunctionSignature, InstId, Type, TypeId, Value, ValueId, +}; +use fxhash::FxHashMap; +use smol_str::SmolStr; +use yultsur::{ + yul::{self, Statement}, + *, +}; + +use crate::{db::CodegenDb, yul::isel::inst_order::StructuralInst}; + +const MEMORY_SLOT_SIZE: usize = 1; +const STORAGE_SLOT_SIZE: usize = 32; + +// TODO: consider return type. +pub fn lower_function(db: &dyn CodegenDb, function: FunctionId) -> String { + let sig = db.codegen_legalized_signature(function); + let body = db.codegen_legalized_body(function); + + let func = FuncLowerHelper::new(db, &sig, &body).lower_func(); + format!("{}", func) +} + +struct FuncLowerHelper<'db, 'a> { + db: &'db dyn CodegenDb, + value_map: FxHashMap, + sig: &'a FunctionSignature, + body: &'a FunctionBody, + ret_value: Option, +} + +impl<'db, 'a> FuncLowerHelper<'db, 'a> { + fn new(db: &'db dyn CodegenDb, sig: &'a FunctionSignature, body: &'a FunctionBody) -> Self { + let mut value_map = FxHashMap::default(); + // Register arguments to value_map. 
+ for &value in body.store.locals() { + match body.store.value_data(value) { + Value::Local(local) if local.is_arg => { + let ident = identifier! {(local.name.clone())}; + value_map.insert(value, ident); + } + _ => {} + } + } + + let ret_value = if sig.return_type.is_some() { + Some(identifier! {("$ret")}) + } else { + None + }; + + Self { + db, + value_map, + sig, + body, + ret_value, + } + } + + fn lower_func(&mut self) -> yul::FunctionDefinition { + let name = identifier! {(self.sig.analyzer_func_id.name(self.db.upcast()))}; + + let parameters = self + .sig + .params + .iter() + .map(|param| identifier! {(param.name.clone())}) + .collect(); + + let ret = self + .ret_value + .clone() + .map(|value| vec![value]) + .unwrap_or_default(); + + let body = self.lower_body(); + + yul::FunctionDefinition { + name, + parameters, + returns: ret, + block: body, + } + } + + fn lower_body(&mut self) -> yul::Block { + let inst_order = InstSerializer::new(self.body).serialize(); + + let mut sink = vec![]; + for inst in inst_order { + self.lower_structural_inst(inst, &mut sink) + } + + yul::Block { statements: sink } + } + + fn lower_structural_inst(&mut self, inst: StructuralInst, sink: &mut Vec) { + let stmt = match inst { + StructuralInst::Inst(inst) => self.lower_inst(inst), + StructuralInst::If { cond, then, else_ } => self.lower_if(cond, then, else_), + StructuralInst::For { body } => self.lower_for(body), + StructuralInst::Break => yul::Statement::Break, + StructuralInst::Continue => yul::Statement::Continue, + }; + + sink.push(stmt) + } + + fn lower_inst(&mut self, inst: InstId) -> yul::Statement { + match &self.body.store.inst_data(inst).kind { + InstKind::Declare { local: value } => { + let (value, ident) = match self.body.store.value_data(*value) { + Value::Local(local) => { + if local.is_tmp { + (*value, format!("${}", local.name)) + } else { + (*value, format!("{}", local.name)) + } + } + _ => unreachable!(), + }; + + let ident = identifier! 
{(ident)}; + self.value_map.insert(value, ident.clone()); + yul::Statement::VariableDeclaration(yul::VariableDeclaration { + identifiers: vec![ident], + expression: None, + }) + } + InstKind::Assign { lhs, rhs } => self.lower_assign(lhs, *rhs), + InstKind::Unary { op, value } => { + let result = self.lower_unary(*op, *value); + self.assign_inst_result(inst, result) + } + InstKind::Binary { op, lhs, rhs } => { + let result = self.lower_binary(*op, *lhs, *rhs, inst); + self.assign_inst_result(inst, result) + } + InstKind::Cast { value, to } => { + let result = self.lower_cast(*value, *to); + self.assign_inst_result(inst, result) + } + InstKind::AggregateConstruct { ty, args } => { + todo!() + } + InstKind::AggregateAccess { value, indices } => { + todo!() + } + InstKind::MapAccess { value, key } => { + todo!() + } + InstKind::Call { + func, + args, + call_type, + } => { + todo!() + } + + InstKind::Revert { arg } => { + todo!() + } + + InstKind::Emit { arg } => { + todo!() + } + + InstKind::Return { arg } => { + if let Some(arg) = arg { + let arg = self.value_expr(*arg); + let ret_value = self.ret_value.clone().unwrap(); + let mut statements = vec![statement! { [ret_value] := [arg]}]; + statements.push(yul::Statement::Leave); + Statement::Block(yul::Block { statements }) + } else { + yul::Statement::Leave + } + } + + InstKind::Keccak256 { arg } => { + todo!() + } + + InstKind::Clone { arg } => { + todo!() + } + + InstKind::ToMem { arg } => { + todo!() + } + + InstKind::AbiEncode { arg } => { + todo!() + } + + InstKind::Create { value, contract } => { + todo!() + } + + InstKind::Create2 { + value, + salt, + contract, + } => { + todo!() + } + + InstKind::Revert { arg } => todo!(), + InstKind::Emit { arg } => { + todo!() + } + InstKind::Jump { .. } | InstKind::Branch { .. 
} | InstKind::Nop => { + unreachable!() + } + + InstKind::YulIntrinsic { op, args } => { + todo!() + } + } + } + + fn lower_if( + &mut self, + cond: ValueId, + then: Vec, + else_: Vec, + ) -> yul::Statement { + let cond = self.value_expr(cond); + let mut then_stmts = vec![]; + let mut else_stmts = vec![]; + + for inst in then { + self.lower_structural_inst(inst, &mut then_stmts); + } + for inst in else_ { + self.lower_structural_inst(inst, &mut else_stmts); + } + + switch! { + switch ([cond]) + (case 1 {[then_stmts...]}) + (case 0 {[else_stmts...]}) + } + } + + fn lower_for(&mut self, body: Vec) -> yul::Statement { + let mut body_stmts = vec![]; + for inst in body { + self.lower_structural_inst(inst, &mut body_stmts); + } + + block_statement! {( + for {} (1) {} + { + [body_stmts...] + } + )} + } + fn lower_assign(&self, _lhs: &AssignableValue, _rhs: ValueId) -> yul::Statement { + todo!() + } + + fn lower_unary(&self, op: UnOp, value: ValueId) -> yul::Expression { + let value = self.value_expr(value); + match op { + UnOp::Not => expression! { iszero([value])}, + UnOp::Neg => { + let zero = literal_expression! {0}; + expression! { sub([zero], [value])} + } + UnOp::Inv => expression! { not([value])}, + } + } + + fn lower_binary(&self, op: BinOp, lhs: ValueId, rhs: ValueId, inst: InstId) -> yul::Expression { + let lhs = self.value_expr(lhs); + let rhs = self.value_expr(rhs); + let is_signed = self + .body + .store + .inst_result(inst) + .map(|val| { + let ty = self.body.store.value_ty(val); + ty.is_signed(self.db.upcast()) + }) + .unwrap_or(false); + match op { + BinOp::Add => expression! {add([lhs], [rhs])}, + BinOp::Sub => expression! {sub([lhs], [rhs])}, + BinOp::Mul => expression! {mul([lhs], [rhs])}, + // TODO: zero division check for div and mod. + BinOp::Div if is_signed => expression! {sdiv([lhs], [rhs])}, + BinOp::Div => expression! {div([lhs], [rhs])}, + BinOp::Mod if is_signed => expression! {smod([lhs], [rhs])}, + BinOp::Mod => expression! 
{mod([lhs], [rhs])}, + BinOp::Pow => expression! {exp([lhs], [rhs])}, + BinOp::Shl => expression! {shl([lhs], [rhs])}, + BinOp::Shr => expression! {shr([lhs], [rhs])}, + BinOp::BitOr | BinOp::LogicalOr => expression! {or([lhs], [rhs])}, + BinOp::BitXor => expression! {xor([lhs], [rhs])}, + BinOp::BitAnd | BinOp::LogicalAnd => expression! {and([lhs], [rhs])}, + BinOp::Eq => expression! {eq([lhs], [rhs])}, + BinOp::Ne => expression! {is_zero((eq([lhs], [rhs])))}, + BinOp::Ge if is_signed => expression! {is_zero((slt([lhs], [rhs])))}, + BinOp::Ge => expression! {is_zero((lt([lhs], [rhs])))}, + BinOp::Gt if is_signed => expression! {sgt([lhs], [rhs])}, + BinOp::Gt => expression! {gt([lhs], [rhs])}, + BinOp::Le if is_signed => expression! {is_zero((sgt([lhs], [rhs])))}, + BinOp::Le => expression! {is_zero((gt([lhs], [rhs])))}, + BinOp::Lt if is_signed => expression! {slt([lhs], [rhs])}, + BinOp::Lt => expression! {lt([lhs], [rhs])}, + } + } + + fn lower_cast(&self, value: ValueId, to: TypeId) -> yul::Expression { + let from_ty = self.body.store.value_ty(value); + debug_assert!(from_ty.is_primitive(self.db.upcast())); + debug_assert!(to.is_primitive(self.db.upcast())); + let from_size = from_ty.size_of(self.db.upcast(), MEMORY_SLOT_SIZE); + let to_size = to.size_of(self.db.upcast(), MEMORY_SLOT_SIZE); + + let value = self.value_expr(value); + if to_size <= from_size { + let mask = bit_mask_expr(to_size); + expression! { and([value], [mask]) } + } else if from_ty.is_signed(self.db.upcast()) { + let significant = literal_expression! {(from_size-1)}; + expression! { signextend([value], [significant])} + } else { + let mask = bit_mask_expr(from_size); + expression! { and([value], [mask]) } + } + } + + fn assign_inst_result(&mut self, inst: InstId, rhs: yul::Expression) -> yul::Statement { + if let Some(result) = self.body.store.inst_result(inst) { + let tmp = self.make_tmp(result); + statement! 
{let [tmp] := [rhs]} + } else { + yul::Statement::Expression(rhs) + } + } + + fn value_expr(&self, value: ValueId) -> yul::Expression { + match self.body.store.value_data(value) { + Value::Local(_) | Value::Temporary { .. } => { + let ident = &self.value_map[&value]; + literal_expression! {(ident)} + } + Value::Immediate { imm, .. } => { + let num = format!("{:#x}", imm); + literal_expression! {(num)} + } + Value::Constant { constant, .. } => match &constant.data(self.db.upcast()).value { + ConstantValue::Immediate(imm) => { + let num = format!("{:#x}", imm); + literal_expression! {(num)} + } + ConstantValue::Str(_) => { + todo!() + } + ConstantValue::Bool(true) => { + literal_expression! {1} + } + ConstantValue::Bool(false) => { + literal_expression! {0} + } + }, + Value::Unit { .. } => unreachable!(), + } + } + + fn value_ident(&self, value: ValueId) -> yul::Identifier { + self.value_map[&value].clone() + } + + fn make_tmp(&mut self, tmp: ValueId) -> yul::Identifier { + let tmp_name = format!("$tmp_{}", tmp.index()); + let ident = identifier! {(tmp_name)}; + self.value_map.insert(tmp, ident.clone()); + ident + } +} + +fn bit_mask(byte_size: usize) -> usize { + (1 << (byte_size * 8)) - 1 +} + +fn bit_mask_expr(byte_size: usize) -> yul::Expression { + let mask = format!("{:#x}", bit_mask(byte_size)); + literal_expression! 
{(mask)} +} diff --git a/crates/codegen/src/yul/isel/inst_order.rs b/crates/codegen/src/yul/isel/inst_order.rs new file mode 100644 index 0000000000..b6e5317e77 --- /dev/null +++ b/crates/codegen/src/yul/isel/inst_order.rs @@ -0,0 +1,830 @@ +use fe_mir::{ + analysis::{ + domtree::DFSet, loop_tree::LoopId, post_domtree::PostIDom, ControlFlowGraph, DomTree, + LoopTree, PostDomTree, + }, + ir::{inst::BranchInfo, BasicBlockId, FunctionBody, InstId, ValueId}, +}; +use fxhash::FxHashSet; + +#[derive(Debug, Clone)] +pub(super) enum StructuralInst { + Inst(InstId), + If { + cond: ValueId, + then: Vec, + else_: Vec, + }, + For { + body: Vec, + }, + Break, + Continue, +} + +pub(super) struct InstSerializer<'a> { + body: &'a FunctionBody, + cfg: ControlFlowGraph, + loop_tree: LoopTree, + df: DFSet, + pd_tree: PostDomTree, + scope: Option, +} + +impl<'a> InstSerializer<'a> { + pub(super) fn new(body: &'a FunctionBody) -> Self { + let cfg = ControlFlowGraph::compute(body); + let domtree = DomTree::compute(&cfg); + let df = domtree.compute_df(&cfg); + let pd_tree = PostDomTree::compute(body); + let loop_tree = LoopTree::compute(&cfg, &domtree); + + Self { + body, + cfg, + loop_tree, + df, + pd_tree, + scope: None, + } + } + + pub(super) fn serialize(&mut self) -> Vec { + self.scope = None; + let entry = self.cfg.entry(); + let mut order = vec![]; + self.serialize_block(entry, &mut order); + order + } + + fn serialize_block(&mut self, block: BasicBlockId, order: &mut Vec) { + match self.loop_tree.loop_of_block(block) { + Some(lp) + if block == self.loop_tree.loop_header(lp) + && Some(block) != self.scope.as_ref().and_then(Scope::loop_header) => + { + let loop_exit = self.find_loop_exit(lp); + self.enter_loop_scope(lp, block, loop_exit); + let mut body = vec![]; + self.serialize_block(block, &mut body); + self.exit_scope(); + order.push(StructuralInst::For { body }); + + match loop_exit { + Some(exit) + if self + .scope + .as_ref() + .map(|scope| scope.if_merge_block() != 
Some(exit)) + .unwrap_or(true) => + { + self.serialize_block(exit, order); + } + _ => {} + } + + return; + } + _ => {} + }; + + for inst in self.body.order.iter_inst(block) { + if self.body.store.is_terminator(inst) { + break; + } + if !self.body.store.is_nop(inst) { + order.push(StructuralInst::Inst(inst)); + } + } + + let terminator = self.body.order.terminator(&self.body.store, block).unwrap(); + match self.analyze_terminator(terminator) { + TerminatorInfo::If { + cond, + then, + else_, + merge_block, + } => self.serialize_if_terminator(cond, *then, *else_, merge_block, order), + TerminatorInfo::ToMergeBlock => {} + TerminatorInfo::Continue => order.push(StructuralInst::Continue), + TerminatorInfo::Break => order.push(StructuralInst::Break), + TerminatorInfo::FallThrough(next) => self.serialize_block(next, order), + TerminatorInfo::NormalInst(inst) => order.push(StructuralInst::Inst(inst)), + } + } + + fn serialize_if_terminator( + &mut self, + cond: ValueId, + then: TerminatorInfo, + else_: TerminatorInfo, + merge_block: Option, + order: &mut Vec, + ) { + let mut then_body = vec![]; + let mut else_body = vec![]; + + self.enter_if_scope(merge_block); + + let mut serialize_dest = + |dest_info, body: &mut Vec, merge_block| match dest_info { + TerminatorInfo::Break => body.push(StructuralInst::Break), + TerminatorInfo::Continue => body.push(StructuralInst::Continue), + TerminatorInfo::ToMergeBlock => {} + TerminatorInfo::FallThrough(dest) => { + if Some(dest) != merge_block { + self.serialize_block(dest, body); + } + } + _ => unreachable!(), + }; + serialize_dest(then, &mut then_body, merge_block); + serialize_dest(else_, &mut else_body, merge_block); + self.exit_scope(); + + order.push(StructuralInst::If { + cond, + then: then_body, + else_: else_body, + }); + if let Some(merge_block) = merge_block { + self.serialize_block(merge_block, order); + } + } + + fn enter_loop_scope(&mut self, lp: LoopId, header: BasicBlockId, exit: Option) { + let kind = ScopeKind::Loop 
{ lp, header, exit }; + let current_scope = std::mem::take(&mut self.scope); + self.scope = Some(Scope { + kind, + parent: current_scope.map(Into::into), + }); + } + + fn enter_if_scope(&mut self, merge_block: Option) { + let kind = ScopeKind::If { merge_block }; + let current_scope = std::mem::take(&mut self.scope); + self.scope = Some(Scope { + kind, + parent: current_scope.map(Into::into), + }); + } + + fn exit_scope(&mut self) { + let current_scope = std::mem::take(&mut self.scope); + self.scope = current_scope.unwrap().parent.map(|parent| *parent); + } + + // NOTE: We assume loop has at most one canonical loop exit. + fn find_loop_exit(&self, lp: LoopId) -> Option { + let mut exit_candidates = vec![]; + for block_in_loop in self.loop_tree.iter_blocks_post_order(&self.cfg, lp) { + for &succ in self.cfg.succs(block_in_loop) { + if !self.loop_tree.is_block_in_loop(succ, lp) { + exit_candidates.push(succ); + } + } + } + + if exit_candidates.is_empty() { + return None; + } + + if exit_candidates.len() == 1 { + let candidate = exit_candidates[0]; + let exit = if let Some(mut df) = self.df.frontiers(candidate) { + debug_assert_eq!(self.df.frontier_num(candidate), 1); + df.next() + } else { + Some(candidate) + }; + return exit; + } + + // If a candidate is a dominance frontier of all other nodes, then the candidate + // is a loop exit. + for &cand in &exit_candidates { + if exit_candidates.iter().all(|&block| { + if block == cand { + true + } else if let Some(mut df) = self.df.frontiers(block) { + df.any(|frontier| frontier == cand) + } else { + true + } + }) { + return Some(cand); + } + } + + // If all candidates have the same dominance frontier, then the frontier block + // is the canonicalized loop exit. 
+ let mut frontier: FxHashSet<_> = self + .df + .frontiers(exit_candidates.pop().unwrap()) + .map(std::iter::Iterator::collect) + .unwrap_or_default(); + for cand in exit_candidates { + for cand_frontier in self.df.frontiers(cand).unwrap() { + if !frontier.contains(&cand_frontier) { + frontier.remove(&cand_frontier); + } + } + } + debug_assert!(frontier.len() < 2); + frontier.iter().next().copied() + } + + fn analyze_terminator(&self, inst: InstId) -> TerminatorInfo { + debug_assert!(self.body.store.is_terminator(inst)); + + match self.body.store.branch_info(inst) { + BranchInfo::Jump(dest) => self.analyze_jump(dest), + BranchInfo::Branch(cond, then, else_) => { + self.analyze_branch(self.body.order.inst_block(inst), cond, then, else_) + } + BranchInfo::NotBranch => TerminatorInfo::NormalInst(inst), + } + } + + fn analyze_branch( + &self, + block: BasicBlockId, + cond: ValueId, + then: BasicBlockId, + else_: BasicBlockId, + ) -> TerminatorInfo { + let then = Box::new(self.analyze_dest(then)); + let else_ = Box::new(self.analyze_dest(else_)); + + let merge_block = match self.pd_tree.post_idom(block) { + PostIDom::Block(block) => { + if let Some(lp) = self.scope.as_ref().and_then(Scope::loop_recursive) { + if self.loop_tree.is_block_in_loop(block, lp) { + Some(block) + } else { + None + } + } else { + Some(block) + } + } + _ => None, + }; + + TerminatorInfo::If { + cond, + then, + else_, + merge_block, + } + } + + fn analyze_jump(&self, dest: BasicBlockId) -> TerminatorInfo { + self.analyze_dest(dest) + } + + fn analyze_dest(&self, dest: BasicBlockId) -> TerminatorInfo { + match &self.scope { + Some(scope) => { + if Some(dest) == scope.loop_header_recursive() { + TerminatorInfo::Continue + } else if Some(dest) == scope.loop_exit_recursive() { + TerminatorInfo::Break + } else if Some(dest) == scope.if_merge_block() { + TerminatorInfo::ToMergeBlock + } else { + TerminatorInfo::FallThrough(dest) + } + } + + None => TerminatorInfo::FallThrough(dest), + } + } +} + +struct 
Scope { + kind: ScopeKind, + parent: Option>, +} + +#[derive(Debug, Clone, Copy)] +enum ScopeKind { + Loop { + lp: LoopId, + header: BasicBlockId, + exit: Option, + }, + If { + merge_block: Option, + }, +} + +impl Scope { + fn loop_recursive(&self) -> Option { + match self.kind { + ScopeKind::Loop { lp, .. } => Some(lp), + _ => self.parent.as_ref()?.loop_recursive(), + } + } + + fn loop_header(&self) -> Option { + match self.kind { + ScopeKind::Loop { header, .. } => Some(header), + _ => None, + } + } + + fn loop_header_recursive(&self) -> Option { + match self.kind { + ScopeKind::Loop { header, .. } => Some(header), + _ => self.parent.as_ref()?.loop_header_recursive(), + } + } + + fn loop_exit_recursive(&self) -> Option { + match self.kind { + ScopeKind::Loop { exit, .. } => exit, + _ => self.parent.as_ref()?.loop_exit_recursive(), + } + } + + fn if_merge_block(&self) -> Option { + match self.kind { + ScopeKind::If { merge_block } => merge_block, + _ => None, + } + } +} + +#[derive(Debug, Clone)] +enum TerminatorInfo { + If { + cond: ValueId, + then: Box, + else_: Box, + merge_block: Option, + }, + ToMergeBlock, + Continue, + Break, + FallThrough(BasicBlockId), + NormalInst(InstId), +} + +#[cfg(test)] +mod tests { + use fe_mir::ir::{body_builder::BodyBuilder, inst::InstKind, FunctionId, SourceInfo, TypeId}; + + use super::*; + + fn body_builder() -> BodyBuilder { + BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) + } + + fn serialize_func_body(func: &mut FunctionBody) -> impl Iterator { + InstSerializer::new(func).serialize().into_iter() + } + + fn expect_if( + inst: StructuralInst, + ) -> ( + impl Iterator, + impl Iterator, + ) { + match inst { + StructuralInst::If { then, else_, .. 
} => (then.into_iter(), else_.into_iter()), + _ => panic!("expect if inst"), + } + } + + fn expect_for(inst: StructuralInst) -> impl Iterator { + match inst { + StructuralInst::For { body } => body.into_iter(), + _ => panic!("expect if inst"), + } + } + + fn expect_break(inst: StructuralInst) { + assert!(matches!(inst, StructuralInst::Break)) + } + + fn expect_continue(inst: StructuralInst) { + assert!(matches!(inst, StructuralInst::Continue)) + } + + fn expect_return(func: &FunctionBody, inst: &StructuralInst) { + match inst { + StructuralInst::Inst(inst) => { + assert!(matches!( + func.store.inst_data(*inst).kind, + InstKind::Return { .. } + )) + } + _ => panic!("expect return"), + } + } + + #[test] + fn if_non_merge() { + // +------+ +-------+ + // | then | <-- | bb0 | + // +------+ +-------+ + // | + // | + // v + // +-------+ + // | else_ | + // +-------+ + let mut builder = body_builder(); + + let then = builder.make_block(); + let else_ = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, then, else_, SourceInfo::dummy()); + + builder.move_to_block(then); + builder.ret(unit, SourceInfo::dummy()); + + builder.move_to_block(else_); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut then, mut else_) = expect_if(order.next().unwrap()); + expect_return(&func, &then.next().unwrap()); + assert!(then.next().is_none()); + expect_return(&func, &else_.next().unwrap()); + assert!(else_.next().is_none()); + + assert!(order.next().is_none()); + } + + #[test] + fn if_merge() { + // +------+ +-------+ + // | then | <-- | bb0 | + // +------+ +-------+ + // | | + // | | + // | v + // | +-------+ + // | | else_ | + // | +-------+ + // | | + // | | + // | v + // | +-------+ + // +--------> | merge | + // +-------+ + let mut builder = body_builder(); + + let then = 
builder.make_block(); + let else_ = builder.make_block(); + let merge = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, then, else_, SourceInfo::dummy()); + + builder.move_to_block(then); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(else_); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(merge); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut then, mut else_) = expect_if(order.next().unwrap()); + assert!(then.next().is_none()); + assert!(else_.next().is_none()); + + expect_return(&func, &order.next().unwrap()); + assert!(order.next().is_none()); + } + + #[test] + fn simple_loop() { + // +--------+ + // | bb0 | -+ + // +--------+ | + // | | + // | | + // v | + // +--------+ | + // +> | header | | + // | +--------+ | + // | | | + // | | | + // | v | + // | +--------+ | + // +- | latch | | + // +--------+ | + // | | + // | | + // v | + // +--------+ | + // | exit | <+ + // +--------+ + let mut builder = body_builder(); + + let header = builder.make_block(); + let latch = builder.make_block(); + let exit = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, header, exit, SourceInfo::dummy()); + + builder.move_to_block(header); + builder.jump(latch, SourceInfo::dummy()); + + builder.move_to_block(latch); + builder.branch(v0, header, exit, SourceInfo::dummy()); + + builder.move_to_block(exit); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut lp, mut empty) = expect_if(order.next().unwrap()); + + let mut body = expect_for(lp.next().unwrap()); + let (mut continue_, mut break_) = 
expect_if(body.next().unwrap()); + assert!(body.next().is_none()); + + expect_continue(continue_.next().unwrap()); + assert!(continue_.next().is_none()); + + expect_break(break_.next().unwrap()); + assert!(break_.next().is_none()); + + assert!(empty.next().is_none()); + + expect_return(&func, &order.next().unwrap()); + assert!(order.next().is_none()); + } + + #[test] + fn loop_with_continue() { + // +-----+ + // +- | bb0 | + // | +-----+ + // | | + // | | + // | v + // | +---------------+ +-----+ + // | | bb1 | --> | bb3 | + // | +---------------+ +-----+ + // | | ^ ^ | + // | | | +---------+ + // | v | + // | +-----+ | + // | | bb4 | -+ + // | +-----+ + // | | + // | | + // | v + // | +-----+ + // +> | bb2 | + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let bb4 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.branch(v0, bb3, bb4, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.jump(bb1, SourceInfo::dummy()); + + builder.move_to_block(bb4); + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut lp, mut empty) = expect_if(order.next().unwrap()); + assert!(empty.next().is_none()); + + let mut body = expect_for(lp.next().unwrap()); + + let (mut continue_, mut empty) = expect_if(body.next().unwrap()); + expect_continue(continue_.next().unwrap()); + assert!(continue_.next().is_none()); + assert!(empty.next().is_none()); + + let (mut continue_, mut break_) = expect_if(body.next().unwrap()); + expect_continue(continue_.next().unwrap()); + assert!(continue_.next().is_none()); + 
expect_break(break_.next().unwrap()); + assert!(break_.next().is_none()); + + assert!(body.next().is_none()); + assert!(lp.next().is_none()); + + expect_return(&func, &order.next().unwrap()); + assert!(order.next().is_none()); + } + + #[test] + fn loop_with_break() { + // +-----+ + // +- | bb0 | + // | +-----+ + // | | + // | | +---------+ + // | v v | + // | +---------------+ +-----+ + // | | bb1 | --> | bb4 | + // | +---------------+ +-----+ + // | | | + // | | | + // | v | + // | +-----+ | + // | | bb3 | | + // | +-----+ | + // | | | + // | | | + // | v | + // | +-----+ | + // +> | bb2 | <---------------+ + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let bb4 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.branch(v0, bb3, bb4, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb4); + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut lp, mut empty) = expect_if(order.next().unwrap()); + assert!(empty.next().is_none()); + + let mut body = expect_for(lp.next().unwrap()); + + let (mut break_, mut latch) = expect_if(body.next().unwrap()); + expect_break(break_.next().unwrap()); + assert!(break_.next().is_none()); + + let (mut continue_, mut break_) = expect_if(latch.next().unwrap()); + assert!(latch.next().is_none()); + expect_continue(continue_.next().unwrap()); + assert!(continue_.next().is_none()); + expect_break(break_.next().unwrap()); + assert!(break_.next().is_none()); + + assert!(body.next().is_none()); + 
assert!(lp.next().is_none()); + + expect_return(&func, &order.next().unwrap()); + assert!(order.next().is_none()); + } + + #[test] + fn loop_no_guard() { + // +-----+ + // | bb0 | + // +-----+ + // | + // | + // v + // +-----+ + // | bb1 | <+ + // +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb2 | -+ + // +-----+ + // | + // | + // v + // +-----+ + // | bb3 | + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.jump(bb1, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.branch(v0, bb1, bb3, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let mut body = expect_for(order.next().unwrap()); + let (mut continue_, mut break_) = expect_if(body.next().unwrap()); + assert!(body.next().is_none()); + + expect_continue(continue_.next().unwrap()); + assert!(continue_.next().is_none()); + + expect_break(break_.next().unwrap()); + assert!(break_.next().is_none()); + + expect_return(&func, &order.next().unwrap()); + assert!(order.next().is_none()); + } + + #[test] + fn infinite_loop() { + // +-----+ + // | bb0 | + // +-----+ + // | + // | + // v + // +-----+ + // | bb1 | <+ + // +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb2 | -+ + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + + builder.jump(bb1, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.jump(bb1, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = 
serialize_func_body(&mut func); + + let mut body = expect_for(order.next().unwrap()); + expect_continue(body.next().unwrap()); + assert!(body.next().is_none()); + + assert!(order.next().is_none()); + } +} diff --git a/crates/codegen/src/yul/isel/mod.rs b/crates/codegen/src/yul/isel/mod.rs index e6a246a6fa..6a99febc02 100644 --- a/crates/codegen/src/yul/isel/mod.rs +++ b/crates/codegen/src/yul/isel/mod.rs @@ -1,25 +1,5 @@ -#![allow(unused)] -use fe_mir::{ - analysis::ControlFlowGraph, - ir::{FunctionBody, FunctionSignature, ValueId}, -}; -use fxhash::FxHashMap; -use smol_str::SmolStr; +mod function; -use crate::db::CodegenDb; +mod inst_order; -struct FuncLowerHelper<'db, 'a> { - db: &'db dyn CodegenDb, - value_map: FxHashMap, - sig: &'a FunctionSignature, - body: &'a FunctionBody, - cfg: ControlFlowGraph, - sink: Vec, - ret_value: Option, -} - -impl<'db, 'a> FuncLowerHelper<'db, 'a> { - fn lower_func(self) -> Vec { - todo!() - } -} +pub use function::lower_function; diff --git a/crates/codegen/src/yul/legalize/body.rs b/crates/codegen/src/yul/legalize/body.rs index 4683402b59..b40af2db68 100644 --- a/crates/codegen/src/yul/legalize/body.rs +++ b/crates/codegen/src/yul/legalize/body.rs @@ -1,7 +1,7 @@ use fe_mir::ir::{ body_cursor::{BodyCursor, CursorLocation}, inst::InstKind, - FunctionBody, Inst, InstId, + FunctionBody, Inst, InstId, Type, TypeId, Value, }; use crate::db::CodegenDb; @@ -12,6 +12,8 @@ pub fn legalize_func_body(db: &dyn CodegenDb, body: &mut FunctionBody) { // Remove critical edges. CriticalEdgeSplitter::new().run(body); + legalize_arg_values(db, body); + // Remove zero-sized types usage. let mut cursor = BodyCursor::new_at_entry(body); loop { @@ -37,7 +39,7 @@ fn legalize_inst_arg(db: &dyn CodegenDb, body: &mut FunctionBody, inst_id: InstI let mut inst = body.store.replace_inst(inst_id, dummy_inst); match &mut inst.kind { - InstKind::AggregateConstruct { args, .. } | InstKind::Call { args, .. } => { + InstKind::Call { args, .. 
} => { args.retain(|arg| !body.store.value_ty(*arg).is_zero_sized(db.upcast())); } @@ -56,11 +58,64 @@ fn legalize_inst_arg(db: &dyn CodegenDb, body: &mut FunctionBody, inst_id: InstI body.store.replace_inst(inst_id, inst); } -/// Remove instruction result if its type is zero-sized. fn legalize_inst_result(db: &dyn CodegenDb, body: &mut FunctionBody, inst: InstId) { - if let Some(result) = body.store.inst_result(inst) { - if body.store.value_ty(result).is_zero_sized(db.upcast()) { - body.store.remove_inst_result(inst) + let result_value = if let Some(result) = body.store.inst_result(inst) { + result + } else { + return; + }; + + let result_ty = body.store.value_ty(result_value); + if result_ty.is_zero_sized(db.upcast()) { + body.store.remove_inst_result(inst); + return; + }; + + let new_ty = if result_ty.is_aggregate(db.upcast()) || result_ty.is_map(db.upcast()) { + match &body.store.inst_data(inst).kind { + InstKind::AggregateAccess { value, .. } => { + let value_ty = body.store.value_ty(*value); + match value_ty.data(db.upcast()).as_ref() { + Type::MPtr(..) => result_ty.make_mptr(db.upcast()), + Type::SPtr(..) => result_ty.make_sptr(db.upcast()), + _ => unreachable!(), + } + } + _ => result_ty.make_mptr(db.upcast()), + } + } else { + return; + }; + + let value = body.store.value_data_mut(result_value); + change_ty(value, new_ty); +} + +fn legalize_arg_values(db: &dyn CodegenDb, body: &mut FunctionBody) { + for value in body.store.func_args_mut() { + let ty = value.ty(); + if ty.is_contract(db.upcast()) { + let slot_ptr = make_storage_ptr(db, ty); + *value = slot_ptr; + } else if ty.is_aggregate(db.upcast()) { + change_ty(value, ty.make_sptr(db.upcast())) } } } + +fn change_ty(value: &mut Value, new_ty: TypeId) { + match value { + Value::Local(val) => val.ty = new_ty, + Value::Immediate { ty, .. } + | Value::Temporary { ty, .. } + | Value::Unit { ty } + | Value::Constant { ty, .. 
} => *ty = new_ty, + } +} + +fn make_storage_ptr(db: &dyn CodegenDb, ty: TypeId) -> Value { + debug_assert!(ty.is_contract(db.upcast())); + let ty = ty.make_sptr(db.upcast()); + + Value::Immediate { imm: 0.into(), ty } +} diff --git a/crates/codegen/src/yul/legalize/critical_edge.rs b/crates/codegen/src/yul/legalize/critical_edge.rs index 40a898c18b..aa468db29e 100644 --- a/crates/codegen/src/yul/legalize/critical_edge.rs +++ b/crates/codegen/src/yul/legalize/critical_edge.rs @@ -22,7 +22,7 @@ impl CriticalEdgeSplitter { pub fn run(&mut self, func: &mut FunctionBody) { let cfg = ControlFlowGraph::compute(func); - for block in func.order.iter_block() { + for block in cfg.post_order() { let terminator = func.order.terminator(&func.store, block).unwrap(); self.add_critical_edges(terminator, func, &cfg); } diff --git a/crates/codegen/src/yul/legalize/mod.rs b/crates/codegen/src/yul/legalize/mod.rs index 62e82f78fe..7832fb834b 100644 --- a/crates/codegen/src/yul/legalize/mod.rs +++ b/crates/codegen/src/yul/legalize/mod.rs @@ -1,5 +1,6 @@ mod body; mod critical_edge; +mod runtime; mod signature; pub use body::legalize_func_body; diff --git a/crates/codegen/src/yul/legalize/runtime.rs b/crates/codegen/src/yul/legalize/runtime.rs new file mode 100644 index 0000000000..e45f827096 --- /dev/null +++ b/crates/codegen/src/yul/legalize/runtime.rs @@ -0,0 +1,68 @@ +use fxhash::{FxHashMap, FxHashSet}; +use id_arena::{Arena, Id}; +use yultsur::{ + yul::{self, FunctionCall}, + *, +}; + +pub trait RuntimeFunctionProvider { + fn alloc(&mut self, bytes: yul::Expression) -> yul::Expression; + + fn collect_definitions(&mut self) -> Vec; +} + +pub struct DefaultRuntime { + called: FxHashSet, + functions: Arena, + dispatcher: FxHashMap, +} + +type RuntimeFunctionId = Id; + +impl RuntimeFunctionProvider for DefaultRuntime { + fn alloc(&mut self, bytes: yul::Expression) -> yul::Expression { + let func_id = self.dispatcher["alloc"]; + self.functions[func_id].call(vec![bytes]) + } + + fn 
collect_definitions(&mut self) -> Vec { + let mut defs = Vec::with_capacity(self.called.len()); + + for func_id in &self.called { + defs.push(self.functions[*func_id].func_def()); + } + + defs + } +} + +struct RuntimeFunction { + name: &'static str, + definition: yul::FunctionDefinition, + arg_num: usize, +} + +impl RuntimeFunction { + fn func_name(&self) -> &'static str { + self.name + } + + fn arg_num(&self) -> usize { + self.arg_num + } + + fn func_def(&self) -> yul::FunctionDefinition { + self.definition.clone() + } + + /// # Panics + /// Panics if a number of arguments doesn't match the definition. + fn call(&self, args: Vec) -> yul::Expression { + debug_assert_eq!(self.arg_num(), args.len()); + + yul::Expression::FunctionCall(FunctionCall { + identifier: identifier! {(self.func_name())}, + arguments: args, + }) + } +} diff --git a/crates/codegen/src/yul/legalize/signature.rs b/crates/codegen/src/yul/legalize/signature.rs index 1354f70f8e..aad5e00319 100644 --- a/crates/codegen/src/yul/legalize/signature.rs +++ b/crates/codegen/src/yul/legalize/signature.rs @@ -4,5 +4,6 @@ use crate::db::CodegenDb; pub fn legalize_func_signature(_db: &dyn CodegenDb, _sig: &mut FunctionSignature) { // TODO: Remove zero sized types from arguments, also remove return type if - // it's zero-sized + // it's zero-size. + // TODO: Remove contract types from arguments. 
} diff --git a/crates/codegen/src/yul/mod.rs b/crates/codegen/src/yul/mod.rs index ba32f5b6eb..32d435afaa 100644 --- a/crates/codegen/src/yul/mod.rs +++ b/crates/codegen/src/yul/mod.rs @@ -1,4 +1,2 @@ pub mod isel; pub mod legalize; - -mod inst_order; diff --git a/crates/driver/Cargo.toml b/crates/driver/Cargo.toml index a5234c6b72..463d249bca 100644 --- a/crates/driver/Cargo.toml +++ b/crates/driver/Cargo.toml @@ -17,6 +17,7 @@ fe-analyzer = {path = "../analyzer", version = "^0.14.0-alpha"} fe-common = {path = "../common", version = "^0.14.0-alpha"} fe-lowering = {path = "../lowering", version = "^0.14.0-alpha"} fe-mir = {path = "../mir", version = "^0.14.0-alpha"} +fe-codegen = {path = "../codegen", version = "^0.14.0-alpha"} fe-parser = {path = "../parser", version = "^0.14.0-alpha"} fe-yulgen = {path = "../yulgen", version = "^0.14.0-alpha"} fe-yulc = {path = "../yulc", version = "^0.14.0-alpha", features = ["solc-backend"], optional = true} diff --git a/crates/driver/src/lib.rs b/crates/driver/src/lib.rs index 764bd8ea6c..7aa4063ab7 100644 --- a/crates/driver/src/lib.rs +++ b/crates/driver/src/lib.rs @@ -1,6 +1,6 @@ #![allow(unused_imports, dead_code)] -pub use fe_mir::db::NewDb; +pub use fe_codegen::db::{CodegenDb, NewDb}; pub use fe_yulgen::Db; use fe_analyzer::context::Analysis; @@ -99,6 +99,28 @@ pub fn dump_mir_single_file(db: &mut NewDb, path: &str, src: &str) -> Result Result, CompileError> { + let module = ModuleId::new_standalone(db, path, src); + + let diags = module.diagnostics(db); + if !diags.is_empty() { + return Err(CompileError(diags)); + } + + let mut functions = vec![]; + for &func in db.mir_lower_module_all_functions(module).as_ref() { + let yul_func = db.codegen_lower_function(func); + functions.push(yul_func.as_ref().clone()); + } + + Ok(functions) +} + fn compile_module_id( db: &mut Db, module_id: ModuleId, diff --git a/crates/fe/src/main.rs b/crates/fe/src/main.rs index 75c4608700..e8f569eae1 100644 --- a/crates/fe/src/main.rs +++ 
b/crates/fe/src/main.rs @@ -77,6 +77,12 @@ pub fn main() { .help("dump mir dot file") .takes_value(false), ) + .arg( + Arg::with_name("codegen") + .long("codegen") + .help("todo") + .takes_value(false), + ) .get_matches(); let input_path = matches.value_of("input").unwrap(); @@ -90,6 +96,11 @@ pub fn main() { if matches.is_present("mir") { return mir_dump(input_path); } + + if matches.is_present("codegen") { + return yul_functions_dump(input_path); + } + #[cfg(not(feature = "solc-backend"))] if with_bytecode { eprintln!("Warning: bytecode output requires 'solc-backend' feature. Try `cargo build --release --features solc-backend`. Skipping."); @@ -296,3 +307,32 @@ fn mir_dump(input_path: &str) { std::process::exit(1) } } + +fn yul_functions_dump(input_path: &str) { + let mut db = fe_driver::NewDb::default(); + if Path::new(input_path).is_file() { + let content = match std::fs::read_to_string(input_path) { + Err(err) => { + eprintln!("Failed to load file: `{}`. Error: {}", input_path, err); + std::process::exit(1) + } + Ok(content) => content, + }; + + match fe_driver::dump_codegen_funcs(&mut db, input_path, &content) { + Ok(functions) => { + for func in functions { + println!("{}", func) + } + } + Err(err) => { + eprintln!("Unable to dump mir `{}", input_path); + print_diagnostics(&db, &err.0); + std::process::exit(1) + } + } + } else { + eprintln!("mir doesn't support ingot yet"); + std::process::exit(1) + } +} diff --git a/crates/mir/src/db/queries/types.rs b/crates/mir/src/db/queries/types.rs index 15365acc3e..52f9bd840f 100644 --- a/crates/mir/src/db/queries/types.rs +++ b/crates/mir/src/db/queries/types.rs @@ -7,7 +7,7 @@ use num_traits::ToPrimitive; use crate::{ db::MirDb, - ir::{types::ArrayDef, value::Immediate, Type, TypeId, Value}, + ir::{types::ArrayDef, Type, TypeId, Value}, lower::types::{lower_event_type, lower_type}, }; @@ -25,7 +25,8 @@ impl TypeId { } pub fn projection_ty(self, db: &dyn MirDb, access: &Value) -> TypeId { - match 
self.data(db).as_ref() { + let ty = self.deref(db); + match ty.data(db).as_ref() { Type::Array(ArrayDef { elem_ty, .. }) => *elem_ty, Type::Tuple(def) => { let index = expect_projection_index(access); @@ -43,6 +44,30 @@ impl TypeId { } } + pub fn deref(self, db: &dyn MirDb) -> TypeId { + match self.data(db).as_ref() { + Type::SPtr(inner) => *inner, + Type::MPtr(inner) => *inner, + _ => self, + } + } + + pub fn deref_recursive(self, db: &dyn MirDb) -> TypeId { + match self.data(db).as_ref() { + Type::SPtr(inner) => inner.deref_recursive(db), + Type::MPtr(inner) => inner.deref_recursive(db), + _ => self, + } + } + + pub fn make_sptr(self, db: &dyn MirDb) -> TypeId { + db.mir_intern_type(Type::SPtr(self).into()) + } + + pub fn make_mptr(self, db: &dyn MirDb) -> TypeId { + db.mir_intern_type(Type::MPtr(self).into()) + } + pub fn projection_ty_imm(self, db: &dyn MirDb, index: usize) -> TypeId { debug_assert!(self.is_aggregate(db)); @@ -55,42 +80,28 @@ impl TypeId { } } - pub fn index_from_fname(self, db: &dyn MirDb, fname: &str, index_ty: TypeId) -> Immediate { - match self.data(db).as_ref() { + pub fn index_from_fname(self, db: &dyn MirDb, fname: &str) -> BigInt { + let ty = self.deref(db); + match ty.data(db).as_ref() { Type::Tuple(_) => { // TODO: Fix this when the syntax for tuple access changes. 
let index_str = &fname[4..]; - Immediate { - value: BigInt::from_str(index_str).unwrap(), - ty: index_ty, - } + BigInt::from_str(index_str).unwrap() } - Type::Struct(def) | Type::Contract(def) => { - let index = def - .fields - .iter() - .enumerate() - .find_map(|(i, field)| (field.0 == fname).then(|| i.into())) - .unwrap(); - Immediate { - value: index, - ty: index_ty, - } - } + Type::Struct(def) | Type::Contract(def) => def + .fields + .iter() + .enumerate() + .find_map(|(i, field)| (field.0 == fname).then(|| i.into())) + .unwrap(), - Type::Event(def) => { - let index = def - .fields - .iter() - .enumerate() - .find_map(|(i, field)| (field.0 == fname).then(|| i.into())) - .unwrap(); - Immediate { - value: index, - ty: index_ty, - } - } + Type::Event(def) => def + .fields + .iter() + .enumerate() + .find_map(|(i, field)| (field.0 == fname).then(|| i.into())) + .unwrap(), other => unreachable!("{:?} does not have fields", other), } @@ -117,6 +128,31 @@ impl TypeId { ) } + pub fn is_integral(self, db: &dyn MirDb) -> bool { + matches!( + self.data(db).as_ref(), + Type::I8 + | Type::I16 + | Type::I32 + | Type::I64 + | Type::I128 + | Type::I256 + | Type::U8 + | Type::U16 + | Type::U32 + | Type::U64 + | Type::U128 + | Type::U256 + ) + } + + pub fn is_signed(self, db: &dyn MirDb) -> bool { + matches!( + self.data(db).as_ref(), + Type::I8 | Type::I16 | Type::I32 | Type::I64 | Type::I128 | Type::I256 + ) + } + /// Returns size of the type in bytes. pub fn size_of(self, db: &dyn MirDb, slot_size: usize) -> usize { match self.data(db).as_ref() { @@ -125,7 +161,7 @@ impl TypeId { Type::I32 | Type::U32 => 4, Type::I64 | Type::U64 => 8, Type::I128 | Type::U128 => 16, - Type::I256 | Type::U256 | Type::Map(_) => 32, + Type::MPtr(..) | Type::SPtr(..) 
| Type::I256 | Type::U256 | Type::Map(_) => 32, Type::Address => 20, Type::Unit => 0, @@ -208,6 +244,10 @@ impl TypeId { ) } + pub fn is_map(self, db: &dyn MirDb) -> bool { + matches!(self.data(db).as_ref(), Type::Map(_)) + } + pub fn is_contract(self, db: &dyn MirDb) -> bool { matches!(self.data(db).as_ref(), Type::Contract(_)) } @@ -230,7 +270,7 @@ fn array_elem_size_imp(arr: &ArrayDef, db: &dyn MirDb, slot_size: usize) -> usiz fn expect_projection_index(value: &Value) -> usize { match value { - Value::Immediate(imm) => imm.value.to_usize().unwrap(), + Value::Immediate { imm, .. } => imm.to_usize().unwrap(), _ => panic!("given `value` is not an immediate"), } } diff --git a/crates/mir/src/ir/body_builder.rs b/crates/mir/src/ir/body_builder.rs index 0a18d542e1..940733b395 100644 --- a/crates/mir/src/ir/body_builder.rs +++ b/crates/mir/src/ir/body_builder.rs @@ -4,13 +4,12 @@ use num_bigint::BigInt; use crate::ir::{ body_cursor::{BodyCursor, CursorLocation}, inst::{BinOp, Inst, InstKind, UnOp}, - value::{Local, Temporary}, + value::Local, BasicBlock, BasicBlockId, FunctionBody, FunctionId, SourceInfo, TypeId, ValueId, }; use super::{ - inst::{CallType, YulIntrinsicOp}, - value::{self, Constant, Immediate}, + inst::{AssignableValue, CallType, YulIntrinsicOp}, ConstantId, Value, }; @@ -78,15 +77,11 @@ impl BodyBuilder { } pub fn make_unit(&mut self, unit_ty: TypeId) -> ValueId { - self.body - .store - .store_value(Value::Unit(value::Unit { ty: unit_ty })) + self.body.store.store_value(Value::Unit { ty: unit_ty }) } pub fn make_imm(&mut self, imm: BigInt, ty: TypeId) -> ValueId { - self.body - .store - .store_value(Value::Immediate(Immediate { value: imm, ty })) + self.body.store.store_value(Value::Immediate { imm, ty }) } pub fn make_imm_from_bool(&mut self, imm: bool, ty: TypeId) -> ValueId { @@ -100,12 +95,12 @@ impl BodyBuilder { pub fn make_constant(&mut self, constant: ConstantId, ty: TypeId) -> ValueId { self.body .store - .store_value(Value::Constant(Constant 
{ constant, ty })) + .store_value(Value::Constant { constant, ty }) } pub fn declare(&mut self, local: Local) -> ValueId { let source = local.source.clone(); - let local_id = self.body.store.store_value(local.into()); + let local_id = self.body.store.store_value(Value::Local(local)); let kind = InstKind::Declare { local: local_id }; let inst = Inst::new(kind, source); @@ -114,10 +109,10 @@ impl BodyBuilder { } pub fn store_func_arg(&mut self, local: Local) -> ValueId { - self.body.store.store_value(local.into()) + self.body.store.store_value(Value::Local(local)) } - pub fn assign(&mut self, lhs: ValueId, rhs: ValueId, source: SourceInfo) { + pub fn assign(&mut self, lhs: AssignableValue, rhs: ValueId, source: SourceInfo) { let kind = InstKind::Assign { lhs, rhs }; let inst = Inst::new(kind, source); self.insert_inst(inst, None); @@ -290,7 +285,7 @@ impl BodyBuilder { self.insert_inst(inst, None); } - pub fn revert(&mut self, arg: ValueId, source: SourceInfo) { + pub fn revert(&mut self, arg: Option, source: SourceInfo) { let kind = InstKind::Revert { arg }; let inst = Inst::new(kind, source); self.insert_inst(inst, None); @@ -308,6 +303,12 @@ impl BodyBuilder { self.insert_inst(inst, None); } + pub fn nop(&mut self, source: SourceInfo) { + let kind = InstKind::Nop; + let inst = Inst::new(kind, source); + self.insert_inst(inst, None); + } + pub fn value_ty(&mut self, value: ValueId) -> TypeId { self.body.store.value_ty(value) } @@ -340,11 +341,11 @@ impl BodyBuilder { let result = if let Some(result_ty) = result_ty { // Map a result value to the inst. 
- let temp = Temporary { + let temp = Value::Temporary { inst: inst_id, ty: result_ty, }; - Some(cursor.store_and_map_result(temp.into())) + Some(cursor.store_and_map_result(temp)) } else { None }; diff --git a/crates/mir/src/ir/function.rs b/crates/mir/src/ir/function.rs index 760dfc854b..9323e456a3 100644 --- a/crates/mir/src/ir/function.rs +++ b/crates/mir/src/ir/function.rs @@ -2,6 +2,7 @@ use fe_analyzer::namespace::items as analyzer_items; use fe_common::impl_intern_key; use fxhash::FxHashMap; use id_arena::Arena; +use num_bigint::BigInt; use smol_str::SmolStr; use super::{ @@ -9,7 +10,7 @@ use super::{ body_order::BodyOrder, inst::{BranchInfo, Inst, InstId, InstKind}, types::TypeId, - value::{Immediate, Local, Value, ValueId}, + value::{Local, Value, ValueId}, BasicBlockId, SourceInfo, }; @@ -90,7 +91,7 @@ pub struct BodyDataStore { /// Maps an immediate to a value to ensure the same immediate results in the /// same value. - immediates: FxHashMap, + immediates: FxHashMap<(BigInt, TypeId), ValueId>, unit_value: Option, @@ -121,9 +122,9 @@ impl BodyDataStore { pub fn store_value(&mut self, value: Value) -> ValueId { match value { - Value::Immediate(imm) => self.store_immediate(imm), + Value::Immediate { imm, ty } => self.store_immediate(imm, ty), - Value::Unit(_) => { + Value::Unit { .. 
} => { if let Some(unit_value) = self.unit_value { unit_value } else { @@ -146,6 +147,10 @@ impl BodyDataStore { } } + pub fn is_nop(&self, inst: InstId) -> bool { + matches!(&self.inst_data(inst).kind, InstKind::Nop) + } + pub fn is_terminator(&self, inst: InstId) -> bool { self.inst_data(inst).is_terminator() } @@ -158,6 +163,18 @@ impl BodyDataStore { &self.values[value] } + pub fn value_data_mut(&mut self, value: ValueId) -> &mut Value { + &mut self.values[value] + } + + pub fn values(&self) -> impl Iterator { + self.values.iter().map(|(_, value_data)| value_data) + } + + pub fn values_mut(&mut self) -> impl Iterator { + self.values.iter_mut().map(|(_, value_data)| value_data) + } + pub fn store_block(&mut self, block: BasicBlock) -> BasicBlockId { self.blocks.alloc(block) } @@ -203,6 +220,27 @@ impl BodyDataStore { &self.locals } + pub fn locals_mut(&mut self) -> &[ValueId] { + &mut self.locals + } + + pub fn func_args(&self) -> impl Iterator + '_ { + self.locals() + .iter() + .filter(|value| match self.value_data(**value) { + Value::Local(local) => local.is_arg, + _ => unreachable!(), + }) + .copied() + } + + pub fn func_args_mut(&mut self) -> impl Iterator { + self.values_mut().filter(|value| match value { + Value::Local(local) => local.is_arg, + _ => false, + }) + } + /// Returns Some(`local_name`) if value is `Value::Local`. 
pub fn local_name(&self, value: ValueId) -> Option<&str> { match self.value_data(value) { @@ -211,12 +249,19 @@ impl BodyDataStore { } } - fn store_immediate(&mut self, imm: Immediate) -> ValueId { - if let Some(value) = self.immediates.get(&imm) { + pub fn replace_value(&mut self, value: ValueId, to: Value) -> Value { + std::mem::replace(&mut self.values[value], to) + } + + fn store_immediate(&mut self, imm: BigInt, ty: TypeId) -> ValueId { + if let Some(value) = self.immediates.get(&(imm.clone(), ty)) { *value } else { - let id = self.values.alloc(Value::Immediate(imm.clone())); - self.immediates.insert(imm, id); + let id = self.values.alloc(Value::Immediate { + imm: imm.clone(), + ty, + }); + self.immediates.insert((imm, ty), id); id } } diff --git a/crates/mir/src/ir/inst.rs b/crates/mir/src/ir/inst.rs index e8ad3eed65..0cd8e032d8 100644 --- a/crates/mir/src/ir/inst.rs +++ b/crates/mir/src/ir/inst.rs @@ -22,7 +22,7 @@ pub enum InstKind { }, Assign { - lhs: ValueId, + lhs: AssignableValue, rhs: ValueId, }, @@ -88,7 +88,7 @@ pub enum InstKind { }, Revert { - arg: ValueId, + arg: Option, }, Emit { @@ -179,6 +179,99 @@ impl Inst { _ => BranchInfo::NotBranch, } } + + pub fn args(&self) -> ArgIter { + use InstKind::*; + match &self.kind { + Declare { local: arg } + | Assign { rhs: arg, .. } + | Unary { value: arg, .. } + | Cast { value: arg, .. } + | Emit { arg } + | Keccak256 { arg } + | Clone { arg } + | ToMem { arg } + | AbiEncode { arg } + | Create { value: arg, .. } + | Branch { cond: arg, .. } => ArgIter::One(Some(*arg)), + + Binary { lhs, rhs, .. } + | MapAccess { + value: lhs, + key: rhs, + } + | Create2 { + value: lhs, + salt: rhs, + .. + } => ArgIter::One(Some(*lhs)).chain(ArgIter::One(Some(*rhs))), + + Revert { arg } | Return { arg } => ArgIter::One(*arg), + + Nop | Jump { .. } => ArgIter::Zero, + + AggregateAccess { value, indices } => { + ArgIter::One(Some(*value)).chain(ArgIter::Slice(indices.iter())) + } + + AggregateConstruct { args, .. 
} | Call { args, .. } | YulIntrinsic { args, .. } => { + ArgIter::Slice(args.iter()) + } + } + } + + pub fn args_mut(&mut self) -> ArgMutIter { + use InstKind::*; + match &mut self.kind { + Declare { local: arg } + | Assign { rhs: arg, .. } + | Unary { value: arg, .. } + | Cast { value: arg, .. } + | Emit { arg } + | Keccak256 { arg } + | Clone { arg } + | ToMem { arg } + | AbiEncode { arg } + | Create { value: arg, .. } + | Branch { cond: arg, .. } => ArgMutIter::One(Some(arg)), + + Binary { lhs, rhs, .. } + | MapAccess { + value: lhs, + key: rhs, + } + | Create2 { + value: lhs, + salt: rhs, + .. + } => ArgMutIter::One(Some(lhs)).chain(ArgMutIter::One(Some(rhs))), + + Revert { arg } | Return { arg } => ArgMutIter::One(arg.as_mut()), + + Nop | Jump { .. } => ArgMutIter::Zero, + + AggregateAccess { value, indices } => { + ArgMutIter::One(Some(value)).chain(ArgMutIter::Slice(indices.iter_mut())) + } + + AggregateConstruct { args, .. } | Call { args, .. } | YulIntrinsic { args, .. } => { + ArgMutIter::Slice(args.iter_mut()) + } + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum AssignableValue { + Value(ValueId), + Aggregate { + lhs: Box, + idx: ValueId, + }, + Map { + lhs: Box, + key: ValueId, + }, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -528,3 +621,69 @@ pub enum BranchInfo { Jump(BasicBlockId), Branch(ValueId, BasicBlockId, BasicBlockId), } + +#[derive(Debug)] +pub enum ArgIter<'a> { + Zero, + One(Option), + Slice(std::slice::Iter<'a, ValueId>), + Chain(Box>, Box>), +} + +impl<'a> ArgIter<'a> { + fn chain(self, rhs: Self) -> Self { + Self::Chain(self.into(), rhs.into()) + } +} + +impl<'a> Iterator for ArgIter<'a> { + type Item = ValueId; + + fn next(&mut self) -> Option { + match self { + Self::Zero => None, + Self::One(value) => value.take(), + Self::Slice(s) => s.next().copied(), + Self::Chain(first, second) => { + if let Some(value) = first.next() { + Some(value) + } else { + second.next() + } + } + } + } +} + +#[derive(Debug)] 
+pub enum ArgMutIter<'a> { + Zero, + One(Option<&'a mut ValueId>), + Slice(std::slice::IterMut<'a, ValueId>), + Chain(Box>, Box>), +} + +impl<'a> ArgMutIter<'a> { + fn chain(self, rhs: Self) -> Self { + Self::Chain(self.into(), rhs.into()) + } +} + +impl<'a> Iterator for ArgMutIter<'a> { + type Item = &'a mut ValueId; + + fn next(&mut self) -> Option { + match self { + Self::Zero => None, + Self::One(value) => value.take(), + Self::Slice(s) => s.next(), + Self::Chain(first, second) => { + if let Some(value) = first.next() { + Some(value) + } else { + second.next() + } + } + } + } +} diff --git a/crates/mir/src/ir/types.rs b/crates/mir/src/ir/types.rs index c300e0be31..7f624a800b 100644 --- a/crates/mir/src/ir/types.rs +++ b/crates/mir/src/ir/types.rs @@ -25,6 +25,8 @@ pub enum Type { Event(EventDef), Contract(StructDef), Map(MapDef), + MPtr(TypeId), + SPtr(TypeId), } /// An interned Id for [`ArrayDef`]. diff --git a/crates/mir/src/ir/value.rs b/crates/mir/src/ir/value.rs index c869dd3949..b9afbbfffa 100644 --- a/crates/mir/src/ir/value.rs +++ b/crates/mir/src/ir/value.rs @@ -9,64 +9,33 @@ pub type ValueId = Id; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Value { /// A value resulted from an instruction. - Temporary(Temporary), + Temporary { inst: InstId, ty: TypeId }, /// A local variable declared in a function body. Local(Local), /// An immediate value. - Immediate(Immediate), + Immediate { imm: BigInt, ty: TypeId }, /// A constant value. - Constant(Constant), + Constant { constant: ConstantId, ty: TypeId }, /// A singleton value representing `Unit` type. - Unit(Unit), + Unit { ty: TypeId }, } impl Value { pub fn ty(&self) -> TypeId { match self { - Self::Temporary(val) => val.ty, Self::Local(val) => val.ty, - Self::Immediate(val) => val.ty, - Self::Constant(val) => val.ty, - Self::Unit(val) => val.ty, + Self::Immediate { ty, .. } + | Self::Temporary { ty, .. } + | Self::Unit { ty } + | Self::Constant { ty, .. } => *ty, } } } -macro_rules! 
embed { - ($(($variant: expr, $ty: ty)),*) => { - $( - impl From<$ty> for Value { - fn from(val: $ty) -> Self { - $variant(val) - } - })* - }; -} - -embed! { - (Value::Temporary, Temporary), - (Value::Local, Local), - (Value::Immediate, Immediate), - (Value::Constant, Constant), - (Value::Unit, Unit) -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct Temporary { - pub inst: InstId, - pub ty: TypeId, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct Immediate { - pub value: BigInt, - pub ty: TypeId, -} - #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Local { /// An original name of a local variable. @@ -114,14 +83,3 @@ impl Local { } } } - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct Constant { - pub constant: ConstantId, - pub ty: TypeId, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct Unit { - pub ty: TypeId, -} diff --git a/crates/mir/src/lower/function.rs b/crates/mir/src/lower/function.rs index cb1c1527c2..6d16b31801 100644 --- a/crates/mir/src/lower/function.rs +++ b/crates/mir/src/lower/function.rs @@ -14,9 +14,14 @@ use smol_str::SmolStr; use crate::{ db::MirDb, ir::{ - self, body_builder::BodyBuilder, constant::ConstantValue, function::Linkage, - inst::CallType, value::Local, BasicBlockId, Constant, FunctionBody, FunctionId, - FunctionParam, FunctionSignature, SourceInfo, TypeId, ValueId, + self, + body_builder::BodyBuilder, + constant::ConstantValue, + function::Linkage, + inst::{AssignableValue, CallType}, + value::Local, + BasicBlockId, Constant, FunctionBody, FunctionId, FunctionParam, FunctionSignature, + SourceInfo, TypeId, Value, ValueId, }, }; @@ -146,7 +151,8 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { .builder .aggregate_access(rhs, vec![index], ty, expr.into()); } - self.builder.assign(local, rhs, stmt.into()); + self.builder + .assign(AssignableValue::Value(local), rhs, stmt.into()); } } } @@ -167,7 +173,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { } ast::FuncStmt::Assign 
{ target, value } => { - let lhs = self.lower_expr(target); + let lhs = self.lower_assignable_value(target); let rhs = self.lower_expr(value); self.builder.assign(lhs, rhs, stmt.into()); } @@ -178,7 +184,8 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { let ty = self.expr_ty(target); let tmp = self.lower_binop(op.kind, lhs, rhs, ty, stmt.into()); - self.builder.assign(lhs, tmp, stmt.into()); + self.builder + .assign(AssignableValue::Value(lhs), tmp, stmt.into()); } ast::FuncStmt::For { target, iter, body } => self.lower_for_loop(target, iter, body), @@ -222,11 +229,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { .branch(cond, then_bb, false_bb, SourceInfo::dummy()); self.builder.move_to_block(false_bb); - let msg = if let Some(msg) = msg { - self.lower_expr(msg) - } else { - self.make_unit() - }; + let msg = msg.as_ref().map(|msg| self.lower_expr(msg)); self.builder.revert(msg, stmt.into()); self.builder.move_to_block(then_bb); @@ -255,8 +258,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { ast::FuncStmt::Pass => { // TODO: Generate appropriate error message. 
- let arg = self.make_unit(); - self.builder.revert(arg, stmt.into()); + self.builder.revert(None, stmt.into()); let next_block = self.builder.make_block(); self.builder.move_to_block(next_block); } @@ -276,7 +278,8 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { let inc = self .builder .add(loop_idx, imm_one, u256_ty, SourceInfo::dummy()); - self.builder.assign(loop_idx, inc, SourceInfo::dummy()); + self.builder + .assign(AssignableValue::Value(loop_idx), inc, SourceInfo::dummy()); let maximum_iter_count = self.scope().maximum_iter_count(&self.scopes).unwrap(); let cond = self .builder @@ -291,12 +294,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { } ast::FuncStmt::Revert { error } => { - let error = if let Some(error) = error { - self.lower_expr(error) - } else { - self.make_unit() - }; - + let error = error.as_ref().map(|err| self.lower_expr(err)); self.builder.revert(error, stmt.into()); let next_block = self.builder.make_block(); self.builder.move_to_block(next_block); @@ -438,14 +436,18 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { let loop_idx = Local::tmp_local("$loop_idx_tmp".into(), u256_ty); let loop_idx = self.builder.declare(loop_idx); let imm_zero = self.builder.make_imm(0u32.into(), u256_ty); - self.builder.assign(loop_idx, imm_zero, SourceInfo::dummy()); + self.builder.assign( + AssignableValue::Value(loop_idx), + imm_zero, + SourceInfo::dummy(), + ); // Evaluates loop variable. let iter = self.lower_expr(iter); // Create maximum loop count. let iter_ty = self.builder.value_ty(iter); - let maximum_iter_count = match iter_ty.data(self.db).as_ref() { + let maximum_iter_count = match iter_ty.deref(self.db).data(self.db).as_ref() { ir::Type::Array(ir::types::ArrayDef { len, .. 
}) => *len, _ => unreachable!(), }; @@ -465,8 +467,11 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { let iter_elem = self.builder .aggregate_access(iter, vec![loop_idx], iter_elem_ty, SourceInfo::dummy()); - self.builder - .assign(loop_value, iter_elem, SourceInfo::dummy()); + self.builder.assign( + AssignableValue::Value(loop_value), + iter_elem, + SourceInfo::dummy(), + ); for stmt in body { self.lower_stmt(stmt); @@ -477,7 +482,8 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { let inc = self .builder .add(loop_idx, imm_one, u256_ty, SourceInfo::dummy()); - self.builder.assign(loop_idx, inc, SourceInfo::dummy()); + self.builder + .assign(AssignableValue::Value(loop_idx), inc, SourceInfo::dummy()); let cond = self .builder .eq(loop_idx, maximum_iter_count, u256_ty, SourceInfo::dummy()); @@ -511,12 +517,14 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.move_to_block(true_bb); let value = self.lower_expr(if_expr); - self.builder.assign(tmp, value, SourceInfo::dummy()); + self.builder + .assign(AssignableValue::Value(tmp), value, SourceInfo::dummy()); self.builder.jump(merge_bb, SourceInfo::dummy()); self.builder.move_to_block(false_bb); let value = self.lower_expr(else_expr); - self.builder.assign(tmp, value, SourceInfo::dummy()); + self.builder + .assign(AssignableValue::Value(tmp), value, SourceInfo::dummy()); self.builder.jump(merge_bb, SourceInfo::dummy()); self.builder.move_to_block(merge_bb); @@ -560,7 +568,6 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { let result_ty = self.expr_ty(expr); if !self.expr_ty(value).is_aggregate(self.db) { - // Indices is empty is the `expr` is map let value = self.lower_expr(value); let key = self.lower_expr(index); self.builder.map_access(value, key, result_ty, expr.into()) @@ -611,6 +618,34 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { } } + fn lower_assignable_value(&mut self, expr: &Node) -> AssignableValue { + match &expr.kind { + ast::Expr::Attribute { value, attr } => { + let idx_ty = self.u256_ty(); + let idx = 
self.expr_ty(value).index_from_fname(self.db, &attr.kind); + let idx = self.builder.make_value(Value::Immediate { + imm: idx, + ty: idx_ty, + }); + let lhs = self.lower_assignable_value(value).into(); + AssignableValue::Aggregate { lhs, idx } + } + ast::Expr::Subscript { value, index } => { + let lhs = self.lower_assignable_value(value).into(); + let attr = self.lower_expr(index); + if self.expr_ty(value).is_aggregate(self.db) { + AssignableValue::Aggregate { lhs, idx: attr } + } else { + AssignableValue::Map { lhs, key: attr } + } + } + _ => { + let value = self.lower_expr(expr); + AssignableValue::Value(value) + } + } + } + fn expr_ty(&self, expr: &Node) -> TypeId { let analyzer_ty = self.analyzer_body.expressions[&expr.id].typ.clone(); self.db.mir_lowered_type(analyzer_ty) @@ -639,12 +674,14 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.move_to_block(true_bb); let rhs = self.lower_expr(rhs); - self.builder.assign(tmp, rhs, SourceInfo::dummy()); + self.builder + .assign(AssignableValue::Value(tmp), rhs, SourceInfo::dummy()); self.builder.jump(merge_bb, SourceInfo::dummy()); self.builder.move_to_block(false_bb); let false_imm = self.builder.make_imm_from_bool(false, ty); - self.builder.assign(tmp, false_imm, SourceInfo::dummy()); + self.builder + .assign(AssignableValue::Value(tmp), false_imm, SourceInfo::dummy()); self.builder.jump(merge_bb, SourceInfo::dummy()); } @@ -654,12 +691,14 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.move_to_block(true_bb); let true_imm = self.builder.make_imm_from_bool(true, ty); - self.builder.assign(tmp, true_imm, SourceInfo::dummy()); + self.builder + .assign(AssignableValue::Value(tmp), true_imm, SourceInfo::dummy()); self.builder.jump(merge_bb, SourceInfo::dummy()); self.builder.move_to_block(false_bb); let rhs = self.lower_expr(rhs); - self.builder.assign(tmp, rhs, SourceInfo::dummy()); + self.builder + .assign(AssignableValue::Value(tmp), rhs, SourceInfo::dummy()); self.builder.jump(merge_bb, 
SourceInfo::dummy()); } } @@ -783,7 +822,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { } else { self.builder.cast(arg, ty, source) } - } else if ty.is_aggregate(self.db) { + } else if ty.deref(self.db).is_aggregate(self.db) { self.builder.aggregate_construct(ty, args, source) } else { unreachable!() @@ -807,11 +846,13 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { ) -> ValueId { match &expr.kind { ast::Expr::Attribute { value, attr } => { - let index = - self.expr_ty(value) - .index_from_fname(self.db, &attr.kind, self.u256_ty()); + let index = self.expr_ty(value).index_from_fname(self.db, &attr.kind); + let index_ty = self.u256_ty(); let value = self.lower_aggregate_access(value, indices); - indices.push(self.builder.make_value(index)); + indices.push(self.builder.make_value(Value::Immediate { + imm: index, + ty: index_ty, + })); value } diff --git a/crates/mir/src/pretty_print/inst.rs b/crates/mir/src/pretty_print/inst.rs index 4af6582b55..93a9420ea7 100644 --- a/crates/mir/src/pretty_print/inst.rs +++ b/crates/mir/src/pretty_print/inst.rs @@ -2,7 +2,11 @@ use std::fmt::{self, Write}; use crate::{ db::MirDb, - ir::{function::BodyDataStore, inst::InstKind, InstId}, + ir::{ + function::BodyDataStore, + inst::{AssignableValue, InstKind}, + InstId, + }, }; use super::PrettyPrint; @@ -31,8 +35,6 @@ impl PrettyPrint for InstId { InstKind::Assign { lhs, rhs } => { lhs.pretty_print(db, store, w)?; - write!(w, ": ")?; - store.value_ty(*lhs).pretty_print(db, store, w)?; write!(w, " = ")?; rhs.pretty_print(db, store, w) } @@ -55,7 +57,7 @@ impl PrettyPrint for InstId { } InstKind::AggregateConstruct { ty, args } => { - ty.pretty_print(db, store, w)?; + ty.deref(db).pretty_print(db, store, w)?; write!(w, "{{")?; if args.is_empty() { return write!(w, "}}"); @@ -113,7 +115,10 @@ impl PrettyPrint for InstId { InstKind::Revert { arg } => { write!(w, "revert ")?; - arg.pretty_print(db, store, w) + if let Some(arg) = arg { + arg.pretty_print(db, store, w)?; + } + Ok(()) } 
InstKind::Emit { arg } => { @@ -182,3 +187,29 @@ impl PrettyPrint for InstId { } } } + +impl PrettyPrint for AssignableValue { + fn pretty_print( + &self, + db: &dyn MirDb, + store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result { + match self { + Self::Value(value) => value.pretty_print(db, store, w), + Self::Aggregate { lhs, idx } => { + lhs.pretty_print(db, store, w)?; + write!(w, ".<")?; + idx.pretty_print(db, store, w)?; + write!(w, ">") + } + + Self::Map { lhs, key } => { + lhs.pretty_print(db, store, w)?; + write!(w, "{{")?; + key.pretty_print(db, store, w)?; + write!(w, "}}") + } + } + } +} diff --git a/crates/mir/src/pretty_print/types.rs b/crates/mir/src/pretty_print/types.rs index 2b8330e037..e3f74d28e0 100644 --- a/crates/mir/src/pretty_print/types.rs +++ b/crates/mir/src/pretty_print/types.rs @@ -69,6 +69,14 @@ impl PrettyPrint for TypeId { def.value_ty.pretty_print(db, store, w)?; write!(w, ">") } + Type::MPtr(inner) => { + write!(w, "*@m ")?; + inner.pretty_print(db, store, w) + } + Type::SPtr(inner) => { + write!(w, "*@s ")?; + inner.pretty_print(db, store, w) + } } } } diff --git a/crates/mir/src/pretty_print/value.rs b/crates/mir/src/pretty_print/value.rs index a8ef1255ce..a39b251669 100644 --- a/crates/mir/src/pretty_print/value.rs +++ b/crates/mir/src/pretty_print/value.rs @@ -15,10 +15,10 @@ impl PrettyPrint for ValueId { w: &mut W, ) -> fmt::Result { match store.value_data(*self) { - Value::Temporary(_) | Value::Local(_) => write!(w, "_{}", self.index()), - Value::Immediate(imm) => write!(w, "{}", imm.value), - Value::Constant(constant) => { - let const_value = constant.constant.data(db); + Value::Temporary { .. } | Value::Local(_) => write!(w, "_{}", self.index()), + Value::Immediate { imm, .. } => write!(w, "{}", imm), + Value::Constant { constant, .. 
} => { + let const_value = constant.data(db); write!(w, "const ")?; match &const_value.value { ConstantValue::Immediate(num) => write!(w, "{}", num), @@ -26,7 +26,7 @@ impl PrettyPrint for ValueId { ConstantValue::Bool(b) => write!(w, "{}", b), } } - Value::Unit(_) => write!(w, "()"), + Value::Unit { .. } => write!(w, "()"), } } }