Skip to content

Commit

Permalink
refactor(mssql): clean up unused imports and other warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
mehcode committed Jun 7, 2020
1 parent 95149c4 commit 559169c
Show file tree
Hide file tree
Showing 20 changed files with 370 additions and 244 deletions.
1 change: 0 additions & 1 deletion sqlx-core/src/mssql/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::error::Error;
use crate::io::Decode;
use crate::mssql::connection::stream::MsSqlStream;
use crate::mssql::protocol::login::Login7;
use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::message::Message;
use crate::mssql::protocol::packet::PacketType;
use crate::mssql::protocol::pre_login::{Encrypt, PreLogin, Version};
Expand Down
19 changes: 13 additions & 6 deletions sqlx-core/src/mssql/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@ use crate::mssql::protocol::sql_batch::SqlBatch;
use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow};

impl MsSqlConnection {
async fn wait_until_ready(&mut self) -> Result<(), Error> {
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
if !self.stream.wbuf.is_empty() {
self.pending_done_count += 1;
self.stream.flush().await?;
}

while self.pending_done_count > 0 {
if let Message::DoneProc(done) | Message::Done(done) =
self.stream.recv_message().await?
{
let message = self.stream.recv_message().await?;

if let Message::DoneProc(done) | Message::Done(done) = message {
// finished RPC procedure *OR* SQL batch
self.handle_done(done);
}
Expand Down Expand Up @@ -59,14 +60,20 @@ impl MsSqlConnection {
self.stream.write_packet(
PacketType::Rpc,
RpcRequest {
transaction_descriptor: self.stream.transaction_descriptor,
arguments: &proc_args,
procedure: proc,
options: OptionFlags::empty(),
},
);
} else {
self.stream
.write_packet(PacketType::SqlBatch, SqlBatch { sql: query });
self.stream.write_packet(
PacketType::SqlBatch,
SqlBatch {
transaction_descriptor: self.stream.transaction_descriptor,
sql: query,
},
);
}

self.stream.flush().await?;
Expand Down
6 changes: 3 additions & 3 deletions sqlx-core/src/mssql/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use futures_core::future::BoxFuture;
use futures_util::{future::ready, FutureExt, TryFutureExt};

use crate::connection::{Connect, Connection};
use crate::error::{BoxDynError, Error};
use crate::error::Error;
use crate::executor::Executor;
use crate::mssql::connection::stream::MsSqlStream;
use crate::mssql::{MsSql, MsSqlConnectOptions};
Expand All @@ -15,7 +15,7 @@ mod executor;
mod stream;

pub struct MsSqlConnection {
stream: MsSqlStream,
pub(crate) stream: MsSqlStream,

// number of Done* messages that we are currently expecting
pub(crate) pending_done_count: usize,
Expand All @@ -42,7 +42,7 @@ impl Connection for MsSqlConnection {

#[doc(hidden)]
fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
unimplemented!()
self.wait_until_ready().boxed()
}

#[doc(hidden)]
Expand Down
34 changes: 29 additions & 5 deletions sqlx-core/src/mssql/connection/stream.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::ops::{Deref, DerefMut};

use bytes::Bytes;
use sqlx_rt::{TcpStream, TlsStream};
use sqlx_rt::TcpStream;

use crate::error::Error;
use crate::io::{BufStream, Encode};
Expand All @@ -21,6 +21,10 @@ use crate::net::MaybeTlsStream;
pub(crate) struct MsSqlStream {
inner: BufStream<MaybeTlsStream<TcpStream>>,

// current transaction descriptor
// set from ENVCHANGE on `BEGIN` and reset to `0` on a ROLLBACK
pub(crate) transaction_descriptor: u64,

// current TabularResult from the server that we are iterating over
response: Option<(PacketHeader, Bytes)>,

Expand All @@ -39,12 +43,13 @@ impl MsSqlStream {
inner,
columns: Vec::new(),
response: None,
transaction_descriptor: 0,
})
}

// writes the packet out to the write buffer
// will (eventually) handle packet chunking
pub(super) fn write_packet<'en, T: Encode<'en>>(&mut self, ty: PacketType, payload: T) {
pub(crate) fn write_packet<'en, T: Encode<'en>>(&mut self, ty: PacketType, payload: T) {
// TODO: Support packet chunking for large packet sizes
// We likely need to double-buffer the writes so we know to chunk

Expand Down Expand Up @@ -98,7 +103,7 @@ impl MsSqlStream {
pub(super) async fn recv_message(&mut self) -> Result<Message, Error> {
loop {
while self.response.as_ref().map_or(false, |r| !r.1.is_empty()) {
let mut buf = if let Some((_, buf)) = self.response.as_mut() {
let buf = if let Some((_, buf)) = self.response.as_mut() {
buf
} else {
// this shouldn't be reachable but just nope out
Expand All @@ -108,8 +113,27 @@ impl MsSqlStream {

let ty = MessageType::get(buf)?;
let message = match ty {
MessageType::EnvChange => Message::EnvChange(EnvChange::get(buf)?),
MessageType::Info => Message::Info(Info::get(buf)?),
MessageType::EnvChange => {
match EnvChange::get(buf)? {
EnvChange::BeginTransaction(desc) => {
self.transaction_descriptor = desc;
}

EnvChange::CommitTransaction(_) | EnvChange::RollbackTransaction(_) => {
self.transaction_descriptor = 0;
}

_ => {}
}

continue;
}

MessageType::Info => {
let _ = Info::get(buf)?;
continue;
}

MessageType::Row => Message::Row(Row::get(buf, &self.columns)?),
MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?),
MessageType::ReturnStatus => Message::ReturnStatus(ReturnStatus::get(buf)?),
Expand Down
18 changes: 14 additions & 4 deletions sqlx-core/src/mssql/protocol/env_change.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use bytes::{Buf, Bytes};

use crate::error::Error;
use crate::io::Decode;
use crate::mssql::io::MsSqlBufExt;

#[derive(Debug)]
Expand All @@ -16,9 +15,9 @@ pub(crate) enum EnvChange {
SqlCollation(Bytes),

// TDS 7.2+
BeginTransaction,
CommitTransaction,
RollbackTransaction,
BeginTransaction(u64),
CommitTransaction(u64),
RollbackTransaction(u64),
EnlistDtcTransaction,
DefectTransaction,
RealTimeLogShipping,
Expand Down Expand Up @@ -46,6 +45,17 @@ impl EnvChange {
5 => EnvChange::UnicodeDataSortingLocalId(data.get_b_varchar()?),
6 => EnvChange::UnicodeDataSortingComparisonFlags(data.get_b_varchar()?),
7 => EnvChange::SqlCollation(data.get_b_varbyte()),
8 => EnvChange::BeginTransaction(data.get_b_varbyte().get_u64_le()),

9 => {
let _ = data.get_u8();
EnvChange::CommitTransaction(data.get_u64_le())
}

10 => {
let _ = data.get_u8();
EnvChange::RollbackTransaction(data.get_u64_le())
}

_ => {
return Err(err_protocol!("unexpected value {} for ENVCHANGE Type", ty));
Expand Down
24 changes: 0 additions & 24 deletions sqlx-core/src/mssql/protocol/login.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,6 @@
use hex::encode;
use std::mem::size_of;

use crate::io::Encode;
use crate::mssql::io::MsSqlBufMutExt;

// Stream definition
// LOGIN7 = Length
// TDSVersion
// PacketSize
// ClientProgVer
// ClientPID
// ConnectionID
// OptionFlags1
// OptionFlags2
// TypeFlags
// OptionFlags3
// ClientTimeZone
// ClientLCID
// OffsetLength
// Data
// FeatureExt

#[derive(Debug)]
pub struct Login7<'a> {
pub version: u32,
Expand Down Expand Up @@ -156,10 +136,6 @@ impl Encode<'_> for Login7<'_> {

// [ChangePassword] New password for the specified login
write_offset(buf, &mut offsets, beg);
offsets += 2;

// [SSPILong] Used for large SSPI data
offsets += 4;

// Establish the length of the entire structure
let len = buf.len();
Expand Down
6 changes: 0 additions & 6 deletions sqlx-core/src/mssql/protocol/message.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
use bytes::{Buf, Bytes};

use crate::mssql::protocol::col_meta_data::ColMetaData;
use crate::mssql::protocol::done::Done;
use crate::mssql::protocol::env_change::EnvChange;
use crate::mssql::protocol::error::Error;
use crate::mssql::protocol::info::Info;
use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::return_status::ReturnStatus;
use crate::mssql::protocol::row::Row;

#[derive(Debug)]
pub(crate) enum Message {
Info(Info),
LoginAck(LoginAck),
EnvChange(EnvChange),
Done(Done),
DoneInProc(Done),
DoneProc(Done),
Expand Down
1 change: 0 additions & 1 deletion sqlx-core/src/mssql/protocol/return_status.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use bitflags::bitflags;
use bytes::{Buf, Bytes};

use crate::error::Error;
Expand Down
4 changes: 1 addition & 3 deletions sqlx-core/src/mssql/protocol/row.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::ops::Range;

use bytes::Bytes;

use crate::error::Error;
use crate::mssql::protocol::col_meta_data::ColumnData;
use crate::mssql::{MsSql, MsSqlTypeInfo};
use crate::mssql::MsSqlTypeInfo;

#[derive(Debug)]
pub(crate) struct Row {
Expand Down
4 changes: 3 additions & 1 deletion sqlx-core/src/mssql/protocol/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use crate::mssql::protocol::header::{AllHeaders, Header};
use crate::mssql::MsSqlArguments;

pub(crate) struct RpcRequest<'a> {
pub(crate) transaction_descriptor: u64,

// the procedure can be encoded as a u16 of a built-in or the name for a custom one
pub(crate) procedure: Either<&'a str, Procedure>,
pub(crate) options: OptionFlags,
Expand Down Expand Up @@ -67,7 +69,7 @@ impl Encode<'_> for RpcRequest<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
AllHeaders(&[Header::TransactionDescriptor {
outstanding_request_count: 1,
transaction_descriptor: 0,
transaction_descriptor: self.transaction_descriptor,
}])
.encode(buf);

Expand Down
3 changes: 2 additions & 1 deletion sqlx-core/src/mssql/protocol/sql_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ use crate::mssql::protocol::header::{AllHeaders, Header};

#[derive(Debug)]
pub(crate) struct SqlBatch<'a> {
pub(crate) transaction_descriptor: u64,
pub(crate) sql: &'a str,
}

impl Encode<'_> for SqlBatch<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
AllHeaders(&[Header::TransactionDescriptor {
outstanding_request_count: 1,
transaction_descriptor: 0,
transaction_descriptor: self.transaction_descriptor,
}])
.encode(buf);

Expand Down
2 changes: 0 additions & 2 deletions sqlx-core/src/mssql/protocol/type_info.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::borrow::Cow;

use bitflags::bitflags;
use bytes::{Buf, Bytes};
use encoding_rs::Encoding;
Expand Down
57 changes: 49 additions & 8 deletions sqlx-core/src/mssql/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::borrow::Cow;

use futures_core::future::BoxFuture;

use crate::error::Error;
use crate::executor::Executor;
use crate::mssql::protocol::packet::PacketType;
use crate::mssql::protocol::sql_batch::SqlBatch;
use crate::mssql::{MsSql, MsSqlConnection};
use crate::transaction::{
begin_ansi_transaction_sql, commit_ansi_transaction_sql, rollback_ansi_transaction_sql,
TransactionManager,
};
use crate::transaction::TransactionManager;

/// Implementation of [`TransactionManager`] for MSSQL.
pub struct MsSqlTransactionManager;
Expand All @@ -15,18 +16,58 @@ impl TransactionManager for MsSqlTransactionManager {
type Database = MsSql;

fn begin(conn: &mut MsSqlConnection, depth: usize) -> BoxFuture<'_, Result<(), Error>> {
unimplemented!()
Box::pin(async move {
let query = if depth == 0 {
Cow::Borrowed("BEGIN TRAN ")
} else {
Cow::Owned(format!("SAVE TRAN _sqlx_savepoint_{}", depth))
};

conn.execute(&*query).await?;

Ok(())
})
}

fn commit(conn: &mut MsSqlConnection, depth: usize) -> BoxFuture<'_, Result<(), Error>> {
unimplemented!()
Box::pin(async move {
if depth == 1 {
// savepoints are not released in MSSQL
conn.execute("COMMIT TRAN").await?;
}

Ok(())
})
}

fn rollback(conn: &mut MsSqlConnection, depth: usize) -> BoxFuture<'_, Result<(), Error>> {
unimplemented!()
Box::pin(async move {
let query = if depth == 1 {
Cow::Borrowed("ROLLBACK TRAN")
} else {
Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1))
};

conn.execute(&*query).await?;

Ok(())
})
}

fn start_rollback(conn: &mut MsSqlConnection, depth: usize) {
unimplemented!()
let query = if depth == 1 {
Cow::Borrowed("ROLLBACK TRAN")
} else {
Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1))
};

conn.pending_done_count += 1;
conn.stream.write_packet(
PacketType::SqlBatch,
SqlBatch {
transaction_descriptor: conn.stream.transaction_descriptor,
sql: &*query,
},
);
}
}
1 change: 0 additions & 1 deletion sqlx-core/src/mssql/types/float.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use byteorder::{ByteOrder, LittleEndian};

use crate::database::{Database, HasArguments, HasValueRef};
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
Expand Down
Loading

0 comments on commit 559169c

Please sign in to comment.