Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(fs): support virtio-fs DAX mapping #31

Merged
merged 2 commits into from
Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions alioth/src/board/x86_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ where
let memory = &self.memory;

let low_mem_size = std::cmp::min(config.mem_size, RAM_32_SIZE);
let pages_low = ArcMemPages::from_anonymous(low_mem_size, None, Some(c"ram-low"))?;
let pages_low = ArcMemPages::from_memfd(low_mem_size, None, Some(c"ram-low"))?;
if self.config.coco.is_some() {
self.memory.ram_bus().register_encrypted_pages(&pages_low)?;
}
Expand Down Expand Up @@ -161,11 +161,8 @@ where
};
memory.add_region(AddrOpt::Fixed(0), Arc::new(region_low))?;
if config.mem_size > RAM_32_SIZE {
let mem_hi = ArcMemPages::from_anonymous(
config.mem_size - RAM_32_SIZE,
None,
Some(c"ram-high"),
)?;
let mem_hi =
ArcMemPages::from_memfd(config.mem_size - RAM_32_SIZE, None, Some(c"ram-high"))?;
if self.config.coco.is_some() {
self.memory.ram_bus().register_encrypted_pages(&mem_hi)?;
}
Expand Down
2 changes: 1 addition & 1 deletion alioth/src/loader/firmware/x86_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub fn load<P: AsRef<Path>>(memory: &Memory, path: P) -> Result<(InitState, ArcM
let size = file.metadata()?.len() as usize;
assert_eq!(size & 0xfff, 0);

let mut rom = ArcMemPages::from_anonymous(size, None, Some(c"rom"))?;
let mut rom = ArcMemPages::from_memfd(size, None, Some(c"rom"))?;
file.read_exact(rom.as_slice_mut())?;

let gpa = MEM_64_START - size;
Expand Down
31 changes: 20 additions & 11 deletions alioth/src/mem/mapped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

use libc::{
c_void, mmap, msync, munmap, MAP_FAILED, MAP_SHARED, MFD_CLOEXEC, MS_ASYNC, PROT_EXEC,
PROT_READ, PROT_WRITE,
c_void, mmap, msync, munmap, MAP_ANONYMOUS, MAP_FAILED, MAP_PRIVATE, MAP_SHARED, MFD_CLOEXEC,
MS_ASYNC, PROT_EXEC, PROT_READ, PROT_WRITE,
};
use parking_lot::{RwLock, RwLockReadGuard};
use zerocopy::{AsBytes, FromBytes};
Expand All @@ -41,7 +41,7 @@ use super::{Error, Result};
/// Record of one mapped memory range: a non-null start address, a length,
/// and the file associated with it.
// NOTE(review): presumably `len` is the byte length that was handed to
// `mmap` and the value is filled in by `from_raw` — the constructor body is
// not fully visible here, confirm.
struct MemPages {
addr: NonNull<c_void>,
len: usize,
fd: File,
/// Optional form of the backing file: `from_raw` now receives an
/// `Option<File>`, and `from_anonymous` passes `None` to it.
fd: Option<File>,
}

unsafe impl Send for MemPages {}
Expand Down Expand Up @@ -83,16 +83,16 @@ impl ArcMemPages {
self.size
}

/// Borrows the file descriptor of the file held by the inner pages.
///
/// In the `Option`-returning form the result is `None` when the inner
/// `fd` field is `None`, so callers have to handle mappings that have no
/// file to hand out.
pub fn fd(&self) -> BorrowedFd {
self._inner.fd.as_fd()
pub fn fd(&self) -> Option<BorrowedFd> {
self._inner.fd.as_ref().map(|f| f.as_fd())
}

pub fn sync(&self) -> Result<()> {
ffi!(unsafe { msync(self.addr as *mut _, self.size, MS_ASYNC) })?;
Ok(())
}

fn from_raw(addr: *mut c_void, len: usize, fd: File) -> Self {
fn from_raw(addr: *mut c_void, len: usize, fd: Option<File>) -> Self {
let addr = NonNull::new(addr).expect("address from mmap() should not be null");
ArcMemPages {
addr: addr.as_ptr() as usize,
Expand All @@ -106,10 +106,10 @@ impl ArcMemPages {
unsafe { mmap(null_mut(), len, prot, MAP_SHARED, file.as_raw_fd(), offset) },
MAP_FAILED
)?;
Ok(Self::from_raw(addr, len, file))
Ok(Self::from_raw(addr, len, Some(file)))
}

pub fn from_anonymous(size: usize, prot: Option<i32>, name: Option<&CStr>) -> Result<Self> {
pub fn from_memfd(size: usize, prot: Option<i32>, name: Option<&CStr>) -> Result<Self> {
let name = name.unwrap_or(c"anon");
let fd = ffi!(unsafe { libc::memfd_create(name.as_ptr(), MFD_CLOEXEC) })?;
let prot = prot.unwrap_or(PROT_WRITE | PROT_READ | PROT_EXEC);
Expand All @@ -119,7 +119,16 @@ impl ArcMemPages {
)?;
let file = unsafe { File::from_raw_fd(fd) };
file.set_len(size as _)?;
Ok(Self::from_raw(addr, size, file))
Ok(Self::from_raw(addr, size, Some(file)))
}

pub fn from_anonymous(size: usize, prot: Option<i32>) -> Result<Self> {
let prot = prot.unwrap_or(PROT_WRITE | PROT_READ | PROT_EXEC);
let addr = ffi!(
unsafe { mmap(null_mut(), size, prot, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0) },
MAP_FAILED
)?;
Ok(Self::from_raw(addr, size, None))
}

/// Given offset and len, return the host virtual address and len;
Expand Down Expand Up @@ -581,8 +590,8 @@ mod test {
fn test_ram_bus_read() {
let bus = RamBus::new(FakeVmMemory);
let prot = PROT_READ | PROT_WRITE;
let mem1 = ArcMemPages::from_anonymous(PAGE_SIZE, Some(prot), None).unwrap();
let mem2 = ArcMemPages::from_anonymous(PAGE_SIZE, Some(prot), None).unwrap();
let mem1 = ArcMemPages::from_memfd(PAGE_SIZE, Some(prot), None).unwrap();
let mem2 = ArcMemPages::from_memfd(PAGE_SIZE, Some(prot), None).unwrap();

if mem1.addr > mem2.addr {
bus.add(0x0, mem1).unwrap();
Expand Down
137 changes: 127 additions & 10 deletions alioth/src/virtio/dev/fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,35 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::io::ErrorKind;
use std::iter::zip;
use std::mem::size_of_val;
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
use std::path::PathBuf;
use std::sync::atomic::Ordering;
use std::sync::Arc;

use bitflags::bitflags;
use libc::{eventfd, EFD_CLOEXEC, EFD_NONBLOCK};
use libc::{
eventfd, mmap, EFD_CLOEXEC, EFD_NONBLOCK, MAP_ANONYMOUS, MAP_FAILED, MAP_FIXED, MAP_PRIVATE,
MAP_SHARED, PROT_NONE,
};
use mio::event::Event;
use mio::unix::SourceFd;
use mio::{Interest, Registry, Token};
use serde::Deserialize;
use zerocopy::{AsBytes, FromBytes, FromZeroes};

use crate::hv::IoeventFd;
use crate::mem::mapped::RamBus;
use crate::mem::mapped::{ArcMemPages, RamBus};
use crate::mem::{MemRegion, MemRegionType};
use crate::virtio::dev::{DevParam, Virtio};
use crate::virtio::queue::{Queue, VirtQueue};
use crate::virtio::vu::{
DeviceConfig, MemoryRegion, MemorySingleRegion, VirtqAddr, VirtqState, VuDev, VuFeature,
};
use crate::virtio::{DeviceId, Error, IrqSender, Result, VirtioFeature};
use crate::{ffi, impl_mmio_for_zerocopy};
use crate::{align_up, ffi, impl_mmio_for_zerocopy};

#[repr(C, align(4))]
#[derive(Debug, FromBytes, FromZeroes, AsBytes)]
Expand All @@ -53,6 +59,18 @@ bitflags! {
}
}

/// Payload of a backend map/unmap request as read by `handle_event`: four
/// parallel arrays of eight `u64` slots each.
///
/// When mapping, `handle_event` pairs slot `i` of every array with the
/// `i`-th file descriptor received alongside the message. When unmapping,
/// only `len` and `cache_offset` are consulted.
#[derive(Debug, Clone, FromBytes, FromZeroes, AsBytes)]
#[repr(C)]
struct VuFsMap {
/// Passed as the file-offset argument of `mmap` when mapping.
pub fd_offset: [u64; 8],
/// Added to the DAX region's base address to form the target address.
pub cache_offset: [u64; 8],
/// Length handed to `mmap`; a zero entry is skipped when unmapping.
pub len: [u64; 8],
/// Passed unchanged as the protection argument of `mmap` when mapping.
pub flags: [u64; 8],
}

/// Backend request code that `handle_event` serves by mapping the received
/// file descriptors into the DAX region.
const VHOST_USER_BACKEND_FS_MAP: u32 = 6;
/// Backend request code that `handle_event` serves by replacing ranges of the
/// DAX region with inaccessible (`PROT_NONE`) anonymous memory.
const VHOST_USER_BACKEND_FS_UNMAP: u32 = 7;

#[derive(Debug)]
pub struct VuFs {
name: Arc<String>,
Expand All @@ -61,12 +79,13 @@ pub struct VuFs {
feature: u64,
num_queues: u16,
regions: Vec<MemoryRegion>,
dax_region: Option<ArcMemPages>,
error_fds: Vec<OwnedFd>,
}

impl VuFs {
pub fn new(param: VuFsParam, name: Arc<String>) -> Result<Self> {
let vu_dev = VuDev::new(param.socket)?;
let mut vu_dev = VuDev::new(param.socket)?;
let dev_feat = vu_dev.get_features()?;
let virtio_feat = VirtioFeature::from_bits_retain(dev_feat);
let need_feat = VirtioFeature::VHOST_PROTOCOL | VirtioFeature::VERSION_1;
Expand All @@ -80,6 +99,10 @@ impl VuFs {
if param.tag.is_none() {
need_feat |= VuFeature::CONFIG;
}
if param.dax_window > 0 {
assert!(param.dax_window.count_ones() == 1 && param.dax_window > (4 << 10));
need_feat |= VuFeature::BACKEND_REQ | VuFeature::BACKEND_SEND_FD;
}
if !prot_feat.contains(need_feat) {
return Err(Error::VuMissingProtocolFeature(need_feat & !prot_feat));
}
Expand All @@ -103,6 +126,13 @@ impl VuFs {
let dev_config = vu_dev.get_config(&empty_cfg)?;
FsConfig::read_from_prefix(&dev_config.region).unwrap()
};
let dax_region = if param.dax_window > 0 {
vu_dev.setup_channel()?;
let size = align_up!(param.dax_window, 4 << 10);
Some(ArcMemPages::from_anonymous(size, Some(PROT_NONE))?)
} else {
None
};

Ok(VuFs {
num_queues,
Expand All @@ -112,6 +142,7 @@ impl VuFs {
feature: dev_feat & !VirtioFeature::VHOST_PROTOCOL.bits(),
regions: Vec::new(),
error_fds: Vec::new(),
dax_region,
})
}
}
Expand All @@ -120,6 +151,8 @@ impl VuFs {
/// Parameters consumed by `VuFs::new` to set up a vhost-user file system
/// device.
pub struct VuFsParam {
/// Passed to `VuDev::new` to open the vhost-user device.
pub socket: PathBuf,
/// Optional tag; when it is absent `VuFs::new` additionally requires the
/// backend's `CONFIG` protocol feature.
pub tag: Option<String>,
/// Size of the DAX window. Falls back to the type's default (0) when the
/// field is omitted, and 0 leaves the DAX region unset.
///
/// A non-zero value makes `VuFs::new` require the `BACKEND_REQ` and
/// `BACKEND_SEND_FD` protocol features and reserve a `PROT_NONE`
/// anonymous region of this size, rounded up to a multiple of 4 KiB.
// NOTE(review): `VuFs::new` checks "power of two and greater than 4 KiB"
// with `assert!`, so an invalid configured value panics instead of being
// reported as an error.
#[serde(default)]
pub dax_window: usize,
}

impl DevParam for VuFsParam {
Expand Down Expand Up @@ -156,6 +189,9 @@ impl Virtio for VuFs {
.set_features(&(feature | VirtioFeature::VHOST_PROTOCOL.bits()))?;
let mem = memory.lock_layout();
for (gpa, slot) in mem.iter() {
let Some(fd) = slot.pages.fd() else {
continue;
};
let region = MemorySingleRegion {
_padding: 0,
region: MemoryRegion {
Expand All @@ -165,9 +201,7 @@ impl Virtio for VuFs {
mmap_offset: 0,
},
};
self.vu_dev
.add_mem_region(&region, slot.pages.fd().as_raw_fd())
.unwrap();
self.vu_dev.add_mem_region(&region, fd.as_raw_fd()).unwrap();
log::info!("region: {region:x?}");
self.regions.push(region.region);
}
Expand Down Expand Up @@ -221,18 +255,93 @@ impl Virtio for VuFs {
self.vu_dev.set_virtq_enable(&virtq_enable).unwrap();
log::info!("virtq_enable: {virtq_enable:x?}");
}
if let Some(channel) = self.vu_dev.get_channel() {
channel.set_nonblocking(true)?;
registry.register(
&mut SourceFd(&channel.as_raw_fd()),
Token(self.num_queues as _),
Interest::READABLE,
)?;
}
Ok(())
}

/// Handles a readiness event for this device.
///
/// Events whose token index is below `queues.len()` are reported as
/// `Error::VuQueueErr`. Any other token is treated as the backend channel,
/// which is registered elsewhere in this impl with
/// `Token(self.num_queues)` after being switched to non-blocking mode.
///
/// For the backend channel, requests are read in a loop until the read
/// fails with `WouldBlock`. Each request must carry a payload exactly the
/// size of `VuFsMap`; it is then served as a map or unmap of the DAX region
/// and acknowledged with a zero `u64`.
///
/// Returns `Error::VuMissingProtocolFeature(BACKEND_REQ)` when no DAX region
/// was set up, `Error::VuInvalidPayloadSize` on a size mismatch, and
/// propagates receive, `mmap` and ack failures.
fn handle_event(
&mut self,
event: &Event,
_queues: &[impl VirtQueue],
queues: &[impl VirtQueue],
_irq_sender: &impl IrqSender,
_registry: &Registry,
) -> Result<()> {
let q_index = event.token();
Err(Error::VuQueueErr(q_index.0 as _))
let q_index = event.token().0;
// Tokens in the queue index range are not expected here; surface them as
// queue errors.
if q_index < queues.len() {
return Err(Error::VuQueueErr(q_index as _));
}

let Some(dax_region) = &self.dax_region else {
return Err(Error::VuMissingProtocolFeature(VuFeature::BACKEND_REQ));
};
// Drain every pending backend request; the loop only exits normally once
// the channel reports `WouldBlock`.
loop {
let mut fs_map = VuFsMap::new_zeroed();
// One slot per `VuFsMap` entry for file descriptors received with the
// message.
let mut fds = [None, None, None, None, None, None, None, None];
let ret = self
.vu_dev
.receive_from_channel(fs_map.as_bytes_mut(), &mut fds);
let (request, size) = match ret {
Ok((r, s)) => (r, s),
Err(Error::Io(e)) if e.kind() == ErrorKind::WouldBlock => break,
Err(e) => return Err(e),
};
if size as usize != size_of_val(&fs_map) {
return Err(Error::VuInvalidPayloadSize(size_of_val(&fs_map), size));
}
match request {
VHOST_USER_BACKEND_FS_MAP => {
// Entries are processed in order and processing stops at the first
// slot without a file descriptor.
for (index, fd) in fds.iter().enumerate() {
let Some(fd) = fd else {
break;
};
let raw_fd = fd.as_raw_fd();
// NOTE(review): `cache_offset` and `len` come straight from the backend
// message and are not checked against the size of the DAX region. With
// `MAP_FIXED`, an out-of-range offset or length would replace whatever
// is mapped at that address in this process. The addition can also
// overflow. Consider validating `offset + len <= region size` first.
let map_addr = dax_region.addr() + fs_map.cache_offset[index] as usize;
log::trace!(
"{}: mapping fd {raw_fd} to offset {:#x}",
self.name,
fs_map.cache_offset[index]
);
// NOTE(review): `flags` is forwarded unchanged as the `mmap` protection
// argument — confirm the backend's flag bits line up with `PROT_*`.
ffi!(
unsafe {
mmap(
map_addr as _,
fs_map.len[index] as _,
fs_map.flags[index] as _,
MAP_SHARED | MAP_FIXED,
raw_fd,
fs_map.fd_offset[index] as _,
)
},
MAP_FAILED
)?;
}
}
VHOST_USER_BACKEND_FS_UNMAP => {
// Unmapping is done by mapping fresh inaccessible anonymous memory over
// the range, so the address range stays reserved inside the region.
for (len, offset) in zip(fs_map.len, fs_map.cache_offset) {
if len == 0 {
continue;
}
log::trace!("{}: unmapping offset {offset:#x}, size {len:#x}", self.name);
// NOTE(review): same missing bounds check as in the map branch.
let map_addr = dax_region.addr() + offset as usize;
let flags = MAP_ANONYMOUS | MAP_PRIVATE | MAP_FIXED;
ffi!(
unsafe { mmap(map_addr as _, len as _, PROT_NONE, flags, -1, 0) },
MAP_FAILED
)?;
}
}
// NOTE(review): any other request code sent by the backend panics here;
// returning an error would keep a misbehaving backend from aborting the
// process.
_ => unimplemented!("unknown request {request:#x}"),
}
// Reached only when the request was served without error: a failing
// `mmap` above returns early and the request is never acknowledged.
self.vu_dev.ack_request(request, &0u64)?;
}
Ok(())
}

fn handle_queue(
Expand Down Expand Up @@ -285,4 +394,12 @@ impl Virtio for VuFs {
Err(Error::InvalidQueueIndex(q_index))
}
}

/// Exposes the DAX region, if one was set up, as a memory region of type
/// `MemRegionType::Hidden`.
///
/// Returns `None` when the device has no DAX region. Otherwise a clone of
/// the region's pages is wrapped in a new `MemRegion`.
fn shared_mem_regions(&self) -> Option<Arc<MemRegion>> {
    self.dax_region.as_ref().map(|pages| {
        let region = MemRegion::with_mapped(pages.clone(), MemRegionType::Hidden);
        Arc::new(region)
    })
}
}
11 changes: 7 additions & 4 deletions alioth/src/virtio/virtio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ pub enum Error {
#[error("Invalid vhost user response message, want {0}, got {1}")]
InvalidVhostRespMsg(u32, u32),

#[error("Invalid vhost user response size, want {0}, get {1}")]
InvalidVhostRespSize(usize, usize),
#[error("Invalid vhost user message size, want {0}, get {1}")]
VuMessageSizeMismatch(usize, usize),

#[error("Invalid vhost user response payload size, want {0}, got {1}")]
InvalidVhostRespPayloadSize(usize, u32),
#[error("Invalid vhost user message payload size, want {0}, got {1}")]
VuInvalidPayloadSize(usize, u32),

#[error("vhost-user backend replied error code {0:#x} to request {1:#x}")]
VuRequestErr(u64, u32),
Expand All @@ -73,6 +73,9 @@ pub enum Error {
#[error("vhost-user backend is missing protocol feature {0:x?}")]
VuMissingProtocolFeature(vu::VuFeature),

#[error("insufficient buffer (size {0}) for holding {1} fds")]
VuInsufficientBuffer(usize, usize),

#[error("vhost backend is missing device feature {0:#x}")]
VhostMissingDeviceFeature(u64),

Expand Down
Loading