diff --git a/memcrab-cli/src/main.rs b/memcrab-cli/src/main.rs index 90d83f1..63beedf 100644 --- a/memcrab-cli/src/main.rs +++ b/memcrab-cli/src/main.rs @@ -1,107 +1,108 @@ -use anyhow::bail; -use clap::Parser; -use core::num::NonZeroUsize; -use memcrab::RawClient; -use memcrab_cache::Cache; -use rustyline::error::ReadlineError; -use rustyline::DefaultEditor; - -#[derive(Parser)] -#[command(author, version, about, long_about = None)] -struct Cli { - #[arg(short = 'H', long, default_value = "127.0.0.1")] - host: String, - - #[arg(short, long, default_value = "9090")] - port: String, - - #[arg(short, long, action)] - server: bool, -} - -async fn eval_lines(addr: String) -> anyhow::Result<()> { - let addr = addr.parse()?; - let mut client = RawClient::connect(addr).await?; - let mut editor = DefaultEditor::new()?; - loop { - let line = editor.readline("memcrab> "); - match line { - Ok(line) => { - eval_line(&mut client, line) - .await - .unwrap_or_else(|e| println!("error: {:?}", e)); - } - Err(ReadlineError::Interrupted) => { - println!("Ctrl-c"); - } - Err(ReadlineError::Eof) => { - println!("quit"); - break; - } - Err(err) => { - println!("error: {:?}", err); - break; - } - } - } - - Ok(()) -} - -async fn eval_line(client: &mut RawClient, line: String) -> anyhow::Result<()> { - let tokens = line.split_whitespace().collect::>(); - if tokens.is_empty() { - return Ok(()); - } - if tokens[0] == "get" { - if tokens.len() != 2 { - bail!("syntax error: expected one key after `get`"); - } - let resp = client.get(tokens[1]).await?; - match resp { - Some(val) => println!("{}: {:?}", tokens[1], val), - None => println!("no value set"), - } - } else if tokens[0] == "set" { - if tokens.len() < 3 { - bail!("syntax error: expected one key and bytes after `set`"); - } - - client - .set( - tokens[1], - tokens[2..] - .iter() - .map(|&s| s.parse().unwrap()) - .collect::>(), - ) - .await?; - } else { - bail!("syntax error: unexpected token {}", tokens[0]); - } - Ok(()) -} - -#[allow(unused)] -async fn serve(addr: String) -> anyhow::Result<()> { - let maxbytes = 100_000; - let maxlen = NonZeroUsize::new(110).unwrap(); - let cache = Cache::new(maxlen, maxbytes); - - todo!("server is not implemented"); - Ok(()) -} - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let cli = Cli::parse(); - - let addr = format!("{}:{}", cli.host, cli.port); - - if cli.server { - serve(addr).await?; - } else { - eval_lines(addr).await?; - } - Ok(()) -} +// use anyhow::bail; +// use clap::Parser; +// use core::num::NonZeroUsize; +// use memcrab::RawClient; +// use memcrab_cache::Cache; +// use rustyline::error::ReadlineError; +// use rustyline::DefaultEditor; +// +// #[derive(Parser)] +// #[command(author, version, about, long_about = None)] +// struct Cli { +// #[arg(short = 'H', long, default_value = "127.0.0.1")] +// host: String, +// +// #[arg(short, long, default_value = "9090")] +// port: String, +// +// #[arg(short, long, action)] +// server: bool, +// } +// +// async fn eval_lines(addr: String) -> anyhow::Result<()> { +// let addr = addr.parse()?; +// let mut client = RawClient::connect(addr).await?; +// let mut editor = DefaultEditor::new()?; +// loop { +// let line = editor.readline("memcrab> "); +// match line { +// Ok(line) => { +// eval_line(&mut client, line) +// .await +// .unwrap_or_else(|e| println!("error: {:?}", e)); +// } +// Err(ReadlineError::Interrupted) => { +// println!("Ctrl-c"); +// } +// Err(ReadlineError::Eof) => { +// println!("quit"); +// break; +// } +// Err(err) => { +// println!("error: {:?}", err); +// break; +// } +// } +// } +// +// Ok(()) +// } +// +// async fn eval_line(client: &mut RawClient, line: String) -> anyhow::Result<()> { +// let tokens = line.split_whitespace().collect::>(); +// if tokens.is_empty() { +// return Ok(()); +// } +// if tokens[0] == "get" { +// if tokens.len() != 2 { +// bail!("syntax error: expected one key after `get`"); +// } +// let resp = client.get(tokens[1]).await?; +// match resp { +// Some(val) => println!("{}: {:?}", tokens[1], val), +// None => println!("no value set"), +// } +// } else if tokens[0] == "set" { +// if tokens.len() < 3 { +// bail!("syntax error: expected one key and bytes after `set`"); +// } +// +// client +// .set( +// tokens[1], +// tokens[2..] +// .iter() +// .map(|&s| s.parse().unwrap()) +// .collect::>(), +// ) +// .await?; +// } else { +// bail!("syntax error: unexpected token {}", tokens[0]); +// } +// Ok(()) +// } +// +// #[allow(unused)] +// async fn serve(addr: String) -> anyhow::Result<()> { +// let maxbytes = 100_000; +// let maxlen = NonZeroUsize::new(110).unwrap(); +// let cache = Cache::new(maxlen, maxbytes); +// +// todo!("server is not implemented"); +// Ok(()) +// } +// +// #[tokio::main] +// async fn main() -> anyhow::Result<()> { +// let cli = Cli::parse(); +// +// let addr = format!("{}:{}", cli.host, cli.port); +// +// if cli.server { +// serve(addr).await?; +// } else { +// eval_lines(addr).await?; +// } +// Ok(()) +// } +fn main() {} diff --git a/memcrab-protocol/README.md b/memcrab-protocol/README.md new file mode 100644 index 0000000..f8028ba --- /dev/null +++ b/memcrab-protocol/README.md @@ -0,0 +1,76 @@ +# memcrab-protocol + +This crate contains the implementation of the protocol that is used by server and the official Rust client. + +**Note that this is not meant to be the official Rust API for memcrab, use the official wrapper client instead.** + + +## Usage +```rust +TODO + +use memcrab_protocol::Socket; +use memcrab_protocol::{Message, Request, Response}; + +#[tokio::main] +async fn main() { +} +``` + + +## Protocol description + +### Encoding + +This is a binary protocol. The keys should be encoded as valid UTF-8 strings though. + +### Messages +TCP messages are not framed to distinct messages by themselves. Instead, we need to implement message borders ourselves. + + +Memcrab messages contain a header of a fixed length, and then payload of variable length. + + +The first byte of the header encodes the kind of message. The rest of the header encodes the information about the lengths of the payload or other metainfo. + + +Message kinds are shared by all messages for client and server. Clients should only send request messages and understand responses messages however, vice versa. + + +| Message kind | first byte | rest of the header | payload +| --- | --- | --- | --- +| VersionRequest | 0 | version | none +| PingRequest | 1 | none | none +| GetRequest | 2 | klen | key +| SetRequest | 3 | klen, vlen, exp | key, value +| DeleteRequest | 4 | klen | key +| ClearRequest | 5 | none | none + +| PongResponse | 128 | none | none +| OkResponse | 129 | none | none +| ValueResponse | 130 | vlen | value +| KeyNotFoundResponse | 131 | none | none +| ErrorResponse | 132 | vlen | value + + +The lengths of fields for klen, vlen, version, etc are as follows: + + +| header field | size (bytes) | +| --- | --- | +| klen | 8 | +| vlen | 8 | +| version | 2 | +| exp | 4 | + +The header length is 21 bytes. + + +### Versioning +Protocol is versioned by a number and are not backwards compatible. + +The current version is `0`. + +The clients must send `Version` message as their first message. The server must close the connection if the version is not compatible. + + diff --git a/memcrab-protocol/src/mapping/alias.rs b/memcrab-protocol/src/alias.rs similarity index 79% rename from memcrab-protocol/src/mapping/alias.rs rename to memcrab-protocol/src/alias.rs index b849446..5c689af 100644 --- a/memcrab-protocol/src/mapping/alias.rs +++ b/memcrab-protocol/src/alias.rs @@ -1,5 +1,4 @@ pub type Version = u16; -pub type ErrMsgLen = u64; pub type KeyLen = u64; pub type ValueLen = u64; pub type Expiration = u32; diff --git a/memcrab-protocol/src/err.rs b/memcrab-protocol/src/err.rs index 4990984..aafd6b1 100644 --- a/memcrab-protocol/src/err.rs +++ b/memcrab-protocol/src/err.rs @@ -1,19 +1,22 @@ use thiserror::Error; #[derive(Error, Debug)] -pub enum ClientSideError { +pub enum Error { #[error("io")] IO(#[from] std::io::Error), - #[error("parsing failed")] - Parsing(#[from] ParsingError), -} -pub type ServerSideError = ClientSideError; + #[error("cannot parse message")] + Parse(#[from] ParseError), +} #[derive(Error, Debug)] -pub enum ParsingError { - #[error("invalid header")] - Header, - #[error("invalid payload")] - Payload, +pub enum ParseError { + #[error("invalid message kind")] + UnknownKind, + + #[error("malformed string")] + InvalidString, + + #[error("message is too big")] + TooBig, } diff --git a/memcrab-protocol/src/kind.rs b/memcrab-protocol/src/kind.rs new file mode 100644 index 0000000..2115c4c --- /dev/null +++ b/memcrab-protocol/src/kind.rs @@ -0,0 +1,18 @@ +use num_enum::{IntoPrimitive, TryFromPrimitive}; + +#[repr(u8)] +#[derive(Debug, Clone, Copy, TryFromPrimitive, IntoPrimitive)] +pub enum MessageKind { + VersionRequest = 0, + PingRequest = 1, + GetRequest = 2, + SetRequest = 3, + DeleteRequest = 4, + ClearRequest = 5, + + PongResponse = 128, + OkResponse = 129, + ValueResponse = 130, + KeyNotFoundResponse = 131, + ErrorResponse = 132, +} diff --git a/memcrab-protocol/src/lib.rs b/memcrab-protocol/src/lib.rs index 9837507..6fe637b 100644 --- a/memcrab-protocol/src/lib.rs +++ b/memcrab-protocol/src/lib.rs @@ -1,16 +1,14 @@ +mod alias; mod err; - -#[allow(unused)] -mod transport; - -#[allow(unused)] -pub(crate) mod mapping; - -pub mod io; - -use mapping::alias::Version; - -pub use err::{ClientSideError, ParsingError, ServerSideError}; -pub use transport::{ClientSocket, ErrorResponse, Request, Response, ServerSocket}; - -pub const PROTOCOL_VERSION: Version = 0; +mod kind; +mod message; +mod sizes; +mod socket; +mod stream; +mod version; + +pub use err::{Error, ParseError}; +pub use message::{Message, Request, Response}; +pub use socket::Socket; +pub use stream::{AsyncReader, AsyncWriter}; +pub use version::VERSION; diff --git a/memcrab-protocol/src/mapping/flags.rs b/memcrab-protocol/src/mapping/flags.rs deleted file mode 100644 index 6473752..0000000 --- a/memcrab-protocol/src/mapping/flags.rs +++ /dev/null @@ -1,24 +0,0 @@ -use num_enum::{IntoPrimitive, TryFromPrimitive}; - -#[repr(u8)] -#[derive(Debug, Clone, Copy, TryFromPrimitive, IntoPrimitive)] -pub enum RequestFlag { - Version = 0, - Ping = 1, - Get = 2, - Set = 3, - Delete = 4, - Clear = 5, -} - -#[repr(u8)] -#[derive(Debug, Clone, Copy, TryFromPrimitive, IntoPrimitive)] -pub enum ResponseFlag { - Pong = 0, - Value = 1, - Ok = 2, - KeyNotFound = 3, - - ValidationErr = 201, - InternalErr = 202, -} diff --git a/memcrab-protocol/src/mapping/mod.rs b/memcrab-protocol/src/mapping/mod.rs deleted file mode 100644 index 112747c..0000000 --- a/memcrab-protocol/src/mapping/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub(crate) mod alias; -pub(crate) mod flags; -pub(crate) mod tokens; diff --git a/memcrab-protocol/src/mapping/tokens.rs b/memcrab-protocol/src/mapping/tokens.rs deleted file mode 100644 index 7850a36..0000000 --- a/memcrab-protocol/src/mapping/tokens.rs +++ /dev/null @@ -1,98 +0,0 @@ -use super::alias::{ErrMsgLen, Expiration, KeyLen, ValueLen, Version}; -use std::mem::size_of; - -#[derive(Debug)] -pub enum Payload { - Zero, - Key(String), - Value(Vec), - Pair { key: String, value: Vec }, - ErrMsg(String), -} - -#[derive(Debug, Clone, Copy)] -pub enum RequestHeader { - Version(Version), - Get { - klen: KeyLen, - }, - Set { - klen: KeyLen, - vlen: ValueLen, - expiration: Expiration, - }, - Delete { - klen: KeyLen, - }, - Clear, - Ping, -} - -impl RequestHeader { - pub const VERSION_SIZE: usize = size_of::(); - pub const KLEN_SIZE: usize = size_of::(); - pub const VLEN_SIZE: usize = size_of::(); - pub const EXP_SIZE: usize = size_of::(); - - // Max size of the request header. - pub const SIZE: usize = { - let set_size = Self::KLEN_SIZE + Self::VLEN_SIZE + Self::EXP_SIZE; - 1 + set_size - }; - - pub fn payload_len(self) -> usize { - match self { - Self::Get { klen } => klen as usize, - Self::Set { - klen, - vlen, - expiration, - } => (klen + vlen) as usize, - Self::Delete { klen } => klen as usize, - _ => 0, - } - } -} - -#[derive(Debug, Clone, Copy)] -pub enum ResponseHeader { - Ok, - Error(ErrorHeader), - Value { vlen: ValueLen }, - KeyNotFound, - Pong, -} - -impl ResponseHeader { - pub const VLEN_SIZE: usize = size_of::(); - pub const SIZE: usize = { 1 + ErrorHeader::SIZE }; - - pub fn payload_len(self) -> usize { - match self { - Self::Error(e) => e.errmsg_len() as usize, - Self::Value { vlen } => vlen as usize, - _ => 0, - } - } -} - -#[derive(Debug, Clone, Copy)] -pub enum ErrorHeader { - Validation { len: ErrMsgLen }, - Internal { len: ErrMsgLen }, -} - -impl ErrorHeader { - pub const MSG_LEN_SIZE: usize = size_of::(); - pub const SIZE: usize = { 1 + size_of::() }; - - pub const fn errmsg_len(self) -> ErrMsgLen { - match self { - Self::Validation { len } => len, - Self::Internal { len } => len, - } - } -} - -#[cfg(test)] -mod tests {} diff --git a/memcrab-protocol/src/message.rs b/memcrab-protocol/src/message.rs new file mode 100644 index 0000000..8fd4dbb --- /dev/null +++ b/memcrab-protocol/src/message.rs @@ -0,0 +1,30 @@ +use crate::alias::{Expiration, Version}; + +#[derive(Debug, Clone, PartialEq)] +pub enum Message { + Request(Request), + Response(Response), +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Request { + Version(Version), + Get(String), + Set { + key: String, + value: Vec, + expiration: Expiration, + }, + Delete(String), + Clear, + Ping, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Response { + Value(Vec), + Ok, + Error(String), + KeyNotFound, + Pong, +} diff --git a/memcrab-protocol/src/sizes.rs b/memcrab-protocol/src/sizes.rs new file mode 100644 index 0000000..e12f094 --- /dev/null +++ b/memcrab-protocol/src/sizes.rs @@ -0,0 +1,11 @@ +use super::alias::{Expiration, KeyLen, ValueLen, Version}; +use std::mem::size_of; + +// Header fields sizes. +pub const VERSION_SIZE: usize = size_of::(); +pub const KLEN_SIZE: usize = size_of::(); +pub const VLEN_SIZE: usize = size_of::(); +pub const EXP_SIZE: usize = size_of::(); + +// We manually calculate the longest possible header and make it fixed. +pub const MAX_HEADER_SIZE: usize = 1 + KLEN_SIZE + VLEN_SIZE + EXP_SIZE; diff --git a/memcrab-protocol/src/socket.rs b/memcrab-protocol/src/socket.rs new file mode 100644 index 0000000..3b768c6 --- /dev/null +++ b/memcrab-protocol/src/socket.rs @@ -0,0 +1,332 @@ +use crate::{ + alias::{Expiration, KeyLen, ValueLen, Version}, + err::{Error, ParseError}, + kind::MessageKind, + message::{Message, Request, Response}, + sizes, + stream::{AsyncReader, AsyncWriter}, +}; + +/// Wraps a stream (typically TCPStream) for receiving/sending framed messages according to memcrab +/// protocol. It is used for both client and server implementations. +#[derive(Debug, Clone)] +pub struct Socket { + stream: S, +} + +impl Socket +where + S: AsyncReader + AsyncWriter + Send, +{ + pub fn new(stream: S) -> Self { + Self { stream } + } + + /// Wait for a complete message from socket and parse it. + pub async fn receive(&mut self) -> Result { + let header = self.stream.read_chunk(sizes::MAX_HEADER_SIZE).await?; + + let kind = MessageKind::try_from(header[0]).map_err(|_| ParseError::UnknownKind)?; + use MessageKind as Kind; + Ok(match kind { + Kind::VersionRequest => { + let version = self.parse_version_header(&header[1..])?; + Message::Request(Request::Version(version)) + } + Kind::PingRequest => Message::Request(Request::Ping), + Kind::GetRequest => { + let klen = self.parse_klen_header(&header[1..])?; + let key_bytes = self.stream.read_chunk(klen as usize).await?; + let key = Self::parse_utf8(key_bytes)?; + Message::Request(Request::Get(key)) + } + Kind::SetRequest => { + let (klen, vlen, expiration) = self.parse_klen_vlen_exp_header(&header[1..])?; + let buf = self + .stream + .read_chunk(klen as usize + vlen as usize) + .await?; + let (key, value) = buf.split_at(klen as usize); + let (key, value) = (Self::parse_utf8(key.to_vec())?, value.to_vec()); + Message::Request(Request::Set { + key, + value, + expiration, + }) + } + Kind::ClearRequest => Message::Request(Request::Clear), + Kind::DeleteRequest => { + let klen = self.parse_klen_header(&header[1..])?; + let key_bytes = self.stream.read_chunk(klen as usize).await.unwrap(); + let key = Self::parse_utf8(key_bytes)?; + Message::Request(Request::Delete(key)) + } + Kind::OkResponse => Message::Response(Response::Ok), + Kind::ErrorResponse => { + // Error is parsed as Value with message as payload. + let vlen = self.parse_vlen_header(&header[1..])?; + let buf = self.stream.read_chunk(vlen as usize).await?; + let msg = Self::parse_utf8(buf)?; + Message::Response(Response::Error(msg)) + } + Kind::PongResponse => Message::Response(Response::Pong), + Kind::ValueResponse => { + let vlen = self.parse_vlen_header(&header[1..])?; + let buf = self.stream.read_chunk(vlen as usize).await?; + Message::Response(Response::Value(buf)) + } + Kind::KeyNotFoundResponse => Message::Response(Response::KeyNotFound), + }) + } + + /// Encode a message and write it to the socket. + pub async fn send(&mut self, msg: Message) -> Result<(), Error> { + let mut bytes = Vec::::new(); + + match msg { + Message::Request(Request::Version(version)) => { + bytes.push(MessageKind::VersionRequest.into()); + bytes.extend_from_slice(&version.to_be_bytes()); + todo!(); + } + Message::Request(Request::Get(key)) => { + bytes.push(MessageKind::GetRequest.into()); + bytes.extend_from_slice(&key.len().to_be_bytes()); + bytes.extend_from_slice(key.as_bytes()); + } + Message::Request(Request::Clear) => { + bytes[0] = MessageKind::ClearRequest.into(); + } + Message::Request(Request::Ping) => { + bytes.push(MessageKind::PingRequest.into()); + } + Message::Request(Request::Set { + key, + mut value, + expiration, + }) => { + bytes.push(MessageKind::SetRequest.into()); + bytes.extend_from_slice(&key.len().to_be_bytes()); + bytes.extend_from_slice(&value.len().to_be_bytes()); + bytes.extend_from_slice(&expiration.to_be_bytes()); + bytes.extend_from_slice(key.as_bytes()); + bytes.append(&mut value); + } + Message::Request(Request::Delete(key)) => { + bytes.push(MessageKind::DeleteRequest.into()); + bytes.extend_from_slice(&key.len().to_be_bytes()); + bytes.extend_from_slice(key.as_bytes()); + } + Message::Response(Response::Ok) => { + bytes.push(MessageKind::OkResponse.into()); + } + Message::Response(Response::Error(msg)) => { + bytes.push(MessageKind::ErrorResponse.into()); + bytes.extend_from_slice(&msg.len().to_be_bytes()); + bytes.append(&mut msg.into_bytes()); + } + Message::Response(Response::Pong) => { + bytes.push(MessageKind::PongResponse.into()); + } + Message::Response(Response::Value(mut value)) => { + bytes.push(MessageKind::ValueResponse.into()); + bytes.extend_from_slice(&value.len().to_be_bytes()); + bytes.append(&mut value); + } + Message::Response(Response::KeyNotFound) => { + bytes.push(MessageKind::KeyNotFoundResponse.into()); + } + } + debug_assert!(bytes.len() <= sizes::MAX_HEADER_SIZE); + bytes.resize(sizes::MAX_HEADER_SIZE, 0u8); + self.stream.write_all(&bytes).await?; + Ok(()) + } + + fn parse_version_header(&mut self, header: &[u8]) -> Result { + let version_bytes = &header[..sizes::VERSION_SIZE]; + let version = Version::from_be_bytes( + version_bytes + .try_into() + .expect("version_bytes len != VERSION_SIZE"), + ); + Ok(version) + } + + fn parse_klen_header(&mut self, header: &[u8]) -> Result { + let klen_bytes = &header[..sizes::KLEN_SIZE]; + let klen = KeyLen::from_be_bytes( + klen_bytes + .try_into() + .expect("klen_bytes.len() should be equal to KLEN_SIZE"), + ); + Ok(klen) + } + + fn parse_klen_vlen_exp_header( + &mut self, + header: &[u8], + ) -> Result<(KeyLen, ValueLen, Expiration), Error> { + let (klen_bytes, rest) = header.split_at(sizes::KLEN_SIZE); + let (vlen_bytes, rest) = rest.split_at(sizes::VLEN_SIZE); + let expiration_bytes = &rest[..sizes::EXP_SIZE]; + + let klen = KeyLen::from_be_bytes( + klen_bytes + .try_into() + .expect("klen_bytes.len() should be equal to KLEN_SIZE"), + ); + let vlen = ValueLen::from_be_bytes( + vlen_bytes + .try_into() + .expect("vlen_bytes.len() should be equal to VLEN_SIZE"), + ); + let expiration = Expiration::from_be_bytes( + expiration_bytes + .try_into() + .expect("expiration_bytes.len() should be equal to EXP_SIZE"), + ); + + Ok((klen, vlen, expiration)) + } + + fn parse_vlen_header(&mut self, header: &[u8]) -> Result { + let vlen_bytes = &header[..sizes::VLEN_SIZE]; + assert_eq!(vlen_bytes.len(), sizes::VLEN_SIZE); + let vlen = ValueLen::from_be_bytes( + vlen_bytes + .try_into() + .expect("vlen_bytes.len() should be equal to VLEN_SIZE"), + ); + + Ok(vlen) + } + + fn parse_utf8(buf: Vec) -> Result { + Ok(String::from_utf8(buf).map_err(|_| ParseError::InvalidString)?) + } +} + +// test in submodule so we can access private stream +#[cfg(test)] +mod test { + use super::*; + + struct MockStream { + read_data: std::collections::VecDeque, + wrote_data: Vec, + } + + #[async_trait::async_trait] + impl AsyncReader for MockStream { + async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), std::io::Error> { + for byte in buf.iter_mut() { + if self.read_data.is_empty() { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + format!( + "read_exact: cannot completely fill buffer, stopped at: buf={:?}", + buf + ), + )); + } + *byte = self.read_data.pop_front().unwrap(); + } + Ok(()) + } + } + + #[async_trait::async_trait] + impl AsyncWriter for MockStream { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + self.wrote_data = Vec::from(buf); + Ok(()) + } + } + + async fn assert_parsed(data: Vec, msg: Message) { + let mut socket = Socket::new(MockStream { + read_data: vec![].into(), + wrote_data: vec![], + }); + socket.stream.read_data = data.into(); + let parsed = socket.receive().await; + assert_eq!(parsed.expect("error while parsing"), msg); + } + + #[tokio::test] + async fn test_socket() { + let mut data: Vec; + + assert_eq!(21, sizes::MAX_HEADER_SIZE); + + // TODO: divide test cases to functions so it does not stop at first fail + reduce + // boilerplate somehow + // TODO!: tests for encoding + + data = vec![0]; + data.append(&mut vec![0, 1]); + data.append(&mut vec![0; 18]); + assert_parsed(data, Message::Request(Request::Version(1))).await; + + data = vec![1]; + data.append(&mut vec![0; 20]); + assert_parsed(data, Message::Request(Request::Ping)).await; + + data = vec![2]; + data.append(&mut vec![0, 0, 0, 0, 0, 0, 0, 2]); // klen + data.append(&mut vec![0; 12]); // rest of header + data.append(&mut vec![97, 98]); // utf8 encoded key + assert_parsed(data, Message::Request(Request::Get("ab".to_owned()))).await; + + data = vec![3]; + data.append(&mut vec![0, 0, 0, 0, 0, 0, 0, 2]); // klen + data.append(&mut vec![0, 0, 0, 0, 0, 0, 0, 3]); // vlen + data.append(&mut vec![0, 0, 1, 0]); // exp + data.append(&mut vec![97, 98]); // utf8 encoded key + data.append(&mut vec![1, 2, 3]); // value + assert_parsed( + data, + Message::Request(Request::Set { + key: "ab".to_owned(), + value: vec![1, 2, 3], + expiration: 256, + }), + ) + .await; + + data = vec![4]; + data.append(&mut vec![0, 0, 0, 0, 0, 0, 0, 2]); // klen + data.append(&mut vec![0; 12]); // rest of header + data.append(&mut vec![97, 98]); // utf8 encoded key + assert_parsed(data, Message::Request(Request::Delete("ab".to_owned()))).await; + + data = vec![5]; + data.append(&mut vec![0; 20]); + assert_parsed(data, Message::Request(Request::Clear)).await; + + data = vec![128]; + data.append(&mut vec![0; 20]); + assert_parsed(data, Message::Response(Response::Pong)).await; + + data = vec![129]; + data.append(&mut vec![0; 20]); + assert_parsed(data, Message::Response(Response::Ok)).await; + + data = vec![130]; + data.append(&mut vec![0, 0, 0, 0, 0, 0, 0, 4]); // vlen + data.append(&mut vec![0; 12]); // rest of header + data.append(&mut vec![1, 2, 3, 4]); // msg + assert_parsed(data, Message::Response(Response::Value(vec![1, 2, 3, 4]))).await; + + data = vec![131]; + data.append(&mut vec![0; 20]); + assert_parsed(data, Message::Response(Response::KeyNotFound)).await; + + data = vec![132]; + data.append(&mut vec![0, 0, 0, 0, 0, 0, 0, 3]); // vlen + data.append(&mut vec![0; 12]); // rest of header + data.append(&mut vec![101, 114, 114]); // msg + assert_parsed(data, Message::Response(Response::Error("err".to_owned()))).await; + } +} diff --git a/memcrab-protocol/src/io/async_reader.rs b/memcrab-protocol/src/stream/async_reader.rs similarity index 100% rename from memcrab-protocol/src/io/async_reader.rs rename to memcrab-protocol/src/stream/async_reader.rs diff --git a/memcrab-protocol/src/io/async_writer.rs b/memcrab-protocol/src/stream/async_writer.rs similarity index 100% rename from memcrab-protocol/src/io/async_writer.rs rename to memcrab-protocol/src/stream/async_writer.rs diff --git a/memcrab-protocol/src/io/mod.rs b/memcrab-protocol/src/stream/mod.rs similarity index 100% rename from memcrab-protocol/src/io/mod.rs rename to memcrab-protocol/src/stream/mod.rs diff --git a/memcrab-protocol/src/transport/client.rs b/memcrab-protocol/src/transport/client.rs deleted file mode 100644 index 5340d9e..0000000 --- a/memcrab-protocol/src/transport/client.rs +++ /dev/null @@ -1,184 +0,0 @@ -use crate::{ - io::{AsyncReader, AsyncWriter}, - mapping::{ - alias::{ErrMsgLen, Expiration, KeyLen, ValueLen}, - flags::{RequestFlag, ResponseFlag}, - tokens::{ErrorHeader, Payload, RequestHeader, ResponseHeader}, - }, - ClientSideError, ErrorResponse, ParsingError, Request, Response, -}; - -#[derive(Debug, Clone)] -pub struct ClientSocket { - stream: S, -} - -impl ClientSocket -where - S: AsyncReader + AsyncWriter + Send, -{ - pub fn new(stream: S) -> Self { - Self { stream } - } - pub async fn make_request(&mut self, request: Request) -> Result { - let req_bytes = self.encode_request(request); - self.stream.write_all(&req_bytes).await?; - - let header_chunk = self.stream.read_chunk(ResponseHeader::SIZE).await?; - let header = self.decode_response_header(&header_chunk)?; - - let payload_chunk = if header.payload_len() > 0 { - self.stream.read_chunk(header.payload_len()).await? - } else { - vec![] - }; - let payload = self.decode_response_payload(header, payload_chunk)?; - - let resp = self.construct_response(header, payload); - Ok(resp) - } - - fn construct_response(&self, header: ResponseHeader, payload: Payload) -> Response { - match header { - ResponseHeader::Ok => Response::Ok, - ResponseHeader::Pong => Response::Pong, - ResponseHeader::KeyNotFound => Response::KeyNotFound, - ResponseHeader::Value { .. } => match payload { - Payload::Zero => Response::Value(vec![]), - Payload::Value(v) => Response::Value(v), - p => panic!("invalid value, payload={:?}", p), - }, - ResponseHeader::Error(err) => { - let msg = match payload { - Payload::Zero => "".to_owned(), - Payload::ErrMsg(msg) => msg, - p => panic!("invalid error msg, payload={:?}", p), - }; - let inner = match err { - ErrorHeader::Internal { .. } => ErrorResponse::Internal(msg), - ErrorHeader::Validation { .. } => ErrorResponse::Validation(msg), - }; - Response::Error(inner) - } - } - } - fn encode_request(&self, request: Request) -> Vec { - let mut bytes = vec![0; RequestHeader::SIZE]; - match request { - Request::Ping => { - bytes[0] = RequestFlag::Ping.into(); - } - Request::Clear => { - bytes[0] = RequestFlag::Clear.into(); - } - Request::Version(v) => { - bytes[0] = RequestFlag::Version.into(); - let [a, b] = v.to_be_bytes(); - bytes[1] = a; - bytes[2] = b; - } - Request::Get(key) => { - bytes[0] = RequestFlag::Get.into(); - let key = key.as_bytes(); - let klen: KeyLen = key.len().try_into().unwrap(); - - for (dst, src) in bytes[1..].iter_mut().zip(klen.to_be_bytes()) { - *dst = src; - } - bytes.extend_from_slice(key); - } - Request::Delete(key) => { - bytes[0] = RequestFlag::Delete.into(); - let key = key.as_bytes(); - let klen: KeyLen = key.len().try_into().unwrap(); - - for (dst, src) in bytes[1..].iter_mut().zip(klen.to_be_bytes()) { - *dst = src; - } - bytes.extend_from_slice(key); - } - Request::Set { - key, - value, - expiration, - } => { - bytes[0] = RequestFlag::Set.into(); - let key = key.as_bytes(); - let klen: KeyLen = key.len().try_into().unwrap(); - let vlen: ValueLen = value.len().try_into().unwrap(); - - let klen_bytes = klen.to_be_bytes(); - let vlen_bytes = vlen.to_be_bytes(); - let exp_bytes = expiration.to_be_bytes(); - - let tail = klen_bytes - .iter() - .chain(vlen_bytes.iter()) - .chain(exp_bytes.iter()); - for (dst, &src) in bytes[1..].iter_mut().zip(tail) { - *dst = src; - } - bytes.extend_from_slice(key); - bytes.extend_from_slice(&value); - } - } - bytes - } - fn decode_response_header(&self, header_chunk: &[u8]) -> Result { - let flag = ResponseFlag::try_from(header_chunk[0]).map_err(|_| ParsingError::Header)?; - match flag { - ResponseFlag::Pong => Ok(ResponseHeader::Pong), - ResponseFlag::Ok => Ok(ResponseHeader::Ok), - ResponseFlag::KeyNotFound => Ok(ResponseHeader::KeyNotFound), - ResponseFlag::Value => { - let vlen_bytes = &header_chunk[1..1 + ResponseHeader::VLEN_SIZE]; - let vlen = KeyLen::from_be_bytes( - vlen_bytes - .try_into() - .expect("vlen_bytes.len() should be equal to VLEN_SIZE"), - ); - Ok(ResponseHeader::Value { vlen }) - } - ResponseFlag::InternalErr => { - let msg_len_bytes = &header_chunk[1..1 + ErrorHeader::MSG_LEN_SIZE]; - let msg_len = ErrMsgLen::from_be_bytes( - msg_len_bytes - .try_into() - .expect("msg_len_bytes.len() should be equal to MSG_LEN_SIZE"), - ); - let err = ErrorHeader::Internal { len: msg_len }; - Ok(ResponseHeader::Error(err)) - } - ResponseFlag::ValidationErr => { - let msg_len_bytes = &header_chunk[1..1 + ErrorHeader::MSG_LEN_SIZE]; - let msg_len = ErrMsgLen::from_be_bytes( - msg_len_bytes - .try_into() - .expect("msg_len_bytes.len() should be equal to MSG_LEN_SIZE"), - ); - let err = ErrorHeader::Validation { len: msg_len }; - Ok(ResponseHeader::Error(err)) - } - } - } - fn decode_response_payload( - &self, - header: ResponseHeader, - payload_chunk: Vec, - ) -> Result { - use ResponseHeader as H; - - match header { - H::Pong | H::Ok | H::KeyNotFound => Ok(Payload::Zero), - H::Value { vlen } => { - assert_eq!(vlen, payload_chunk.len() as u64); - Ok(Payload::Value(payload_chunk)) - } - H::Error(inner) => { - assert_eq!(inner.errmsg_len(), payload_chunk.len() as u64); - let msg = String::from_utf8(payload_chunk).map_err(|_| ParsingError::Payload)?; - Ok(Payload::ErrMsg(msg)) - } - } - } -} diff --git a/memcrab-protocol/src/transport/mod.rs b/memcrab-protocol/src/transport/mod.rs deleted file mode 100644 index bcff728..0000000 --- a/memcrab-protocol/src/transport/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod client; -mod schemas; -mod server; - -pub use client::ClientSocket; -pub use schemas::{ErrorResponse, Request, Response}; -pub use server::ServerSocket; diff --git a/memcrab-protocol/src/transport/schemas.rs b/memcrab-protocol/src/transport/schemas.rs deleted file mode 100644 index 593e917..0000000 --- a/memcrab-protocol/src/transport/schemas.rs +++ /dev/null @@ -1,33 +0,0 @@ -use crate::mapping::alias::{Expiration, Version}; -use thiserror::Error; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Request { - Version(Version), - Get(String), - Set { - key: String, - value: Vec, - expiration: Expiration, - }, - Delete(String), - Clear, - Ping, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Response { - Value(Vec), - Ok, - Error(ErrorResponse), - KeyNotFound, - Pong, -} - -#[derive(Error, Debug, Clone, PartialEq, Eq)] -pub enum ErrorResponse { - #[error("validation error")] - Validation(String), - #[error("internal error")] - Internal(String), -} diff --git a/memcrab-protocol/src/transport/server.rs b/memcrab-protocol/src/transport/server.rs deleted file mode 100644 index a056db2..0000000 --- a/memcrab-protocol/src/transport/server.rs +++ /dev/null @@ -1,198 +0,0 @@ -use crate::{ - io::{AsyncReader, AsyncWriter}, - mapping::{ - alias::{ErrMsgLen, Expiration, KeyLen, ValueLen, Version}, - flags::{RequestFlag, ResponseFlag}, - tokens::{Payload, RequestHeader, ResponseHeader}, - }, - ErrorResponse, ParsingError, Request, Response, ServerSideError, -}; - -#[derive(Debug, Clone)] -pub struct ServerSocket { - stream: S, -} - -impl ServerSocket -where - S: AsyncReader + AsyncWriter + Send, -{ - pub fn new(stream: S) -> Self { - Self { stream } - } - pub async fn recv_request(&mut self) -> Result { - let header_chunk = self.stream.read_chunk(RequestHeader::SIZE).await?; - let header = self.decode_request_header(&header_chunk)?; - - let payload_chunk = if header.payload_len() > 0 { - self.stream.read_chunk(header.payload_len()).await? - } else { - vec![] - }; - let payload = self.decode_request_payload(header, payload_chunk)?; - - let req = self.construct_request(header, payload); - Ok(req) - } - pub async fn send_response(&mut self, response: &Response) -> Result<(), ServerSideError> { - let response_bytes = self.encode_response(response); - self.stream.write_all(&response_bytes).await?; - Ok(()) - } - - fn decode_request_header(&self, header_chunk: &[u8]) -> Result { - let flag = RequestFlag::try_from(header_chunk[0]).map_err(|_| ParsingError::Header)?; - match flag { - RequestFlag::Version => { - let version_bytes = &header_chunk[1..1 + RequestHeader::VERSION_SIZE]; - let version = Version::from_be_bytes( - version_bytes - .try_into() - .expect("version_bytes.len() should be equal to VERSION_SIZE"), - ); - Ok(RequestHeader::Version(version)) - } - RequestFlag::Ping => Ok(RequestHeader::Ping), - RequestFlag::Get => { - let klen_bytes = &header_chunk[1..1 + RequestHeader::KLEN_SIZE]; - let klen = KeyLen::from_be_bytes( - klen_bytes - .try_into() - .expect("klen_bytes.len() should be equal to KLEN_SIZE"), - ); - Ok(RequestHeader::Get { klen }) - } - RequestFlag::Set => { - let tail = &header_chunk[1..]; - let (klen_bytes, tail) = tail.split_at(RequestHeader::KLEN_SIZE); - let (vlen_bytes, tail) = tail.split_at(RequestHeader::VLEN_SIZE); - let expiration_bytes = &tail[..RequestHeader::EXP_SIZE]; - - let klen = KeyLen::from_be_bytes( - klen_bytes - .try_into() - .expect("klen_bytes.len() should be equal to KLEN_SIZE"), - ); - let vlen = ValueLen::from_be_bytes( - vlen_bytes - .try_into() - .expect("vlen_bytes.len() should be equal to VLEN_SIZE"), - ); - let expiration = Expiration::from_be_bytes( - expiration_bytes - .try_into() - .expect("expiration_bytes.len() should be equal to EXP_SIZE"), - ); - Ok(RequestHeader::Set { - klen, - vlen, - expiration, - }) - } - RequestFlag::Clear => Ok(RequestHeader::Clear), - RequestFlag::Delete => { - let klen_bytes = &header_chunk[1..1 + RequestHeader::KLEN_SIZE]; - let klen = KeyLen::from_be_bytes( - klen_bytes - .try_into() - .expect("klen_bytes.len() should be equal to KLEN_SIZE"), - ); - Ok(RequestHeader::Delete { klen }) - } - } - } - fn decode_request_payload( - &self, - header: RequestHeader, - payload_chunk: Vec, - ) -> Result { - match header { - RequestHeader::Ping => Ok(Payload::Zero), - RequestHeader::Version(v) => Ok(Payload::Zero), - RequestHeader::Delete { klen } => { - assert_eq!(klen, payload_chunk.len() as u64); - let key = String::from_utf8(payload_chunk).map_err(|_| ParsingError::Payload)?; - Ok(Payload::Key(key)) - } - RequestHeader::Clear => Ok(Payload::Zero), - RequestHeader::Get { klen } => { - assert_eq!(klen, payload_chunk.len() as u64); - let key = String::from_utf8(payload_chunk).map_err(|_| ParsingError::Payload)?; - Ok(Payload::Key(key)) - } - RequestHeader::Set { - klen, - vlen, - expiration, - } => { - let (head, tail) = payload_chunk.split_at(klen.try_into().unwrap()); - let key = String::from_utf8(head.to_vec()).map_err(|_| ParsingError::Payload)?; - let value = tail.to_vec(); - Ok(Payload::Pair { key, value }) - } - } - } - fn construct_request(&self, header: RequestHeader, payload: Payload) -> Request { - use RequestHeader as H; - - match (header, payload) { - (H::Ping, Payload::Zero) => Request::Ping, - (H::Version(v), Payload::Zero) => Request::Version(v), - (H::Delete { .. }, Payload::Key(key)) => Request::Delete(key), - (H::Clear, Payload::Zero) => Request::Clear, - (H::Get { .. }, Payload::Key(key)) => Request::Get(key), - (H::Set { expiration, .. }, Payload::Pair { key, value }) => Request::Set { - key, - value, - expiration, - }, - tuple => panic!("invalid (header, payload): {:?}", tuple), - } - } - fn encode_response(&self, response: &Response) -> Vec { - let mut bytes = vec![0; ResponseHeader::SIZE]; - - match response { - Response::Pong => { - bytes[0] = ResponseFlag::Pong.into(); - } - Response::Ok => { - bytes[0] = ResponseFlag::Ok.into(); - } - Response::KeyNotFound => { - bytes[0] = ResponseFlag::KeyNotFound.into(); - } - Response::Value(value) => { - bytes[0] = ResponseFlag::Value.into(); - let vlen: ValueLen = value.len().try_into().unwrap(); - for (dst, src) in bytes[1..].iter_mut().zip(vlen.to_be_bytes()) { - *dst = src; - } - bytes.extend_from_slice(value); - } - Response::Error(err) => { - match err { - ErrorResponse::Internal(msg) => { - bytes[0] = ResponseFlag::InternalErr.into(); - let msg = msg.as_bytes(); - let msg_len: ErrMsgLen = msg.len().try_into().unwrap(); - for (dst, src) in bytes[1..].iter_mut().zip(msg_len.to_be_bytes()) { - *dst = src; - } - bytes.extend_from_slice(msg); - } - ErrorResponse::Validation(msg) => { - bytes[0] = ResponseFlag::ValidationErr.into(); - let msg = msg.as_bytes(); - let msg_len: ErrMsgLen = msg.len().try_into().unwrap(); - for (dst, src) in bytes[1..].iter_mut().zip(msg_len.to_be_bytes()) { - *dst = src; - } - bytes.extend_from_slice(msg); - } - }; - } - } - bytes - } -} diff --git a/memcrab-protocol/src/version.rs b/memcrab-protocol/src/version.rs new file mode 100644 index 0000000..26656ed --- /dev/null +++ b/memcrab-protocol/src/version.rs @@ -0,0 +1,3 @@ +use crate::alias::Version; + +pub const VERSION: Version = 0; diff --git a/memcrab-protocol/tests/test_cmd.rs b/memcrab-protocol/tests/test_cmd.rs deleted file mode 100644 index 5f2534c..0000000 --- a/memcrab-protocol/tests/test_cmd.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::net::SocketAddr; -use std::time::Duration; -use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc::{self, Receiver, Sender}; - -use memcrab_protocol::io::{AsyncReader, AsyncWriter}; -use memcrab_protocol::{ClientSocket, ErrorResponse, Request, Response, ServerSocket}; - -#[tokio::test] -async fn test_commands() { - let addr = "127.0.0.1:9008".parse().unwrap(); - let (echo_sender, echo_receiver) = mpsc::channel::(1); - - let cache = |req: Request| -> Response { - match req { - Request::Ping => Response::Pong, - Request::Clear => Response::Ok, - Request::Version(_) => { - let inner = ErrorResponse::Internal("..".to_owned()); - Response::Error(inner) - } - Request::Get(_) => Response::Value(vec![1, 2]), - Request::Delete(_) => Response::KeyNotFound, - Request::Set { .. } => Response::Ok, - } - }; - - tokio::spawn(start_server_socket(addr, echo_sender, cache)); - - let stream = connect(addr).await; - let socket = ClientSocket::new(stream); - let mut client = TestClient::new(socket, echo_receiver); - - let resp = client.make_request(Request::Ping).await; - assert_eq!(resp, Response::Pong); - - let resp = client.make_request(Request::Delete("abc".to_owned())).await; - assert_eq!(resp, Response::KeyNotFound); - - client.make_request(Request::Version(1)).await; - client.make_request(Request::Clear).await; - client.make_request(Request::Get("123".to_owned())).await; - client - .make_request(Request::Set { - key: "alex".to_owned(), - value: vec![11, 2], - expiration: 2, - }) - .await; -} - -async fn connect(addr: SocketAddr) -> TcpStream { - for _ in 0..30 { - match TcpStream::connect(addr).await { - Ok(stream) => return stream, - Err(_) => { - tokio::time::sleep(Duration::from_secs_f32(0.1)).await; - } - } - } - unreachable!("could not connect to {:?}", addr) -} - -async fn start_server_socket( - addr: SocketAddr, - echo_channel: Sender, - mut cache: impl FnMut(Request) -> Response, -) { - let listener = TcpListener::bind(addr).await.unwrap(); - - let (stream, _) = listener.accept().await.unwrap(); - let mut server = ServerSocket::new(stream); - loop { - let req = server.recv_request().await.unwrap(); - let resp = cache(req.clone()); - server.send_response(&resp).await.unwrap(); - echo_channel.send(req.clone()).await.unwrap(); - } -} - -struct TestClient -where - S: AsyncReader + AsyncWriter + Send, -{ - socket: ClientSocket, - echo_channel: Receiver, -} - -impl TestClient -where - S: AsyncReader + AsyncWriter + Send, -{ - fn new(socket: ClientSocket, echo_channel: Receiver) -> Self { - Self { - socket, - echo_channel, - } - } - async fn make_request(&mut self, request: Request) -> Response { - let resp = self.socket.make_request(request.clone()).await.unwrap(); - let echo = self.echo_channel.recv().await.unwrap(); - assert_eq!(echo, request); - resp - } -} diff --git a/memcrab/src/lib.rs b/memcrab/src/lib.rs index c174f41..3caa6a5 100644 --- a/memcrab/src/lib.rs +++ b/memcrab/src/lib.rs @@ -1,7 +1,7 @@ #[allow(unused_variables)] mod raw_client; -pub use raw_client::RawClient; +// pub use raw_client::RawClient; #[cfg(test)] mod tests {} diff --git a/memcrab/src/raw_client.rs b/memcrab/src/raw_client.rs index d0d97ac..f68a750 100644 --- a/memcrab/src/raw_client.rs +++ b/memcrab/src/raw_client.rs @@ -1,19 +1,19 @@ -use std::net::SocketAddr; - -use memcrab_protocol::ClientSideError; - -pub struct RawClient {} - -impl RawClient { - pub async fn connect(addr: SocketAddr) -> Result { - todo!() - } - pub async fn get(&self, key: impl Into) -> Result>, ClientSideError> { - let key = key.into(); - todo!() - } - pub async fn set(&self, key: impl Into, value: Vec) -> Result<(), ClientSideError> { - let key = key.into(); - todo!() - } -} +// use std::net::SocketAddr; +// +// use memcrab_protocol::ClientSideError; +// +// pub struct RawClient {} +// +// impl RawClient { +// pub async fn connect(addr: SocketAddr) -> Result { +// todo!() +// } +// pub async fn get(&self, key: impl Into) -> Result>, ClientSideError> { +// let key = key.into(); +// todo!() +// } +// pub async fn set(&self, key: impl Into, value: Vec) -> Result<(), ClientSideError> { +// let key = key.into(); +// todo!() +// } +// }