Skip to content

Commit

Permalink
feat(mssql): fix a few bugs and implement Connection::describe
Browse files Browse the repository at this point in the history
  • Loading branch information
mehcode committed Jun 7, 2020
1 parent 559169c commit ef2527f
Show file tree
Hide file tree
Showing 27 changed files with 424 additions and 61 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* text=auto eol=lf
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,8 @@ required-features = [ "mssql" ]
name = "mssql-types"
path = "tests/mssql/types.rs"
required-features = [ "mssql" ]

[[test]]
name = "mssql-describe"
path = "tests/mssql/describe.rs"
required-features = [ "mssql" ]
4 changes: 3 additions & 1 deletion sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ default = [ "runtime-async-std" ]
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ]
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
sqlite = [ "libsqlite3-sys" ]
mssql = [ "uuid", "encoding_rs" ]
mssql = [ "uuid", "encoding_rs", "regex" ]

# types
all-types = [ "chrono", "time", "bigdecimal", "ipnetwork", "json", "uuid" ]
Expand Down Expand Up @@ -65,11 +65,13 @@ log = { version = "0.4.8", default-features = false }
md-5 = { version = "0.8.0", default-features = false, optional = true }
memchr = { version = "2.3.3", default-features = false }
num-bigint = { version = "0.2.6", default-features = false, optional = true, features = [ "std" ] }
once_cell = "1.4.0"
percent-encoding = "2.1.0"
parking_lot = "0.10.2"
threadpool = "*"
phf = { version = "0.8.0", features = [ "macros" ] }
rand = { version = "0.7.3", default-features = false, optional = true, features = [ "std" ] }
regex = { version = "1.3.9", optional = true }
serde = { version = "1.0.106", features = [ "derive", "rc" ], optional = true }
serde_json = { version = "1.0.51", features = [ "raw_value" ], optional = true }
sha-1 = { version = "0.8.2", default-features = false, optional = true }
Expand Down
71 changes: 42 additions & 29 deletions sqlx-core/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,36 +68,49 @@ where
}
}

impl<'q, T: 'q + Encode<'q, DB>, DB: Database> Encode<'q, DB> for Option<T> {
#[inline]
fn produces(&self) -> DB::TypeInfo {
if let Some(v) = self {
v.produces()
} else {
T::type_info()
}
}
#[allow(unused_macros)]
macro_rules! impl_encode_for_option {
($DB:ident) => {
impl<'q, T: 'q + crate::encode::Encode<'q, $DB>> crate::encode::Encode<'q, $DB>
for Option<T>
{
#[inline]
fn produces(&self) -> <$DB as crate::database::Database>::TypeInfo {
if let Some(v) = self {
v.produces()
} else {
T::type_info()
}
}

#[inline]
fn encode(self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
if let Some(v) = self {
v.encode(buf)
} else {
IsNull::Yes
}
}
#[inline]
fn encode(
self,
buf: &mut <$DB as crate::database::HasArguments<'q>>::ArgumentBuffer,
) -> crate::encode::IsNull {
if let Some(v) = self {
v.encode(buf)
} else {
crate::encode::IsNull::Yes
}
}

#[inline]
fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
if let Some(v) = self {
v.encode_by_ref(buf)
} else {
IsNull::Yes
}
}
#[inline]
fn encode_by_ref(
&self,
buf: &mut <$DB as crate::database::HasArguments<'q>>::ArgumentBuffer,
) -> crate::encode::IsNull {
if let Some(v) = self {
v.encode_by_ref(buf)
} else {
crate::encode::IsNull::Yes
}
}

#[inline]
fn size_hint(&self) -> usize {
self.as_ref().map_or(0, Encode::size_hint)
}
#[inline]
fn size_hint(&self) -> usize {
self.as_ref().map_or(0, crate::encode::Encode::size_hint)
}
}
};
}
8 changes: 7 additions & 1 deletion sqlx-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ pub mod connection;
#[macro_use]
pub mod transaction;

#[macro_use]
pub mod encode;

pub mod database;
pub mod decode;
pub mod describe;
pub mod encode;
pub mod executor;
mod ext;
pub mod from_row;
Expand All @@ -59,3 +61,7 @@ pub mod sqlite;
#[cfg(feature = "mysql")]
#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))]
pub mod mysql;

#[cfg(feature = "mssql")]
#[cfg_attr(docsrs, doc(cfg(feature = "mssql")))]
pub mod mssql;
14 changes: 14 additions & 0 deletions sqlx-core/src/mssql/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::arguments::Arguments;
use crate::encode::Encode;
use crate::mssql::database::MsSql;
use crate::mssql::io::MsSqlBufMutExt;
use crate::mssql::protocol::rpc::StatusFlags;

#[derive(Default)]
pub struct MsSqlArguments {
Expand Down Expand Up @@ -31,6 +32,19 @@ impl MsSqlArguments {
self.add_named("", value);
}

pub(crate) fn declare<'q, T: Encode<'q, MsSql>>(&mut self, name: &str, initial_value: T) {
let ty = initial_value.produces();

let mut ty_name = String::new();
ty.0.fmt(&mut ty_name);

self.data.put_b_varchar(name); // [ParamName]
self.data.push(StatusFlags::BY_REF_VALUE.bits()); // [StatusFlags]

ty.0.put(&mut self.data); // [TYPE_INFO]
ty.0.put_value(&mut self.data, initial_value); // [ParamLenData]
}

pub(crate) fn append(&mut self, arguments: &mut MsSqlArguments) {
self.ordinal += arguments.ordinal;
self.data.append(&mut arguments.data);
Expand Down
3 changes: 1 addition & 2 deletions sqlx-core/src/mssql/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ impl MsSqlConnection {
server_name: "",
client_interface_name: "",
language: "",
// FIXME: connect this to options.database
database: "",
database: &*options.database,
client_id: [0; 6],
},
);
Expand Down
126 changes: 109 additions & 17 deletions sqlx-core/src/mssql/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@ use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::TryStreamExt;
use once_cell::sync::Lazy;
use regex::Regex;

use crate::describe::Describe;
use crate::describe::{Column, Describe};
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::mssql::protocol::done::Done;
use crate::mssql::protocol::col_meta_data::Flags;
use crate::mssql::protocol::done::{Done, Status};
use crate::mssql::protocol::message::Message;
use crate::mssql::protocol::packet::PacketType;
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
use crate::mssql::protocol::sql_batch::SqlBatch;
use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow};
use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow, MsSqlTypeInfo};

impl MsSqlConnection {
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
Expand All @@ -25,8 +28,10 @@ impl MsSqlConnection {
let message = self.stream.recv_message().await?;

if let Message::DoneProc(done) | Message::Done(done) = message {
// finished RPC procedure *OR* SQL batch
self.handle_done(done);
if !done.status.contains(Status::DONE_MORE) {
// finished RPC procedure *OR* SQL batch
self.handle_done(done);
}
}
}

Expand Down Expand Up @@ -106,20 +111,23 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection {
yield v;
}

Message::DoneProc(done) => {
self.handle_done(done);
break;
}
Message::Done(done) | Message::DoneProc(done) => {
if done.status.contains(Status::DONE_COUNT) {
let v = Either::Left(done.affected_rows);
yield v;
}

Message::DoneInProc(done) => {
// finished SQL query *within* procedure
let v = Either::Left(done.affected_rows);
yield v;
if !done.status.contains(Status::DONE_MORE) {
self.handle_done(done);
break;
}
}

Message::Done(done) => {
self.handle_done(done);
break;
Message::DoneInProc(done) => {
if done.status.contains(Status::DONE_COUNT) {
let v = Either::Left(done.affected_rows);
yield v;
}
}

_ => {}
Expand Down Expand Up @@ -157,6 +165,90 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection {
'c: 'e,
E: Execute<'q, Self::Database>,
{
unimplemented!()
let s = query.query();

// [sp_prepare] will emit the column meta data
// small issue is that we need to declare all the used placeholders with a "fallback" type
// we currently use regex to collect them; false positives are *okay* but false
// negatives would break the query
let proc = Either::Right(Procedure::Prepare);

// NOTE: this does not support unicode identifiers; as we don't even support
// named parameters (yet) this is probably fine, for now

static PARAMS_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"@p[[:alnum:]]+").unwrap());

let mut params = String::new();
let mut num_params = 0;

for m in PARAMS_RE.captures_iter(s) {
if !params.is_empty() {
params.push_str(",");
}

params.push_str(&m[0]);

// NOTE: this means that a query! of `SELECT @p1` will have the macros believe
// it will return nvarchar(1); this is a greater issue with `query!` that we
// we need to circle back to. This doesn't happen much in practice however.
params.push_str(" nvarchar(1)");

num_params += 1;
}

let params = if params.is_empty() {
None
} else {
Some(&*params)
};

let mut args = MsSqlArguments::default();

args.declare("", 0_i32);
args.add_unnamed(params);
args.add_unnamed(s);
args.add_unnamed(0x0001_i32); // 1 = SEND_METADATA

self.stream.write_packet(
PacketType::Rpc,
RpcRequest {
transaction_descriptor: self.stream.transaction_descriptor,
arguments: &args,
procedure: proc,
options: OptionFlags::empty(),
},
);

Box::pin(async move {
self.stream.flush().await?;

loop {
match self.stream.recv_message().await? {
Message::DoneProc(done) | Message::Done(done) => {
if !done.status.contains(Status::DONE_MORE) {
// done with prepare
break;
}
}

_ => {}
}
}

let mut columns = Vec::with_capacity(self.stream.columns.len());

for col in &self.stream.columns {
columns.push(Column {
name: col.col_name.clone(),
type_info: Some(MsSqlTypeInfo(col.type_info.clone())),
not_null: Some(!col.flags.contains(Flags::NULLABLE)),
});
}

Ok(Describe {
params: vec![None; num_params],
columns,
})
})
}
}
5 changes: 4 additions & 1 deletion sqlx-core/src/mssql/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::message::{Message, MessageType};
use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status};
use crate::mssql::protocol::return_status::ReturnStatus;
use crate::mssql::protocol::return_value::ReturnValue;
use crate::mssql::protocol::row::Row;
use crate::mssql::{MsSqlConnectOptions, MsSqlDatabaseError};
use crate::net::MaybeTlsStream;
Expand All @@ -30,7 +31,7 @@ pub(crate) struct MsSqlStream {

// most recent column data from ColMetaData
// we need to store this as its needed when decoding <Row>
columns: Vec<ColumnData>,
pub(crate) columns: Vec<ColumnData>,
}

impl MsSqlStream {
Expand Down Expand Up @@ -112,6 +113,7 @@ impl MsSqlStream {
};

let ty = MessageType::get(buf)?;

let message = match ty {
MessageType::EnvChange => {
match EnvChange::get(buf)? {
Expand All @@ -137,6 +139,7 @@ impl MsSqlStream {
MessageType::Row => Message::Row(Row::get(buf, &self.columns)?),
MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?),
MessageType::ReturnStatus => Message::ReturnStatus(ReturnStatus::get(buf)?),
MessageType::ReturnValue => Message::ReturnValue(ReturnValue::get(buf)?),
MessageType::Done => Message::Done(Done::get(buf)?),
MessageType::DoneInProc => Message::DoneInProc(Done::get(buf)?),
MessageType::DoneProc => Message::DoneProc(Done::get(buf)?),
Expand Down
Loading

0 comments on commit ef2527f

Please sign in to comment.