Skip to content

Commit

Permalink
WIP feat(sqlite): create better constructors for SqliteConnectOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
abonander committed Aug 3, 2024
1 parent 4acecfc commit b307691
Show file tree
Hide file tree
Showing 9 changed files with 617 additions and 70 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ uuid = "1.1.2"

# Common utility crates
dotenvy = { version = "0.15.0", default-features = false }
tempfile = "3.10.1"
once_cell = { version = "1.19.0", default-features = false, features = ["std"] }

# Runtimes
[workspace.dependencies.async-std]
Expand Down Expand Up @@ -176,7 +178,7 @@ url = "2.2.2"
rand = "0.8.4"
rand_xoshiro = "0.6.0"
hex = "0.4.3"
tempfile = "3.10.1"
tempfile = { workspace = true }
criterion = { version = "0.5.1", features = ["async_tokio"] }

# If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`,
Expand Down
14 changes: 11 additions & 3 deletions sqlx-core/src/rt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,30 @@ where

#[track_caller]
pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
try_spawn_blocking(f).unwrap_or_else(missing_rt)
}

pub fn try_spawn_blocking<F, R>(f: F) -> Result<JoinHandle<R>, F>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
#[cfg(feature = "_rt-tokio")]
if let Ok(handle) = tokio::runtime::Handle::try_current() {
return JoinHandle::Tokio(handle.spawn_blocking(f));
return Ok(JoinHandle::Tokio(handle.spawn_blocking(f)));
}

#[cfg(feature = "_rt-async-std")]
{
JoinHandle::AsyncStd(async_std::task::spawn_blocking(f))
Ok(JoinHandle::AsyncStd(async_std::task::spawn_blocking(f)))
}

#[cfg(not(feature = "_rt-async-std"))]
missing_rt(f)
Err(f)
}

pub async fn yield_now() {
Expand Down
3 changes: 3 additions & 0 deletions sqlx-sqlite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ tracing = { version = "0.1.37", features = ["log"] }
serde = { version = "1.0.145", features = ["derive"], optional = true }
regex = { version = "1.5.5", optional = true }

tempfile = { workspace = true }
once_cell = { workspace = true }

[dependencies.libsqlite3-sys]
version = "0.30.1"
default-features = false
Expand Down
148 changes: 106 additions & 42 deletions sqlx-sqlite/src/connection/establish.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
use crate::connection::handle::ConnectionHandle;
use crate::connection::LogSettings;
use crate::connection::{ConnectionState, Statements};
use crate::error::Error;
use crate::{SqliteConnectOptions, SqliteError};
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::ffi::{c_void, CStr, CString};
use std::io;
use std::os::raw::c_int;
use std::ptr::{addr_of_mut, null, null_mut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

use libsqlite3_sys::{
sqlite3, sqlite3_busy_timeout, sqlite3_db_config, sqlite3_extended_result_codes, sqlite3_free,
sqlite3_load_extension, sqlite3_open_v2, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, SQLITE_OK,
SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX,
SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE,
SQLITE_OPEN_URI,
};
use percent_encoding::NON_ALPHANUMERIC;

use sqlx_core::IndexMap;
use std::collections::BTreeMap;
use std::ffi::{c_void, CStr, CString};
use std::io;
use std::os::raw::c_int;
use std::ptr::{addr_of_mut, null, null_mut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

use crate::connection::handle::ConnectionHandle;
use crate::connection::LogSettings;
use crate::connection::{ConnectionState, Statements};
use crate::error::Error;
use crate::options::{Filename, SqliteTempPath};
use crate::{SqliteConnectOptions, SqliteError};

// This was originally `AtomicU64` but that's not supported on MIPS (or PowerPC):
// https://github.com/launchbadge/sqlx/issues/2859
Expand All @@ -42,7 +48,7 @@ impl SqliteLoadExtensionMode {
}

pub struct EstablishParams {
filename: CString,
filename: EstablishFilename,
open_flags: i32,
busy_timeout: Duration,
statement_cache_capacity: usize,
Expand All @@ -54,20 +60,16 @@ pub struct EstablishParams {
register_regexp_function: bool,
}

enum EstablishFilename {
Owned(CString),
Temp {
temp: SqliteTempPath,
query: Option<String>,
},
}

impl EstablishParams {
pub fn from_options(options: &SqliteConnectOptions) -> Result<Self, Error> {
let mut filename = options
.filename
.to_str()
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"filename passed to SQLite must be valid UTF-8",
)
})?
.to_owned();

// By default, we connect to an in-memory database.
// [SQLITE_OPEN_NOMUTEX] will instruct [sqlite3_open_v2] to return an error if it
// cannot satisfy our wish for a thread-safe, lock-free connection object

Expand Down Expand Up @@ -105,21 +107,51 @@ impl EstablishParams {
query_params.insert("vfs", vfs);
}

if !query_params.is_empty() {
filename = format!(
"file:{}?{}",
percent_encoding::percent_encode(filename.as_bytes(), NON_ALPHANUMERIC),
serde_urlencoded::to_string(&query_params).unwrap()
);
flags |= libsqlite3_sys::SQLITE_OPEN_URI;
}
let filename = match &options.filename {
Filename::Owned(owned) => {
let filename_str = owned.to_str().ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"filename passed to SQLite must be valid UTF-8",
)
})?;

let filename = if !query_params.is_empty() {
flags |= SQLITE_OPEN_URI;

format!(
"file:{}?{}",
percent_encoding::percent_encode(filename_str.as_bytes(), NON_ALPHANUMERIC),
serde_urlencoded::to_string(&query_params)
.expect("BUG: failed to URL encode query parameters")
)
} else {
filename_str.to_string()
};

let filename = CString::new(filename).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"filename passed to SQLite must not contain nul bytes",
)
})?;

EstablishFilename::Owned(filename)
}
Filename::Temp(temp) => {
let query = (!query_params.is_empty()).then(|| {
flags |= SQLITE_OPEN_URI;

serde_urlencoded::to_string(&query_params)
.expect("BUG: failed to URL encode query parameters")
});

let filename = CString::new(filename).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"filename passed to SQLite must not contain nul bytes",
)
})?;
EstablishFilename::Temp {
temp: temp.clone(),
query,
}
}
};

let extensions = options
.extensions
Expand Down Expand Up @@ -187,12 +219,43 @@ impl EstablishParams {
}

pub(crate) fn establish(&self) -> Result<ConnectionState, Error> {
let mut open_flags = self.open_flags;

let (filename, temp) = match &self.filename {
EstablishFilename::Owned(cstr) => (Cow::Borrowed(&**cstr), None),
EstablishFilename::Temp { temp, query } => {
let path = temp.force_create_blocking()?.to_str().ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"filename passed to SQLite must be valid UTF-8",
)
})?;

let filename = if let Some(query) = query {
// Ensure the flag is set.
open_flags |= SQLITE_OPEN_URI;
format!("file:{path}?{query}")
} else {
path.to_string()
};

(
Cow::Owned(CString::new(filename).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"filename passed to SQLite must not contain nul bytes",
)
})?),
Some(temp)
)
}
};

let mut handle = null_mut();

// <https://www.sqlite.org/c3ref/open.html>
let mut status = unsafe {
sqlite3_open_v2(self.filename.as_ptr(), &mut handle, self.open_flags, null())
};
let mut status =
unsafe { sqlite3_open_v2(filename.as_ptr(), &mut handle, open_flags, null()) };

if handle.is_null() {
// Failed to allocate memory
Expand Down Expand Up @@ -296,6 +359,7 @@ impl EstablishParams {
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None,
_temp: temp.cloned()
})
}
}
10 changes: 8 additions & 2 deletions sqlx-sqlite/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::os::raw::{c_char, c_int, c_void};
use std::panic::catch_unwind;
use std::ptr;
use std::ptr::NonNull;

use std::sync::Arc;
use futures_core::future::BoxFuture;
use futures_intrusive::sync::MutexGuard;
use futures_util::future;
Expand All @@ -24,7 +24,7 @@ use sqlx_core::transaction::Transaction;

use crate::connection::establish::EstablishParams;
use crate::connection::worker::ConnectionWorker;
use crate::options::OptimizeOnClose;
use crate::options::{OptimizeOnClose, SqliteTempPath, TempFilename};
use crate::statement::VirtualStatement;
use crate::{Sqlite, SqliteConnectOptions};

Expand Down Expand Up @@ -106,6 +106,12 @@ pub(crate) struct ConnectionState {
progress_handler_callback: Option<Handler>,

update_hook_callback: Option<UpdateHookHandler>,

/// (MUST BE LAST) If applicable, hold a strong ref to the temporary directory
/// until the connection is closed.
///
/// When the last strong ref is dropped, the temporary directory is deleted.
pub(crate) _temp: Option<SqliteTempPath>,
}

impl ConnectionState {
Expand Down
2 changes: 1 addition & 1 deletion sqlx-sqlite/src/connection/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::connection::establish::EstablishParams;
use crate::connection::execute;
use crate::connection::ConnectionState;
use crate::{Sqlite, SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement};

use crate::options::TempFilename;
// Each SQLite connection has a dedicated thread.

// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce
Expand Down
Loading

0 comments on commit b307691

Please sign in to comment.