diff --git a/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs b/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs index a1dda0147b5c..526730a865d8 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs @@ -37,8 +37,8 @@ use rustc_metadata::fs::{emit_wrapper_file, METADATA_FILENAME}; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::mir::mono::MonoItem; +use rustc_middle::ty::{self, Instance, InstanceDef, ParamEnv, TyCtxt}; use rustc_middle::util::Providers; -use rustc_middle::ty::{Instance, InstanceDef, TyCtxt}; use rustc_session::config::{CrateType, OutputFilenames, OutputType}; use rustc_session::cstore::MetadataLoaderDyn; use rustc_session::output::out_filename; @@ -82,7 +82,7 @@ impl GotocCodegenBackend { starting_items: &[MonoItem<'tcx>], symtab_goto: &Path, machine_model: &MachineModel, - mut check_contract: Option, + check_contract: Option, ) -> (GotocCtx<'tcx>, Vec>) { let items = with_timer( || collect_reachable_items(tcx, starting_items), @@ -146,25 +146,38 @@ impl GotocCodegenBackend { } } - if let Some(did) = &mut check_contract { - let attrs = KaniAttributes::for_item(tcx, *did); - *did = attrs.inner_check().unwrap().unwrap() - } + check_contract.map(|check_contract_def| { + let check_contract_attrs = KaniAttributes::for_item(tcx, check_contract_def); + let wrapper_def = check_contract_attrs.inner_check().unwrap().unwrap(); - let mut instances_of_check = items.iter().copied().filter_map(|i| match i { - MonoItem::Fn(instance @ Instance { def: InstanceDef::Item(did), .. }) => (check_contract == Some(did)).then_some(instance), - _ => None - }); - let instance_of_check = instances_of_check.next().unwrap(); - assert!(instances_of_check.next().is_none()); - let attrs = KaniAttributes::for_item(tcx, instance_of_check.def_id()); - let assigns_contract = attrs.modifies_contract().unwrap_or_else(|| { - debug!(?instance_of_check, "had no assigns contract specified"); - vec![] - }); - gcx.attach_contract(instance_of_check, assigns_contract); + let mut instances_of_check = items.iter().copied().filter_map(|i| match i { + MonoItem::Fn(instance @ Instance { def: InstanceDef::Item(did), .. }) => { + (wrapper_def == did).then_some(instance) + } + _ => None, + }); + let instance_of_check = instances_of_check.next().unwrap(); + assert!(instances_of_check.next().is_none()); + let attrs = KaniAttributes::for_item(tcx, instance_of_check.def_id()); + let assigns_contract = attrs.modifies_contract().unwrap_or_else(|| { + debug!(?instance_of_check, "had no assigns contract specified"); + vec![] + }); + gcx.attach_contract(instance_of_check, assigns_contract); + + let tracker_def = check_contract_attrs.recursion_tracker().unwrap().unwrap(); + let tracker_instance = Instance::expect_resolve( + tcx, + ParamEnv::reveal_all(), + tracker_def, + ty::List::empty(), + ); - tcx.symbol_name(instance_of_check).to_string() + ( + tcx.symbol_name(instance_of_check).to_string(), + tcx.symbol_name(tracker_instance).to_string(), + ) + }) }, "codegen", ); diff --git a/kani-compiler/src/kani_middle/attributes.rs b/kani-compiler/src/kani_middle/attributes.rs index 8b480bb37309..c6e8c2663b2a 100644 --- a/kani-compiler/src/kani_middle/attributes.rs +++ b/kani-compiler/src/kani_middle/attributes.rs @@ -60,6 +60,7 @@ enum KaniAttributeKind { IsContractGenerated, Modifies, InnerCheck, + RecursionTracker, } impl KaniAttributeKind { @@ -77,6 +78,7 @@ impl KaniAttributeKind { | KaniAttributeKind::ReplacedWith | KaniAttributeKind::CheckedWith | KaniAttributeKind::Modifies + | KaniAttributeKind::RecursionTracker | KaniAttributeKind::InnerCheck | KaniAttributeKind::IsContractGenerated => false, } @@ -222,6 +224,10 @@ impl<'tcx> KaniAttributes<'tcx> { .map(|target| expect_key_string_value(self.tcx.sess, target)) } + pub fn recursion_tracker(&self) -> Option> { + self.eval_sibling_attribute(KaniAttributeKind::RecursionTracker) + } + fn eval_sibling_attribute( &self, kind: KaniAttributeKind, @@ -340,6 +346,9 @@ impl<'tcx> KaniAttributes<'tcx> { KaniAttributeKind::InnerCheck => { self.inner_check(); } + KaniAttributeKind::RecursionTracker => { + self.recursion_tracker(); + } } } } @@ -443,6 +452,7 @@ impl<'tcx> KaniAttributes<'tcx> { | KaniAttributeKind::IsContractGenerated | KaniAttributeKind::Modifies | KaniAttributeKind::InnerCheck + | KaniAttributeKind::RecursionTracker | KaniAttributeKind::ReplacedWith => { self.tcx.sess.span_err(self.tcx.def_span(self.item), format!("Contracts are not supported on harnesses. (Found the kani-internal contract attribute `{}`)", kind.as_ref())); } @@ -498,7 +508,7 @@ impl<'tcx> KaniAttributes<'tcx> { .span_note( self.tcx.def_span(def_id), format!( - "Try adding a contract to this function or use the unsound `{}` attribute instead.", + "Try adding a contract to this function or use the unsound `{}` attribute instead.", KaniAttributeKind::Stub.as_ref(), ) ) diff --git a/kani-driver/src/call_goto_instrument.rs b/kani-driver/src/call_goto_instrument.rs index 444298fb1769..c6b36a7353f0 100644 --- a/kani-driver/src/call_goto_instrument.rs +++ b/kani-driver/src/call_goto_instrument.rs @@ -22,7 +22,7 @@ impl KaniSession { output: &Path, project: &Project, harness: &HarnessMetadata, - contract_info: Option, + contract_info: Option<(String, String)>, ) -> Result<()> { // We actually start by calling goto-cc to start the specialization: self.specialize_to_proof_harness(input, output, &harness.mangled_name)?; @@ -168,7 +168,7 @@ impl KaniSession { &self, harness: &HarnessMetadata, file: &Path, - check: Option, + check: Option<(String, String)>, ) -> Result<()> { if check.is_none() { return Ok(()); @@ -177,9 +177,14 @@ impl KaniSession { let mut args: Vec = vec!["--dfcc".into(), (&harness.mangled_name).into()]; - if let Some(function) = check { + if let Some((function, recursion_tracker)) = check { println!("enforcing function contract for {function}"); - args.extend(["--enforce-contract".into(), function.into()]); + args.extend([ + "--enforce-contract".into(), + function.into(), + "--nondet-static-exclude".into(), + recursion_tracker.into(), + ]); } args.extend([file.into(), file.into()]); diff --git a/kani-driver/src/harness_runner.rs b/kani-driver/src/harness_runner.rs index 563d3f2f86c6..b7cfcd3c35c7 100644 --- a/kani-driver/src/harness_runner.rs +++ b/kani-driver/src/harness_runner.rs @@ -80,7 +80,7 @@ impl<'sess, 'pr> HarnessRunner<'sess, 'pr> { Ok(results) } - fn get_contract_info(&self, harness: &'pr HarnessMetadata) -> Result> { + fn get_contract_info(&self, harness: &'pr HarnessMetadata) -> Result> { let contract_info_artifact = self.project.get_harness_artifact(&harness, ArtifactType::ContractMetadata).unwrap(); diff --git a/library/kani_macros/src/sysroot/contracts.rs b/library/kani_macros/src/sysroot/contracts.rs index 3fc360690842..c50b88f42752 100644 --- a/library/kani_macros/src/sysroot/contracts.rs +++ b/library/kani_macros/src/sysroot/contracts.rs @@ -239,26 +239,24 @@ use std::{ }; use syn::{ parse_macro_input, spanned::Spanned, visit::Visit, visit_mut::VisitMut, Attribute, Expr, FnArg, - ItemFn, PredicateType, ReturnType, Signature, Token, TraitBound, - TypeParamBound, WhereClause, + ItemFn, PredicateType, ReturnType, Signature, Token, TraitBound, TypeParamBound, WhereClause, }; #[allow(dead_code)] pub fn requires(attr: TokenStream, item: TokenStream) -> TokenStream { - requires_ensures_main(attr, item, 0) + requires_ensures_main(attr, item, ContractConditionsType::Requires) } #[allow(dead_code)] pub fn ensures(attr: TokenStream, item: TokenStream) -> TokenStream { - requires_ensures_main(attr, item, 1) + requires_ensures_main(attr, item, ContractConditionsType::Ensures) } #[allow(dead_code)] pub fn modifies(attr: TokenStream, item: TokenStream) -> TokenStream { - requires_ensures_main(attr, item, 2) + requires_ensures_main(attr, item, ContractConditionsType::Modifies) } - /// This is very similar to the kani_attribute macro, but it instead creates /// key-value style attributes which I find a little easier to parse. macro_rules! passthrough { @@ -283,7 +281,21 @@ macro_rules! passthrough { } passthrough!(stub_verified, false); -passthrough!(proof_for_contract, true); + +pub fn proof_for_contract(attr: TokenStream, item: TokenStream) -> TokenStream { + let args = proc_macro2::TokenStream::from(attr); + let ItemFn { attrs, vis, sig, block } = parse_macro_input!(item as ItemFn); + quote!( + #[allow(dead_code)] + #[kanitool::proof_for_contract = stringify!(#args)] + #(#attrs)* + #vis #sig { + let _ = std::boxed::Box::new(0_usize); + #block + } + ) + .into() +} /// Classifies the state a function is in in the contract handling pipeline. #[derive(Clone, Copy, PartialEq, Eq)] @@ -362,7 +374,7 @@ impl ContractFunctionState { struct ContractConditionsHandler<'a> { function_state: ContractFunctionState, /// Information specific to the type of contract attribute we're expanding. - condition_type: ContractConditionsType, + condition_type: ContractConditionsData, /// Body of the function this attribute was found on. annotated_fn: &'a mut ItemFn, /// An unparsed, unmodified copy of `attr`, used in the error messages. @@ -372,9 +384,16 @@ struct ContractConditionsHandler<'a> { hash: Option, } +#[derive(Copy, Clone, Eq, PartialEq)] +enum ContractConditionsType { + Requires, + Ensures, + Modifies, +} + /// Information needed for generating check and replace handlers for different /// contract attributes. -enum ContractConditionsType { +enum ContractConditionsData { Requires { /// The contents of the attribute. attr: Expr, @@ -391,7 +410,7 @@ enum ContractConditionsType { }, } -impl ContractConditionsType { +impl ContractConditionsData { /// Constructs a [`Self::Ensures`] from the signature of the decorated /// function and the contents of the decorating attribute. /// @@ -399,7 +418,7 @@ impl ContractConditionsType { /// `argument_names`. fn new_ensures(sig: &Signature, mut attr: Expr) -> Self { let argument_names = rename_argument_occurrences(sig, &mut attr); - ContractConditionsType::Ensures { argument_names, attr } + ContractConditionsData::Ensures { argument_names, attr } } } @@ -412,7 +431,7 @@ impl<'a> ContractConditionsHandler<'a> { /// [`ContractConditionsType`] depending on `is_requires`. fn new( function_state: ContractFunctionState, - is_requires: u8, + is_requires: ContractConditionsType, attr: TokenStream, annotated_fn: &'a mut ItemFn, attr_copy: TokenStream2, @@ -420,9 +439,13 @@ impl<'a> ContractConditionsHandler<'a> { hash: Option, ) -> Result { let condition_type = match is_requires { - 0 => ContractConditionsType::Requires { attr: syn::parse(attr)? }, - 1 => ContractConditionsType::new_ensures(&annotated_fn.sig, syn::parse(attr)?), - 2 => ContractConditionsType::Modifies { + ContractConditionsType::Requires => { + ContractConditionsData::Requires { attr: syn::parse(attr)? } + } + ContractConditionsType::Ensures => { + ContractConditionsData::new_ensures(&annotated_fn.sig, syn::parse(attr)?) + } + ContractConditionsType::Modifies => ContractConditionsData::Modifies { attr: chunks_by(TokenStream2::from(attr), is_token_stream_2_comma) .map(syn::parse2) .filter_map(|expr| match expr { @@ -434,7 +457,6 @@ impl<'a> ContractConditionsHandler<'a> { }) .collect(), }, - _ => unreachable!(), }; Ok(Self { function_state, condition_type, annotated_fn, attr_copy, output, hash }) @@ -450,14 +472,14 @@ impl<'a> ContractConditionsHandler<'a> { let Self { attr_copy, .. } = self; match &self.condition_type { - ContractConditionsType::Requires { attr } => { + ContractConditionsData::Requires { attr } => { let block = self.create_inner_call([].into_iter()); quote!( kani::assume(#attr); #(#block)* ) } - ContractConditionsType::Ensures { argument_names, attr } => { + ContractConditionsData::Ensures { argument_names, attr } => { let (arg_copies, copy_clean) = make_unsafe_argument_copies(&argument_names); // The code that enforces the postconditions and cleans up the shallow @@ -480,7 +502,7 @@ impl<'a> ContractConditionsHandler<'a> { result ) } - ContractConditionsType::Modifies { attr } => { + ContractConditionsData::Modifies { attr } => { let wrapper_name = self.make_wrapper_name().to_string(); let wrapper_args = make_wrapper_args(attr.len()); // TODO handle first invocation where this is the actual body. @@ -556,11 +578,11 @@ impl<'a> ContractConditionsHandler<'a> { let return_type = return_type_to_type(&sig.output); match &self.condition_type { - ContractConditionsType::Requires { attr } => quote!( + ContractConditionsData::Requires { attr } => quote!( kani::assert(#attr, stringify!(#attr_copy)); #call_to_prior ), - ContractConditionsType::Ensures { attr, argument_names } => { + ContractConditionsData::Ensures { attr, argument_names } => { let (arg_copies, copy_clean) = make_unsafe_argument_copies(&argument_names); quote!( #arg_copies @@ -570,7 +592,7 @@ impl<'a> ContractConditionsHandler<'a> { result ) } - ContractConditionsType::Modifies { .. } => { + ContractConditionsData::Modifies { .. } => { quote!(kani::assert(false, "Replacement with modifies is not supported yet.")) } } @@ -637,7 +659,7 @@ impl<'a> ContractConditionsHandler<'a> { } fn emit_augmented_modifies_wrapper(&mut self) { - if let ContractConditionsType::Modifies { attr } = &self.condition_type { + if let ContractConditionsData::Modifies { attr } = &self.condition_type { let wrapper_args = make_wrapper_args(attr.len()); let sig = &mut self.annotated_fn.sig; for arg in wrapper_args.clone() { @@ -877,7 +899,11 @@ fn make_unsafe_argument_copies( /// /// See the [module level documentation][self] for a description of how the code /// generation works. -fn requires_ensures_main(attr: TokenStream, item: TokenStream, is_requires: u8) -> TokenStream { +fn requires_ensures_main( + attr: TokenStream, + item: TokenStream, + is_requires: ContractConditionsType, +) -> TokenStream { let attr_copy = TokenStream2::from(attr.clone()); let mut output = proc_macro2::TokenStream::new(); @@ -949,10 +975,20 @@ fn requires_ensures_main(attr: TokenStream, item: TokenStream, is_requires: u8) // and "replace" functions. let item_hash = hash.unwrap(); - let check_fn_name = identifier_for_generated_function(&original_function_name, "check", item_hash); - let replace_fn_name = identifier_for_generated_function(&original_function_name, "replace", item_hash); - let recursion_wrapper_name = - identifier_for_generated_function(&original_function_name, "recursion_wrapper", item_hash); + let check_fn_name = + identifier_for_generated_function(&original_function_name, "check", item_hash); + let replace_fn_name = + identifier_for_generated_function(&original_function_name, "replace", item_hash); + let recursion_wrapper_name = identifier_for_generated_function( + &original_function_name, + "recursion_wrapper", + item_hash, + ); + let recursion_tracker_name = identifier_for_generated_function( + &original_function_name, + "recursion_tracker", + item_hash, + ); // Constructing string literals explicitly here, because `stringify!` // doesn't work. Let's say we have an identifier `check_fn` and we were @@ -966,6 +1002,8 @@ fn requires_ensures_main(attr: TokenStream, item: TokenStream, is_requires: u8) syn::LitStr::new(&handler.make_wrapper_name().to_string(), Span::call_site()); let recursion_wrapper_name_str = syn::LitStr::new(&recursion_wrapper_name.to_string(), Span::call_site()); + let recursion_tracker_name_str = + syn::LitStr::new(&recursion_tracker_name.to_string(), Span::call_site()); // The order of `attrs` and `kanitool::{checked_with, // is_contract_generated}` is important here, because macros are @@ -982,8 +1020,8 @@ fn requires_ensures_main(attr: TokenStream, item: TokenStream, is_requires: u8) #[kanitool::checked_with = #recursion_wrapper_name_str] #[kanitool::replaced_with = #replace_fn_name_str] #[kanitool::inner_check = #wrapper_fn_name_str] + #[kanitool::recursion_tracker = #recursion_tracker_name_str] #vis #sig { - let _ = std::boxed::Box::new(0_usize); #block } )); @@ -1001,16 +1039,17 @@ fn requires_ensures_main(attr: TokenStream, item: TokenStream, is_requires: u8) }; handler.output.extend(quote!( + + static mut #recursion_tracker_name: bool = false; #[allow(dead_code, unused_variables)] #[kanitool::is_contract_generated(recursion_wrapper)] #wrapper_sig { - static mut REENTRY: bool = false; - if unsafe { REENTRY } { + if unsafe { #recursion_tracker_name } { #call_replace(#(#args),*) } else { - unsafe { REENTRY = true }; + unsafe { #recursion_tracker_name = true }; let result = #call_check(#(#also_args),*); - unsafe { REENTRY = false }; + unsafe { #recursion_tracker_name = false }; result } }