Skip to content

Commit

Permalink
Add support for is_err method for Result<T>
Browse files Browse the repository at this point in the history
  • Loading branch information
Sword-Smith committed Dec 5, 2023
1 parent 6fd40fa commit d096838
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 1 deletion.
49 changes: 48 additions & 1 deletion src/libraries/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ pub(crate) fn result_type(ok_type: ast_types::DataType) -> crate::composite_type
type_parameter: Some(ok_type.clone()),
};
let is_ok_method = is_ok_method(&enum_type);
let is_err_method = is_err_method(&enum_type);
let unwrap_method = unwrap_method(&enum_type);

crate::composite_types::TypeContext {
composite_type: enum_type.try_into().unwrap(),
methods: vec![is_ok_method, unwrap_method],
methods: vec![is_ok_method, is_err_method, unwrap_method],
associated_functions: vec![],
}
}
Expand Down Expand Up @@ -134,6 +135,52 @@ fn unwrap_method(enum_type: &ast_types::EnumType) -> ast::Method<Typing> {
}
}

fn is_err_method(enum_type: &ast_types::EnumType) -> ast::Method<Typing> {
let stack_size = enum_type.stack_size();
let swap_to_bottom = match stack_size {
0 => unreachable!(),
1 => triton_asm!(),
2..=16 => triton_asm!(swap { stack_size - 1 }),
_ => panic!("Can't handle this yet"), // This should work with spilling
};
let remove_data = match stack_size {
0 => unreachable!(),
1 => triton_asm!(pop),
2..=16 => {
let as_str = "pop\n".repeat(stack_size - 1);
triton_asm!({ as_str })
}
_ => panic!("Can't handle this yet"),
};
let is_ok_input_data_type = ast_types::DataType::Reference(Box::new(enum_type.into()));
let method_signature = ast::FnSignature {
name: "is_err".to_owned(),
args: vec![AbstractArgument::ValueArgument(AbstractValueArg {
name: "self".to_owned(),
data_type: is_ok_input_data_type,
mutable: false,
})],
output: ast_types::DataType::Bool,
arg_evaluation_order: Default::default(),
};

ast::Method {
body: crate::ast::RoutineBody::<Typing>::Instructions(triton_asm!(
// _ [ok_type] discriminant
{&swap_to_bottom}
// _ discriminant [ok_type']

{&remove_data}
// _ discriminant

push 0
eq
// _ (discriminant == 0 :== variant is 'Err')
)),
signature: method_signature,
}
}

fn is_ok_method(enum_type: &ast_types::EnumType) -> ast::Method<Typing> {
let stack_size = enum_type.stack_size();
let swap_to_bottom = match stack_size {
Expand Down
1 change: 1 addition & 0 deletions src/tests_and_benchmarks/ozk/programs/result_types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod boxed_struct;
mod copy_types;
mod non_copy_types;
mod question_mark_operator;
Expand Down
106 changes: 106 additions & 0 deletions src/tests_and_benchmarks/ozk/programs/result_types/boxed_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use arbitrary::Arbitrary;
use triton_vm::Digest;
use twenty_first::shared_math::b_field_element::BFieldElement;
use twenty_first::shared_math::bfield_codec::BFieldCodec;
use twenty_first::shared_math::x_field_element::XFieldElement;

use crate::tests_and_benchmarks::ozk::rust_shadows as tasm;

#[derive(Arbitrary, BFieldCodec, Clone, Debug)]
struct NotCopyStruct {
digests: Vec<Digest>,
xfes: Vec<XFieldElement>,
}

#[allow(clippy::assertions_on_constants)]
#[allow(clippy::len_zero)]
fn main() {
let a_stack_value: BFieldElement = BFieldElement::new(404);
let boxed_enum_type: Box<NotCopyStruct> =
NotCopyStruct::decode(&tasm::load_from_memory(BFieldElement::new(84))).unwrap();

let as_ok: Result<Box<NotCopyStruct>, ()> = Ok(boxed_enum_type);
assert!(as_ok.is_ok());

match as_ok {
Result::Ok(inner) => {
tasm::tasm_io_write_to_stdout___u32(inner.digests.len() as u32);
tasm::tasm_io_write_to_stdout___u32(inner.xfes.len() as u32);
if inner.digests.len() > 0 {
tasm::tasm_io_write_to_stdout___digest(inner.digests[0]);
}
if inner.digests.len() > 1 {
tasm::tasm_io_write_to_stdout___digest(inner.digests[1]);
}
if inner.xfes.len() > 0 {
tasm::tasm_io_write_to_stdout___xfe(inner.xfes[0]);
}
if inner.xfes.len() > 1 {
tasm::tasm_io_write_to_stdout___xfe(inner.xfes[1]);
}
}
Result::Err(_) => {
assert!(false);
}
};

let as_err: Result<Box<NotCopyStruct>, ()> = Err(());
match as_err {
Result::Ok(_) => {
assert!(false);
}
Result::Err(_) => {
tasm::tasm_io_write_to_stdout___bfe(a_stack_value);
}
};

// assert!(as_ok.is_ok());
assert!(as_err.is_err());

return;
}

mod test {
use arbitrary::Unstructured;
use itertools::Itertools;
use rand::random;
use std::collections::HashMap;
use std::default::Default;

use crate::tests_and_benchmarks::ozk::{ozk_parsing, rust_shadows};
use crate::tests_and_benchmarks::test_helpers::shared_test::*;

use super::*;

#[test]
fn boxed_result_test() {
for _ in 0..5 {
let rand: [u8; 1000] = random();
let sv = NotCopyStruct::arbitrary(&mut Unstructured::new(&rand)).unwrap();
let non_determinism = init_memory_from(&sv, BFieldElement::new(84));
let stdin = vec![];

// Run program on host machine
let native_output =
rust_shadows::wrap_main_with_io(&main)(stdin.clone(), non_determinism.clone());
let test_program = ozk_parsing::compile_for_test(
"result_types",
"boxed_struct",
"main",
crate::ast_types::ListType::Unsafe,
);
println!("test_program:\n{}", test_program.iter().join("\n"));
let vm_output = execute_compiled_with_stack_memory_and_ins_for_test(
&test_program,
vec![],
&mut HashMap::default(),
stdin,
non_determinism.clone(),
0,
)
.unwrap();
assert_eq!(native_output, vm_output.output);
println!("native_output: {native_output:#?}");
}
}
}

0 comments on commit d096838

Please sign in to comment.