diff --git a/source/postcard-dyn/Cargo.toml b/source/postcard-dyn/Cargo.toml index acbab01..47cd057 100644 --- a/source/postcard-dyn/Cargo.toml +++ b/source/postcard-dyn/Cargo.toml @@ -19,7 +19,9 @@ documentation = "https://docs.rs/postcard-dyn/" [dependencies] +hashbrown = { version = "0.15.2", default-features = false, features = ["default-hasher"] } serde = { version = "1.0.202", features = ["derive"] } +serde-content = "0.1.0" serde_json = "1.0.117" [dependencies.postcard] @@ -31,3 +33,7 @@ path = "../postcard" version = "0.2" features = ["use-std", "derive"] path = "../postcard-schema" + +[dev-dependencies.serde_json] +version = "1.0" +features = ["preserve_order"] diff --git a/source/postcard-dyn/src/de.rs b/source/postcard-dyn/src/de.rs index 1d56dac..fcd4c6e 100644 --- a/source/postcard-dyn/src/de.rs +++ b/source/postcard-dyn/src/de.rs @@ -6,7 +6,7 @@ use crate::Error; pub fn from_slice_dyn( schema: &OwnedNamedType, data: &[u8], -) -> Result> { +) -> Result> { // Matches current value type (`serde_json::Value`)'s representation crate::reserialize::lossy::reserialize_with_structs_and_enums_as_maps( schema, diff --git a/source/postcard-dyn/src/error.rs b/source/postcard-dyn/src/error.rs index 711283a..6c7bf1d 100644 --- a/source/postcard-dyn/src/error.rs +++ b/source/postcard-dyn/src/error.rs @@ -2,12 +2,12 @@ use core::fmt::{self, Display}; /// Errors encountered by `postcard-dyn` #[derive(Debug, Clone, PartialEq, Eq)] -pub enum Error { - Deserialize(postcard::Error), +pub enum Error { + Deserialize(DeserializeError), Serialize(SerializeError), } -impl Display for Error { +impl Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Deserialize(err) => Display::fmt(err, f), @@ -16,7 +16,7 @@ impl Display for Error { } } -impl core::error::Error for Error { +impl core::error::Error for Error { fn source(&self) -> Option<&(dyn core::error::Error + 'static)> { match self { Self::Deserialize(err) => err.source(), diff --git a/source/postcard-dyn/src/reserialize.rs b/source/postcard-dyn/src/reserialize.rs index 51c086d..dcab9fe 100644 --- a/source/postcard-dyn/src/reserialize.rs +++ b/source/postcard-dyn/src/reserialize.rs @@ -1,293 +1,420 @@ -//! Dynamically reserialize [`postcard`]-encoded values into a [`Serializer`]. +//! Dynamically reserialize [`postcard`]-encoded values from [`Deserializer`]s into [`Serializer`]s. //! -//! This module implements transformations from postcard-encoded data to other serialized forms +//! This module implements transformations between postcard-encoded data and other serialized forms //! based on [dynamic schemas](postcard_schema::schema::owned). For example, this could be used to -//! transform postcard-encoded data to JSON or another human-readable format. +//! transform postcard-encoded data to JSON or another human-readable format, or to transform JSON +//! to postcard. //! //! # Limitations //! -//! Several [`Serializer`] methods require `&'static str`s parameters: namely, those for -//! serializing structs and enums require `&'static str`s for the names of structs, fields, enums, -//! and variants. Since these transformations work with dynamic schemas that contain [`String`]s -//! instead of `&'static str`s, lossless reserialization is possible only with other compromises. +//! Several [`Deserializer`] and [`Serializer`] methods require `&'static` parameters, namely: +//! - [`Serializer`] methods for serializing structs and enums require `&'static str`s for the +//! names of structs, fields, enums, and variants. +//! - [`Deserializer`] methods for deserializing structs and enums require the same `&'static str`s +//! as the corresponding serialize methods, and moreover `&'static [&'static str]`s for the names +//! of fields in structs. +//! +//! Since these transformations work with dynamic schemas that contain [`String`]s instead of +//! `&'static str`s, lossless reserialization is possible only with other compromises. //! //! In particular, reserialization can be either: //! - Lossless with implementation compromises: see [`lossless`] //! - Lossy with regards to structs and enums: see [`lossy`] -use core::cell::{Cell, RefCell}; +use core::{cell::Cell, fmt, marker::PhantomData, slice}; -use postcard::{de_flavors::Flavor, Deserializer}; use postcard_schema::schema::owned::{OwnedDataModelType, OwnedNamedType}; -use serde::{ - ser::{Error as _, SerializeMap, SerializeSeq, SerializeTuple}, - Deserialize, Serialize, Serializer, -}; +use serde::{de, ser::Error as _, Deserialize, Deserializer, Serialize, Serializer}; use crate::Error; pub mod lossless; pub mod lossy; -fn reserialize<'de, F, S>( - schema: &OwnedNamedType, - deserializer: &mut Deserializer<'de, F>, - serializer: S, - structs_and_enums: impl structs_and_enums::Strategy, -) -> Result> -where - F: Flavor<'de>, - S: Serializer, -{ - let reserializer = Reserialize { - schema, - deserializer: &RefCell::new(deserializer), - deserializer_error: &Cell::new(None), - strategy: &structs_and_enums, - }; - match reserializer.serialize(serializer) { - Ok(out) => { - debug_assert_eq!(reserializer.deserializer_error.take(), None); - Ok(out) - } - Err(err) => { - if let Some(err) = reserializer.deserializer_error.take() { - Err(Error::Deserialize(err)) - } else { - Err(Error::Serialize(err)) - } - } - } -} +mod expecting; +mod strategy; +use strategy::Strategy; -struct Reserialize<'a, 'de, 'deserializer, F, Strategy> -where - F: Flavor<'de> + 'de, -{ - schema: &'a OwnedNamedType, - deserializer: &'a RefCell<&'deserializer mut postcard::Deserializer<'de, F>>, - deserializer_error: &'a Cell>, +mod map; +mod option; +mod seq; +mod tuple; + +struct Context<'a, Strategy> { strategy: &'a Strategy, } -impl<'a, 'de, 'deserializer, F, Strategy> serde::Serialize - for Reserialize<'a, 'de, 'deserializer, F, Strategy> -where - F: Flavor<'de> + 'de, - Strategy: structs_and_enums::Strategy, -{ +struct Reserialize { + f: Cell>, + deserializer_error: Cell>, +} + +trait ReserializeFn { + type DeserializeError: de::Error; + + fn reserialize( + self, + serializer: S, + ) -> Result>; +} + +impl serde::Serialize for Reserialize { fn serialize(&self, serializer: S) -> Result { - match self.serialize_inner(serializer) { + let f = self.f.take().unwrap(); + match f.reserialize(serializer) { Ok(out) => Ok(out), Err(Error::Serialize(err)) => Err(err), Err(Error::Deserialize(err)) => { - self.deserializer_error.set(Some(err.clone())); - Err(S::Error::custom(err)) + let res = Err(S::Error::custom(format_args!("{err}"))); + self.deserializer_error.set(Some(err)); + res } } } } -impl<'a, 'de: 'a, 'deserializer, F, Strategy> Reserialize<'a, 'de, 'deserializer, F, Strategy> +impl<'a, Strategy: strategy::Strategy> Context<'a, Strategy> { + fn reserialize( + &self, + reserialize: F, + f: impl FnOnce(&Reserialize) -> T, + ) -> Result { + let reserialize = Reserialize { + f: Cell::new(Some(reserialize)), + deserializer_error: Cell::new(None), + }; + let res = f(&reserialize); + match reserialize.deserializer_error.take() { + Some(err) => Err(err), + None => Ok(res), + } + } + + fn reserialize_ty<'de, D: Deserializer<'de>, T>( + &self, + schema: &OwnedNamedType, + deserializer: D, + f: impl FnOnce(&Reserialize>) -> T, + ) -> Result { + self.reserialize( + ReserializeTy { + context: self, + deserializer, + schema, + de: PhantomData, + }, + f, + ) + } +} + +struct ReserializeTy<'a, 'de, D, Strategy> { + context: &'a Context<'a, Strategy>, + deserializer: D, + schema: &'a OwnedNamedType, + de: PhantomData<&'de ()>, +} + +impl<'de, D, Strategy> ReserializeFn for ReserializeTy<'_, 'de, D, Strategy> where - F: Flavor<'de>, - Strategy: structs_and_enums::Strategy, + D: Deserializer<'de>, + Strategy: strategy::Strategy, { - fn serialize_inner(&self, serializer: S) -> Result> { - match &self.schema.ty { - OwnedDataModelType::Bool => serializer.serialize_bool(self.deserialize()?), - OwnedDataModelType::U8 => serializer.serialize_u8(self.deserialize()?), - OwnedDataModelType::U16 => serializer.serialize_u16(self.deserialize()?), - OwnedDataModelType::U32 => serializer.serialize_u32(self.deserialize()?), - OwnedDataModelType::U64 => serializer.serialize_u64(self.deserialize()?), - OwnedDataModelType::U128 => serializer.serialize_u128(self.deserialize()?), - OwnedDataModelType::I8 => serializer.serialize_i8(self.deserialize()?), - OwnedDataModelType::I16 => serializer.serialize_i16(self.deserialize()?), - OwnedDataModelType::I32 => serializer.serialize_i32(self.deserialize()?), - OwnedDataModelType::I64 => serializer.serialize_i64(self.deserialize()?), - OwnedDataModelType::I128 => serializer.serialize_i128(self.deserialize()?), - OwnedDataModelType::Usize => self.deserialize::()?.serialize(serializer), - OwnedDataModelType::Isize => self.deserialize::()?.serialize(serializer), - OwnedDataModelType::F32 => serializer.serialize_f32(self.deserialize()?), - OwnedDataModelType::F64 => serializer.serialize_f64(self.deserialize()?), - OwnedDataModelType::Char => serializer.serialize_char(self.deserialize()?), - OwnedDataModelType::String => serializer.serialize_str(self.deserialize()?), - OwnedDataModelType::ByteArray => serializer.serialize_bytes(self.deserialize()?), - OwnedDataModelType::Option(inner) => { - if self.deserialize()? { - serializer.serialize_some(&self.with_schema(inner)) - } else { - serializer.serialize_none() - } - } + type DeserializeError = D::Error; + + fn reserialize( + self, + serializer: S, + ) -> Result> { + fn deserialize<'de, T, D, SerializerError>( + deserializer: D, + ) -> Result> + where + T: Deserialize<'de>, + D: Deserializer<'de>, + { + T::deserialize(deserializer).map_err(Error::Deserialize) + } + let (context, deserializer, schema) = (self.context, self.deserializer, self.schema); + match &schema.ty { + OwnedDataModelType::Schema => OwnedNamedType::deserialize(deserializer) + .map_err(Error::Deserialize)? + .serialize(serializer), OwnedDataModelType::Unit => serializer.serialize_unit(), - OwnedDataModelType::Seq(element) => { - let len = self.deserialize()?; - serializer - .serialize_seq(Some(len)) - .and_then(|mut serializer| { - for _ in 0..len { - serializer.serialize_element(&self.with_schema(element))?; - } - serializer.end() - }) + OwnedDataModelType::Bool => serializer.serialize_bool(deserialize(deserializer)?), + OwnedDataModelType::U8 => serializer.serialize_u8(deserialize(deserializer)?), + OwnedDataModelType::U16 => serializer.serialize_u16(deserialize(deserializer)?), + OwnedDataModelType::U32 => serializer.serialize_u32(deserialize(deserializer)?), + OwnedDataModelType::U64 => serializer.serialize_u64(deserialize(deserializer)?), + OwnedDataModelType::U128 => serializer.serialize_u128(deserialize(deserializer)?), + OwnedDataModelType::I8 => serializer.serialize_i8(deserialize(deserializer)?), + OwnedDataModelType::I16 => serializer.serialize_i16(deserialize(deserializer)?), + OwnedDataModelType::I32 => serializer.serialize_i32(deserialize(deserializer)?), + OwnedDataModelType::I64 => serializer.serialize_i64(deserialize(deserializer)?), + OwnedDataModelType::I128 => serializer.serialize_i128(deserialize(deserializer)?), + OwnedDataModelType::Usize => { + deserialize::(deserializer)?.serialize(serializer) } - OwnedDataModelType::Tuple(elements) => serializer - .serialize_tuple(elements.len()) - .and_then(|mut serializer| { - for element in elements { - serializer.serialize_element(&self.with_schema(element))?; - } - serializer.end() - }), - OwnedDataModelType::UnitStruct => self - .strategy - .serialize_unit_struct(serializer, &self.schema.name), - OwnedDataModelType::NewtypeStruct(inner) => self.strategy.serialize_newtype_struct( - serializer, - &self.schema.name, - &self.with_schema(inner), - ), - OwnedDataModelType::TupleStruct(fields) => { - self.strategy - .serialize_tuple_struct(serializer, self, &self.schema.name, fields) + OwnedDataModelType::Isize => { + deserialize::(deserializer)?.serialize(serializer) } - OwnedDataModelType::Struct(fields) => { - self.strategy - .serialize_struct(serializer, self, &self.schema.name, fields) + OwnedDataModelType::F32 => serializer.serialize_f32(deserialize(deserializer)?), + OwnedDataModelType::F64 => serializer.serialize_f64(deserialize(deserializer)?), + OwnedDataModelType::Char => serializer.serialize_char(deserialize(deserializer)?), + OwnedDataModelType::String => serializer.serialize_str(deserialize(deserializer)?), + OwnedDataModelType::ByteArray => serializer.serialize_bytes(deserialize(deserializer)?), + OwnedDataModelType::Option(inner) => deserializer + .deserialize_option(option::Visitor { + context, + serializer, + schema: inner, + }) + .map_err(Error::Deserialize)?, + OwnedDataModelType::Map { key, val } => deserializer + .deserialize_map(map::Visitor { + context, + serializer, + key, + val, + }) + .map_err(Error::Deserialize)?, + OwnedDataModelType::Seq(element) => deserializer + .deserialize_seq(seq::Visitor { + context, + serializer, + schemas: slice::from_ref(element), + }) + .map_err(Error::Deserialize)?, + OwnedDataModelType::Tuple(elements) => deserializer + .deserialize_tuple( + elements.len(), + tuple::Visitor { + context, + serializer, + fields: elements, + reserializer: expecting::Tuple, + }, + ) + .map_err(Error::Deserialize)?, + OwnedDataModelType::UnitStruct => { + Strategy::reserialize_unit_struct(context, deserializer, serializer, &schema.name) + .map_err(Error::Deserialize)? } - OwnedDataModelType::Map { key, val } => { - let map_len = self.deserialize()?; - serializer - .serialize_map(Some(map_len)) - .and_then(|mut serializer| { - for _ in 0..map_len { - // Important these are deserialized in order instead of using - // serialize_entry() which could deserialize the value first - serializer.serialize_key(&self.with_schema(key))?; - serializer.serialize_value(&self.with_schema(val))?; - } - serializer.end() - }) - } - OwnedDataModelType::Enum(variants) => { - let variant: u32 = self.deserialize()?; - let schema = usize::try_from(variant) - .ok() - .and_then(|variant| variants.get(variant)) - .ok_or(postcard::Error::DeserializeBadEncoding) - .map_err(Error::Deserialize)?; - self.strategy - .serialize_enum(serializer, self, &self.schema.name, variant, schema) - } - OwnedDataModelType::Schema => todo!(), + OwnedDataModelType::NewtypeStruct(inner) => Strategy::reserialize_newtype_struct( + context, + deserializer, + serializer, + expecting::Struct { + name: &schema.name, + data: expecting::data::Newtype { schema: inner }, + }, + ) + .map_err(Error::Deserialize)?, + OwnedDataModelType::TupleStruct(fields) => Strategy::reserialize_tuple_struct( + context, + deserializer, + serializer, + expecting::Struct { + name: &schema.name, + data: expecting::data::Tuple { elements: fields }, + }, + ) + .map_err(Error::Deserialize)?, + OwnedDataModelType::Struct(fields) => Strategy::reserialize_struct( + context, + deserializer, + serializer, + expecting::Struct { + name: &schema.name, + data: expecting::data::Struct { fields }, + }, + ) + .map_err(Error::Deserialize)?, + OwnedDataModelType::Enum(variants) => Strategy::reserialize_enum( + context, + deserializer, + serializer, + expecting::Enum { + name: &schema.name, + variants, + }, + ) + .map_err(Error::Deserialize)?, } .map_err(Error::Serialize) } +} - fn with_schema(&self, schema: &'a OwnedNamedType) -> Self { - Self { - schema, - deserializer: self.deserializer, - deserializer_error: self.deserializer_error, - strategy: self.strategy, - } - } +struct Expected<'a>(fmt::Arguments<'a>); - fn deserialize, SerializeError: serde::ser::Error>( - &self, - ) -> Result> { - T::deserialize(&mut **self.deserializer.borrow_mut()).map_err(Error::Deserialize) +impl de::Expected for Expected<'_> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "{}", self.0) } } -mod structs_and_enums { - //! How to reserialize structs and enums to work around [`Serializer`]'s `&'static str` requirements. +#[cfg(test)] +mod tests { + use core::fmt::Debug; + use postcard::ser_flavors::Flavor; + use postcard_schema::Schema; + use serde::de::DeserializeOwned; + use serde_json::json; - use postcard::de_flavors::Flavor; - use postcard_schema::schema::owned::{OwnedNamedType, OwnedNamedValue, OwnedNamedVariant}; - use serde::{Serialize, Serializer}; + use super::*; - /// Type-erased wrapper around [`super::Reserialize`] to avoid needing to introduce all of the - /// generic lifetime parameters and bounds associated with the aforementioned. - pub(crate) trait Reserialize { - fn with_schema<'a>(&'a self, schema: &'a OwnedNamedType) -> impl Serialize + 'a; + #[derive(Serialize, Deserialize, Schema, PartialEq, Debug)] + enum Enum { + Struct { a: u8, b: u8 }, + Tuple(bool, u8), + Newtype(u32), + Unit, } - impl<'a, 'de, 'deserializer, F, Strategy> Reserialize - for &super::Reserialize<'a, 'de, 'deserializer, F, Strategy> + #[derive(Serialize, Deserialize, Schema, PartialEq, Debug)] + struct Struct { + a: Option, + b: u8, + c: u8, + } + + fn postcard_to_json(postcard: &[u8]) -> serde_json::Value { + let schema = T::SCHEMA.into(); + let leaky = lossless::reserialize_leaky( + &schema, + &mut postcard::Deserializer::from_bytes(postcard), + serde_json::value::Serializer, + ) + .unwrap(); + let lossy = lossy::reserialize_with_structs_and_enums_as_maps( + &schema, + &mut postcard::Deserializer::from_bytes(postcard), + serde_json::value::Serializer, + ) + .unwrap(); + assert_eq!(leaky, lossy); + leaky + } + + fn json_to_postcard(json: &serde_json::Value) -> Vec { + let mut serializer = postcard::Serializer { + output: postcard::ser_flavors::AllocVec::new(), + }; + lossless::reserialize_leaky(&T::SCHEMA.into(), json, &mut serializer).unwrap(); + serializer.output.finalize().unwrap() + } + + fn test_postcard_to_json_and_back(value: T) where - F: Flavor<'de> + 'de, - Strategy: self::Strategy, + T: Schema + Serialize + DeserializeOwned + Debug + PartialEq, { - fn with_schema<'b>(&'b self, schema: &'b OwnedNamedType) -> impl Serialize + 'b { - super::Reserialize::with_schema(self, schema) - } + let postcard_bytes = postcard::to_allocvec(&value).unwrap(); + let json = postcard_to_json::(&postcard_bytes); + assert_eq!(json, serde_json::to_value(&value).unwrap()); + assert_eq!(T::deserialize(&json).unwrap(), value); + + let roundtripped_postcard_bytes = json_to_postcard::(&json); + assert_eq!(roundtripped_postcard_bytes, postcard_bytes); + assert_eq!( + postcard::from_bytes::(&roundtripped_postcard_bytes).unwrap(), + value + ); } - /// How to reserialize structs and enums to work around [`Serializer`]'s `&'static str` requirements. - pub(crate) trait Strategy { - fn serialize_unit_struct( - &self, - serializer: S, - name: &str, - ) -> Result; - - fn serialize_newtype_struct( - &self, - serializer: S, - name: &str, - value: &T, - ) -> Result; - - fn serialize_tuple_struct( - &self, - serializer: S, - reserialize: impl Reserialize, - name: &str, - fields: &[OwnedNamedType], - ) -> Result; - - fn serialize_struct( - &self, - serializer: S, - reserialize: impl Reserialize, - name: &str, - fields: &[OwnedNamedValue], - ) -> Result; - - fn serialize_enum( - &self, - serializer: S, - reserialize: impl Reserialize, - name: &str, - variant_index: u32, - variant: &OwnedNamedVariant, - ) -> Result; + fn test_json_to_postcard(json: serde_json::Value) + where + T: Schema + Serialize + DeserializeOwned + Debug + PartialEq, + { + let postcard_bytes = json_to_postcard::(&json); + let from_json = T::deserialize(&json).unwrap(); + let from_postcard_bytes = postcard::from_bytes::(&postcard_bytes).unwrap(); + assert_eq!(from_postcard_bytes, from_json); } -} -#[test] -fn errors() { - use postcard_schema::Schema; + fn test_json_to_postcard_and_back(json: serde_json::Value) + where + T: Schema + Serialize + DeserializeOwned + Debug + PartialEq, + { + let postcard_bytes = json_to_postcard::(&json); + + let from_json = T::deserialize(&json).unwrap(); + let from_postcard_bytes = postcard::from_bytes::(&postcard_bytes).unwrap(); + assert_eq!(from_postcard_bytes, from_json); + + let json_roundtripped = postcard_to_json::(&postcard_bytes); + assert_eq!(json_roundtripped, json); + + let from_json_roundtripped = T::deserialize(&json_roundtripped).unwrap(); + assert_eq!(from_json_roundtripped, from_json); + } + + #[test] + fn json() { + use test_postcard_to_json_and_back as test; + test(Enum::Struct { a: 5, b: 10 }); + test(Enum::Tuple(false, 15)); + test(Enum::Newtype(20)); + test(Enum::Unit); + test(Struct { + a: Some(5), + b: 10, + c: 7, + }); + } + + #[test] + /// Make sure reserialization handles out-of-order struct fields correctly. + /// Serializers like postcard rely on struct fields being serialized in order. + fn out_of_order_fields() { + use test_json_to_postcard_and_back as test; + test::(json!({"Struct": {"b": 10, "a": 5}})); + test::(json!({"Struct": {"a": 5, "b": 0}})); + test::(json!({"a": 5, "b": 0, "c": 10})); + test::(json!({"b": 0, "a": 5, "c": 10})); + test::(json!({"b": 0, "c": 10, "a": 5})); + } + + #[test] + fn extra_fields() { + use test_json_to_postcard as test; + test::(json!({"Struct": {"a": 5, "b": 0, "UNUSED": 10}})); + test::(json!({"a": 5, "xyz": "wat", "b": 0, "c": 10})); + } + + #[test] + #[should_panic = "missing field `b`"] + fn missing_fields() { + test_json_to_postcard::(json!({"Struct": {"a": 5}})); + } - assert!(matches!( - lossy::reserialize_with_structs_and_enums_as_maps( - &u8::SCHEMA.into(), - &mut postcard::Deserializer::from_bytes(&[]), - serde_json::value::Serializer - ), - Err(Error::Deserialize( - postcard::Error::DeserializeUnexpectedEnd - )) - )); - // Bad enum discriminant - assert!(matches!( - lossy::reserialize_with_structs_and_enums_as_maps( - &Result::::SCHEMA.into(), - &mut postcard::Deserializer::from_bytes(&[99]), - serde_json::value::Serializer - ), - Err(Error::Deserialize(postcard::Error::DeserializeBadEncoding)) - )); + #[test] + #[should_panic = "invalid length 1, expected tuple variant Enum::Tuple with 2 elements"] + fn missing_tuple_fields() { + test_json_to_postcard::(json!({"Tuple": [false]})); + } + + #[test] + /// Make sure both deserializer and serializer errors are bubbled up + fn errors() { + use postcard_schema::Schema; + + assert!(matches!( + dbg!(lossless::reserialize_leaky( + &u8::SCHEMA.into(), + &mut postcard::Deserializer::from_bytes(&[]), + serde_json::value::Serializer + )), + Err(Error::Deserialize( + postcard::Error::DeserializeUnexpectedEnd + )) + )); + assert!(matches!( + dbg!(lossless::reserialize_leaky( + &u8::SCHEMA.into(), + &mut postcard::Deserializer::from_bytes(&[5]), + &mut serde_json::Serializer::new(std::io::Cursor::new([].as_mut_slice())) + )), + Err(Error::Serialize(_)) + )); + } } diff --git a/source/postcard-dyn/src/reserialize/expecting.rs b/source/postcard-dyn/src/reserialize/expecting.rs new file mode 100644 index 0000000..5313f2f --- /dev/null +++ b/source/postcard-dyn/src/reserialize/expecting.rs @@ -0,0 +1,118 @@ +use core::{ + fmt::{self, Display}, + ops::RangeTo, +}; + +use postcard_schema::schema::owned::OwnedNamedVariant; +use serde::de::{self, Expected}; + +pub trait Unexpected: de::Error { + fn missing_elements(len: usize, expected: &dyn Expected, expected_elements: usize) -> Self { + Self::invalid_length( + len, + &super::Expected(format_args!( + "{expected} with {expected_elements} element{}", + if expected_elements == 1 { "" } else { "s" }, + )), + ) + } + + fn unknown_variant_index(index: impl Into, expected: RangeTo) -> Self { + Self::invalid_value( + de::Unexpected::Unsigned(index.into()), + &super::Expected(format_args!("variant index 0 <= i < {}", expected.end)), + ) + } +} + +impl Unexpected for Error {} + +pub struct Tuple; + +pub struct Struct<'a, Data> { + pub name: &'a str, + pub data: Data, +} + +pub struct Enum<'name, 'schema> { + pub name: &'name str, + pub variants: &'schema [OwnedNamedVariant], +} + +pub struct Variant<'a, Data> { + pub enum_name: &'a str, + pub variant_index: u32, + pub variant_name: &'a str, + pub data: Data, +} + +pub mod data { + use postcard_schema::schema::owned::{OwnedNamedType, OwnedNamedValue}; + + pub struct Unit; + pub struct Newtype<'a> { + pub schema: &'a OwnedNamedType, + } + pub struct Tuple<'a> { + pub elements: &'a [OwnedNamedType], + } + pub struct Struct<'a> { + pub fields: &'a [OwnedNamedValue], + } +} + +impl Expected for Tuple { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a tuple") + } +} + +impl Expected for Struct<'_, data::Unit> { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "unit struct {}", self.name) + } +} + +impl Expected for Struct<'_, data::Newtype<'_>> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "tuple struct {}", self.name) + } +} + +impl Expected for Struct<'_, data::Tuple<'_>> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "tuple struct {}", self.name) + } +} + +impl Expected for Struct<'_, data::Struct<'_>> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "struct {}", self.name) + } +} + +impl Expected for Enum<'_, '_> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "enum {}", self.name) + } +} + +impl Expected for Variant<'_, data::Struct<'_>> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!( + formatter, + "struct variant {}::{}", + self.enum_name, self.variant_name + ) + } +} + +impl Expected for Variant<'_, data::Tuple<'_>> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!( + formatter, + "tuple variant {}::{}", + self.enum_name, self.variant_name + ) + } +} diff --git a/source/postcard-dyn/src/reserialize/lossless.rs b/source/postcard-dyn/src/reserialize/lossless.rs index ab98b8a..dd58c96 100644 --- a/source/postcard-dyn/src/reserialize/lossless.rs +++ b/source/postcard-dyn/src/reserialize/lossless.rs @@ -4,26 +4,33 @@ //! This module provides implementations with different compromises: //! - [`reserialize_leaky()`] leaks memory for each unique struct/enum/variant/field name -use core::cell::RefCell; -use std::collections::HashSet; +use core::{cell::RefCell, fmt, str}; -use postcard::de_flavors::Flavor; -use postcard_schema::schema::owned::{ - OwnedDataModelVariant, OwnedNamedType, OwnedNamedValue, OwnedNamedVariant, -}; +use postcard_schema::schema::owned::OwnedNamedType; use serde::{ - ser::{SerializeStruct, SerializeStructVariant, SerializeTupleStruct, SerializeTupleVariant}, - Serialize, Serializer, + de::{self, Deserializer}, + ser::Serializer, }; use crate::Error; -use super::{reserialize, structs_and_enums::Reserialize}; +use super::{ + expecting, + strategy::{self, Strategy as _}, + Context, +}; + +mod interned; +use interned::Interned; + +mod enums; +mod structs; +mod tuples; /// Reserialize [`postcard`]-encoded data losslessly, **leaking memory**. /// /// In order to serialize structs and enums losslessly, this **allocates and leaks each unique -/// struct/enum/variant/field name**. +/// struct/enum/variant/field name, and the list of field names for each struct**. /// /// # Examples /// @@ -54,131 +61,189 @@ use super::{reserialize, structs_and_enums::Reserialize}; /// # Ok(()) /// # } /// ``` -pub fn reserialize_leaky<'de, F, S>( +pub fn reserialize_leaky<'de, D, S>( schema: &OwnedNamedType, - deserializer: &mut postcard::Deserializer<'de, F>, + deserializer: D, serializer: S, -) -> Result> +) -> Result> where - F: Flavor<'de>, + D: Deserializer<'de>, S: Serializer, { - reserialize(schema, deserializer, serializer, Strategy) + Strategy.reserialize(schema, deserializer, serializer) } /// Reserialize structs and enums losslessly, **leaking memory**. struct Strategy; impl Strategy { - fn intern(&self, s: &str) -> &'static str { + fn with_interned(&self, f: impl FnOnce(&mut Interned) -> T) -> T { thread_local! { - static STRINGS: RefCell> = RefCell::new(HashSet::new()); + static INTERNED: RefCell = RefCell::new(Default::default()); } - STRINGS.with_borrow_mut(|strings| { - if !strings.contains(s) { - strings.insert(String::leak(s.to_string())); - } - *strings.get(s).unwrap() - }) + INTERNED.with_borrow_mut(f) } } -impl super::structs_and_enums::Strategy for Strategy { - fn serialize_unit_struct( - &self, +impl strategy::Strategy for Strategy { + fn reserialize_unit_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, serializer: S, name: &str, - ) -> Result { - serializer.serialize_unit_struct(self.intern(name)) + ) -> Result, D::Error> { + struct Visitor { + serializer: S, + expecting: expecting::Struct<'static, expecting::data::Unit>, + } + + impl<'de, S: Serializer> de::Visitor<'de> for Visitor { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_unit(self) -> Result { + Ok(self.serializer.serialize_unit_struct(self.expecting.name)) + } + } + + let name = context + .strategy + .with_interned(|interned| interned.intern_identifier(name)); + deserializer.deserialize_unit_struct( + name, + Visitor { + serializer, + expecting: expecting::Struct { + name, + data: expecting::data::Unit, + }, + }, + ) } - fn serialize_newtype_struct( - &self, + fn reserialize_newtype_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, serializer: S, - name: &str, - value: &T, - ) -> Result { - serializer.serialize_newtype_struct(self.intern(name), value) + expecting: expecting::Struct<'_, expecting::data::Newtype>, + ) -> Result, D::Error> { + struct Visitor<'a, S> { + context: &'a Context<'a, Strategy>, + serializer: S, + expecting: expecting::Struct<'static, expecting::data::Newtype<'a>>, + } + + impl<'a, 'de, S: Serializer> de::Visitor<'de> for Visitor<'a, S> { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + self.context + .reserialize_ty(self.expecting.data.schema, deserializer, |inner| { + self.serializer + .serialize_newtype_struct(self.expecting.name, inner) + }) + } + } + + let name = context + .strategy + .with_interned(|interned| interned.intern_identifier(expecting.name)); + deserializer.deserialize_newtype_struct( + name, + Visitor { + context, + serializer, + expecting: expecting::Struct { + name, + data: expecting.data, + }, + }, + ) } - fn serialize_tuple_struct( - &self, + fn reserialize_tuple_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, serializer: S, - reserialize: impl Reserialize, - name: &str, - fields: &[OwnedNamedType], - ) -> Result { - let mut serializer = serializer.serialize_tuple_struct(self.intern(name), fields.len())?; - for field in fields { - serializer.serialize_field(&reserialize.with_schema(field))?; - } - serializer.end() + expecting: expecting::Struct<'_, expecting::data::Tuple>, + ) -> Result, D::Error> { + let name = context + .strategy + .with_interned(|interned| interned.intern_identifier(expecting.name)); + deserializer.deserialize_tuple_struct( + name, + expecting.data.elements.len(), + super::tuple::Visitor { + context, + serializer, + fields: expecting.data.elements, + reserializer: expecting::Struct { + name, + data: expecting.data, + }, + }, + ) } - fn serialize_struct( - &self, + fn reserialize_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, serializer: S, - reserialize: impl Reserialize, - name: &str, - fields: &[OwnedNamedValue], - ) -> Result { - let mut serializer = serializer.serialize_struct(self.intern(name), fields.len())?; - for field in fields { - serializer.serialize_field( - self.intern(&field.name), - &reserialize.with_schema(&field.ty), - )?; - } - serializer.end() + expecting: expecting::Struct<'_, expecting::data::Struct>, + ) -> Result, D::Error> { + let fields = expecting.data.fields; + let (name, field_names) = context.strategy.with_interned(|interned| { + let name = interned.intern_identifier(expecting.name); + let field_names = interned.intern_slice(fields.iter().map(|f| f.name.as_str())); + (name, field_names) + }); + deserializer.deserialize_struct( + name, + field_names, + structs::Visitor { + context, + serializer, + fields, + field_names, + reserializer: expecting::Struct { + name, + data: expecting.data, + }, + }, + ) } - fn serialize_enum( - &self, + fn reserialize_enum<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, serializer: S, - reserialize: impl Reserialize, - name: &str, - variant_index: u32, - variant: &OwnedNamedVariant, - ) -> Result { - match &variant.ty { - OwnedDataModelVariant::UnitVariant => serializer.serialize_unit_variant( - self.intern(name), - variant_index, - self.intern(&variant.name), - ), - OwnedDataModelVariant::NewtypeVariant(inner) => serializer.serialize_newtype_variant( - self.intern(name), - variant_index, - self.intern(&variant.name), - &reserialize.with_schema(inner), - ), - OwnedDataModelVariant::TupleVariant(fields) => { - let mut serializer = serializer.serialize_tuple_variant( - self.intern(name), - variant_index, - self.intern(&variant.name), - fields.len(), - )?; - for field in fields { - serializer.serialize_field(&reserialize.with_schema(field))?; - } - serializer.end() - } - OwnedDataModelVariant::StructVariant(fields) => { - let mut serializer = serializer.serialize_struct_variant( - self.intern(name), - variant_index, - self.intern(&variant.name), - fields.len(), - )?; - for field in fields { - serializer.serialize_field( - self.intern(&field.name), - &reserialize.with_schema(&field.ty), - )?; - } - serializer.end() - } - } + expecting: expecting::Enum<'_, '_>, + ) -> Result, D::Error> { + let variants = expecting.variants; + let (name, variant_names) = context.strategy.with_interned(|interned| { + let name = interned.intern_identifier(expecting.name); + let variant_names = interned.intern_slice(variants.iter().map(|v| v.name.as_str())); + (name, variant_names) + }); + deserializer.deserialize_enum( + name, + variant_names, + enums::Visitor { + context, + serializer, + expecting: expecting::Enum { name, variants }, + variant_names, + }, + ) } } diff --git a/source/postcard-dyn/src/reserialize/lossless/enums.rs b/source/postcard-dyn/src/reserialize/lossless/enums.rs new file mode 100644 index 0000000..295f8a7 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossless/enums.rs @@ -0,0 +1,174 @@ +use core::{fmt, str}; + +use postcard_schema::schema::owned::{OwnedDataModelVariant, OwnedNamedType, OwnedNamedVariant}; +use serde::{ + de::{self, DeserializeSeed, EnumAccess, VariantAccess}, + Deserializer, Serializer, +}; + +use crate::reserialize::{ + self, + expecting::{self, Unexpected}, + Context, +}; + +use super::Strategy; + +pub struct Visitor<'a, S> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub expecting: expecting::Enum<'static, 'a>, + pub variant_names: &'static [&'static str], +} + +impl<'a, 'de, S: Serializer> de::Visitor<'de> for Visitor<'a, S> { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_enum>(self, data: A) -> Result { + let ((variant_index, variant_name, variant), deserializer) = + data.variant_seed(VariantVisitor { + variants: self.expecting.variants, + variant_names: self.variant_names, + })?; + match variant { + OwnedDataModelVariant::UnitVariant => { + deserializer.unit_variant()?; + Ok(self.serializer.serialize_unit_variant( + self.expecting.name, + variant_index, + variant_name, + )) + } + OwnedDataModelVariant::NewtypeVariant(inner) => { + deserializer.newtype_variant_seed(NewtypeVariantSeed { + context: self.context, + schema: inner, + serializer: self.serializer, + location: expecting::Variant { + enum_name: self.expecting.name, + variant_index, + variant_name, + data: expecting::data::Newtype { schema: inner }, + }, + }) + } + OwnedDataModelVariant::TupleVariant(fields) => deserializer.tuple_variant( + fields.len(), + reserialize::tuple::Visitor { + context: self.context, + serializer: self.serializer, + fields, + reserializer: expecting::Variant { + enum_name: self.expecting.name, + variant_index, + variant_name, + data: expecting::data::Tuple { elements: fields }, + }, + }, + ), + OwnedDataModelVariant::StructVariant(fields) => { + let field_names = self.context.strategy.with_interned(|interned| { + interned.intern_slice(fields.iter().map(|f| f.name.as_str())) + }); + deserializer.struct_variant( + field_names, + super::structs::Visitor { + context: self.context, + serializer: self.serializer, + fields, + field_names, + reserializer: expecting::Variant { + enum_name: self.expecting.name, + variant_index, + variant_name, + data: expecting::data::Struct { fields }, + }, + }, + ) + } + } + } +} + +struct VariantVisitor<'a> { + variants: &'a [OwnedNamedVariant], + variant_names: &'static [&'static str], +} + +impl<'a, 'de> DeserializeSeed<'de> for VariantVisitor<'a> { + type Value = (u32, &'static str, &'a OwnedDataModelVariant); + + fn deserialize>(self, deserializer: D) -> Result { + deserializer.deserialize_identifier(self) + } +} + +impl<'a, 'de> de::Visitor<'de> for VariantVisitor<'a> { + type Value = (u32, &'static str, &'a OwnedDataModelVariant); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "variant identifier") + } + + fn visit_u64(self, value: u64) -> Result { + let err = || E::unknown_variant_index(value, ..self.variants.len()); + let index = u32::try_from(value).map_err(|_| err())?; + let (name, schema) = { + let idx = usize::try_from(value).map_err(|_| err())?; + (self.variant_names.get(idx)) + .zip(self.variants.get(idx)) + .ok_or_else(err)? + }; + Ok((index, name, &schema.ty)) + } + + fn visit_str(self, value: &str) -> Result { + self.find(value.as_bytes()) + .ok_or_else(|| E::unknown_variant(value, self.variant_names)) + } + + fn visit_bytes(self, value: &[u8]) -> Result { + self.find(value).ok_or_else(|| match str::from_utf8(value) { + Ok(value) => E::unknown_variant(value, self.variant_names), + Err(_) => E::invalid_value(de::Unexpected::Bytes(value), &self), + }) + } +} + +impl<'a> VariantVisitor<'a> { + fn find(&self, variant: &[u8]) -> Option<(u32, &'static str, &'a OwnedDataModelVariant)> { + (self.variant_names.iter()) + .zip(self.variants) + .enumerate() + .find_map(|(index, (&name, schema))| { + (name.as_bytes() == variant).then_some((index as u32, name, &schema.ty)) + }) + } +} + +struct NewtypeVariantSeed<'a, S> { + context: &'a Context<'a, Strategy>, + schema: &'a OwnedNamedType, + serializer: S, + location: expecting::Variant<'static, expecting::data::Newtype<'a>>, +} + +impl<'a, 'de, S: Serializer> DeserializeSeed<'de> for NewtypeVariantSeed<'a, S> { + type Value = Result; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |inner| { + self.serializer.serialize_newtype_variant( + self.location.enum_name, + self.location.variant_index, + self.location.variant_name, + inner, + ) + }) + } +} diff --git a/source/postcard-dyn/src/reserialize/lossless/interned.rs b/source/postcard-dyn/src/reserialize/lossless/interned.rs new file mode 100644 index 0000000..4689b7f --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossless/interned.rs @@ -0,0 +1,80 @@ +use core::hash::{Hash, Hasher}; + +use hashbrown::HashSet; + +#[derive(Default)] +pub struct Interned { + strings: HashSet<&'static str>, + slices: HashSet, +} + +impl Interned { + pub fn intern_identifier(&mut self, s: &str) -> &'static str { + Self::intern_str(&mut self.strings, s) + } + + fn intern_str(strings: &mut HashSet<&'static str>, s: &str) -> &'static str { + strings.get_or_insert_with(s, |s| String::leak(s.to_string())) + } + + pub fn intern_slice<'a>( + &mut self, + strings: impl IntoIterator, + ) -> &'static [&'static str] { + let Slice(slice) = self + .slices + .get_or_insert_with(&Iter(strings.into_iter()), |elements| { + let strings = elements.0.clone(); + let interned = strings.map(|s| Self::intern_str(&mut self.strings, s)); + Slice(Box::leak(interned.collect())) + }); + slice + } +} + +#[derive(PartialEq, Eq)] +struct Slice(&'static [&'static str]); +struct Iter<'a, I: Iterator>(I); + +impl Hash for Slice { + fn hash(&self, state: &mut H) { + for s in self.0 { + s.hash(state) + } + } +} + +impl<'a, I: Iterator + Clone> Hash for Iter<'a, I> { + fn hash(&self, state: &mut H) { + for s in self.0.clone() { + s.hash(state) + } + } +} + +impl<'a, I> hashbrown::Equivalent for Iter<'a, I> +where + I: Iterator + Clone, +{ + fn equivalent(&self, slice: &Slice) -> bool { + self.0.clone().eq(slice.0.iter().copied()) + } +} + +#[cfg(test)] +mod tests { + use super::Interned; + + #[test] + fn basic() { + let mut interned = Interned::default(); + + assert_eq!(interned.intern_identifier("hello"), "hello"); + + let slices: &[&[&str]] = &[&[], &["foo"], &["foo", "bar"]]; + for &slice in slices { + assert_eq!(interned.intern_slice(slice.iter().copied()), slice); + } + assert_eq!(interned.slices.len(), slices.len()); + } +} diff --git a/source/postcard-dyn/src/reserialize/lossless/structs.rs b/source/postcard-dyn/src/reserialize/lossless/structs.rs new file mode 100644 index 0000000..b0edf8a --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossless/structs.rs @@ -0,0 +1,286 @@ +use core::fmt; +use std::collections::HashMap; + +use postcard_schema::schema::owned::{OwnedNamedType, OwnedNamedValue}; +use serde::{ + de::{self, DeserializeSeed, Deserializer, Error as _, MapAccess, SeqAccess}, + ser::{self, Error as _, Serialize, SerializeStruct, Serializer}, +}; + +use crate::reserialize::{ + expecting::{self, Unexpected}, + Context, +}; + +use super::Strategy; + +pub struct Visitor<'a, S, Strategy, Reserializer> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub reserializer: Reserializer, + pub fields: &'a [OwnedNamedValue], + pub field_names: &'static [&'static str], +} + +trait Reserializer: de::Expected { + type SerializeFields: SerializeStruct; + + fn reserialize_struct( + &self, + serializer: S, + len: usize, + ) -> Result; +} + +struct SerializeStructVariant(T); + +impl SerializeStruct for SerializeStructVariant { + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(key, value) + } + + fn end(self) -> Result { + self.0.end() + } +} + +impl Reserializer for expecting::Variant<'static, expecting::data::Struct<'_>> { + type SerializeFields = SerializeStructVariant; + + fn reserialize_struct( + &self, + serializer: S, + len: usize, + ) -> Result { + serializer + .serialize_struct_variant(self.enum_name, self.variant_index, self.variant_name, len) + .map(SerializeStructVariant) + } +} + +impl Reserializer for expecting::Struct<'static, expecting::data::Struct<'_>> { + type SerializeFields = S::SerializeStruct; + + fn reserialize_struct( + &self, + serializer: S, + len: usize, + ) -> Result { + serializer.serialize_struct(self.name, len) + } +} + +impl<'de, S, Reserializer> de::Visitor<'de> for Visitor<'_, S, Strategy, Reserializer> +where + S: Serializer, + Reserializer: self::Reserializer, +{ + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.reserializer, formatter) + } + + fn visit_seq>(self, mut seq: A) -> Result { + let mut serializer = match self + .reserializer + .reserialize_struct(self.serializer, self.field_names.len()) + { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + let fields = (self.field_names.iter()) + .zip(self.fields) + .map(|(&name, field)| (name, &field.ty)); + for (idx, (name, schema)) in fields.enumerate() { + let seed = FieldSeed { + context: self.context, + serializer: &mut serializer, + name, + schema, + }; + let res = seq.next_element_seed(seed)?.ok_or_else(|| { + A::Error::missing_elements(idx, &self.reserializer, self.fields.len()) + })?; + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + } + Ok(serializer.end()) + } + + fn visit_map>(self, mut map: A) -> Result { + let mut serializer = match self + .reserializer + .reserialize_struct(self.serializer, self.field_names.len()) + { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + + let key = FieldVisitor { + fields: self.fields, + field_names: self.field_names, + }; + let mut remaining_fields = self.field_names.iter().peekable(); + let mut out_of_order_fields = None; + while let Some(field) = map.next_key_seed(&key)? { + match field { + Err(Ignored) => { + // This only works for self-describing formats, but it should only + // be self-describing formats that deserialize to ignored fields. + let de::IgnoredAny = map.next_value::()?; + } + Ok((name, schema)) if remaining_fields.next_if_eq(&&name).is_some() => { + let res = map.next_value_seed(FieldSeed { + context: self.context, + serializer: &mut serializer, + name, + schema, + })?; + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + } + Ok((name, schema)) => { + // Fields were deserialized out-of-order. Serializers assume fields are + // serialized in-order, so buffer up the out of order fields then serialize + // them in order. + let out_of_order = out_of_order_fields.get_or_insert_with(|| { + OutOfOrderFields(HashMap::with_capacity(remaining_fields.len())) + }); + let res: Result<(), serde_content::Error> = map.next_value_seed(FieldSeed { + context: self.context, + serializer: out_of_order, + name, + schema, + })?; + match res { + Ok(()) => {} + Err(err) => return Ok(Err(S::Error::custom(err))), + } + } + } + } + let mut out_of_order = out_of_order_fields + .map(|OutOfOrderFields(out_of_order)| out_of_order) + .unwrap_or(HashMap::new()); + for field in remaining_fields { + match out_of_order.remove(field) { + Some(value) => match serializer.serialize_field(field, &value) { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + }, + None => return Err(A::Error::missing_field(field)), + } + } + for field in self.field_names { + if out_of_order.contains_key(field) { + return Err(A::Error::duplicate_field(field)); + } + } + Ok(serializer.end()) + } +} + +struct FieldSeed<'a, S> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + name: &'static str, + schema: &'a OwnedNamedType, +} + +impl<'de, 'a, S: SerializeStruct> DeserializeSeed<'de> for FieldSeed<'a, S> { + type Value = Result<(), S::Error>; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |value| { + self.serializer.serialize_field(self.name, value) + }) + } +} + +struct FieldVisitor<'a> { + fields: &'a [OwnedNamedValue], + field_names: &'static [&'static str], +} + +struct Ignored; + +impl<'a, 'de> DeserializeSeed<'de> for &FieldVisitor<'a> { + type Value = Result<(&'static str, &'a OwnedNamedType), Ignored>; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_identifier(self) + } +} + +impl<'a, 'de> de::Visitor<'de> for &FieldVisitor<'a> { + type Value = Result<(&'static str, &'a OwnedNamedType), Ignored>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "field identifier") + } + + fn visit_u64(self, value: u64) -> Result { + Ok((usize::try_from(value).ok()) + .and_then(|idx| { + let (&name, schema) = (self.field_names.get(idx)).zip(self.fields.get(idx))?; + Some((name, &schema.ty)) + }) + .ok_or(Ignored)) + } + + fn visit_str(self, value: &str) -> Result { + Ok(self.find(value.as_bytes())) + } + + fn visit_bytes(self, value: &[u8]) -> Result { + Ok(self.find(value)) + } +} + +impl<'a> FieldVisitor<'a> { + fn find(&self, field: &[u8]) -> Result<(&'static str, &'a OwnedNamedType), Ignored> { + self.field_names + .iter() + .zip(self.fields) + .find_map(|(&name, schema)| (name.as_bytes() == field).then_some((name, &schema.ty))) + .ok_or(Ignored) + } +} + +#[derive(Debug)] +struct OutOfOrderFields<'a>(HashMap<&'a str, serde_content::Value<'a>>); + +impl<'a> SerializeStruct for OutOfOrderFields<'a> { + type Ok = HashMap<&'a str, serde_content::Value<'a>>; + type Error = serde_content::Error; + + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> Result<(), Self::Error> { + let value = value.serialize(serde_content::Serializer::new())?; + debug_assert!(self.0.len() < self.0.capacity()); + self.0.insert(key, value); + Ok(()) + } + + fn end(self) -> Result { + Ok(self.0) + } +} diff --git a/source/postcard-dyn/src/reserialize/lossless/tuples.rs b/source/postcard-dyn/src/reserialize/lossless/tuples.rs new file mode 100644 index 0000000..8e1a07d --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossless/tuples.rs @@ -0,0 +1,66 @@ +use serde::ser::{self, Serialize, SerializeTuple, Serializer}; + +use crate::reserialize::{expecting, tuple::Reserializer}; + +pub struct SerializeTupleVariant(T); +pub struct SerializeTupleStruct(T); + +impl Reserializer for expecting::Variant<'static, expecting::data::Tuple<'_>> { + type SerializeTuple = SerializeTupleVariant; + + fn reserialize_tuple( + &self, + serializer: S, + len: usize, + ) -> Result { + serializer + .serialize_tuple_variant(self.enum_name, self.variant_index, self.variant_name, len) + .map(SerializeTupleVariant) + } +} + +impl Reserializer for expecting::Struct<'static, expecting::data::Tuple<'_>> { + type SerializeTuple = SerializeTupleStruct; + + fn reserialize_tuple( + &self, + serializer: S, + len: usize, + ) -> Result { + serializer + .serialize_tuple_struct(self.name, len) + .map(SerializeTupleStruct) + } +} + +impl SerializeTuple for SerializeTupleVariant { + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(value) + } + + fn end(self) -> Result { + self.0.end() + } +} + +impl SerializeTuple for SerializeTupleStruct { + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(value) + } + + fn end(self) -> Result { + self.0.end() + } +} diff --git a/source/postcard-dyn/src/reserialize/lossy.rs b/source/postcard-dyn/src/reserialize/lossy.rs index e3ab7de..327f11c 100644 --- a/source/postcard-dyn/src/reserialize/lossy.rs +++ b/source/postcard-dyn/src/reserialize/lossy.rs @@ -6,17 +6,22 @@ //! instead of actual structs and enums. use postcard::de_flavors::Flavor; -use postcard_schema::schema::owned::{ - OwnedDataModelVariant, OwnedNamedType, OwnedNamedValue, OwnedNamedVariant, -}; +use postcard_schema::schema::owned::OwnedNamedType; use serde::{ - ser::{SerializeMap, SerializeSeq, SerializeTuple}, - Serialize, Serializer, + de::{Deserialize, Deserializer}, + ser::{Serialize, Serializer}, }; -use crate::Error; +use crate::{reserialize, Error}; + +use super::{ + expecting, + strategy::{self, Strategy as _}, + Context, +}; -use super::{reserialize, structs_and_enums::Reserialize}; +mod enums; +mod structs; /// Reserialize [`postcard`]-encoded data, transforming structs and enums into maps. /// @@ -64,129 +69,87 @@ pub fn reserialize_with_structs_and_enums_as_maps<'de, F, S>( schema: &OwnedNamedType, deserializer: &mut postcard::Deserializer<'de, F>, serializer: S, -) -> Result> +) -> Result> where F: Flavor<'de>, S: Serializer, { - reserialize(schema, deserializer, serializer, Strategy) + Strategy.reserialize(schema, deserializer, serializer) } /// Reserialize structs and enums as maps similar to [`serde_json`]. struct Strategy; -impl super::structs_and_enums::Strategy for Strategy { - fn serialize_unit_struct( - &self, +impl strategy::Strategy for Strategy { + fn reserialize_unit_struct<'de, D: Deserializer<'de>, S: Serializer>( + _context: &Context<'_, Self>, + deserializer: D, serializer: S, _name: &str, - ) -> Result { - serializer.serialize_unit() + ) -> Result, D::Error> { + <()>::deserialize(deserializer)?; + Ok(serializer.serialize_unit()) } - fn serialize_newtype_struct( - &self, + fn reserialize_newtype_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, serializer: S, - _name: &str, - value: &T, - ) -> Result { - value.serialize(serializer) + expecting: expecting::Struct<'_, expecting::data::Newtype>, + ) -> Result, D::Error> { + context.reserialize_ty(expecting.data.schema, deserializer, |inner| { + inner.serialize(serializer) + }) } - fn serialize_tuple_struct( - &self, + fn reserialize_tuple_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, serializer: S, - reserialize: impl Reserialize, - _name: &str, - fields: &[OwnedNamedType], - ) -> Result { - let mut serializer = serializer.serialize_seq(Some(fields.len()))?; - for field in fields { - serializer.serialize_element(&reserialize.with_schema(field))?; - } - serializer.end() + expecting: expecting::Struct<'_, expecting::data::Tuple<'_>>, + ) -> Result, D::Error> { + deserializer.deserialize_tuple( + expecting.data.elements.len(), + reserialize::tuple::Visitor { + context, + serializer, + fields: expecting.data.elements, + reserializer: expecting::Tuple, + }, + ) } - fn serialize_struct( - &self, + fn reserialize_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, serializer: S, - reserialize: impl Reserialize, - _name: &str, - fields: &[OwnedNamedValue], - ) -> Result { - let mut serializer = serializer.serialize_map(Some(fields.len()))?; - for field in fields { - serializer.serialize_entry(&field.name, &reserialize.with_schema(&field.ty))?; - } - serializer.end() + expecting: expecting::Struct<'_, expecting::data::Struct<'_>>, + ) -> Result, D::Error> { + deserializer.deserialize_tuple( + expecting.data.fields.len(), + reserialize::tuple::Visitor { + context, + serializer, + fields: expecting.data.fields.iter().map(|f| &f.ty), + reserializer: structs::ReserializeStructAsMap { expecting }, + }, + ) } - fn serialize_enum( - &self, + fn reserialize_enum<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, serializer: S, - reserialize: impl Reserialize, - _name: &str, - _variant_index: u32, - variant: &OwnedNamedVariant, - ) -> Result { - struct ReserializeTuple<'a, Reserialize> { - reserialize: Reserialize, - elements: &'a [OwnedNamedType], - } - impl Serialize for ReserializeTuple<'_, R> { - fn serialize(&self, serializer: S) -> Result { - let mut serializer = serializer.serialize_tuple(self.elements.len())?; - for element in self.elements { - serializer.serialize_element(&self.reserialize.with_schema(element))?; - } - serializer.end() - } - } - - struct ReserializeFields<'a, Reserialize> { - reserialize: Reserialize, - fields: &'a [OwnedNamedValue], - } - impl Serialize for ReserializeFields<'_, R> { - fn serialize(&self, serializer: S) -> Result { - let mut serializer = serializer.serialize_map(Some(self.fields.len()))?; - for field in self.fields { - serializer - .serialize_entry(&field.name, &self.reserialize.with_schema(&field.ty))?; - } - serializer.end() - } - } - - match &variant.ty { - OwnedDataModelVariant::UnitVariant => serializer.serialize_str(&variant.name), - OwnedDataModelVariant::NewtypeVariant(inner) => { - let mut serializer = serializer.serialize_map(Some(1))?; - serializer.serialize_entry(&variant.name, &reserialize.with_schema(inner))?; - serializer.end() - } - OwnedDataModelVariant::TupleVariant(fields) => { - let mut serializer = serializer.serialize_map(Some(1))?; - serializer.serialize_entry( - &variant.name, - &ReserializeTuple { - reserialize, - elements: fields, - }, - )?; - serializer.end() - } - OwnedDataModelVariant::StructVariant(fields) => { - let mut serializer = serializer.serialize_map(Some(1))?; - serializer.serialize_entry( - &variant.name, - &ReserializeFields { - reserialize, - fields, - }, - )?; - serializer.end() - } - } + expecting: expecting::Enum<'_, '_>, + ) -> Result, D::Error> { + // Postcard encodes enums as (index, value) + deserializer.deserialize_tuple( + 2, + enums::Visitor { + serializer, + context, + expecting, + }, + ) } } diff --git a/source/postcard-dyn/src/reserialize/lossy/enums.rs b/source/postcard-dyn/src/reserialize/lossy/enums.rs new file mode 100644 index 0000000..f3528a3 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossy/enums.rs @@ -0,0 +1,293 @@ +use core::{fmt, marker::PhantomData}; + +use postcard_schema::schema::owned::{OwnedDataModelVariant, OwnedNamedType}; +use serde::{ + de::{self, DeserializeSeed, Deserializer, SeqAccess}, + ser::{Error as _, SerializeMap, SerializeTuple, Serializer}, +}; + +use crate::{ + reserialize::{ + expecting::{self, Unexpected}, + Context, ReserializeFn, + }, + Error, +}; + +use super::Strategy; + +pub struct Visitor<'a, S> { + pub serializer: S, + pub context: &'a Context<'a, Strategy>, + pub expecting: expecting::Enum<'a, 'a>, +} + +impl<'de, S: Serializer> de::Visitor<'de> for Visitor<'_, S> { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_seq>(self, mut seq: A) -> Result { + let variant_index: u32 = seq.next_element()?.unwrap(); + let variant = (usize::try_from(variant_index).ok()) + .and_then(|v| self.expecting.variants.get(v)) + .ok_or_else(|| { + A::Error::unknown_variant_index(variant_index, ..self.expecting.variants.len()) + })?; + + let err = || S::Error::custom("missing variant data"); + Ok(match &variant.ty { + OwnedDataModelVariant::UnitVariant => self.serializer.serialize_str(&variant.name), + OwnedDataModelVariant::NewtypeVariant(inner) => seq + .next_element_seed(NewtypeVariantSeed { + serializer: self.serializer, + context: self.context, + variant: &variant.name, + inner, + })? + .ok_or_else(err) + .and_then(|res| res), + OwnedDataModelVariant::TupleVariant(fields) => seq + .next_element_seed(TupleVariantVisitor { + serializer: self.serializer, + context: self.context, + expecting: expecting::Variant { + enum_name: self.expecting.name, + variant_index, + variant_name: &variant.name, + data: expecting::data::Tuple { elements: fields }, + }, + })? + .ok_or_else(err) + .and_then(|res| res), + OwnedDataModelVariant::StructVariant(fields) => seq + .next_element_seed(StructVariantVisitor { + serializer: self.serializer, + context: self.context, + expecting: expecting::Variant { + enum_name: self.expecting.name, + variant_index, + variant_name: &variant.name, + data: expecting::data::Struct { fields }, + }, + })? + .ok_or_else(err) + .and_then(|res| res), + }) + } +} + +struct NewtypeVariantSeed<'a, S> { + serializer: S, + context: &'a Context<'a, Strategy>, + variant: &'a str, + inner: &'a OwnedNamedType, +} + +impl<'de, S: Serializer> DeserializeSeed<'de> for NewtypeVariantSeed<'_, S> { + type Value = Result; + + fn deserialize>(self, deserializer: D) -> Result { + let mut serializer = match self.serializer.serialize_map(Some(1)) { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + let res = self + .context + .reserialize_ty(self.inner, deserializer, |value| { + serializer.serialize_entry(self.variant, value) + })?; + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + Ok(serializer.end()) + } +} + +struct TupleVariantVisitor<'a, S> { + serializer: S, + context: &'a Context<'a, Strategy>, + expecting: expecting::Variant<'a, expecting::data::Tuple<'a>>, +} + +impl<'de, S: Serializer> DeserializeSeed<'de> for TupleVariantVisitor<'_, S> { + type Value = Result; + + fn deserialize>(self, deserializer: D) -> Result { + deserializer.deserialize_tuple(self.expecting.data.elements.len(), self) + } +} + +impl<'de, S: Serializer> de::Visitor<'de> for TupleVariantVisitor<'_, S> { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_seq>(self, seq: A) -> Result { + let mut serializer = match self.serializer.serialize_map(Some(1)) { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + self.context.reserialize( + ReserializeTupleVariant { + context: self.context, + seq, + expecting: &self.expecting, + de: PhantomData, + }, + |data| { + serializer.serialize_entry(self.expecting.variant_name, data)?; + serializer.end() + }, + ) + } +} +struct ReserializeTupleVariant<'de, 'a, A> { + context: &'a Context<'a, Strategy>, + seq: A, + expecting: &'a expecting::Variant<'a, expecting::data::Tuple<'a>>, + de: PhantomData<&'de ()>, +} + +struct ElementSeed<'a, S> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + schema: &'a OwnedNamedType, +} + +impl<'de, 'a, S: SerializeTuple> DeserializeSeed<'de> for ElementSeed<'a, S> { + type Value = Result<(), S::Error>; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |element| { + self.serializer.serialize_element(element) + }) + } +} + +impl<'de, 'a, A: SeqAccess<'de>> ReserializeFn for ReserializeTupleVariant<'de, 'a, A> { + type DeserializeError = A::Error; + + fn reserialize( + mut self, + serializer: S, + ) -> Result> { + let fields = self.expecting.data.elements; + let mut serializer = serializer + .serialize_tuple(fields.len()) + .map_err(Error::Serialize)?; + for (idx, field) in fields.iter().enumerate() { + self.seq + .next_element_seed(ElementSeed { + context: self.context, + serializer: &mut serializer, + schema: field, + }) + .map_err(Error::Deserialize)? + .ok_or_else(|| A::Error::missing_elements(idx, self.expecting, fields.len())) + .map_err(Error::Deserialize)? + .map_err(Error::Serialize)?; + } + serializer.end().map_err(Error::Serialize) + } +} + +struct StructVariantVisitor<'a, S> { + serializer: S, + context: &'a Context<'a, Strategy>, + expecting: expecting::Variant<'a, expecting::data::Struct<'a>>, +} + +impl<'de, S: Serializer> DeserializeSeed<'de> for StructVariantVisitor<'_, S> { + type Value = Result; + + fn deserialize>(self, deserializer: D) -> Result { + deserializer.deserialize_tuple(self.expecting.data.fields.len(), self) + } +} + +impl<'de, S: Serializer> de::Visitor<'de> for StructVariantVisitor<'_, S> { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_seq>(self, seq: A) -> Result { + let mut serializer = match self.serializer.serialize_map(Some(1)) { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + self.context.reserialize( + ReserializeStructVariant { + context: self.context, + seq, + expecting: &self.expecting, + de: PhantomData, + }, + |data| { + serializer.serialize_entry(self.expecting.variant_name, data)?; + serializer.end() + }, + ) + } +} + +struct ReserializeStructVariant<'a, 'de, A> { + context: &'a Context<'a, Strategy>, + seq: A, + expecting: &'a expecting::Variant<'a, expecting::data::Struct<'a>>, + de: PhantomData<&'de ()>, +} + +struct FieldSeed<'a, S> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + key: &'a str, + schema: &'a OwnedNamedType, +} + +impl<'de, 'a, S: SerializeMap> DeserializeSeed<'de> for FieldSeed<'a, S> { + type Value = Result<(), S::Error>; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |value| { + self.serializer.serialize_entry(self.key, value) + }) + } +} + +impl<'de, 'a, A: SeqAccess<'de>> ReserializeFn for ReserializeStructVariant<'a, 'de, A> { + type DeserializeError = A::Error; + + fn reserialize( + mut self, + serializer: S, + ) -> Result> { + let fields = self.expecting.data.fields; + let mut serializer = serializer + .serialize_map(Some(fields.len())) + .map_err(Error::Serialize)?; + for (idx, field) in fields.iter().enumerate() { + self.seq + .next_element_seed(FieldSeed { + context: self.context, + serializer: &mut serializer, + key: &field.name, + schema: &field.ty, + }) + .map_err(Error::Deserialize)? + .ok_or_else(|| A::Error::missing_elements(idx, self.expecting, fields.len())) + .map_err(Error::Deserialize)? + .map_err(Error::Serialize)?; + } + serializer.end().map_err(Error::Serialize) + } +} diff --git a/source/postcard-dyn/src/reserialize/lossy/structs.rs b/source/postcard-dyn/src/reserialize/lossy/structs.rs new file mode 100644 index 0000000..600ce25 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossy/structs.rs @@ -0,0 +1,55 @@ +use core::slice; + +use postcard_schema::schema::owned::OwnedNamedValue; +use serde::{ + de, + ser::{Serialize, SerializeMap, SerializeTuple, Serializer}, +}; + +use crate::reserialize::{self, expecting}; + +pub struct ReserializeStructAsMap<'a> { + pub expecting: expecting::Struct<'a, expecting::data::Struct<'a>>, +} + +impl de::Expected for ReserializeStructAsMap<'_> { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } +} + +impl<'a, S: Serializer> reserialize::tuple::Reserializer for ReserializeStructAsMap<'a> { + type SerializeTuple = SerializeFieldsAsMapEntries<'a, S::SerializeMap>; + + fn reserialize_tuple( + &self, + serializer: S, + len: usize, + ) -> Result::Error> { + let serializer = serializer.serialize_map(Some(len))?; + let fields = self.expecting.data.fields.iter(); + Ok(SerializeFieldsAsMapEntries { serializer, fields }) + } +} + +pub struct SerializeFieldsAsMapEntries<'a, S> { + serializer: S, + fields: slice::Iter<'a, OwnedNamedValue>, +} + +impl<'a, S: SerializeMap> SerializeTuple for SerializeFieldsAsMapEntries<'a, S> { + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let field = self.fields.next().unwrap(); + self.serializer.serialize_entry(&field.name, value) + } + + fn end(self) -> Result { + self.serializer.end() + } +} diff --git a/source/postcard-dyn/src/reserialize/map.rs b/source/postcard-dyn/src/reserialize/map.rs new file mode 100644 index 0000000..a68fbd0 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/map.rs @@ -0,0 +1,98 @@ +use core::fmt; + +use postcard_schema::schema::owned::OwnedNamedType; +use serde::{ + de::{self, DeserializeSeed, MapAccess}, + ser::SerializeMap, + Deserializer, Serializer, +}; + +use super::Context; + +pub struct Visitor<'a, S, Strategy> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub key: &'a OwnedNamedType, + pub val: &'a OwnedNamedType, +} + +impl<'de, S, Strategy> de::Visitor<'de> for Visitor<'_, S, Strategy> +where + S: Serializer, + Strategy: super::Strategy, +{ + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map") + } + + fn visit_map>(self, mut map: A) -> Result { + let mut serializer = match self.serializer.serialize_map(map.size_hint()) { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + while let Some(res) = map.next_key_seed(KeySeed { + context: self.context, + serializer: &mut serializer, + schema: self.key, + })? { + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + let seed = ValueSeed { + context: self.context, + serializer: &mut serializer, + schema: self.val, + }; + match map.next_value_seed(seed)? { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + } + Ok(serializer.end()) + } +} + +struct KeySeed<'a, S, Strategy> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + schema: &'a OwnedNamedType, +} + +impl<'de, 'a, S, Strategy> DeserializeSeed<'de> for KeySeed<'a, S, Strategy> +where + S: SerializeMap, + Strategy: super::Strategy, +{ + type Value = Result<(), S::Error>; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |key| { + self.serializer.serialize_key(key) + }) + } +} + +struct ValueSeed<'a, S, Strategy> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + schema: &'a OwnedNamedType, +} + +impl<'de, 'a, S, Strategy> DeserializeSeed<'de> for ValueSeed<'a, S, Strategy> +where + S: SerializeMap, + Strategy: super::Strategy, +{ + type Value = Result<(), S::Error>; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |val| { + self.serializer.serialize_value(val) + }) + } +} diff --git a/source/postcard-dyn/src/reserialize/option.rs b/source/postcard-dyn/src/reserialize/option.rs new file mode 100644 index 0000000..348eef3 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/option.rs @@ -0,0 +1,35 @@ +use core::fmt; + +use postcard_schema::schema::owned::OwnedNamedType; +use serde::{de, Deserializer, Serializer}; + +use super::Context; + +pub struct Visitor<'a, S, Strategy> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub schema: &'a OwnedNamedType, +} + +impl<'a, 'de, S, Strategy> de::Visitor<'de> for Visitor<'a, S, Strategy> +where + S: Serializer, + Strategy: super::Strategy, +{ + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("option") + } + + fn visit_none(self) -> Result { + Ok(self.serializer.serialize_none()) + } + + fn visit_some>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |inner| { + self.serializer.serialize_some(inner) + }) + } +} diff --git a/source/postcard-dyn/src/reserialize/seq.rs b/source/postcard-dyn/src/reserialize/seq.rs new file mode 100644 index 0000000..0d0f584 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/seq.rs @@ -0,0 +1,80 @@ +use core::fmt; + +use postcard_schema::schema::owned::OwnedNamedType; +use serde::{ + de::{self, DeserializeSeed, Error as _, SeqAccess}, + ser::SerializeSeq, + Deserializer, Serializer, +}; + +use super::{Context, Expected}; + +pub struct Visitor<'a, S, Strategy> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub schemas: &'a [OwnedNamedType], +} + +impl<'de, S, Strategy> de::Visitor<'de> for Visitor<'_, S, Strategy> +where + S: Serializer, + Strategy: super::Strategy, +{ + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence") + } + + fn visit_seq>(self, mut seq: A) -> Result { + let mut serializer = match self.serializer.serialize_seq(seq.size_hint()) { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + let mut seed = ElementSeed { + context: self.context, + serializer: &mut serializer, + schemas: self.schemas, + idx: 0, + }; + while let Some(res) = seq.next_element_seed(&mut seed)? { + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + } + Ok(serializer.end()) + } +} + +struct ElementSeed<'a, S: 'a, Strategy> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + schemas: &'a [OwnedNamedType], + idx: usize, +} + +impl<'de, 'a, S, Strategy> DeserializeSeed<'de> for &mut ElementSeed<'a, S, Strategy> +where + S: SerializeSeq, + Strategy: super::Strategy, +{ + type Value = Result<(), S::Error>; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let schemas = self.schemas; + let schema = schemas.get(self.idx).ok_or_else(|| { + D::Error::invalid_length( + self.idx + 1, + &Expected(format_args!("sequence of length {}", schemas.len())), + ) + })?; + self.context + .reserialize_ty(schema, deserializer, |element| { + self.serializer.serialize_element(element) + }) + } +} diff --git a/source/postcard-dyn/src/reserialize/strategy.rs b/source/postcard-dyn/src/reserialize/strategy.rs new file mode 100644 index 0000000..8c98cb2 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/strategy.rs @@ -0,0 +1,62 @@ +//! How to reserialize structs and enums to work around [`Deserializer`] and [`Serializer`]'s +//! `&'static str` requirements. + +use postcard_schema::schema::owned::OwnedNamedType; +use serde::{Deserializer, Serialize, Serializer}; + +use crate::Error; + +use super::{expecting, Context}; + +/// How to reserialize structs and enums to work around [`Deserializer`] and [`Serializer`]'s +/// `&'static str` requirements. +pub(super) trait Strategy: Sized { + fn reserialize_unit_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + name: &str, + ) -> Result, D::Error>; + + fn reserialize_newtype_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Newtype>, + ) -> Result, D::Error>; + + fn reserialize_tuple_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Tuple>, + ) -> Result, D::Error>; + + fn reserialize_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Struct>, + ) -> Result, D::Error>; + + fn reserialize_enum<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Enum<'_, '_>, + ) -> Result, D::Error>; + + fn reserialize<'de, D: Deserializer<'de>, S: Serializer>( + &self, + schema: &OwnedNamedType, + deserializer: D, + serializer: S, + ) -> Result> { + let context = Context { strategy: self }; + match context.reserialize_ty(schema, deserializer, |value| value.serialize(serializer)) { + Ok(Ok(out)) => Ok(out), + Ok(Err(err)) => Err(Error::Serialize(err)), + Err(err) => Err(Error::Deserialize(err)), + } + } +} diff --git a/source/postcard-dyn/src/reserialize/tuple.rs b/source/postcard-dyn/src/reserialize/tuple.rs new file mode 100644 index 0000000..4dd9fbd --- /dev/null +++ b/source/postcard-dyn/src/reserialize/tuple.rs @@ -0,0 +1,108 @@ +use core::fmt; + +use postcard_schema::schema::owned::OwnedNamedType; +use serde::{ + de::{self, DeserializeSeed, SeqAccess}, + ser::SerializeTuple, + Deserializer, Serializer, +}; + +use super::{ + expecting::{self, Unexpected}, + Context, +}; + +pub struct Visitor<'a, S, Strategy, Fields, Reserializer> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub fields: Fields, + pub reserializer: Reserializer, +} + +pub trait Reserializer: de::Expected { + type SerializeTuple: SerializeTuple; + + fn reserialize_tuple( + &self, + serializer: S, + len: usize, + ) -> Result; +} + +impl Reserializer for expecting::Tuple { + type SerializeTuple = S::SerializeTuple; + + fn reserialize_tuple( + &self, + serializer: S, + len: usize, + ) -> Result { + serializer.serialize_tuple(len) + } +} + +impl<'de, 'schema, S, Strategy, Fields, Reserializer> de::Visitor<'de> + for Visitor<'_, S, Strategy, Fields, Reserializer> +where + S: Serializer, + Strategy: super::Strategy, + Fields: IntoIterator, + Reserializer: self::Reserializer, +{ + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.reserializer, formatter) + } + + fn visit_seq>(self, mut seq: A) -> Result { + let fields = self.fields.into_iter(); + let num_fields = fields.len(); + let mut serializer = match self + .reserializer + .reserialize_tuple(self.serializer, num_fields) + { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + for (idx, field) in fields.enumerate() { + let seed = ElementSeed { + context: self.context, + serializer: &mut serializer, + field, + }; + let res = seq + .next_element_seed(seed)? + .ok_or_else(|| A::Error::missing_elements(idx, &self.reserializer, num_fields))?; + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + } + Ok(serializer.end()) + } +} + +struct ElementSeed<'a, S, Strategy> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + field: &'a OwnedNamedType, +} + +impl<'de, 'a, S, Strategy> DeserializeSeed<'de> for ElementSeed<'a, S, Strategy> +where + S: SerializeTuple, + Strategy: super::Strategy, +{ + type Value = Result<(), S::Error>; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + self.context + .reserialize_ty(self.field, deserializer, |element| { + self.serializer.serialize_element(element) + }) + } +}