From 955212af7f82557e29931334e2d23df41f61c758 Mon Sep 17 00:00:00 2001 From: Changyuan Lyu Date: Sun, 2 Jun 2024 11:22:57 -0700 Subject: [PATCH 1/2] fix(mem): add private anonymous memory back Private anonymous memory is useful when a device does not want other devices to access its memory. Fixes: a53681865381 ("feat(mem)!: create anonymous mem with memfd_create") Signed-off-by: Changyuan Lyu --- alioth/src/board/x86_64.rs | 9 +++----- alioth/src/loader/firmware/x86_64.rs | 2 +- alioth/src/mem/mapped.rs | 31 ++++++++++++++++++---------- alioth/src/virtio/dev/fs.rs | 7 ++++--- 4 files changed, 28 insertions(+), 21 deletions(-) diff --git a/alioth/src/board/x86_64.rs b/alioth/src/board/x86_64.rs index 7cb935e..5c46710 100644 --- a/alioth/src/board/x86_64.rs +++ b/alioth/src/board/x86_64.rs @@ -132,7 +132,7 @@ where let memory = &self.memory; let low_mem_size = std::cmp::min(config.mem_size, RAM_32_SIZE); - let pages_low = ArcMemPages::from_anonymous(low_mem_size, None, Some(c"ram-low"))?; + let pages_low = ArcMemPages::from_memfd(low_mem_size, None, Some(c"ram-low"))?; if self.config.coco.is_some() { self.memory.ram_bus().register_encrypted_pages(&pages_low)?; } @@ -161,11 +161,8 @@ where }; memory.add_region(AddrOpt::Fixed(0), Arc::new(region_low))?; if config.mem_size > RAM_32_SIZE { - let mem_hi = ArcMemPages::from_anonymous( - config.mem_size - RAM_32_SIZE, - None, - Some(c"ram-high"), - )?; + let mem_hi = + ArcMemPages::from_memfd(config.mem_size - RAM_32_SIZE, None, Some(c"ram-high"))?; if self.config.coco.is_some() { self.memory.ram_bus().register_encrypted_pages(&mem_hi)?; } diff --git a/alioth/src/loader/firmware/x86_64.rs b/alioth/src/loader/firmware/x86_64.rs index 104138e..07dd35b 100644 --- a/alioth/src/loader/firmware/x86_64.rs +++ b/alioth/src/loader/firmware/x86_64.rs @@ -29,7 +29,7 @@ pub fn load>(memory: &Memory, path: P) -> Result<(InitState, ArcM let size = file.metadata()?.len() as usize; assert_eq!(size & 0xfff, 0); - let mut rom = ArcMemPages::from_anonymous(size, None, Some(c"rom"))?; + let mut rom = ArcMemPages::from_memfd(size, None, Some(c"rom"))?; file.read_exact(rom.as_slice_mut())?; let gpa = MEM_64_START - size; diff --git a/alioth/src/mem/mapped.rs b/alioth/src/mem/mapped.rs index f08895c..dbeec9a 100644 --- a/alioth/src/mem/mapped.rs +++ b/alioth/src/mem/mapped.rs @@ -25,8 +25,8 @@ use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use libc::{ - c_void, mmap, msync, munmap, MAP_FAILED, MAP_SHARED, MFD_CLOEXEC, MS_ASYNC, PROT_EXEC, - PROT_READ, PROT_WRITE, + c_void, mmap, msync, munmap, MAP_ANONYMOUS, MAP_FAILED, MAP_PRIVATE, MAP_SHARED, MFD_CLOEXEC, + MS_ASYNC, PROT_EXEC, PROT_READ, PROT_WRITE, }; use parking_lot::{RwLock, RwLockReadGuard}; use zerocopy::{AsBytes, FromBytes}; @@ -41,7 +41,7 @@ use super::{Error, Result}; struct MemPages { addr: NonNull, len: usize, - fd: File, + fd: Option, } unsafe impl Send for MemPages {} @@ -83,8 +83,8 @@ impl ArcMemPages { self.size } - pub fn fd(&self) -> BorrowedFd { - self._inner.fd.as_fd() + pub fn fd(&self) -> Option { + self._inner.fd.as_ref().map(|f| f.as_fd()) } pub fn sync(&self) -> Result<()> { @@ -92,7 +92,7 @@ impl ArcMemPages { Ok(()) } - fn from_raw(addr: *mut c_void, len: usize, fd: File) -> Self { + fn from_raw(addr: *mut c_void, len: usize, fd: Option) -> Self { let addr = NonNull::new(addr).expect("address from mmap() should not be null"); ArcMemPages { addr: addr.as_ptr() as usize, @@ -106,10 +106,10 @@ impl ArcMemPages { unsafe { mmap(null_mut(), len, prot, MAP_SHARED, file.as_raw_fd(), offset) }, MAP_FAILED )?; - Ok(Self::from_raw(addr, len, file)) + Ok(Self::from_raw(addr, len, Some(file))) } - pub fn from_anonymous(size: usize, prot: Option, name: Option<&CStr>) -> Result { + pub fn from_memfd(size: usize, prot: Option, name: Option<&CStr>) -> Result { let name = name.unwrap_or(c"anon"); let fd = ffi!(unsafe { libc::memfd_create(name.as_ptr(), MFD_CLOEXEC) })?; let prot = prot.unwrap_or(PROT_WRITE | PROT_READ | PROT_EXEC); @@ -119,7 +119,16 @@ impl ArcMemPages { )?; let file = unsafe { File::from_raw_fd(fd) }; file.set_len(size as _)?; - Ok(Self::from_raw(addr, size, file)) + Ok(Self::from_raw(addr, size, Some(file))) + } + + pub fn from_anonymous(size: usize, prot: Option) -> Result { + let prot = prot.unwrap_or(PROT_WRITE | PROT_READ | PROT_EXEC); + let addr = ffi!( + unsafe { mmap(null_mut(), size, prot, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0) }, + MAP_FAILED + )?; + Ok(Self::from_raw(addr, size, None)) } /// Given offset and len, return the host virtual address and len; @@ -581,8 +590,8 @@ mod test { fn test_ram_bus_read() { let bus = RamBus::new(FakeVmMemory); let prot = PROT_READ | PROT_WRITE; - let mem1 = ArcMemPages::from_anonymous(PAGE_SIZE, Some(prot), None).unwrap(); - let mem2 = ArcMemPages::from_anonymous(PAGE_SIZE, Some(prot), None).unwrap(); + let mem1 = ArcMemPages::from_memfd(PAGE_SIZE, Some(prot), None).unwrap(); + let mem2 = ArcMemPages::from_memfd(PAGE_SIZE, Some(prot), None).unwrap(); if mem1.addr > mem2.addr { bus.add(0x0, mem1).unwrap(); diff --git a/alioth/src/virtio/dev/fs.rs b/alioth/src/virtio/dev/fs.rs index 3636dc1..396d15d 100644 --- a/alioth/src/virtio/dev/fs.rs +++ b/alioth/src/virtio/dev/fs.rs @@ -156,6 +156,9 @@ impl Virtio for VuFs { .set_features(&(feature | VirtioFeature::VHOST_PROTOCOL.bits()))?; let mem = memory.lock_layout(); for (gpa, slot) in mem.iter() { + let Some(fd) = slot.pages.fd() else { + continue; + }; let region = MemorySingleRegion { _padding: 0, region: MemoryRegion { @@ -165,9 +168,7 @@ impl Virtio for VuFs { mmap_offset: 0, }, }; - self.vu_dev - .add_mem_region(®ion, slot.pages.fd().as_raw_fd()) - .unwrap(); + self.vu_dev.add_mem_region(®ion, fd.as_raw_fd()).unwrap(); log::info!("region: {region:x?}"); self.regions.push(region.region); } From ed396f0905f619b859e6185d6bd547dcbbfaf072 Mon Sep 17 00:00:00 2001 From: Changyuan Lyu Date: Sat, 1 Jun 2024 19:37:20 -0700 Subject: [PATCH 2/2] feat(fs): support virtio-fs DAX mapping Signed-off-by: Changyuan Lyu --- alioth/src/virtio/dev/fs.rs | 130 ++++++++++++++++++++++++++++++++++-- alioth/src/virtio/virtio.rs | 11 +-- alioth/src/virtio/vu.rs | 112 ++++++++++++++++++++++++++++--- 3 files changed, 231 insertions(+), 22 deletions(-) diff --git a/alioth/src/virtio/dev/fs.rs b/alioth/src/virtio/dev/fs.rs index 396d15d..34e6073 100644 --- a/alioth/src/virtio/dev/fs.rs +++ b/alioth/src/virtio/dev/fs.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::io::ErrorKind; +use std::iter::zip; use std::mem::size_of_val; use std::os::fd::{AsRawFd, FromRawFd, OwnedFd}; use std::path::PathBuf; @@ -19,7 +21,10 @@ use std::sync::atomic::Ordering; use std::sync::Arc; use bitflags::bitflags; -use libc::{eventfd, EFD_CLOEXEC, EFD_NONBLOCK}; +use libc::{ + eventfd, mmap, EFD_CLOEXEC, EFD_NONBLOCK, MAP_ANONYMOUS, MAP_FAILED, MAP_FIXED, MAP_PRIVATE, + MAP_SHARED, PROT_NONE, +}; use mio::event::Event; use mio::unix::SourceFd; use mio::{Interest, Registry, Token}; @@ -27,14 +32,15 @@ use serde::Deserialize; use zerocopy::{AsBytes, FromBytes, FromZeroes}; use crate::hv::IoeventFd; -use crate::mem::mapped::RamBus; +use crate::mem::mapped::{ArcMemPages, RamBus}; +use crate::mem::{MemRegion, MemRegionType}; use crate::virtio::dev::{DevParam, Virtio}; use crate::virtio::queue::{Queue, VirtQueue}; use crate::virtio::vu::{ DeviceConfig, MemoryRegion, MemorySingleRegion, VirtqAddr, VirtqState, VuDev, VuFeature, }; use crate::virtio::{DeviceId, Error, IrqSender, Result, VirtioFeature}; -use crate::{ffi, impl_mmio_for_zerocopy}; +use crate::{align_up, ffi, impl_mmio_for_zerocopy}; #[repr(C, align(4))] #[derive(Debug, FromBytes, FromZeroes, AsBytes)] @@ -53,6 +59,18 @@ bitflags! { } } +#[derive(Debug, Clone, FromBytes, FromZeroes, AsBytes)] +#[repr(C)] +struct VuFsMap { + pub fd_offset: [u64; 8], + pub cache_offset: [u64; 8], + pub len: [u64; 8], + pub flags: [u64; 8], +} + +const VHOST_USER_BACKEND_FS_MAP: u32 = 6; +const VHOST_USER_BACKEND_FS_UNMAP: u32 = 7; + #[derive(Debug)] pub struct VuFs { name: Arc, @@ -61,12 +79,13 @@ pub struct VuFs { feature: u64, num_queues: u16, regions: Vec, + dax_region: Option, error_fds: Vec, } impl VuFs { pub fn new(param: VuFsParam, name: Arc) -> Result { - let vu_dev = VuDev::new(param.socket)?; + let mut vu_dev = VuDev::new(param.socket)?; let dev_feat = vu_dev.get_features()?; let virtio_feat = VirtioFeature::from_bits_retain(dev_feat); let need_feat = VirtioFeature::VHOST_PROTOCOL | VirtioFeature::VERSION_1; @@ -80,6 +99,10 @@ impl VuFs { if param.tag.is_none() { need_feat |= VuFeature::CONFIG; } + if param.dax_window > 0 { + assert!(param.dax_window.count_ones() == 1 && param.dax_window > (4 << 10)); + need_feat |= VuFeature::BACKEND_REQ | VuFeature::BACKEND_SEND_FD; + } if !prot_feat.contains(need_feat) { return Err(Error::VuMissingProtocolFeature(need_feat & !prot_feat)); } @@ -103,6 +126,13 @@ impl VuFs { let dev_config = vu_dev.get_config(&empty_cfg)?; FsConfig::read_from_prefix(&dev_config.region).unwrap() }; + let dax_region = if param.dax_window > 0 { + vu_dev.setup_channel()?; + let size = align_up!(param.dax_window, 4 << 10); + Some(ArcMemPages::from_anonymous(size, Some(PROT_NONE))?) + } else { + None + }; Ok(VuFs { num_queues, @@ -112,6 +142,7 @@ impl VuFs { feature: dev_feat & !VirtioFeature::VHOST_PROTOCOL.bits(), regions: Vec::new(), error_fds: Vec::new(), + dax_region, }) } } @@ -120,6 +151,8 @@ impl VuFs { pub struct VuFsParam { pub socket: PathBuf, pub tag: Option, + #[serde(default)] + pub dax_window: usize, } impl DevParam for VuFsParam { @@ -222,18 +255,93 @@ impl Virtio for VuFs { self.vu_dev.set_virtq_enable(&virtq_enable).unwrap(); log::info!("virtq_enable: {virtq_enable:x?}"); } + if let Some(channel) = self.vu_dev.get_channel() { + channel.set_nonblocking(true)?; + registry.register( + &mut SourceFd(&channel.as_raw_fd()), + Token(self.num_queues as _), + Interest::READABLE, + )?; + } Ok(()) } fn handle_event( &mut self, event: &Event, - _queues: &[impl VirtQueue], + queues: &[impl VirtQueue], _irq_sender: &impl IrqSender, _registry: &Registry, ) -> Result<()> { - let q_index = event.token(); - Err(Error::VuQueueErr(q_index.0 as _)) + let q_index = event.token().0; + if q_index < queues.len() { + return Err(Error::VuQueueErr(q_index as _)); + } + + let Some(dax_region) = &self.dax_region else { + return Err(Error::VuMissingProtocolFeature(VuFeature::BACKEND_REQ)); + }; + loop { + let mut fs_map = VuFsMap::new_zeroed(); + let mut fds = [None, None, None, None, None, None, None, None]; + let ret = self + .vu_dev + .receive_from_channel(fs_map.as_bytes_mut(), &mut fds); + let (request, size) = match ret { + Ok((r, s)) => (r, s), + Err(Error::Io(e)) if e.kind() == ErrorKind::WouldBlock => break, + Err(e) => return Err(e), + }; + if size as usize != size_of_val(&fs_map) { + return Err(Error::VuInvalidPayloadSize(size_of_val(&fs_map), size)); + } + match request { + VHOST_USER_BACKEND_FS_MAP => { + for (index, fd) in fds.iter().enumerate() { + let Some(fd) = fd else { + break; + }; + let raw_fd = fd.as_raw_fd(); + let map_addr = dax_region.addr() + fs_map.cache_offset[index] as usize; + log::trace!( + "{}: mapping fd {raw_fd} to offset {:#x}", + self.name, + fs_map.cache_offset[index] + ); + ffi!( + unsafe { + mmap( + map_addr as _, + fs_map.len[index] as _, + fs_map.flags[index] as _, + MAP_SHARED | MAP_FIXED, + raw_fd, + fs_map.fd_offset[index] as _, + ) + }, + MAP_FAILED + )?; + } + } + VHOST_USER_BACKEND_FS_UNMAP => { + for (len, offset) in zip(fs_map.len, fs_map.cache_offset) { + if len == 0 { + continue; + } + log::trace!("{}: unmapping offset {offset:#x}, size {len:#x}", self.name); + let map_addr = dax_region.addr() + offset as usize; + let flags = MAP_ANONYMOUS | MAP_PRIVATE | MAP_FIXED; + ffi!( + unsafe { mmap(map_addr as _, len as _, PROT_NONE, flags, -1, 0) }, + MAP_FAILED + )?; + } + } + _ => unimplemented!("unknown request {request:#x}"), + } + self.vu_dev.ack_request(request, &0u64)?; + } + Ok(()) } fn handle_queue( @@ -286,4 +394,12 @@ impl Virtio for VuFs { Err(Error::InvalidQueueIndex(q_index)) } } + + fn shared_mem_regions(&self) -> Option> { + let dax_region = self.dax_region.as_ref()?; + Some(Arc::new(MemRegion::with_mapped( + dax_region.clone(), + MemRegionType::Hidden, + ))) + } } diff --git a/alioth/src/virtio/virtio.rs b/alioth/src/virtio/virtio.rs index 421e493..7858d6e 100644 --- a/alioth/src/virtio/virtio.rs +++ b/alioth/src/virtio/virtio.rs @@ -55,11 +55,11 @@ pub enum Error { #[error("Invalid vhost user response message, want {0}, got {1}")] InvalidVhostRespMsg(u32, u32), - #[error("Invalid vhost user response size, want {0}, get {1}")] - InvalidVhostRespSize(usize, usize), + #[error("Invalid vhost user message size, want {0}, get {1}")] + VuMessageSizeMismatch(usize, usize), - #[error("Invalid vhost user response payload size, want {0}, got {1}")] - InvalidVhostRespPayloadSize(usize, u32), + #[error("Invalid vhost user message payload size, want {0}, got {1}")] + VuInvalidPayloadSize(usize, u32), #[error("vhost-user backend replied error code {0:#x} to request {1:#x}")] VuRequestErr(u64, u32), @@ -73,6 +73,9 @@ pub enum Error { #[error("vhost-user backend is missing protocol feature {0:x?}")] VuMissingProtocolFeature(vu::VuFeature), + #[error("insufficient buffer (size {0}) for holding {1} fds")] + VuInsufficientBuffer(usize, usize), + #[error("vhost backend is missing device feature {0:#x}")] VhostMissingDeviceFeature(u64), diff --git a/alioth/src/virtio/vu.rs b/alioth/src/virtio/vu.rs index ba7f0a5..c36a5fd 100644 --- a/alioth/src/virtio/vu.rs +++ b/alioth/src/virtio/vu.rs @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::io::{IoSlice, IoSliceMut, Read}; +use std::io::{IoSlice, IoSliceMut, Read, Write}; +use std::iter::zip; use std::mem::{size_of, size_of_val}; -use std::os::fd::{AsRawFd, RawFd}; +use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}; use std::os::unix::net::UnixStream; use std::path::Path; use std::ptr::null_mut; @@ -114,6 +115,9 @@ impl MessageFlag { pub const fn sender() -> Self { MessageFlag(MessageFlag::VERSION_1 | MessageFlag::NEED_REPLY) } + pub const fn receiver() -> Self { + MessageFlag(MessageFlag::VERSION_1 | MessageFlag::REPLY) + } } #[derive(Debug, AsBytes, FromBytes, FromZeroes)] @@ -178,15 +182,36 @@ pub struct Message { #[derive(Debug)] pub struct VuDev { conn: UnixStream, + channel: Option, } impl VuDev { pub fn new>(sock: P) -> Result { Ok(VuDev { conn: UnixStream::connect(sock)?, + channel: None, }) } + pub fn setup_channel(&mut self) -> Result<()> { + if self.channel.is_some() { + return Ok(()); + } + let mut socket_fds = [0; 2]; + ffi!(unsafe { + libc::socketpair(libc::PF_UNIX, libc::SOCK_STREAM, 0, socket_fds.as_mut_ptr()) + })?; + self.set_backend_req_fd(socket_fds[1])?; + ffi!(unsafe { libc::close(socket_fds[1]) })?; + let channel = unsafe { UnixStream::from_raw_fd(socket_fds[0]) }; + self.channel = Some(channel); + Ok(()) + } + + pub fn get_channel(&self) -> Option<&UnixStream> { + self.channel.as_ref() + } + fn send_msg( &self, req: u32, @@ -252,24 +277,18 @@ impl VuDev { let read_size = (&self.conn).read_vectored(&mut bufs)?; let expect_size = size_of::() + bufs[1].len(); if read_size != expect_size { - return Err(Error::InvalidVhostRespSize(expect_size, read_size)); + return Err(Error::VuMessageSizeMismatch(expect_size, read_size)); } if resp.request != req { return Err(Error::InvalidVhostRespMsg(req, resp.request)); } if size_of::() != 0 { if resp.size != size_of::() as u32 { - return Err(Error::InvalidVhostRespPayloadSize( - size_of::(), - resp.size, - )); + return Err(Error::VuInvalidPayloadSize(size_of::(), resp.size)); } } else { if resp.size != size_of::() as u32 { - return Err(Error::InvalidVhostRespPayloadSize( - size_of::(), - resp.size, - )); + return Err(Error::VuInvalidPayloadSize(size_of::(), resp.size)); } if ret_code != 0 { return Err(Error::VuRequestErr(ret_code, req)); @@ -338,4 +357,75 @@ impl VuDev { pub fn remove_mem_region(&self, payload: &MemorySingleRegion) -> Result<()> { self.send_msg(VHOST_USER_REM_MEM_REG, payload, &[]) } + + fn set_backend_req_fd(&self, fd: RawFd) -> Result<()> { + self.send_msg(VHOST_USER_SET_BACKEND_REQ_FD, &0u64, &[fd]) + } + + pub fn receive_from_channel( + &self, + buf: &mut [u8], + fds: &mut [Option], + ) -> Result<(u32, u32)> { + let mut msg = Message::new_zeroed(); + let mut bufs = [IoSliceMut::new(msg.as_bytes_mut()), IoSliceMut::new(buf)]; + const CMSG_BUF_LEN: usize = unsafe { libc::CMSG_SPACE(8) } as usize; + debug_assert_eq!(CMSG_BUF_LEN % size_of::(), 0); + let mut cmsg_buf = [0u64; CMSG_BUF_LEN / size_of::()]; + let mut uds_msg = libc::msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: bufs.as_mut_ptr() as _, + msg_iovlen: bufs.len(), + msg_control: cmsg_buf.as_mut_ptr() as _, + msg_controllen: CMSG_BUF_LEN, + msg_flags: 0, + }; + let Some(channel) = &self.channel else { + return Err(Error::VuMissingProtocolFeature(VuFeature::BACKEND_REQ)); + }; + let r_size = ffi!(unsafe { libc::recvmsg(channel.as_raw_fd(), &mut uds_msg, 0) })? as usize; + let expected_size = size_of::() + msg.size as usize; + if r_size != expected_size { + return Err(Error::VuMessageSizeMismatch(expected_size, r_size)); + } + + let cmsg_ptr = unsafe { libc::CMSG_FIRSTHDR(&uds_msg) }; + if cmsg_ptr.is_null() { + return Ok((msg.request, msg.size)); + } + let cmsg = unsafe { &*cmsg_ptr }; + if cmsg.cmsg_level != libc::SOL_SOCKET || cmsg.cmsg_type != libc::SCM_RIGHTS { + return Ok((msg.request, msg.size)); + } + let cmsg_data_ptr = unsafe { libc::CMSG_DATA(cmsg_ptr) } as *const RawFd; + let count = + (cmsg_ptr as usize + cmsg.cmsg_len - cmsg_data_ptr as usize) / size_of::(); + if count > fds.len() { + return Err(Error::VuInsufficientBuffer(fds.len(), count)); + } + for (fd, index) in zip(fds.iter_mut(), 0..count) { + *fd = Some(unsafe { + OwnedFd::from_raw_fd(std::ptr::read_unaligned(cmsg_data_ptr.add(index))) + }); + } + Ok((msg.request, msg.size)) + } + + pub fn ack_request(&self, req: u32, payload: &T) -> Result<()> { + let Some(channel) = &self.channel else { + return Err(Error::VuMissingProtocolFeature(VuFeature::BACKEND_REQ)); + }; + let msg = Message { + request: req, + flag: MessageFlag::receiver(), + size: size_of_val(payload) as _, + }; + let bufs = [ + IoSlice::new(msg.as_bytes()), + IoSlice::new(payload.as_bytes()), + ]; + Write::write_vectored(&mut (&*channel), &bufs)?; + Ok(()) + } }