From d096838c8952000111931d9571bdfbb36c7658fe Mon Sep 17 00:00:00 2001 From: Thorkil Vaerge Date: Tue, 5 Dec 2023 19:02:53 +0100 Subject: [PATCH] Add support for `is_err` method for `Result` --- src/libraries/core.rs | 49 +++++++- .../ozk/programs/result_types.rs | 1 + .../ozk/programs/result_types/boxed_struct.rs | 106 ++++++++++++++++++ 3 files changed, 155 insertions(+), 1 deletion(-) create mode 100644 src/tests_and_benchmarks/ozk/programs/result_types/boxed_struct.rs diff --git a/src/libraries/core.rs b/src/libraries/core.rs index 35321252..95ea5796 100644 --- a/src/libraries/core.rs +++ b/src/libraries/core.rs @@ -33,11 +33,12 @@ pub(crate) fn result_type(ok_type: ast_types::DataType) -> crate::composite_type type_parameter: Some(ok_type.clone()), }; let is_ok_method = is_ok_method(&enum_type); + let is_err_method = is_err_method(&enum_type); let unwrap_method = unwrap_method(&enum_type); crate::composite_types::TypeContext { composite_type: enum_type.try_into().unwrap(), - methods: vec![is_ok_method, unwrap_method], + methods: vec![is_ok_method, is_err_method, unwrap_method], associated_functions: vec![], } } @@ -134,6 +135,52 @@ fn unwrap_method(enum_type: &ast_types::EnumType) -> ast::Method { } } +fn is_err_method(enum_type: &ast_types::EnumType) -> ast::Method { + let stack_size = enum_type.stack_size(); + let swap_to_bottom = match stack_size { + 0 => unreachable!(), + 1 => triton_asm!(), + 2..=16 => triton_asm!(swap { stack_size - 1 }), + _ => panic!("Can't handle this yet"), // This should work with spilling + }; + let remove_data = match stack_size { + 0 => unreachable!(), + 1 => triton_asm!(pop), + 2..=16 => { + let as_str = "pop\n".repeat(stack_size - 1); + triton_asm!({ as_str }) + } + _ => panic!("Can't handle this yet"), + }; + let is_ok_input_data_type = ast_types::DataType::Reference(Box::new(enum_type.into())); + let method_signature = ast::FnSignature { + name: "is_err".to_owned(), + args: vec![AbstractArgument::ValueArgument(AbstractValueArg { + name: "self".to_owned(), + data_type: is_ok_input_data_type, + mutable: false, + })], + output: ast_types::DataType::Bool, + arg_evaluation_order: Default::default(), + }; + + ast::Method { + body: crate::ast::RoutineBody::::Instructions(triton_asm!( + // _ [ok_type] discriminant + {&swap_to_bottom} + // _ discriminant [ok_type'] + + {&remove_data} + // _ discriminant + + push 0 + eq + // _ (discriminant == 0 :== variant is 'Err') + )), + signature: method_signature, + } +} + fn is_ok_method(enum_type: &ast_types::EnumType) -> ast::Method { let stack_size = enum_type.stack_size(); let swap_to_bottom = match stack_size { diff --git a/src/tests_and_benchmarks/ozk/programs/result_types.rs b/src/tests_and_benchmarks/ozk/programs/result_types.rs index b5928835..5fa0e4dc 100644 --- a/src/tests_and_benchmarks/ozk/programs/result_types.rs +++ b/src/tests_and_benchmarks/ozk/programs/result_types.rs @@ -1,3 +1,4 @@ +mod boxed_struct; mod copy_types; mod non_copy_types; mod question_mark_operator; diff --git a/src/tests_and_benchmarks/ozk/programs/result_types/boxed_struct.rs b/src/tests_and_benchmarks/ozk/programs/result_types/boxed_struct.rs new file mode 100644 index 00000000..57775c5b --- /dev/null +++ b/src/tests_and_benchmarks/ozk/programs/result_types/boxed_struct.rs @@ -0,0 +1,106 @@ +use arbitrary::Arbitrary; +use triton_vm::Digest; +use twenty_first::shared_math::b_field_element::BFieldElement; +use twenty_first::shared_math::bfield_codec::BFieldCodec; +use twenty_first::shared_math::x_field_element::XFieldElement; + +use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; + +#[derive(Arbitrary, BFieldCodec, Clone, Debug)] +struct NotCopyStruct { + digests: Vec, + xfes: Vec, +} + +#[allow(clippy::assertions_on_constants)] +#[allow(clippy::len_zero)] +fn main() { + let a_stack_value: BFieldElement = BFieldElement::new(404); + let boxed_enum_type: Box = + NotCopyStruct::decode(&tasm::load_from_memory(BFieldElement::new(84))).unwrap(); + + let as_ok: Result, ()> = Ok(boxed_enum_type); + assert!(as_ok.is_ok()); + + match as_ok { + Result::Ok(inner) => { + tasm::tasm_io_write_to_stdout___u32(inner.digests.len() as u32); + tasm::tasm_io_write_to_stdout___u32(inner.xfes.len() as u32); + if inner.digests.len() > 0 { + tasm::tasm_io_write_to_stdout___digest(inner.digests[0]); + } + if inner.digests.len() > 1 { + tasm::tasm_io_write_to_stdout___digest(inner.digests[1]); + } + if inner.xfes.len() > 0 { + tasm::tasm_io_write_to_stdout___xfe(inner.xfes[0]); + } + if inner.xfes.len() > 1 { + tasm::tasm_io_write_to_stdout___xfe(inner.xfes[1]); + } + } + Result::Err(_) => { + assert!(false); + } + }; + + let as_err: Result, ()> = Err(()); + match as_err { + Result::Ok(_) => { + assert!(false); + } + Result::Err(_) => { + tasm::tasm_io_write_to_stdout___bfe(a_stack_value); + } + }; + + // assert!(as_ok.is_ok()); + assert!(as_err.is_err()); + + return; +} + +mod test { + use arbitrary::Unstructured; + use itertools::Itertools; + use rand::random; + use std::collections::HashMap; + use std::default::Default; + + use crate::tests_and_benchmarks::ozk::{ozk_parsing, rust_shadows}; + use crate::tests_and_benchmarks::test_helpers::shared_test::*; + + use super::*; + + #[test] + fn boxed_result_test() { + for _ in 0..5 { + let rand: [u8; 1000] = random(); + let sv = NotCopyStruct::arbitrary(&mut Unstructured::new(&rand)).unwrap(); + let non_determinism = init_memory_from(&sv, BFieldElement::new(84)); + let stdin = vec![]; + + // Run program on host machine + let native_output = + rust_shadows::wrap_main_with_io(&main)(stdin.clone(), non_determinism.clone()); + let test_program = ozk_parsing::compile_for_test( + "result_types", + "boxed_struct", + "main", + crate::ast_types::ListType::Unsafe, + ); + println!("test_program:\n{}", test_program.iter().join("\n")); + let vm_output = execute_compiled_with_stack_memory_and_ins_for_test( + &test_program, + vec![], + &mut HashMap::default(), + stdin, + non_determinism.clone(), + 0, + ) + .unwrap(); + assert_eq!(native_output, vm_output.output); + println!("native_output: {native_output:#?}"); + } + } +}