diff --git a/scylla/src/client/session.rs b/scylla/src/client/session.rs index 3d3035e08..e7b8aa3ac 100644 --- a/scylla/src/client/session.rs +++ b/scylla/src/client/session.rs @@ -16,7 +16,7 @@ use crate::cluster::node::{InternalKnownNode, KnownNode, NodeRef}; use crate::cluster::{Cluster, ClusterNeatDebug, ClusterState}; use crate::errors::{ BadQuery, MetadataError, NewSessionError, ProtocolError, QueryError, RequestAttemptError, - RequestError, TracingProtocolError, + RequestError, TracingProtocolError, UseKeyspaceError, }; use crate::frame::response::result; #[cfg(feature = "ssl")] @@ -1697,7 +1697,7 @@ where &self, keyspace_name: impl Into, case_sensitive: bool, - ) -> Result<(), QueryError> { + ) -> Result<(), UseKeyspaceError> { let keyspace_name = keyspace_name.into(); self.keyspace_name .store(Some(Arc::new(keyspace_name.clone()))); diff --git a/scylla/src/client/session_test.rs b/scylla/src/client/session_test.rs index 87c1c26e0..6d870d1eb 100644 --- a/scylla/src/client/session_test.rs +++ b/scylla/src/client/session_test.rs @@ -9,7 +9,7 @@ use crate::cluster::metadata::{ CollectionType, ColumnKind, ColumnType, NativeType, UserDefinedType, }; use crate::deserialize::DeserializeOwnedValue; -use crate::errors::{BadKeyspaceName, BadQuery, DbError, QueryError}; +use crate::errors::{BadKeyspaceName, DbError, QueryError, UseKeyspaceError}; use crate::observability::tracing::TracingInfo; use crate::policies::retry::{RequestInfo, RetryDecision, RetryPolicy, RetrySession}; use crate::prepared_statement::PreparedStatement; @@ -750,24 +750,23 @@ async fn test_use_keyspace() { // Test that invalid keyspaces get rejected assert!(matches!( session.use_keyspace("", false).await, - Err(QueryError::BadQuery(BadQuery::BadKeyspaceName( - BadKeyspaceName::Empty - ))) + Err(UseKeyspaceError::BadKeyspaceName(BadKeyspaceName::Empty)) )); let long_name: String = ['a'; 49].iter().collect(); assert!(matches!( session.use_keyspace(long_name, false).await, - Err(QueryError::BadQuery(BadQuery::BadKeyspaceName( - BadKeyspaceName::TooLong(_, _) + Err(UseKeyspaceError::BadKeyspaceName(BadKeyspaceName::TooLong( + _, + _ ))) )); assert!(matches!( session.use_keyspace("abcd;dfdsf", false).await, - Err(QueryError::BadQuery(BadQuery::BadKeyspaceName( + Err(UseKeyspaceError::BadKeyspaceName( BadKeyspaceName::IllegalCharacter(_, ';') - ))) + )) )); // Make sure that use_keyspace on SessionBuiler works diff --git a/scylla/src/cluster/node.rs b/scylla/src/cluster/node.rs index 808902c23..1e66f9be9 100644 --- a/scylla/src/cluster/node.rs +++ b/scylla/src/cluster/node.rs @@ -3,7 +3,7 @@ use tokio::net::lookup_host; use tracing::warn; use uuid::Uuid; -use crate::errors::{ConnectionPoolError, QueryError}; +use crate::errors::{ConnectionPoolError, UseKeyspaceError}; use crate::network::Connection; use crate::network::VerifiedKeyspaceName; use crate::network::{NodeConnectionPool, PoolConfig}; @@ -180,7 +180,7 @@ impl Node { pub(crate) async fn use_keyspace( &self, keyspace_name: VerifiedKeyspaceName, - ) -> Result<(), QueryError> { + ) -> Result<(), UseKeyspaceError> { if let Some(pool) = &self.pool { pool.use_keyspace(keyspace_name).await?; } diff --git a/scylla/src/cluster/worker.rs b/scylla/src/cluster/worker.rs index 133b81f44..1d047377e 100644 --- a/scylla/src/cluster/worker.rs +++ b/scylla/src/cluster/worker.rs @@ -1,5 +1,5 @@ use crate::client::session::TABLET_CHANNEL_SIZE; -use crate::errors::{MetadataError, NewSessionError, QueryError}; +use crate::errors::{MetadataError, NewSessionError, RequestAttemptError, UseKeyspaceError}; use crate::frame::response::event::{Event, StatusChangeEvent}; use crate::network::{PoolConfig, VerifiedKeyspaceName}; use crate::policies::host_filter::HostFilter; @@ -101,7 +101,7 @@ struct RefreshRequest { #[derive(Debug)] struct UseKeyspaceRequest { keyspace_name: VerifiedKeyspaceName, - response_chan: tokio::sync::oneshot::Sender>, + response_chan: tokio::sync::oneshot::Sender>, } impl Cluster { @@ -202,7 +202,7 @@ impl Cluster { pub(crate) async fn use_keyspace( &self, keyspace_name: VerifiedKeyspaceName, - ) -> Result<(), QueryError> { + ) -> Result<(), UseKeyspaceError> { let (response_sender, response_receiver) = tokio::sync::oneshot::channel(); self.use_keyspace_channel @@ -390,12 +390,12 @@ impl ClusterWorker { async fn send_use_keyspace( cluster_data: Arc, keyspace_name: &VerifiedKeyspaceName, - ) -> Result<(), QueryError> { + ) -> Result<(), UseKeyspaceError> { let use_keyspace_futures = cluster_data .known_peers .values() .map(|node| node.use_keyspace(keyspace_name.clone())); - let use_keyspace_results: Vec> = + let use_keyspace_results: Vec> = join_all(use_keyspace_futures).await; use_keyspace_result(use_keyspace_results.into_iter()) @@ -438,8 +438,8 @@ impl ClusterWorker { /// /// This function assumes that `use_keyspace_results` iterator is NON-EMPTY! pub(crate) fn use_keyspace_result( - use_keyspace_results: impl Iterator>, -) -> Result<(), QueryError> { + use_keyspace_results: impl Iterator>, +) -> Result<(), UseKeyspaceError> { // If there was at least one Ok and the rest were broken connection errors we can return Ok // keyspace name is correct and will be used on broken connection on the next reconnect @@ -447,13 +447,13 @@ pub(crate) fn use_keyspace_result( // If there was an error different than broken connection error return this error - something is wrong let mut was_ok: bool = false; - let mut broken_conn_error: Option = None; + let mut broken_conn_error: Option = None; for result in use_keyspace_results { match result { Ok(()) => was_ok = true, Err(err) => match err { - QueryError::BrokenConnection(_) | QueryError::ConnectionPoolError(_) => { + UseKeyspaceError::RequestError(RequestAttemptError::BrokenConnectionError(_)) => { broken_conn_error = Some(err) } _ => return Err(err), diff --git a/scylla/src/errors.rs b/scylla/src/errors.rs index 479682fa5..3ae8216a6 100644 --- a/scylla/src/errors.rs +++ b/scylla/src/errors.rs @@ -89,10 +89,6 @@ pub enum QueryError { #[error("Protocol error: {0}")] ProtocolError(#[from] ProtocolError), - /// Timeout error has occurred, function didn't complete in time. - #[error("Timeout Error")] - TimeoutError, - /// A connection has been broken during query execution. #[error(transparent)] BrokenConnection(#[from] BrokenConnectionError), @@ -112,6 +108,10 @@ pub enum QueryError { #[error("Schema agreement exceeded {}ms", std::time::Duration::as_millis(.0))] SchemaAgreementTimeout(std::time::Duration), + /// 'USE KEYSPACE <>' request failed. + #[error("'USE KEYSPACE <>' request failed: {0}")] + UseKeyspaceError(#[from] UseKeyspaceError), + // TODO: This should not belong here, but it requires changes to error types // returned in async iterator API. This should be handled in separate PR. // The reason this needs to be included is that topology.rs makes use of iter API and returns QueryError. @@ -144,39 +144,6 @@ impl From for QueryError { } } -impl From for NewSessionError { - fn from(query_error: QueryError) -> NewSessionError { - match query_error { - QueryError::DbError(e, msg) => NewSessionError::DbError(e, msg), - QueryError::BadQuery(e) => NewSessionError::BadQuery(e), - QueryError::CqlRequestSerialization(e) => NewSessionError::CqlRequestSerialization(e), - QueryError::CqlResultParseError(e) => NewSessionError::CqlResultParseError(e), - QueryError::CqlErrorParseError(e) => NewSessionError::CqlErrorParseError(e), - QueryError::BodyExtensionsParseError(e) => NewSessionError::BodyExtensionsParseError(e), - QueryError::EmptyPlan => NewSessionError::EmptyPlan, - QueryError::MetadataError(e) => NewSessionError::MetadataError(e), - QueryError::ConnectionPoolError(e) => NewSessionError::ConnectionPoolError(e), - QueryError::ProtocolError(e) => NewSessionError::ProtocolError(e), - QueryError::TimeoutError => NewSessionError::TimeoutError, - QueryError::BrokenConnection(e) => NewSessionError::BrokenConnection(e), - QueryError::UnableToAllocStreamId => NewSessionError::UnableToAllocStreamId, - QueryError::RequestTimeout(dur) => NewSessionError::RequestTimeout(dur), - QueryError::SchemaAgreementTimeout(dur) => NewSessionError::SchemaAgreementTimeout(dur), - #[allow(deprecated)] - QueryError::IntoLegacyQueryResultError(e) => { - NewSessionError::IntoLegacyQueryResultError(e) - } - QueryError::NextRowError(e) => NewSessionError::NextRowError(e), - } - } -} - -impl From for QueryError { - fn from(keyspace_err: BadKeyspaceName) -> QueryError { - QueryError::BadQuery(BadQuery::BadKeyspaceName(keyspace_err)) - } -} - impl From for QueryError { fn from(error: response::Error) -> QueryError { QueryError::DbError(error.error, error.reason) @@ -197,91 +164,13 @@ pub enum NewSessionError { #[error("Empty known nodes list")] EmptyKnownNodesList, - /// Database sent a response containing some error with a message - #[error("Database returned an error: {0}, Error message: {1}")] - DbError(DbError, String), - - /// Caller passed an invalid query - #[error(transparent)] - BadQuery(#[from] BadQuery), - - /// Failed to serialize CQL request. - #[error("Failed to serialize CQL request: {0}")] - CqlRequestSerialization(#[from] CqlRequestSerializationError), - - /// Load balancing policy returned an empty plan. - #[error( - "Load balancing policy returned an empty plan.\ - First thing to investigate should be the logic of custom LBP implementation.\ - If you think that your LBP implementation is correct, or you make use of `DefaultPolicy`,\ - then this is most probably a driver bug!" - )] - EmptyPlan, - - /// Failed to deserialize frame body extensions. - #[error(transparent)] - BodyExtensionsParseError(#[from] FrameBodyExtensionsParseError), - /// Failed to perform initial cluster metadata fetch. #[error("Failed to perform initial cluster metadata fetch: {0}")] MetadataError(#[from] MetadataError), - /// Received a RESULT server response, but failed to deserialize it. - #[error(transparent)] - CqlResultParseError(#[from] CqlResultParseError), - - /// Received an ERROR server response, but failed to deserialize it. - #[error("Failed to deserialize ERROR response: {0}")] - CqlErrorParseError(#[from] CqlErrorParseError), - - /// Selected node's connection pool is in invalid state. - #[error("No connections in the pool: {0}")] - ConnectionPoolError(#[from] ConnectionPoolError), - - /// Protocol error. - #[error("Protocol error: {0}")] - ProtocolError(#[from] ProtocolError), - - /// Timeout error has occurred, couldn't connect to node in time. - #[error("Timeout Error")] - TimeoutError, - - /// A connection has been broken during query execution. - #[error(transparent)] - BrokenConnection(#[from] BrokenConnectionError), - - /// Driver was unable to allocate a stream id to execute a query on. - #[error("Unable to allocate stream id")] - UnableToAllocStreamId, - - /// Failed to run a request within a provided client timeout. - #[error( - "Request execution exceeded a client timeout of {}ms", - std::time::Duration::as_millis(.0) - )] - RequestTimeout(std::time::Duration), - - /// Schema agreement timed out. - #[error("Schema agreement exceeded {}ms", std::time::Duration::as_millis(.0))] - SchemaAgreementTimeout(std::time::Duration), - - // TODO: This should not belong here, but it requires changes to error types - // returned in async iterator API. This should be handled in separate PR. - // The reason this needs to be included is that topology.rs makes use of iter API and returns QueryError. - // Once iter API is adjusted, we can then adjust errors returned by topology module (e.g. refactor MetadataError and not include it in QueryError). - /// An error occurred during async iteration over rows of result. - #[error("An error occurred during async iteration over rows of result: {0}")] - NextRowError(#[from] NextRowError), - - /// Failed to convert [`QueryResult`][crate::response::query_result::QueryResult] - /// into [`LegacyQueryResult`][crate::response::legacy_query_result::LegacyQueryResult]. - #[deprecated( - since = "0.15.1", - note = "Legacy deserialization API is inefficient and is going to be removed soon" - )] - #[allow(deprecated)] - #[error("Failed to convert `QueryResult` into `LegacyQueryResult`: {0}")] - IntoLegacyQueryResultError(#[from] IntoLegacyQueryResultError), + /// 'USE KEYSPACE <>' request failed. + #[error("'USE KEYSPACE <>' request failed: {0}")] + UseKeyspaceError(#[from] UseKeyspaceError), } /// A protocol error. @@ -316,10 +205,6 @@ pub enum ProtocolError { reprepared_id: Vec, }, - /// USE KEYSPACE protocol error. - #[error("USE KEYSPACE protocol error: {0}")] - UseKeyspace(#[from] UseKeyspaceProtocolError), - /// A protocol error appeared during schema version fetch. #[error("Schema version fetch protocol error: {0}")] SchemaVersionFetch(#[from] SchemaVersionFetchError), @@ -342,17 +227,31 @@ pub enum ProtocolError { RepreparedIdMissingInBatch, } -/// A protocol error that occurred during `USE KEYSPACE <>` request. +/// An error that occurred during `USE KEYSPACE <>` request. #[derive(Error, Debug, Clone)] #[non_exhaustive] -pub enum UseKeyspaceProtocolError { +pub enum UseKeyspaceError { + /// Passed invalid keyspace name to use. + #[error("Passed invalid keyspace name to use: {0}")] + BadKeyspaceName(#[from] BadKeyspaceName), + + /// An error during request execution. + #[error(transparent)] + RequestError(#[from] RequestAttemptError), + + /// Keyspace name mismatch. #[error("Keyspace name mismtach; expected: {expected_keyspace_name_lowercase}, received: {result_keyspace_name_lowercase}")] KeyspaceNameMismatch { expected_keyspace_name_lowercase: String, result_keyspace_name_lowercase: String, }, - #[error("Received unexpected response: {0}. Expected RESULT:Set_keyspace")] - UnexpectedResponse(CqlResponseKind), + + /// Failed to run a request within a provided client timeout. + #[error( + "Request execution exceeded a client timeout of {}ms", + std::time::Duration::as_millis(.0) + )] + RequestTimeout(std::time::Duration), } /// A protocol error that occurred during schema version fetch. @@ -592,10 +491,6 @@ pub enum BadQuery { #[error("Serialized values are too long to compute partition key! Length: {0}, Max allowed length: {1}")] ValuesTooLongForKey(usize, usize), - /// Passed invalid keyspace name to use - #[error("Passed invalid keyspace name to use: {0}")] - BadKeyspaceName(#[from] BadKeyspaceName), - /// Too many queries in the batch statement #[error("Number of Queries in Batch Statement supplied is {0} which has exceeded the max value of 65,535")] TooManyQueriesInBatchStatement(usize), diff --git a/scylla/src/network/connection.rs b/scylla/src/network/connection.rs index d4bd142ee..8e99ecdc0 100644 --- a/scylla/src/network/connection.rs +++ b/scylla/src/network/connection.rs @@ -11,7 +11,7 @@ use crate::errors::{ BadKeyspaceName, BrokenConnectionError, BrokenConnectionErrorKind, ConnectionError, ConnectionSetupRequestError, ConnectionSetupRequestErrorKind, CqlEventHandlingError, DbError, InternalRequestError, ProtocolError, QueryError, RequestAttemptError, ResponseParseError, - SchemaVersionFetchError, TranslationError, UseKeyspaceProtocolError, + SchemaVersionFetchError, TranslationError, UseKeyspaceError, }; use crate::frame::protocol_features::ProtocolFeatures; use crate::frame::{ @@ -1179,7 +1179,7 @@ impl Connection { pub(super) async fn use_keyspace( &self, keyspace_name: &VerifiedKeyspaceName, - ) -> Result<(), QueryError> { + ) -> Result<(), UseKeyspaceError> { // Trying to pass keyspace_name as bound value doesn't work // We have to send "USE " + keyspace_name let query: Query = match keyspace_name.is_case_sensitive { @@ -1187,17 +1187,14 @@ impl Connection { false => format!("USE {}", keyspace_name.as_str()).into(), }; - let query_response = self - .query_raw_unpaged(&query) - .await - .map_err(RequestAttemptError::into_query_error)?; + let query_response = self.query_raw_unpaged(&query).await?; Self::verify_use_keyspace_result(keyspace_name, query_response) } fn verify_use_keyspace_result( keyspace_name: &VerifiedKeyspaceName, query_response: QueryResponse, - ) -> Result<(), QueryError> { + ) -> Result<(), UseKeyspaceError> { match query_response.response { Response::Result(result::Result::SetKeyspace(set_keyspace)) => { if !set_keyspace @@ -1207,24 +1204,20 @@ impl Connection { let expected_keyspace_name_lowercase = keyspace_name.as_str().to_lowercase(); let result_keyspace_name_lowercase = set_keyspace.keyspace_name.to_lowercase(); - return Err(ProtocolError::UseKeyspace( - UseKeyspaceProtocolError::KeyspaceNameMismatch { - expected_keyspace_name_lowercase, - result_keyspace_name_lowercase, - }, - ) - .into()); + return Err(UseKeyspaceError::KeyspaceNameMismatch { + expected_keyspace_name_lowercase, + result_keyspace_name_lowercase, + }); } Ok(()) } - Response::Error(err) => Err(err.into()), - _ => Err( - ProtocolError::UseKeyspace(UseKeyspaceProtocolError::UnexpectedResponse( - query_response.response.to_response_kind(), - )) - .into(), - ), + Response::Error(err) => Err(UseKeyspaceError::RequestError( + RequestAttemptError::DbError(err.error, err.reason), + )), + _ => Err(UseKeyspaceError::RequestError( + RequestAttemptError::UnexpectedResponse(query_response.response.to_response_kind()), + )), } } diff --git a/scylla/src/network/connection_pool.rs b/scylla/src/network/connection_pool.rs index 2390365f2..835ac3f21 100644 --- a/scylla/src/network/connection_pool.rs +++ b/scylla/src/network/connection_pool.rs @@ -7,7 +7,9 @@ use super::connection::{ ErrorReceiver, VerifiedKeyspaceName, }; -use crate::errors::{BrokenConnectionErrorKind, ConnectionError, ConnectionPoolError, QueryError}; +use crate::errors::{ + BrokenConnectionErrorKind, ConnectionError, ConnectionPoolError, UseKeyspaceError, +}; use crate::routing::{Shard, ShardCount, Sharder}; use crate::cluster::metadata::{PeerEndpoint, UntranslatedEndpoint}; @@ -317,7 +319,7 @@ impl NodeConnectionPool { pub(crate) async fn use_keyspace( &self, keyspace_name: VerifiedKeyspaceName, - ) -> Result<(), QueryError> { + ) -> Result<(), UseKeyspaceError> { let (response_sender, response_receiver) = tokio::sync::oneshot::channel(); self.use_keyspace_request_sender @@ -481,7 +483,7 @@ struct PoolRefiller { #[derive(Debug)] struct UseKeyspaceRequest { keyspace_name: VerifiedKeyspaceName, - response_sender: tokio::sync::oneshot::Sender>, + response_sender: tokio::sync::oneshot::Sender>, } impl PoolRefiller { @@ -1075,7 +1077,7 @@ impl PoolRefiller { fn use_keyspace( &mut self, keyspace_name: VerifiedKeyspaceName, - response_sender: tokio::sync::oneshot::Sender>, + response_sender: tokio::sync::oneshot::Sender>, ) { self.current_keyspace = Some(keyspace_name.clone()); @@ -1097,12 +1099,13 @@ impl PoolRefiller { return Ok(()); } - let use_keyspace_results: Vec> = tokio::time::timeout( + let use_keyspace_results: Vec> = tokio::time::timeout( connect_timeout, futures::future::join_all(use_keyspace_futures), ) .await - .map_err(|_| QueryError::TimeoutError)?; + // FIXME: We could probably make USE KEYSPACE request timeout configurable in the future. + .map_err(|_| UseKeyspaceError::RequestTimeout(connect_timeout))?; crate::cluster::use_keyspace_result(use_keyspace_results.into_iter()) };