From 9f890cefc4c4c2898ba0463317efbc2b921af29b Mon Sep 17 00:00:00 2001 From: Yan Chen Date: Tue, 19 Apr 2022 09:43:48 -0700 Subject: [PATCH 1/4] extra args cannot be null --- rust/candid/src/de.rs | 6 +++--- rust/candid/tests/serde.rs | 4 ++-- test/construct.test.did | 2 +- test/prim.test.did | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index 8efc298c..a9f25a4a 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -50,13 +50,13 @@ impl<'de> IDLDeserialize<'de> { { let expected_type = self.de.table.trace_type(&expected_type)?; if self.de.types.is_empty() { - if matches!(expected_type, Type::Opt(_) | Type::Reserved | Type::Null) { + if matches!(expected_type, Type::Opt(_) | Type::Reserved) { self.de.expect_type = expected_type; - self.de.wire_type = Type::Null; + self.de.wire_type = Type::Reserved; return T::deserialize(&mut self.de); } else { return Err(Error::msg(format!( - "No more values on the wire, the expected type {} is not opt, reserved or null", + "No more values on the wire, the expected type {} is not opt or reserved", expected_type ))); } diff --git a/rust/candid/tests/serde.rs b/rust/candid/tests/serde.rs index 23b4432c..cf24faa3 100644 --- a/rust/candid/tests/serde.rs +++ b/rust/candid/tests/serde.rs @@ -588,12 +588,12 @@ fn test_multiargs() { Vec<(Int, &str)>, (Int, String), Option, - (), + candid::Reserved, candid::Reserved ) .unwrap(); assert_eq!(tuple.2, None); - assert_eq!(tuple.3, ()); + assert_eq!(tuple.3, candid::Reserved); assert_eq!(tuple.4, candid::Reserved); } diff --git a/test/construct.test.did b/test/construct.test.did index 397da032..0b0d7f9b 100644 --- a/test/construct.test.did +++ b/test/construct.test.did @@ -188,7 +188,7 @@ assert blob "DIDL\02\6b\02\d1\a7\cf\02\7f\f1\f3\92\8e\04\01\6c\02\a0\d2\ac\a8\04 == "(variant { cons = record { head = 1; tail = variant { cons = record { head = 2; tail = variant { nil } } } } })" : (VariantList) "variant: list"; assert blob "DIDL\02\6b\02\d1\a7\cf\02\7f\f1\f3\92\8e\04\01\6c\02\a0\d2\ac\a8\04\7c\90\ed\da\e7\04\00\01\00\00" - == "(variant {nil}, null, null, null, null)" : (VariantList, opt VariantList, null, reserved, opt int) "variant: extra args"; + == "(variant {nil}, null, null, null, null)" : (VariantList, opt VariantList, reserved, reserved, opt int) "variant: extra args"; assert blob "DIDL\02\6b\02\d1\a7\cf\02\7f\f1\f3\92\8e\04\01\6c\02\a0\d2\ac\a8\04\7c\90\ed\da\e7\04\00\01\00\00" !: (VariantList, opt int, vec int) "non-null extra args"; diff --git a/test/prim.test.did b/test/prim.test.did index 2fc1533e..dde6b696 100644 --- a/test/prim.test.did +++ b/test/prim.test.did @@ -23,7 +23,7 @@ assert blob "DIDL\00\01\5e" !: () "Out of range type"; assert blob "DIDL\00\01\7f" : (null); assert blob "DIDL\00\01\7e" !: (null) "wrong type"; assert blob "DIDL\00\01\7f\00" !: (null) "null: too long"; -assert blob "DIDL\00\00" : (null) "null: extra null values"; +assert blob "DIDL\00\00" !: (null) "null: extra null values"; assert blob "DIDL\00\01\7e\00" == "(false)" : (bool) "bool: false"; assert blob "DIDL\00\01\7e\01" == "(true)" : (bool) "bool: true"; From 2af87f66ae67735c41f99507921f8b8a8890bbe1 Mon Sep 17 00:00:00 2001 From: Yan Chen Date: Thu, 21 Apr 2022 17:28:27 -0700 Subject: [PATCH 2/4] checkpoint --- rust/candid/src/de.rs | 31 ++++++++++++++----------------- rust/candid/src/parser/value.rs | 9 +++++++++ 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index a9f25a4a..cda95bea 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -70,15 +70,6 @@ impl<'de> IDLDeserialize<'de> { expected_type.clone() }; self.de.wire_type = ty.clone(); - self.de - .check_subtype() - .with_context(|| self.de.dump_state()) - .with_context(|| { - format!( - "Fail to decode argument {} from {} to {}", - ind, ty, expected_type - ) - })?; let v = T::deserialize(&mut self.de) .with_context(|| self.de.dump_state()) @@ -241,8 +232,7 @@ impl<'de> Deserializer<'de> { .0 .try_into() .map_err(Error::msg)?), - // We already did subtype checking before deserialize, so this is unreachable code - _ => assert!(false), + t => return Err(Error::msg(format!("{} cannot be deserialized to int", t))), }; bytes.extend_from_slice(&int.0.to_signed_bytes_le()); visitor.visit_byte_buf(bytes) @@ -282,6 +272,7 @@ impl<'de> Deserializer<'de> { { self.unroll_type()?; assert!(matches!(self.wire_type, Type::Service(_))); + self.check_subtype()?; let mut bytes = vec![4u8]; let id = PrincipalBytes::read(&mut self.input)?.inner; bytes.extend_from_slice(&id); @@ -293,6 +284,7 @@ impl<'de> Deserializer<'de> { { self.unroll_type()?; assert!(matches!(self.wire_type, Type::Func(_))); + self.check_subtype()?; if !BoolValue::read(&mut self.input)?.0 { return Err(Error::msg("Opaque reference not supported")); } @@ -414,7 +406,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { let nat = Nat::decode(&mut self.input).map_err(Error::msg)?; nat.0.try_into().map_err(Error::msg)? } - _ => assert!(false), + t => return Err(Error::msg(format!("{} cannot be deserialized to int", t))), }; visitor.visit_i128(value) } @@ -491,22 +483,27 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { self.wire_type = *t1.clone(); self.expect_type = *t2.clone(); if BoolValue::read(&mut self.input)?.0 { + // this visitor is the same as visit_some, but converts Err to None + visitor + .__private_visit_untagged_option(self) + .map_err(|_| Error::msg("cannot deserialize opt value")) + /* if self.check_subtype().is_ok() { visitor.visit_some(self) } else { self.deserialize_ignored_any(serde::de::IgnoredAny)?; visitor.visit_none() - } + }*/ } else { visitor.visit_none() } } (_, Type::Opt(t2)) => { self.expect_type = self.table.trace_type(&*t2)?; - if !matches!(self.expect_type, Type::Null | Type::Reserved | Type::Opt(_)) - && self.check_subtype().is_ok() - { - visitor.visit_some(self) + if !matches!(self.expect_type, Type::Null | Type::Reserved | Type::Opt(_)) { + visitor + .__private_visit_untagged_option(self) + .map_err(|_| Error::msg("cannot deserialize opt value")) } else { self.deserialize_ignored_any(serde::de::IgnoredAny)?; visitor.visit_none() diff --git a/rust/candid/src/parser/value.rs b/rust/candid/src/parser/value.rs index a0c95053..4d2a1d8a 100644 --- a/rust/candid/src/parser/value.rs +++ b/rust/candid/src/parser/value.rs @@ -495,6 +495,15 @@ impl<'de> Visitor<'de> for IDLValueVisitor { let v = Deserialize::deserialize(deserializer)?; Ok(IDLValue::Opt(Box::new(v))) } + fn __private_visit_untagged_option(self, deserializer: D) -> DResult<()> + where + D: serde::Deserializer<'de>, + { + Ok(match Deserialize::deserialize(deserializer) { + Ok(v) => IDLValue::Opt(Box::new(v)), + Err(_) => IDLValue::None, + }) + } fn visit_unit(self) -> DResult { Ok(IDLValue::Null) } From bbe4ece75df53869b3c59d63cf5d4b2b80152532 Mon Sep 17 00:00:00 2001 From: Yan Chen Date: Fri, 22 Apr 2022 15:42:20 -0700 Subject: [PATCH 3/4] update test --- rust/candid/src/de.rs | 2 ++ test/construct.test.did | 10 +++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index cda95bea..e6f4cece 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -484,6 +484,8 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { self.expect_type = *t2.clone(); if BoolValue::read(&mut self.input)?.0 { // this visitor is the same as visit_some, but converts Err to None + // TODO the problem is that when decoding fails, it doesn't finish reading the whole part. + // Then there is a mismatch with next read visitor .__private_visit_untagged_option(self) .map_err(|_| Error::msg("cannot deserialize opt value")) diff --git a/test/construct.test.did b/test/construct.test.did index 0b0d7f9b..5d597041 100644 --- a/test/construct.test.did +++ b/test/construct.test.did @@ -54,7 +54,7 @@ assert blob "DIDL\00\00" == "(null)" : (Opt) "op // vector assert blob "DIDL\01\6d\7c\01\00\00" == "(vec {})" : (vec int) "vec: empty"; -assert blob "DIDL\01\6d\7c\01\00\00" !: (vec int8) "vec: non subtype empty"; +assert blob "DIDL\01\6d\7c\01\00\00" : (vec int8) "vec: non subtype empty"; assert blob "DIDL\01\6d\7c\01\00\02\01\02" == "(vec { 1; 2 })" : (vec int) "vec"; assert blob "DIDL\01\6d\7b\01\00\02\01\02" == "(blob \"\\01\\02\")" : (vec nat8) "vec: blob"; assert blob "DIDL\01\6d\00\01\00\00" == "(vec {})" : (Vec) "vec: recursive vector"; @@ -139,10 +139,10 @@ assert "(variant {})" ! assert blob "DIDL\01\6b\00\01\00" !: (variant {}) "variant: no empty value"; assert blob "DIDL\01\6b\01\00\7f\01\00\00" == "(variant {0})" : (variant {0}) "variant: numbered field"; assert blob "DIDL\01\6b\01\00\7f\01\00\00\2a" !: (variant {0:int}) "variant: type mismatch"; -assert blob "DIDL\01\6b\02\00\7f\01\7c\01\00\01\2a" !: (variant {0:int; 1:int}) "variant: type mismatch in unused tag"; +assert blob "DIDL\01\6b\02\00\7f\01\7c\01\00\01\2a" : (variant {0:int; 1:int}) "variant: type mismatch in unused tag"; assert blob "DIDL\01\6b\01\00\7f\01\00\00" == "(variant {0})" : (variant {0;1}) "variant: ignore field"; -assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" !: (variant {0}) "variant {0;1} !<: variant {0}"; -assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" == "(null)" : (opt variant {0}) "variant {0;1} <: opt variant {0}"; +assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" : (variant {0}) "variant {0;1} <: variant {0}"; +assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" == "(variant {0})" : (opt variant {0}) "variant {0;1} <: opt variant {0}"; assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\01" == "(variant {1})" : (variant {0;1;2}) "variant: change index"; assert blob "DIDL\01\6b\01\00\7f\01\00\00" !: (variant {1}) "variant: missing field"; assert blob "DIDL\01\6b\01\00\7f\01\00\01" !: (variant {0}) "variant: index out of range"; @@ -188,7 +188,7 @@ assert blob "DIDL\02\6b\02\d1\a7\cf\02\7f\f1\f3\92\8e\04\01\6c\02\a0\d2\ac\a8\04 == "(variant { cons = record { head = 1; tail = variant { cons = record { head = 2; tail = variant { nil } } } } })" : (VariantList) "variant: list"; assert blob "DIDL\02\6b\02\d1\a7\cf\02\7f\f1\f3\92\8e\04\01\6c\02\a0\d2\ac\a8\04\7c\90\ed\da\e7\04\00\01\00\00" - == "(variant {nil}, null, null, null, null)" : (VariantList, opt VariantList, reserved, reserved, opt int) "variant: extra args"; + == "(variant {nil}, null, null, null, null)" : (VariantList, opt VariantList, opt empty, reserved, opt int) "variant: extra args"; assert blob "DIDL\02\6b\02\d1\a7\cf\02\7f\f1\f3\92\8e\04\01\6c\02\a0\d2\ac\a8\04\7c\90\ed\da\e7\04\00\01\00\00" !: (VariantList, opt int, vec int) "non-null extra args"; From 7366db006904e7363eb7158c32d332fe23e11ac0 Mon Sep 17 00:00:00 2001 From: Yan Chen Date: Mon, 25 Apr 2022 12:20:27 -0700 Subject: [PATCH 4/4] reorder assert --- rust/candid/src/de.rs | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index e6f4cece..f2407799 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -223,7 +223,6 @@ impl<'de> Deserializer<'de> { { use std::convert::TryInto; self.unroll_type()?; - assert!(self.expect_type == Type::Int); let mut bytes = vec![0u8]; let int = match &self.wire_type { Type::Int => Int::decode(&mut self.input).map_err(Error::msg)?, @@ -235,6 +234,7 @@ impl<'de> Deserializer<'de> { t => return Err(Error::msg(format!("{} cannot be deserialized to int", t))), }; bytes.extend_from_slice(&int.0.to_signed_bytes_le()); + assert!(self.expect_type == Type::Int); visitor.visit_byte_buf(bytes) } fn deserialize_nat<'a, V>(&'a mut self, visitor: V) -> Result @@ -242,9 +242,9 @@ impl<'de> Deserializer<'de> { V: Visitor<'de>, { self.unroll_type()?; - assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat); let mut bytes = vec![1u8]; let nat = Nat::decode(&mut self.input).map_err(Error::msg)?; + assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat); bytes.extend_from_slice(&nat.0.to_bytes_le()); visitor.visit_byte_buf(bytes) } @@ -253,9 +253,9 @@ impl<'de> Deserializer<'de> { V: Visitor<'de>, { self.unroll_type()?; - assert!(self.expect_type == Type::Principal && self.wire_type == Type::Principal); let mut bytes = vec![2u8]; let id = PrincipalBytes::read(&mut self.input)?.inner; + assert!(self.expect_type == Type::Principal && self.wire_type == Type::Principal); bytes.extend_from_slice(&id); visitor.visit_byte_buf(bytes) } @@ -271,10 +271,10 @@ impl<'de> Deserializer<'de> { V: Visitor<'de>, { self.unroll_type()?; - assert!(matches!(self.wire_type, Type::Service(_))); self.check_subtype()?; let mut bytes = vec![4u8]; let id = PrincipalBytes::read(&mut self.input)?.inner; + assert!(matches!(self.wire_type, Type::Service(_))); bytes.extend_from_slice(&id); visitor.visit_byte_buf(bytes) } @@ -283,7 +283,6 @@ impl<'de> Deserializer<'de> { V: Visitor<'de>, { self.unroll_type()?; - assert!(matches!(self.wire_type, Type::Func(_))); self.check_subtype()?; if !BoolValue::read(&mut self.input)?.0 { return Err(Error::msg("Opaque reference not supported")); @@ -294,6 +293,7 @@ impl<'de> Deserializer<'de> { let meth = self.borrow_bytes(len)?; // TODO find a better way leb128::write::unsigned(&mut bytes, len as u64)?; + assert!(matches!(self.wire_type, Type::Func(_))); bytes.extend_from_slice(meth); bytes.extend_from_slice(&id); visitor.visit_byte_buf(bytes) @@ -312,9 +312,9 @@ macro_rules! primitive_impl { fn [](self, visitor: V) -> Result where V: Visitor<'de> { self.unroll_type()?; - assert!(self.expect_type == $type && self.wire_type == $type); let val = self.input.$($value)*().map_err(|_| Error::msg(format!("Cannot read {} value", stringify!($type))))?; //let val: $ty = self.input.read_le()?; + assert!(self.expect_type == $type && self.wire_type == $type); visitor.[](val) } } @@ -396,7 +396,6 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { { use std::convert::TryInto; self.unroll_type()?; - assert!(self.expect_type == Type::Int); let value: i128 = match &self.wire_type { Type::Int => { let int = Int::decode(&mut self.input).map_err(Error::msg)?; @@ -408,6 +407,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } t => return Err(Error::msg(format!("{} cannot be deserialized to int", t))), }; + assert!(self.expect_type == Type::Int); visitor.visit_i128(value) } fn deserialize_u128(self, visitor: V) -> Result @@ -416,9 +416,9 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { { use std::convert::TryInto; self.unroll_type()?; - assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat); let nat = Nat::decode(&mut self.input).map_err(Error::msg)?; let value: u128 = nat.0.try_into().map_err(Error::msg)?; + assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat); visitor.visit_u128(value) } fn deserialize_unit(self, visitor: V) -> Result @@ -434,8 +434,8 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { self.unroll_type()?; - assert!(self.expect_type == Type::Bool && self.wire_type == Type::Bool); let res = BoolValue::read(&mut self.input)?; + assert!(self.expect_type == Type::Bool && self.wire_type == Type::Bool); visitor.visit_bool(res.0) } fn deserialize_string(self, visitor: V) -> Result @@ -443,10 +443,10 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { self.unroll_type()?; - assert!(self.expect_type == Type::Text && self.wire_type == Type::Text); let len = Len::read(&mut self.input)?.0; let bytes = self.borrow_bytes(len)?.to_owned(); let value = String::from_utf8(bytes).map_err(Error::msg)?; + assert!(self.expect_type == Type::Text && self.wire_type == Type::Text); visitor.visit_string(value) } fn deserialize_str(self, visitor: V) -> Result @@ -454,10 +454,10 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { self.unroll_type()?; - assert!(self.expect_type == Type::Text && self.wire_type == Type::Text); let len = Len::read(&mut self.input)?.0; let slice = self.borrow_bytes(len)?; let value: &str = std::str::from_utf8(slice).map_err(Error::msg)?; + assert!(self.expect_type == Type::Text && self.wire_type == Type::Text); visitor.visit_borrowed_str(value) } fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result @@ -483,9 +483,6 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { self.wire_type = *t1.clone(); self.expect_type = *t2.clone(); if BoolValue::read(&mut self.input)?.0 { - // this visitor is the same as visit_some, but converts Err to None - // TODO the problem is that when decoding fails, it doesn't finish reading the whole part. - // Then there is a mismatch with next read visitor .__private_visit_untagged_option(self) .map_err(|_| Error::msg("cannot deserialize opt value")) @@ -546,13 +543,13 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } fn deserialize_byte_buf>(self, visitor: V) -> Result { self.unroll_type()?; + let len = Len::read(&mut self.input)?.0; + let bytes = self.borrow_bytes(len)?.to_owned(); + //let bytes = Bytes::read(&mut self.input)?.inner; assert!( self.expect_type == Type::Vec(Box::new(Type::Nat8)) && self.wire_type == Type::Vec(Box::new(Type::Nat8)) ); - let len = Len::read(&mut self.input)?.0; - let bytes = self.borrow_bytes(len)?.to_owned(); - //let bytes = Bytes::read(&mut self.input)?.inner; visitor.visit_byte_buf(bytes) } fn deserialize_bytes>(self, visitor: V) -> Result {