Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flock lock #90

Merged
merged 15 commits into from
Jan 8, 2025
8 changes: 4 additions & 4 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ jobs:
run: cargo test --no-default-features --verbose

- name: Run Tests (ssl cross)
run: |
cargo test --no-default-features --features ureq,native-tls
cargo test --no-default-features --features ureq,rustls-tls
cargo test --no-default-features --features tokio,native-tls
run: >
cargo test --no-default-features --features ureq,native-tls &&
cargo test --no-default-features --features ureq,rustls-tls &&
cargo test --no-default-features --features tokio,native-tls &&
cargo test --no-default-features --features tokio,rustls-tls

- name: Run Audit
Expand Down
13 changes: 13 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ ureq = { version = "2.8.0", optional = true, features = [
] }
native-tls = { version = "0.2.12", optional = true }

[target.'cfg(windows)'.dependencies.windows-sys]
version = "0.59"
features = ["Win32_Foundation", "Win32_Storage_FileSystem", "Win32_System_IO"]
optional = true

[target.'cfg(unix)'.dependencies.libc]
version = "0.2"
optional = true

[features]
default = ["default-tls", "tokio", "ureq"]
# These features are only relevant when used with the `tokio` feature, but this might change in the future.
Expand All @@ -59,6 +68,8 @@ tokio = [
"dep:thiserror",
"dep:tokio",
"tokio/rt-multi-thread",
"dep:libc",
"dep:windows-sys",
]
ureq = [
"dep:http",
Expand All @@ -68,6 +79,8 @@ ureq = [
"dep:serde_json",
"dep:thiserror",
"dep:ureq",
"dep:libc",
"dep:windows-sys",
]

[dev-dependencies]
Expand Down
94 changes: 73 additions & 21 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use std::io::Seek;
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
use std::str::FromStr;
use std::time::Duration;
use thiserror::Error;
use ureq::{Agent, AgentBuilder, Request};

Expand Down Expand Up @@ -70,38 +69,90 @@ impl HeaderAgent {
}
}

#[derive(Debug)]
struct Handle {
_file: std::fs::File,
path: PathBuf,
file: std::fs::File,
}

impl Drop for Handle {
fn drop(&mut self) {
std::fs::remove_file(&self.path).expect("Removing lockfile")
unlock(&self.file);
}
}

fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
path.set_extension("lock");

let mut lock_handle = None;
for i in 0..30 {
match std::fs::File::create_new(path.clone()) {
Ok(handle) => {
lock_handle = Some(handle);
break;
}
_ => {
if i == 0 {
log::warn!("Waiting for lock {path:?}");
}
std::thread::sleep(Duration::from_secs(1));
}
let file = std::fs::File::create(path.clone())?;
let mut res = lock(&file);
for _ in 0..5 {
if res == 0 {
break;
}
std::thread::sleep(std::time::Duration::from_secs(1));
res = lock(&file);
}
if res != 0 {
Err(ApiError::LockAcquisition(path))
} else {
Ok(Handle { file })
}
}

#[cfg(target_family = "unix")]
mod unix {
use std::os::fd::AsRawFd;

pub(crate) fn lock(file: &std::fs::File) -> i32 {
unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) }
}
pub(crate) fn unlock(file: &std::fs::File) -> i32 {
unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) }
}
}
#[cfg(target_family = "unix")]
use unix::{lock, unlock};

#[cfg(target_family = "windows")]
mod windows {
use std::os::windows::io::AsRawHandle;
use windows_sys::Win32::Foundation::HANDLE;
use windows_sys::Win32::Storage::FileSystem::{
LockFileEx, UnlockFile, LOCKFILE_EXCLUSIVE_LOCK, LOCKFILE_FAIL_IMMEDIATELY,
};

pub(crate) fn lock(file: &std::fs::File) -> i32 {
unsafe {
let mut overlapped = std::mem::zeroed();
let flags = LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY;
let res = LockFileEx(
file.as_raw_handle() as HANDLE,
flags,
0,
!0,
!0,
&mut overlapped,
);
1 - res
}
}
let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(path.clone()))?;
Ok(Handle { path, _file })
pub(crate) fn unlock(file: &std::fs::File) -> i32 {
unsafe { UnlockFile(file.as_raw_handle() as HANDLE, 0, 0, !0, !0) }
}
}
#[cfg(target_family = "windows")]
use windows::{lock, unlock};

#[cfg(not(any(target_family = "unix", target_family = "windows")))]
mod other {
pub(crate) fn lock(file: &std::fs::File) -> i32 {
0
}
pub(crate) fn unlock(file: &std::fs::File) -> i32 {
0
}
}
#[cfg(not(any(target_family = "unix", target_family = "windows")))]
use other::{lock, unlock};

#[derive(Debug, Error)]
/// All errors the API can throw
Expand Down Expand Up @@ -688,7 +739,7 @@ impl ApiRepo {
.blob_path(&metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;

let lock = lock_file(blob_path.clone())?;
let lock = lock_file(blob_path.clone()).unwrap();
let mut tmp_path = blob_path.clone();
tmp_path.set_extension(EXTENSION);
let tmp_filename =
Expand Down Expand Up @@ -769,6 +820,7 @@ mod tests {
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::io::{Seek, SeekFrom, Write};
use std::time::Duration;

struct TempDir {
path: PathBuf,
Expand Down
107 changes: 84 additions & 23 deletions src/api/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use std::collections::BinaryHeap;
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::io::AsyncReadExt;
use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
Expand Down Expand Up @@ -65,36 +64,89 @@ impl Progress for () {
}

struct Handle {
_file: tokio::fs::File,
path: PathBuf,
file: tokio::fs::File,
}

impl Drop for Handle {
fn drop(&mut self) {
std::fs::remove_file(&self.path).expect("Removing lockfile")
unlock(&self.file);
}
}

async fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
path.set_extension("lock");

let mut lock_handle = None;
for i in 0..30 {
match tokio::fs::File::create_new(path.clone()).await {
Ok(handle) => {
lock_handle = Some(handle);
break;
}
Err(_err) => {
if i == 0 {
log::warn!("Waiting for lock {path:?}");
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
let file = tokio::fs::File::create(path.clone()).await?;
let mut res = lock(&file);
for _ in 0..5 {
if res == 0 {
break;
}
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
res = lock(&file);
}
if res != 0 {
Err(ApiError::LockAcquisition(path))
} else {
Ok(Handle { file })
}
}

#[cfg(target_family = "unix")]
mod unix {
use std::os::fd::AsRawFd;

pub(crate) fn lock(file: &tokio::fs::File) -> i32 {
unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) }
}
pub(crate) fn unlock(file: &tokio::fs::File) -> i32 {
unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) }
}
}
#[cfg(target_family = "unix")]
use unix::{lock, unlock};

#[cfg(target_family = "windows")]
mod windows {
use std::os::windows::io::AsRawHandle;
use windows_sys::Win32::Foundation::HANDLE;
use windows_sys::Win32::Storage::FileSystem::{
LockFileEx, UnlockFile, LOCKFILE_EXCLUSIVE_LOCK, LOCKFILE_FAIL_IMMEDIATELY,
};

pub(crate) fn lock(file: &tokio::fs::File) -> i32 {
unsafe {
let mut overlapped = std::mem::zeroed();
let flags = LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY;
let res = LockFileEx(
file.as_raw_handle() as HANDLE,
flags,
0,
!0,
!0,
&mut overlapped,
);
1 - res
}
}
let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(path.clone()))?;
Ok(Handle { path, _file })
pub(crate) fn unlock(file: &tokio::fs::File) -> i32 {
unsafe { UnlockFile(file.as_raw_handle() as HANDLE, 0, 0, !0, !0) }
}
}
#[cfg(target_family = "windows")]
use windows::{lock, unlock};

#[cfg(not(any(target_family = "unix", target_family = "windows")))]
mod other {
pub(crate) fn lock(file: &tokio::fs::File) -> i32 {
0
}
pub(crate) fn unlock(file: &tokio::fs::File) -> i32 {
0
}
}
#[cfg(not(any(target_family = "unix", target_family = "windows")))]
use other::{lock, unlock};

#[derive(Debug, Error)]
/// All errors the API can throw
Expand Down Expand Up @@ -670,14 +722,21 @@ impl ApiRepo {
.await?;
file.seek(SeekFrom::Start(length as u64)).await?;
file.write_all(&committed.to_le_bytes()).await?;
file.flush().await?;
}
}
tokio::fs::OpenOptions::new()
let mut f = tokio::fs::OpenOptions::new()
.write(true)
.open(&filename)
.await?
.set_len(length as u64)
.await?;
f.set_len(length as u64).await?;
// XXX Extremely important and not obvious.
// Tokio::fs doesn't guarantee data is written at the end of `.await`
// boundaries. Even though we await the `set_len` it may not have been
// committed to disk, leading to invalid rename.
// Forcing a flush forces the data (here the truncation) to be committed to disk
f.flush().await?;

progressbar.finish().await;
Ok(filename)
}
Expand Down Expand Up @@ -714,6 +773,7 @@ impl ApiRepo {
.await?;
file.seek(SeekFrom::Start(start as u64)).await?;
file.write_all(&buf).await?;
file.flush().await?;
Ok((start, stop))
}

Expand Down Expand Up @@ -798,7 +858,7 @@ impl ApiRepo {
let blob_path = cache.blob_path(&metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;

let lock = lock_file(blob_path.clone()).await;
let lock = lock_file(blob_path.clone()).await?;
progress.init(metadata.size, filename).await;
let mut tmp_path = blob_path.clone();
tmp_path.set_extension(EXTENSION);
Expand Down Expand Up @@ -859,6 +919,7 @@ mod tests {
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::io::{Seek, Write};
use std::time::Duration;

struct TempDir {
path: PathBuf,
Expand Down
Loading