Skip to content

Commit

Permalink
WIP: Switch Batches to new serialization traits
Browse files Browse the repository at this point in the history
  • Loading branch information
Lorak-mmk committed Dec 7, 2023
1 parent 366d6c4 commit 685fca4
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 246 deletions.
99 changes: 40 additions & 59 deletions scylla-cql/src/frame/request/batch.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use bytes::{Buf, BufMut};
use std::{borrow::Cow, convert::TryInto};

use crate::frame::{
frame_errors::ParseError,
request::{RequestOpcode, SerializableRequest},
types::{self, SerialConsistency},
value::{BatchValues, BatchValuesIterator, LegacySerializedValues},
use crate::{
frame::{
frame_errors::ParseError,
request::{RequestOpcode, SerializableRequest},
types::{self, SerialConsistency},
},
types::serialize::row::SerializedValues,
};

use super::DeserializableRequest;
Expand All @@ -16,18 +18,16 @@ const FLAG_WITH_DEFAULT_TIMESTAMP: u8 = 0x20;
const ALL_FLAGS: u8 = FLAG_WITH_SERIAL_CONSISTENCY | FLAG_WITH_DEFAULT_TIMESTAMP;

#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
pub struct Batch<'b, Statement, Values>
pub struct Batch<'b, Statement>
where
BatchStatement<'b>: From<&'b Statement>,
Statement: Clone,
Values: BatchValues,
{
pub statements: Cow<'b, [Statement]>,
pub batch_type: BatchType,
pub consistency: types::Consistency,
pub serial_consistency: Option<types::SerialConsistency>,
pub timestamp: Option<i64>,
pub values: Values,
}

/// The type of a batch.
Expand Down Expand Up @@ -64,15 +64,20 @@ impl TryFrom<u8> for BatchType {

#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord)]
pub enum BatchStatement<'a> {
Query { text: Cow<'a, str> },
Prepared { id: Cow<'a, [u8]> },
Query {
text: Cow<'a, str>,
values: Cow<'a, SerializedValues>,
},
Prepared {
id: Cow<'a, [u8]>,
values: Cow<'a, SerializedValues>,
},
}

impl<Statement, Values> SerializableRequest for Batch<'_, Statement, Values>
impl<Statement> SerializableRequest for Batch<'_, Statement>
where
for<'s> BatchStatement<'s>: From<&'s Statement>,
Statement: Clone,
Values: BatchValues,
{
const OPCODE: RequestOpcode = RequestOpcode::Batch;

Expand All @@ -83,36 +88,9 @@ where
// Serializing queries
types::write_short(self.statements.len().try_into()?, buf);

let counts_mismatch_err = |n_values: usize, n_statements: usize| {
ParseError::BadDataToSerialize(format!(
"Length of provided values must be equal to number of batch statements \
(got {n_values} values, {n_statements} statements)"
))
};
let mut n_serialized_statements = 0usize;
let mut value_lists = self.values.batch_values_iter();
for (idx, statement) in self.statements.iter().enumerate() {
BatchStatement::from(statement).serialize(buf)?;
value_lists
.write_next_to_request(buf)
.ok_or_else(|| counts_mismatch_err(idx, self.statements.len()))??;
n_serialized_statements += 1;
}
// At this point, we have all statements serialized. If any values are still left, we have a mismatch.
if value_lists.skip_next().is_some() {
return Err(counts_mismatch_err(
n_serialized_statements + 1 /*skipped above*/ + value_lists.count(),
n_serialized_statements,
));
}
if n_serialized_statements != self.statements.len() {
// We want to check this to avoid propagating an invalid construction of self.statements_count as a
// hard-to-debug silent fail
return Err(ParseError::BadDataToSerialize(format!(
"Invalid Batch constructed: not as many statements serialized as announced \
(batch.statement_count: {announced_statement_count}, {n_serialized_statements}",
announced_statement_count = self.statements.len()
)));
for statement in self.statements.iter() {
let stmt = BatchStatement::from(statement);
stmt.serialize(buf)?;
}

// Serializing consistency
Expand Down Expand Up @@ -146,11 +124,13 @@ impl BatchStatement<'_> {
match kind {
0 => {
let text = Cow::Owned(types::read_long_string(buf)?.to_owned());
Ok(BatchStatement::Query { text })
let values = Cow::Owned(SerializedValues::new_from_frame(buf)?);
Ok(BatchStatement::Query { text, values })
}
1 => {
let id = types::read_short_bytes(buf)?.to_vec().into();
Ok(BatchStatement::Prepared { id })
let values = Cow::Owned(SerializedValues::new_from_frame(buf)?);
Ok(BatchStatement::Prepared { id, values })
}
_ => Err(ParseError::BadIncomingData(format!(
"Unexpected batch statement kind: {}",
Expand All @@ -163,13 +143,15 @@ impl BatchStatement<'_> {
impl BatchStatement<'_> {
fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> {
match self {
Self::Query { text } => {
Self::Query { text, values } => {
buf.put_u8(0);
types::write_long_string(text, buf)?;
values.write_to_request(buf);
}
Self::Prepared { id } => {
Self::Prepared { id, values } => {
buf.put_u8(1);
types::write_short_bytes(id, buf)?;
values.write_to_request(buf);
}
}

Expand All @@ -180,25 +162,28 @@ impl BatchStatement<'_> {
impl<'s, 'b> From<&'s BatchStatement<'b>> for BatchStatement<'s> {
fn from(value: &'s BatchStatement) -> Self {
match value {
BatchStatement::Query { text } => BatchStatement::Query { text: text.clone() },
BatchStatement::Prepared { id } => BatchStatement::Prepared { id: id.clone() },
BatchStatement::Query { text, values } => BatchStatement::Query {
text: text.clone(),
values: values.clone(),
},
BatchStatement::Prepared { id, values } => BatchStatement::Prepared {
id: id.clone(),
values: values.clone(),
},
}
}
}

impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec<LegacySerializedValues>> {
impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>> {
fn deserialize(buf: &mut &[u8]) -> Result<Self, ParseError> {
let batch_type = buf.get_u8().try_into()?;

let statements_count: usize = types::read_short(buf)?.into();
let statements_with_values = (0..statements_count)
let statements = (0..statements_count)
.map(|_| {
let batch_statement = BatchStatement::deserialize(buf)?;

// As stated in CQL protocol v4 specification, values names in Batch are broken and should be never used.
let values = LegacySerializedValues::new_from_frame(buf, false)?;

Ok((batch_statement, values))
Ok(batch_statement)
})
.collect::<Result<Vec<_>, ParseError>>()?;

Expand Down Expand Up @@ -233,16 +218,12 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec<LegacySeria
.then(|| types::read_long(buf))
.transpose()?;

let (statements, values): (Vec<BatchStatement>, Vec<LegacySerializedValues>) =
statements_with_values.into_iter().unzip();

Ok(Self {
statements: Cow::Owned(statements),
batch_type,
consistency,
serial_consistency,
timestamp,
statements: Cow::Owned(statements),
values,
})
}
}
32 changes: 6 additions & 26 deletions scylla-cql/src/frame/request/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ pub use startup::Startup;
use self::batch::BatchStatement;

use super::types::SerialConsistency;
use super::value::LegacySerializedValues;

#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)]
#[repr(u8)]
Expand Down Expand Up @@ -59,7 +58,7 @@ pub trait DeserializableRequest: SerializableRequest + Sized {
pub enum Request<'r> {
Query(Query<'r>),
Execute(Execute<'r>),
Batch(Batch<'r, BatchStatement<'r>, Vec<LegacySerializedValues>>),
Batch(Batch<'r, BatchStatement<'r>>),
}

impl<'r> Request<'r> {
Expand Down Expand Up @@ -100,7 +99,7 @@ impl<'r> Request<'r> {

#[cfg(test)]
mod tests {
use std::{borrow::Cow, ops::Deref};
use std::borrow::Cow;

use bytes::Bytes;

Expand Down Expand Up @@ -176,9 +175,12 @@ mod tests {
let statements = vec![
BatchStatement::Query {
text: query.contents,
values: query.parameters.values.clone(),
},
// Not execute's values, because named values are not supported in batches.
BatchStatement::Prepared {
id: Cow::Borrowed(&execute.id),
values: query.parameters.values,
},
];
let batch = Batch {
Expand All @@ -187,22 +189,6 @@ mod tests {
consistency: Consistency::EachQuorum,
serial_consistency: Some(SerialConsistency::LocalSerial),
timestamp: Some(32432),

// Not execute's values, because named values are not supported in batches.
values: vec![
query
.parameters
.values
.deref()
.clone()
.into_old_serialized_values(),
query
.parameters
.values
.deref()
.clone()
.into_old_serialized_values(),
],
};
{
let mut buf = Vec::new();
Expand Down Expand Up @@ -264,20 +250,14 @@ mod tests {
// Batch
let statements = vec![BatchStatement::Query {
text: query.contents,
values: query.parameters.values,
}];
let batch = Batch {
statements: Cow::Owned(statements),
batch_type: BatchType::Logged,
consistency: Consistency::EachQuorum,
serial_consistency: None,
timestamp: None,

values: vec![query
.parameters
.values
.deref()
.clone()
.into_old_serialized_values()],
};
{
let mut buf = Vec::new();
Expand Down
59 changes: 26 additions & 33 deletions scylla/src/statement/batch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::borrow::Cow;
use std::sync::Arc;

use scylla_cql::types::serialize::row::{SerializeRow, SerializedValues};
use scylla_cql::types::serialize::SerializationError;

use crate::history::HistoryListener;
use crate::retry_policy::RetryPolicy;
use crate::statement::{prepared_statement::PreparedStatement, query::Query};
Expand All @@ -17,7 +20,7 @@ pub use crate::frame::request::batch::BatchType;
pub struct Batch {
pub(crate) config: StatementConfig,

pub statements: Vec<BatchStatement>,
pub(crate) statements: Vec<BatchStatement>,
batch_type: BatchType,
}

Expand All @@ -30,18 +33,21 @@ impl Batch {
}
}

/// Creates a new, empty `Batch` of `batch_type` type with the provided statements.
pub fn new_with_statements(batch_type: BatchType, statements: Vec<BatchStatement>) -> Self {
Self {
batch_type,
statements,
..Default::default()
}
pub fn append_query(&mut self, query: impl Into<Query>) {
self.statements.push(BatchStatement::Query(query.into()));
}

/// Appends a new statement to the batch.
pub fn append_statement(&mut self, statement: impl Into<BatchStatement>) {
self.statements.push(statement.into());
pub fn append_statement(
&mut self,
statement: PreparedStatement,
values: impl SerializeRow,
) -> Result<(), SerializationError> {
let serialized = statement.serialize_values(&values)?;
self.statements.push(BatchStatement::PreparedStatement {
statement,
values: serialized,
});
Ok(())
}

/// Gets type of batch.
Expand Down Expand Up @@ -156,27 +162,12 @@ impl Default for Batch {

/// This enum represents a CQL statement, that can be part of batch.
#[derive(Clone)]
pub enum BatchStatement {
pub(crate) enum BatchStatement {
Query(Query),
PreparedStatement(PreparedStatement),
}

impl From<&str> for BatchStatement {
fn from(s: &str) -> Self {
BatchStatement::Query(Query::from(s))
}
}

impl From<Query> for BatchStatement {
fn from(q: Query) -> Self {
BatchStatement::Query(q)
}
}

impl From<PreparedStatement> for BatchStatement {
fn from(p: PreparedStatement) -> Self {
BatchStatement::PreparedStatement(p)
}
PreparedStatement {
statement: PreparedStatement,
values: SerializedValues,
},
}

impl<'a: 'b, 'b> From<&'a BatchStatement>
Expand All @@ -187,11 +178,13 @@ impl<'a: 'b, 'b> From<&'a BatchStatement>
BatchStatement::Query(query) => {
scylla_cql::frame::request::batch::BatchStatement::Query {
text: Cow::Borrowed(&query.contents),
values: Cow::Owned(SerializedValues::new()),
}
}
BatchStatement::PreparedStatement(prepared) => {
BatchStatement::PreparedStatement { statement, values } => {
scylla_cql::frame::request::batch::BatchStatement::Prepared {
id: Cow::Borrowed(prepared.get_id()),
id: Cow::Borrowed(statement.get_id()),
values: Cow::Borrowed(values),
}
}
}
Expand Down
Loading

0 comments on commit 685fca4

Please sign in to comment.