Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust] Subtype checking only for reference types (WIP) #341

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 30 additions & 34 deletions rust/candid/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ impl<'de> IDLDeserialize<'de> {
{
let expected_type = self.de.table.trace_type(&expected_type)?;
if self.de.types.is_empty() {
if matches!(expected_type, Type::Opt(_) | Type::Reserved | Type::Null) {
if matches!(expected_type, Type::Opt(_) | Type::Reserved) {
self.de.expect_type = expected_type;
self.de.wire_type = Type::Null;
self.de.wire_type = Type::Reserved;
return T::deserialize(&mut self.de);
} else {
return Err(Error::msg(format!(
"No more values on the wire, the expected type {} is not opt, reserved or null",
"No more values on the wire, the expected type {} is not opt or reserved",
expected_type
)));
}
Expand All @@ -70,15 +70,6 @@ impl<'de> IDLDeserialize<'de> {
expected_type.clone()
};
self.de.wire_type = ty.clone();
self.de
.check_subtype()
.with_context(|| self.de.dump_state())
.with_context(|| {
format!(
"Fail to decode argument {} from {} to {}",
ind, ty, expected_type
)
})?;

let v = T::deserialize(&mut self.de)
.with_context(|| self.de.dump_state())
Expand Down Expand Up @@ -232,7 +223,6 @@ impl<'de> Deserializer<'de> {
{
use std::convert::TryInto;
self.unroll_type()?;
assert!(self.expect_type == Type::Int);
let mut bytes = vec![0u8];
let int = match &self.wire_type {
Type::Int => Int::decode(&mut self.input).map_err(Error::msg)?,
Expand All @@ -241,20 +231,20 @@ impl<'de> Deserializer<'de> {
.0
.try_into()
.map_err(Error::msg)?),
// We already did subtype checking before deserialize, so this is unreachable code
_ => assert!(false),
t => return Err(Error::msg(format!("{} cannot be deserialized to int", t))),
};
bytes.extend_from_slice(&int.0.to_signed_bytes_le());
assert!(self.expect_type == Type::Int);
visitor.visit_byte_buf(bytes)
}
fn deserialize_nat<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat);
let mut bytes = vec![1u8];
let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat);
bytes.extend_from_slice(&nat.0.to_bytes_le());
visitor.visit_byte_buf(bytes)
}
Expand All @@ -263,9 +253,9 @@ impl<'de> Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Principal && self.wire_type == Type::Principal);
let mut bytes = vec![2u8];
let id = PrincipalBytes::read(&mut self.input)?.inner;
assert!(self.expect_type == Type::Principal && self.wire_type == Type::Principal);
bytes.extend_from_slice(&id);
visitor.visit_byte_buf(bytes)
}
Expand All @@ -281,9 +271,10 @@ impl<'de> Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(matches!(self.wire_type, Type::Service(_)));
self.check_subtype()?;
let mut bytes = vec![4u8];
let id = PrincipalBytes::read(&mut self.input)?.inner;
assert!(matches!(self.wire_type, Type::Service(_)));
bytes.extend_from_slice(&id);
visitor.visit_byte_buf(bytes)
}
Expand All @@ -292,7 +283,7 @@ impl<'de> Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(matches!(self.wire_type, Type::Func(_)));
self.check_subtype()?;
if !BoolValue::read(&mut self.input)?.0 {
return Err(Error::msg("Opaque reference not supported"));
}
Expand All @@ -302,6 +293,7 @@ impl<'de> Deserializer<'de> {
let meth = self.borrow_bytes(len)?;
// TODO find a better way
leb128::write::unsigned(&mut bytes, len as u64)?;
assert!(matches!(self.wire_type, Type::Func(_)));
bytes.extend_from_slice(meth);
bytes.extend_from_slice(&id);
visitor.visit_byte_buf(bytes)
Expand All @@ -320,9 +312,9 @@ macro_rules! primitive_impl {
fn [<deserialize_ $ty>]<V>(self, visitor: V) -> Result<V::Value>
where V: Visitor<'de> {
self.unroll_type()?;
assert!(self.expect_type == $type && self.wire_type == $type);
let val = self.input.$($value)*().map_err(|_| Error::msg(format!("Cannot read {} value", stringify!($type))))?;
//let val: $ty = self.input.read_le()?;
assert!(self.expect_type == $type && self.wire_type == $type);
visitor.[<visit_ $ty>](val)
}
}
Expand Down Expand Up @@ -404,7 +396,6 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
{
use std::convert::TryInto;
self.unroll_type()?;
assert!(self.expect_type == Type::Int);
let value: i128 = match &self.wire_type {
Type::Int => {
let int = Int::decode(&mut self.input).map_err(Error::msg)?;
Expand All @@ -414,8 +405,9 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
nat.0.try_into().map_err(Error::msg)?
}
_ => assert!(false),
t => return Err(Error::msg(format!("{} cannot be deserialized to int", t))),
};
assert!(self.expect_type == Type::Int);
visitor.visit_i128(value)
}
fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value>
Expand All @@ -424,9 +416,9 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
{
use std::convert::TryInto;
self.unroll_type()?;
assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat);
let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
let value: u128 = nat.0.try_into().map_err(Error::msg)?;
assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat);
visitor.visit_u128(value)
}
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
Expand All @@ -442,30 +434,30 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Bool && self.wire_type == Type::Bool);
let res = BoolValue::read(&mut self.input)?;
assert!(self.expect_type == Type::Bool && self.wire_type == Type::Bool);
visitor.visit_bool(res.0)
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Text && self.wire_type == Type::Text);
let len = Len::read(&mut self.input)?.0;
let bytes = self.borrow_bytes(len)?.to_owned();
let value = String::from_utf8(bytes).map_err(Error::msg)?;
assert!(self.expect_type == Type::Text && self.wire_type == Type::Text);
visitor.visit_string(value)
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Text && self.wire_type == Type::Text);
let len = Len::read(&mut self.input)?.0;
let slice = self.borrow_bytes(len)?;
let value: &str = std::str::from_utf8(slice).map_err(Error::msg)?;
assert!(self.expect_type == Type::Text && self.wire_type == Type::Text);
visitor.visit_borrowed_str(value)
}
fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
Expand All @@ -491,22 +483,26 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
self.wire_type = *t1.clone();
self.expect_type = *t2.clone();
if BoolValue::read(&mut self.input)?.0 {
visitor
.__private_visit_untagged_option(self)
.map_err(|_| Error::msg("cannot deserialize opt value"))
/*
if self.check_subtype().is_ok() {
visitor.visit_some(self)
} else {
self.deserialize_ignored_any(serde::de::IgnoredAny)?;
visitor.visit_none()
}
}*/
} else {
visitor.visit_none()
}
}
(_, Type::Opt(t2)) => {
self.expect_type = self.table.trace_type(&*t2)?;
if !matches!(self.expect_type, Type::Null | Type::Reserved | Type::Opt(_))
&& self.check_subtype().is_ok()
{
visitor.visit_some(self)
if !matches!(self.expect_type, Type::Null | Type::Reserved | Type::Opt(_)) {
visitor
.__private_visit_untagged_option(self)
.map_err(|_| Error::msg("cannot deserialize opt value"))
} else {
self.deserialize_ignored_any(serde::de::IgnoredAny)?;
visitor.visit_none()
Expand Down Expand Up @@ -547,13 +543,13 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
}
fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
self.unroll_type()?;
let len = Len::read(&mut self.input)?.0;
let bytes = self.borrow_bytes(len)?.to_owned();
//let bytes = Bytes::read(&mut self.input)?.inner;
assert!(
self.expect_type == Type::Vec(Box::new(Type::Nat8))
&& self.wire_type == Type::Vec(Box::new(Type::Nat8))
);
let len = Len::read(&mut self.input)?.0;
let bytes = self.borrow_bytes(len)?.to_owned();
//let bytes = Bytes::read(&mut self.input)?.inner;
visitor.visit_byte_buf(bytes)
}
fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
Expand Down
9 changes: 9 additions & 0 deletions rust/candid/src/parser/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,15 @@ impl<'de> Visitor<'de> for IDLValueVisitor {
let v = Deserialize::deserialize(deserializer)?;
Ok(IDLValue::Opt(Box::new(v)))
}
fn __private_visit_untagged_option<D>(self, deserializer: D) -> DResult<()>
where
D: serde::Deserializer<'de>,
{
Ok(match Deserialize::deserialize(deserializer) {
Ok(v) => IDLValue::Opt(Box::new(v)),
Err(_) => IDLValue::None,
})
}
fn visit_unit<E>(self) -> DResult<E> {
Ok(IDLValue::Null)
}
Expand Down
4 changes: 2 additions & 2 deletions rust/candid/tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,12 +588,12 @@ fn test_multiargs() {
Vec<(Int, &str)>,
(Int, String),
Option<i32>,
(),
candid::Reserved,
candid::Reserved
)
.unwrap();
assert_eq!(tuple.2, None);
assert_eq!(tuple.3, ());
assert_eq!(tuple.3, candid::Reserved);
assert_eq!(tuple.4, candid::Reserved);
}

Expand Down
10 changes: 5 additions & 5 deletions test/construct.test.did
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ assert blob "DIDL\00\00" == "(null)" : (Opt) "op

// vector
assert blob "DIDL\01\6d\7c\01\00\00" == "(vec {})" : (vec int) "vec: empty";
assert blob "DIDL\01\6d\7c\01\00\00" !: (vec int8) "vec: non subtype empty";
assert blob "DIDL\01\6d\7c\01\00\00" : (vec int8) "vec: non subtype empty";
assert blob "DIDL\01\6d\7c\01\00\02\01\02" == "(vec { 1; 2 })" : (vec int) "vec";
assert blob "DIDL\01\6d\7b\01\00\02\01\02" == "(blob \"\\01\\02\")" : (vec nat8) "vec: blob";
assert blob "DIDL\01\6d\00\01\00\00" == "(vec {})" : (Vec) "vec: recursive vector";
Expand Down Expand Up @@ -139,10 +139,10 @@ assert "(variant {})" !
assert blob "DIDL\01\6b\00\01\00" !: (variant {}) "variant: no empty value";
assert blob "DIDL\01\6b\01\00\7f\01\00\00" == "(variant {0})" : (variant {0}) "variant: numbered field";
assert blob "DIDL\01\6b\01\00\7f\01\00\00\2a" !: (variant {0:int}) "variant: type mismatch";
assert blob "DIDL\01\6b\02\00\7f\01\7c\01\00\01\2a" !: (variant {0:int; 1:int}) "variant: type mismatch in unused tag";
assert blob "DIDL\01\6b\02\00\7f\01\7c\01\00\01\2a" : (variant {0:int; 1:int}) "variant: type mismatch in unused tag";
assert blob "DIDL\01\6b\01\00\7f\01\00\00" == "(variant {0})" : (variant {0;1}) "variant: ignore field";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" !: (variant {0}) "variant {0;1} !<: variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" == "(null)" : (opt variant {0}) "variant {0;1} <: opt variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" : (variant {0}) "variant {0;1} <: variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" == "(variant {0})" : (opt variant {0}) "variant {0;1} <: opt variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\01" == "(variant {1})" : (variant {0;1;2}) "variant: change index";
assert blob "DIDL\01\6b\01\00\7f\01\00\00" !: (variant {1}) "variant: missing field";
assert blob "DIDL\01\6b\01\00\7f\01\00\01" !: (variant {0}) "variant: index out of range";
Expand Down Expand Up @@ -188,7 +188,7 @@ assert blob "DIDL\02\6b\02\d1\a7\cf\02\7f\f1\f3\92\8e\04\01\6c\02\a0\d2\ac\a8\04
== "(variant { cons = record { head = 1; tail = variant { cons = record { head = 2; tail = variant { nil } } } } })" : (VariantList) "variant: list";

assert blob "DIDL\02\6b\02\d1\a7\cf\02\7f\f1\f3\92\8e\04\01\6c\02\a0\d2\ac\a8\04\7c\90\ed\da\e7\04\00\01\00\00"
== "(variant {nil}, null, null, null, null)" : (VariantList, opt VariantList, null, reserved, opt int) "variant: extra args";
== "(variant {nil}, null, null, null, null)" : (VariantList, opt VariantList, opt empty, reserved, opt int) "variant: extra args";

assert blob "DIDL\02\6b\02\d1\a7\cf\02\7f\f1\f3\92\8e\04\01\6c\02\a0\d2\ac\a8\04\7c\90\ed\da\e7\04\00\01\00\00" !: (VariantList, opt int, vec int) "non-null extra args";

Expand Down
2 changes: 1 addition & 1 deletion test/prim.test.did
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ assert blob "DIDL\00\01\5e" !: () "Out of range type";
assert blob "DIDL\00\01\7f" : (null);
assert blob "DIDL\00\01\7e" !: (null) "wrong type";
assert blob "DIDL\00\01\7f\00" !: (null) "null: too long";
assert blob "DIDL\00\00" : (null) "null: extra null values";
assert blob "DIDL\00\00" !: (null) "null: extra null values";

assert blob "DIDL\00\01\7e\00" == "(false)" : (bool) "bool: false";
assert blob "DIDL\00\01\7e\01" == "(true)" : (bool) "bool: true";
Expand Down