diff --git a/src/tests_and_benchmarks/ozk/ozk_parsing.rs b/src/tests_and_benchmarks/ozk/ozk_parsing.rs index 0b1eb4f8..716a9074 100644 --- a/src/tests_and_benchmarks/ozk/ozk_parsing.rs +++ b/src/tests_and_benchmarks/ozk/ozk_parsing.rs @@ -12,18 +12,21 @@ pub type StructsAndMethods = HashMap (StructsAndMethods, Option) { + function_name: Option<&str>, +) -> (StructsAndMethods, Option, Vec) { get_standard_setup!(ast_types::ListType::Unsafe, graft_config, _lib); let mut outer_function: Option = None; let mut custom_types: HashMap, Vec)> = HashMap::default(); + let mut dependencies = vec![]; for item in parsed_file.items { match item { // Top-level function declaration syn::Item::Fn(func) => { - if func.sig.ident == function_name { - outer_function = Some(func.to_owned()); + if let Some(function_name) = function_name { + if func.sig.ident == function_name { + outer_function = Some(func.to_owned()); + } } } @@ -83,7 +86,43 @@ fn extract_types_and_function( } }; } - _ => {} + syn::Item::Use(syn::ItemUse { + attrs: _, + vis: _, + use_token: _, + leading_colon: _, + tree, + semi_token: _, + }) => { + fn get_module_name(tree: &syn::UseTree) -> Option { + match tree { + syn::UseTree::Path(use_path) => { + if use_path.ident == "super" { + match use_path.tree.as_ref() { + syn::UseTree::Path(use_path) => { + if let syn::UseTree::Glob(_) = *use_path.tree { + Some(use_path.ident.to_string()) + } else { + None + } + } + _ => None, + } + } else { + None + } + } + _ => None, + } + } + + // Import files specified with `use super::::*;` + // This is currently the only way to import + if let Some(module_name) = get_module_name(&tree) { + dependencies.push(module_name); + } + } + _ => (), } } @@ -93,7 +132,7 @@ fn extract_types_and_function( .map(|(struct_name, (option_struct, methods))| (struct_name.clone(), (option_struct.unwrap_or_else(|| panic!("Couldn't find struct definition for {struct_name} for which methods was defined")), methods))) .collect(); - (structs, outer_function) + (structs, outer_function, dependencies) } /// Return the Rust-AST for the `main` function and all custom types defined in the @@ -103,17 +142,34 @@ pub(super) fn parse_function_and_structs( module_name: &str, function_name: &str, ) -> (syn::ItemFn, StructsAndMethods, String) { - let path = format!( - "{}/src/tests_and_benchmarks/ozk/programs/{directory}/{module_name}.rs", - env!("CARGO_MANIFEST_DIR"), - ); - let content = fs::read_to_string(&path).expect("Unable to read file {path}"); - let parsed_file: syn::File = syn::parse_str(&content).expect("Unable to parse rust code"); - let (custom_types, main_parsed) = extract_types_and_function(parsed_file, function_name); + fn parse_function_and_structs_inner( + directory: &str, + module_name: &str, + function_name: Option<&str>, + ) -> (Option, StructsAndMethods, String) { + let path = format!( + "{}/src/tests_and_benchmarks/ozk/programs/{directory}/{module_name}.rs", + env!("CARGO_MANIFEST_DIR"), + ); + let content = fs::read_to_string(&path).expect("Unable to read file {path}"); + let parsed_file: syn::File = syn::parse_str(&content).expect("Unable to parse rust code"); + let (mut custom_types, main_parsed, dependencies) = + extract_types_and_function(parsed_file, function_name); + + for dependency in dependencies { + let (_, imported_custom_types, _) = + parse_function_and_structs_inner(directory, &dependency, None); + custom_types.extend(imported_custom_types.into_iter()) + } + + (main_parsed, custom_types, module_name.to_owned()) + } + let (main_parsed, custom_types, module_name) = + parse_function_and_structs_inner(directory, module_name, Some(function_name)); match main_parsed { Some(main) => (main, custom_types, module_name.to_owned()), - None => panic!("Failed to parse file {path}"), + None => panic!("Failed to parse module {module_name}"), } } diff --git a/src/tests_and_benchmarks/ozk/programs/other.rs b/src/tests_and_benchmarks/ozk/programs/other.rs index 70082c7c..4298ab09 100644 --- a/src/tests_and_benchmarks/ozk/programs/other.rs +++ b/src/tests_and_benchmarks/ozk/programs/other.rs @@ -1,8 +1,10 @@ mod hash_varlen; +mod import_type_declaration; mod nested_tuples; #[allow(dead_code)] mod removal_record_integrity_partial; mod returning_block_expr_u32; mod simple_encode; mod simple_map_on_bfe; +mod simple_struct; mod value; diff --git a/src/tests_and_benchmarks/ozk/programs/other/import_type_declaration.rs b/src/tests_and_benchmarks/ozk/programs/other/import_type_declaration.rs new file mode 100644 index 00000000..9fe0dc23 --- /dev/null +++ b/src/tests_and_benchmarks/ozk/programs/other/import_type_declaration.rs @@ -0,0 +1,67 @@ +use super::simple_struct::*; +use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; +use triton_vm::BFieldElement; +use twenty_first::shared_math::bfield_codec::BFieldCodec; + +fn main() { + let ts: Box = + SimpleStruct::decode(&tasm::load_from_memory(BFieldElement::new(300))).unwrap(); + + tasm::tasm_io_write_to_stdout___u128(ts.a); + tasm::tasm_io_write_to_stdout___bfe(ts.b); + tasm::tasm_io_write_to_stdout___bool(ts.c); + tasm::tasm_io_write_to_stdout___u32(ts.d.len() as u32); + tasm::tasm_io_write_to_stdout___digest(ts.e); + + return; +} + +mod tests { + use super::*; + use crate::tests_and_benchmarks::ozk::ozk_parsing; + use crate::tests_and_benchmarks::ozk::rust_shadows; + use crate::tests_and_benchmarks::test_helpers::shared_test::execute_compiled_with_stack_memory_and_ins_for_test; + use crate::tests_and_benchmarks::test_helpers::shared_test::init_memory_from; + use arbitrary::Arbitrary; + use arbitrary::Unstructured; + use itertools::Itertools; + use rand::random; + use std::collections::HashMap; + + #[test] + fn import_type_declaration_test() { + let rand: [u8; 2000] = random(); + let test_struct = SimpleStruct::arbitrary(&mut Unstructured::new(&rand)).unwrap(); + let non_determinism = init_memory_from(&test_struct, BFieldElement::new(300)); + let stdin = vec![]; + + // Run test on host machine + let native_output = + rust_shadows::wrap_main_with_io(&main)(stdin.clone(), non_determinism.clone()); + + // Run test on Triton-VM + let test_program = ozk_parsing::compile_for_test( + "other", + "import_type_declaration", + "main", + crate::ast_types::ListType::Unsafe, + ); + let vm_output = execute_compiled_with_stack_memory_and_ins_for_test( + &test_program, + vec![], + &mut HashMap::default(), + stdin, + non_determinism, + 0, + ) + .unwrap(); + if native_output != vm_output.output { + panic!( + "native_output:\n {}, got:\n{}. Code was:\n{}", + native_output.iter().join(", "), + vm_output.output.iter().join(", "), + test_program.iter().join("\n") + ); + } + } +} diff --git a/src/tests_and_benchmarks/ozk/programs/other/simple_struct.rs b/src/tests_and_benchmarks/ozk/programs/other/simple_struct.rs new file mode 100644 index 00000000..bff0fac7 --- /dev/null +++ b/src/tests_and_benchmarks/ozk/programs/other/simple_struct.rs @@ -0,0 +1,13 @@ +use arbitrary::Arbitrary; +use tasm_lib::structure::tasm_object::TasmObject; +use triton_vm::{BFieldElement, Digest}; +use twenty_first::shared_math::bfield_codec::BFieldCodec; + +#[derive(TasmObject, BFieldCodec, Clone, Arbitrary)] +pub(super) struct SimpleStruct { + pub a: u128, + pub b: BFieldElement, + pub c: bool, + pub d: Vec, + pub e: Digest, +}