Skip to content

Commit

Permalink
Merge pull request #1180 from muzarski/use-keyspace-errors-refactor
Browse files Browse the repository at this point in the history
errors: refactor errors on USE KEYSPACE execution path and clean up NewSessionError
  • Loading branch information
wprzytula authored Jan 30, 2025
2 parents 9730847 + 32f5a21 commit 4c464b6
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 178 deletions.
4 changes: 2 additions & 2 deletions scylla/src/client/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::cluster::node::{InternalKnownNode, KnownNode, NodeRef};
use crate::cluster::{Cluster, ClusterNeatDebug, ClusterState};
use crate::errors::{
BadQuery, MetadataError, NewSessionError, ProtocolError, QueryError, RequestAttemptError,
RequestError, TracingProtocolError,
RequestError, TracingProtocolError, UseKeyspaceError,
};
use crate::frame::response::result;
#[cfg(feature = "ssl")]
Expand Down Expand Up @@ -1697,7 +1697,7 @@ where
&self,
keyspace_name: impl Into<String>,
case_sensitive: bool,
) -> Result<(), QueryError> {
) -> Result<(), UseKeyspaceError> {
let keyspace_name = keyspace_name.into();
self.keyspace_name
.store(Some(Arc::new(keyspace_name.clone())));
Expand Down
15 changes: 7 additions & 8 deletions scylla/src/client/session_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::cluster::metadata::{
CollectionType, ColumnKind, ColumnType, NativeType, UserDefinedType,
};
use crate::deserialize::DeserializeOwnedValue;
use crate::errors::{BadKeyspaceName, BadQuery, DbError, QueryError};
use crate::errors::{BadKeyspaceName, DbError, QueryError, UseKeyspaceError};
use crate::observability::tracing::TracingInfo;
use crate::policies::retry::{RequestInfo, RetryDecision, RetryPolicy, RetrySession};
use crate::prepared_statement::PreparedStatement;
Expand Down Expand Up @@ -750,24 +750,23 @@ async fn test_use_keyspace() {
// Test that invalid keyspaces get rejected
assert!(matches!(
session.use_keyspace("", false).await,
Err(QueryError::BadQuery(BadQuery::BadKeyspaceName(
BadKeyspaceName::Empty
)))
Err(UseKeyspaceError::BadKeyspaceName(BadKeyspaceName::Empty))
));

let long_name: String = ['a'; 49].iter().collect();
assert!(matches!(
session.use_keyspace(long_name, false).await,
Err(QueryError::BadQuery(BadQuery::BadKeyspaceName(
BadKeyspaceName::TooLong(_, _)
Err(UseKeyspaceError::BadKeyspaceName(BadKeyspaceName::TooLong(
_,
_
)))
));

assert!(matches!(
session.use_keyspace("abcd;dfdsf", false).await,
Err(QueryError::BadQuery(BadQuery::BadKeyspaceName(
Err(UseKeyspaceError::BadKeyspaceName(
BadKeyspaceName::IllegalCharacter(_, ';')
)))
))
));

// Make sure that use_keyspace on SessionBuiler works
Expand Down
4 changes: 2 additions & 2 deletions scylla/src/cluster/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use tokio::net::lookup_host;
use tracing::warn;
use uuid::Uuid;

use crate::errors::{ConnectionPoolError, QueryError};
use crate::errors::{ConnectionPoolError, UseKeyspaceError};
use crate::network::Connection;
use crate::network::VerifiedKeyspaceName;
use crate::network::{NodeConnectionPool, PoolConfig};
Expand Down Expand Up @@ -180,7 +180,7 @@ impl Node {
pub(crate) async fn use_keyspace(
&self,
keyspace_name: VerifiedKeyspaceName,
) -> Result<(), QueryError> {
) -> Result<(), UseKeyspaceError> {
if let Some(pool) = &self.pool {
pool.use_keyspace(keyspace_name).await?;
}
Expand Down
18 changes: 9 additions & 9 deletions scylla/src/cluster/worker.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::client::session::TABLET_CHANNEL_SIZE;
use crate::errors::{MetadataError, NewSessionError, QueryError};
use crate::errors::{MetadataError, NewSessionError, RequestAttemptError, UseKeyspaceError};
use crate::frame::response::event::{Event, StatusChangeEvent};
use crate::network::{PoolConfig, VerifiedKeyspaceName};
use crate::policies::host_filter::HostFilter;
Expand Down Expand Up @@ -101,7 +101,7 @@ struct RefreshRequest {
#[derive(Debug)]
struct UseKeyspaceRequest {
keyspace_name: VerifiedKeyspaceName,
response_chan: tokio::sync::oneshot::Sender<Result<(), QueryError>>,
response_chan: tokio::sync::oneshot::Sender<Result<(), UseKeyspaceError>>,
}

impl Cluster {
Expand Down Expand Up @@ -202,7 +202,7 @@ impl Cluster {
pub(crate) async fn use_keyspace(
&self,
keyspace_name: VerifiedKeyspaceName,
) -> Result<(), QueryError> {
) -> Result<(), UseKeyspaceError> {
let (response_sender, response_receiver) = tokio::sync::oneshot::channel();

self.use_keyspace_channel
Expand Down Expand Up @@ -390,12 +390,12 @@ impl ClusterWorker {
async fn send_use_keyspace(
cluster_data: Arc<ClusterState>,
keyspace_name: &VerifiedKeyspaceName,
) -> Result<(), QueryError> {
) -> Result<(), UseKeyspaceError> {
let use_keyspace_futures = cluster_data
.known_peers
.values()
.map(|node| node.use_keyspace(keyspace_name.clone()));
let use_keyspace_results: Vec<Result<(), QueryError>> =
let use_keyspace_results: Vec<Result<(), UseKeyspaceError>> =
join_all(use_keyspace_futures).await;

use_keyspace_result(use_keyspace_results.into_iter())
Expand Down Expand Up @@ -438,22 +438,22 @@ impl ClusterWorker {
///
/// This function assumes that `use_keyspace_results` iterator is NON-EMPTY!
pub(crate) fn use_keyspace_result(
use_keyspace_results: impl Iterator<Item = Result<(), QueryError>>,
) -> Result<(), QueryError> {
use_keyspace_results: impl Iterator<Item = Result<(), UseKeyspaceError>>,
) -> Result<(), UseKeyspaceError> {
// If there was at least one Ok and the rest were broken connection errors we can return Ok
// keyspace name is correct and will be used on broken connection on the next reconnect

// If there were only broken connection errors then return broken connection error.
// If there was an error different than broken connection error return this error - something is wrong

let mut was_ok: bool = false;
let mut broken_conn_error: Option<QueryError> = None;
let mut broken_conn_error: Option<UseKeyspaceError> = None;

for result in use_keyspace_results {
match result {
Ok(()) => was_ok = true,
Err(err) => match err {
QueryError::BrokenConnection(_) | QueryError::ConnectionPoolError(_) => {
UseKeyspaceError::RequestError(RequestAttemptError::BrokenConnectionError(_)) => {
broken_conn_error = Some(err)
}
_ => return Err(err),
Expand Down
155 changes: 25 additions & 130 deletions scylla/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,6 @@ pub enum QueryError {
#[error("Protocol error: {0}")]
ProtocolError(#[from] ProtocolError),

/// Timeout error has occurred, function didn't complete in time.
#[error("Timeout Error")]
TimeoutError,

/// A connection has been broken during query execution.
#[error(transparent)]
BrokenConnection(#[from] BrokenConnectionError),
Expand All @@ -112,6 +108,10 @@ pub enum QueryError {
#[error("Schema agreement exceeded {}ms", std::time::Duration::as_millis(.0))]
SchemaAgreementTimeout(std::time::Duration),

/// 'USE KEYSPACE <>' request failed.
#[error("'USE KEYSPACE <>' request failed: {0}")]
UseKeyspaceError(#[from] UseKeyspaceError),

// TODO: This should not belong here, but it requires changes to error types
// returned in async iterator API. This should be handled in separate PR.
// The reason this needs to be included is that topology.rs makes use of iter API and returns QueryError.
Expand Down Expand Up @@ -144,39 +144,6 @@ impl From<SerializationError> for QueryError {
}
}

impl From<QueryError> for NewSessionError {
fn from(query_error: QueryError) -> NewSessionError {
match query_error {
QueryError::DbError(e, msg) => NewSessionError::DbError(e, msg),
QueryError::BadQuery(e) => NewSessionError::BadQuery(e),
QueryError::CqlRequestSerialization(e) => NewSessionError::CqlRequestSerialization(e),
QueryError::CqlResultParseError(e) => NewSessionError::CqlResultParseError(e),
QueryError::CqlErrorParseError(e) => NewSessionError::CqlErrorParseError(e),
QueryError::BodyExtensionsParseError(e) => NewSessionError::BodyExtensionsParseError(e),
QueryError::EmptyPlan => NewSessionError::EmptyPlan,
QueryError::MetadataError(e) => NewSessionError::MetadataError(e),
QueryError::ConnectionPoolError(e) => NewSessionError::ConnectionPoolError(e),
QueryError::ProtocolError(e) => NewSessionError::ProtocolError(e),
QueryError::TimeoutError => NewSessionError::TimeoutError,
QueryError::BrokenConnection(e) => NewSessionError::BrokenConnection(e),
QueryError::UnableToAllocStreamId => NewSessionError::UnableToAllocStreamId,
QueryError::RequestTimeout(dur) => NewSessionError::RequestTimeout(dur),
QueryError::SchemaAgreementTimeout(dur) => NewSessionError::SchemaAgreementTimeout(dur),
#[allow(deprecated)]
QueryError::IntoLegacyQueryResultError(e) => {
NewSessionError::IntoLegacyQueryResultError(e)
}
QueryError::NextRowError(e) => NewSessionError::NextRowError(e),
}
}
}

impl From<BadKeyspaceName> for QueryError {
fn from(keyspace_err: BadKeyspaceName) -> QueryError {
QueryError::BadQuery(BadQuery::BadKeyspaceName(keyspace_err))
}
}

impl From<response::Error> for QueryError {
fn from(error: response::Error) -> QueryError {
QueryError::DbError(error.error, error.reason)
Expand All @@ -197,91 +164,13 @@ pub enum NewSessionError {
#[error("Empty known nodes list")]
EmptyKnownNodesList,

/// Database sent a response containing some error with a message
#[error("Database returned an error: {0}, Error message: {1}")]
DbError(DbError, String),

/// Caller passed an invalid query
#[error(transparent)]
BadQuery(#[from] BadQuery),

/// Failed to serialize CQL request.
#[error("Failed to serialize CQL request: {0}")]
CqlRequestSerialization(#[from] CqlRequestSerializationError),

/// Load balancing policy returned an empty plan.
#[error(
"Load balancing policy returned an empty plan.\
First thing to investigate should be the logic of custom LBP implementation.\
If you think that your LBP implementation is correct, or you make use of `DefaultPolicy`,\
then this is most probably a driver bug!"
)]
EmptyPlan,

/// Failed to deserialize frame body extensions.
#[error(transparent)]
BodyExtensionsParseError(#[from] FrameBodyExtensionsParseError),

/// Failed to perform initial cluster metadata fetch.
#[error("Failed to perform initial cluster metadata fetch: {0}")]
MetadataError(#[from] MetadataError),

/// Received a RESULT server response, but failed to deserialize it.
#[error(transparent)]
CqlResultParseError(#[from] CqlResultParseError),

/// Received an ERROR server response, but failed to deserialize it.
#[error("Failed to deserialize ERROR response: {0}")]
CqlErrorParseError(#[from] CqlErrorParseError),

/// Selected node's connection pool is in invalid state.
#[error("No connections in the pool: {0}")]
ConnectionPoolError(#[from] ConnectionPoolError),

/// Protocol error.
#[error("Protocol error: {0}")]
ProtocolError(#[from] ProtocolError),

/// Timeout error has occurred, couldn't connect to node in time.
#[error("Timeout Error")]
TimeoutError,

/// A connection has been broken during query execution.
#[error(transparent)]
BrokenConnection(#[from] BrokenConnectionError),

/// Driver was unable to allocate a stream id to execute a query on.
#[error("Unable to allocate stream id")]
UnableToAllocStreamId,

/// Failed to run a request within a provided client timeout.
#[error(
"Request execution exceeded a client timeout of {}ms",
std::time::Duration::as_millis(.0)
)]
RequestTimeout(std::time::Duration),

/// Schema agreement timed out.
#[error("Schema agreement exceeded {}ms", std::time::Duration::as_millis(.0))]
SchemaAgreementTimeout(std::time::Duration),

// TODO: This should not belong here, but it requires changes to error types
// returned in async iterator API. This should be handled in separate PR.
// The reason this needs to be included is that topology.rs makes use of iter API and returns QueryError.
// Once iter API is adjusted, we can then adjust errors returned by topology module (e.g. refactor MetadataError and not include it in QueryError).
/// An error occurred during async iteration over rows of result.
#[error("An error occurred during async iteration over rows of result: {0}")]
NextRowError(#[from] NextRowError),

/// Failed to convert [`QueryResult`][crate::response::query_result::QueryResult]
/// into [`LegacyQueryResult`][crate::response::legacy_query_result::LegacyQueryResult].
#[deprecated(
since = "0.15.1",
note = "Legacy deserialization API is inefficient and is going to be removed soon"
)]
#[allow(deprecated)]
#[error("Failed to convert `QueryResult` into `LegacyQueryResult`: {0}")]
IntoLegacyQueryResultError(#[from] IntoLegacyQueryResultError),
/// 'USE KEYSPACE <>' request failed.
#[error("'USE KEYSPACE <>' request failed: {0}")]
UseKeyspaceError(#[from] UseKeyspaceError),
}

/// A protocol error.
Expand Down Expand Up @@ -316,10 +205,6 @@ pub enum ProtocolError {
reprepared_id: Vec<u8>,
},

/// USE KEYSPACE protocol error.
#[error("USE KEYSPACE protocol error: {0}")]
UseKeyspace(#[from] UseKeyspaceProtocolError),

/// A protocol error appeared during schema version fetch.
#[error("Schema version fetch protocol error: {0}")]
SchemaVersionFetch(#[from] SchemaVersionFetchError),
Expand All @@ -342,17 +227,31 @@ pub enum ProtocolError {
RepreparedIdMissingInBatch,
}

/// A protocol error that occurred during `USE KEYSPACE <>` request.
/// An error that occurred during `USE KEYSPACE <>` request.
#[derive(Error, Debug, Clone)]
#[non_exhaustive]
pub enum UseKeyspaceProtocolError {
pub enum UseKeyspaceError {
/// Passed invalid keyspace name to use.
#[error("Passed invalid keyspace name to use: {0}")]
BadKeyspaceName(#[from] BadKeyspaceName),

/// An error during request execution.
#[error(transparent)]
RequestError(#[from] RequestAttemptError),

/// Keyspace name mismatch.
#[error("Keyspace name mismtach; expected: {expected_keyspace_name_lowercase}, received: {result_keyspace_name_lowercase}")]
KeyspaceNameMismatch {
expected_keyspace_name_lowercase: String,
result_keyspace_name_lowercase: String,
},
#[error("Received unexpected response: {0}. Expected RESULT:Set_keyspace")]
UnexpectedResponse(CqlResponseKind),

/// Failed to run a request within a provided client timeout.
#[error(
"Request execution exceeded a client timeout of {}ms",
std::time::Duration::as_millis(.0)
)]
RequestTimeout(std::time::Duration),
}

/// A protocol error that occurred during schema version fetch.
Expand Down Expand Up @@ -592,10 +491,6 @@ pub enum BadQuery {
#[error("Serialized values are too long to compute partition key! Length: {0}, Max allowed length: {1}")]
ValuesTooLongForKey(usize, usize),

/// Passed invalid keyspace name to use
#[error("Passed invalid keyspace name to use: {0}")]
BadKeyspaceName(#[from] BadKeyspaceName),

/// Too many queries in the batch statement
#[error("Number of Queries in Batch Statement supplied is {0} which has exceeded the max value of 65,535")]
TooManyQueriesInBatchStatement(usize),
Expand Down
Loading

0 comments on commit 4c464b6

Please sign in to comment.