diff --git a/examples/aarch64/src/main.rs b/examples/aarch64/src/main.rs index 2d7ad046..90e067a6 100644 --- a/examples/aarch64/src/main.rs +++ b/examples/aarch64/src/main.rs @@ -105,7 +105,7 @@ extern "C" fn main(x0: u64, x1: u64, x2: u64, x3: u64) { debug!("Found VirtIO MMIO device at {:?}", region); let header = NonNull::new(region.starting_address as *mut VirtIOHeader).unwrap(); - match unsafe { MmioTransport::new(header) } { + match unsafe { MmioTransport::new(header, region.size.unwrap()) } { Err(e) => warn!("Error creating VirtIO MMIO transport: {}", e), Ok(transport) => { info!( diff --git a/examples/riscv/src/main.rs b/examples/riscv/src/main.rs index 1435353d..0d91a0b6 100644 --- a/examples/riscv/src/main.rs +++ b/examples/riscv/src/main.rs @@ -69,7 +69,7 @@ fn virtio_probe(node: FdtNode) { node.compatible().map(Compatible::first), ); let header = NonNull::new(vaddr as *mut VirtIOHeader).unwrap(); - match unsafe { MmioTransport::new(header) } { + match unsafe { MmioTransport::new(header, size) } { Err(e) => warn!("Error creating VirtIO MMIO transport: {}", e), Ok(transport) => { info!( diff --git a/examples/x86_64/Cargo.toml b/examples/x86_64/Cargo.toml index 640d83a6..a2358a71 100644 --- a/examples/x86_64/Cargo.toml +++ b/examples/x86_64/Cargo.toml @@ -11,7 +11,10 @@ default = ["tcp"] [dependencies] log = "0.4.17" spin = "0.9" -x86_64 = "0.14" +x86_64 = { version = "0.14.12", default-features = false, features = [ + "instructions", + "abi_x86_interrupt", +] } uart_16550 = "0.2" linked_list_allocator = "0.10" lazy_static = { version = "1.4.0", features = ["spin_no_std"] } diff --git a/src/lib.rs b/src/lib.rs index 4e1d29c2..aa3737df 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,9 +14,9 @@ //! use core::ptr::NonNull; //! use virtio_drivers::transport::mmio::{MmioTransport, VirtIOHeader}; //! -//! # fn example(mmio_device_address: usize) { +//! # fn example(mmio_device_address: usize, mmio_size: usize) { //! let header = NonNull::new(mmio_device_address as *mut VirtIOHeader).unwrap(); -//! let transport = unsafe { MmioTransport::new(header) }.unwrap(); +//! let transport = unsafe { MmioTransport::new(header, mmio_size) }.unwrap(); //! # } //! ``` //! diff --git a/src/queue.rs b/src/queue.rs index cef85932..b21d8429 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -995,7 +995,9 @@ mod tests { #[test] fn queue_too_big() { let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); - let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); + let mut transport = + unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::()) } + .unwrap(); assert_eq!( VirtQueue::::new(&mut transport, 0, false, false).unwrap_err(), Error::InvalidParam @@ -1005,7 +1007,9 @@ mod tests { #[test] fn queue_already_used() { let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); - let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); + let mut transport = + unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::()) } + .unwrap(); VirtQueue::::new(&mut transport, 0, false, false).unwrap(); assert_eq!( VirtQueue::::new(&mut transport, 0, false, false).unwrap_err(), @@ -1016,7 +1020,9 @@ mod tests { #[test] fn add_empty() { let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); - let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); + let mut transport = + unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::()) } + .unwrap(); let mut queue = VirtQueue::::new(&mut transport, 0, false, false).unwrap(); assert_eq!( unsafe { queue.add(&[], &mut []) }.unwrap_err(), @@ -1027,7 +1033,9 @@ mod tests { #[test] fn add_too_many() { let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); - let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); + let mut transport = + unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::()) } + .unwrap(); let mut queue = VirtQueue::::new(&mut transport, 0, false, false).unwrap(); assert_eq!(queue.available_desc(), 4); assert_eq!( @@ -1039,7 +1047,9 @@ mod tests { #[test] fn add_buffers() { let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); - let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); + let mut transport = + unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::()) } + .unwrap(); let mut queue = VirtQueue::::new(&mut transport, 0, false, false).unwrap(); assert_eq!(queue.available_desc(), 4); @@ -1102,7 +1112,9 @@ mod tests { use core::ptr::slice_from_raw_parts; let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); - let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); + let mut transport = + unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::()) } + .unwrap(); let mut queue = VirtQueue::::new(&mut transport, 0, true, false).unwrap(); assert_eq!(queue.available_desc(), 4); diff --git a/src/transport/mmio.rs b/src/transport/mmio.rs index 1d00f751..f70a404f 100644 --- a/src/transport/mmio.rs +++ b/src/transport/mmio.rs @@ -12,6 +12,7 @@ use core::{ mem::{align_of, size_of}, ptr::NonNull, }; +use zerocopy::{FromBytes, Immutable, IntoBytes}; const MAGIC_VALUE: u32 = 0x7472_6976; pub(crate) const LEGACY_VERSION: u32 = 1; @@ -61,6 +62,9 @@ pub enum MmioError { /// The header reports a device ID of 0. #[error("Device ID was zero")] ZeroDeviceId, + /// The MMIO region size was smaller than the header size we expect. + #[error("MMIO region too small")] + MmioRegionTooSmall, } /// MMIO Device Register Interface, both legacy and modern. @@ -263,6 +267,8 @@ impl VirtIOHeader { pub struct MmioTransport { header: NonNull, version: MmioVersion, + /// The size in bytes of the config space. + config_space_size: usize, } impl MmioTransport { @@ -272,7 +278,7 @@ impl MmioTransport { /// # Safety /// `header` must point to a properly aligned valid VirtIO MMIO region, which must remain valid /// for the lifetime of the transport that is returned. - pub unsafe fn new(header: NonNull) -> Result { + pub unsafe fn new(header: NonNull, mmio_size: usize) -> Result { let magic = volread!(header, magic); if magic != MAGIC_VALUE { return Err(MmioError::BadMagic(magic)); @@ -280,8 +286,15 @@ impl MmioTransport { if volread!(header, device_id) == 0 { return Err(MmioError::ZeroDeviceId); } + let Some(config_space_size) = mmio_size.checked_sub(CONFIG_SPACE_OFFSET) else { + return Err(MmioError::MmioRegionTooSmall); + }; let version = volread!(header, version).try_into()?; - Ok(Self { header, version }) + Ok(Self { + header, + version, + config_space_size, + }) } /// Gets the version of the VirtIO MMIO transport. @@ -484,40 +497,52 @@ impl Transport for MmioTransport { } } - fn read_config_space(&self, offset: usize) -> Result { + fn read_config_space(&self, offset: usize) -> Result { assert!(align_of::() <= 4, "Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.", align_of::()); assert!(offset % align_of::() == 0); - // SAFETY: The caller of `MmioTransport::new` guaranteed that the header pointer was valid, - // which includes the config space. - unsafe { - Ok(self - .header - .cast::() - .byte_add(CONFIG_SPACE_OFFSET) - .byte_add(offset) - .read_volatile()) + if self.config_space_size < offset + size_of::() { + Err(Error::ConfigSpaceTooSmall) + } else { + // SAFETY: The caller of `MmioTransport::new` guaranteed that the header pointer was valid, + // which includes the config space. + unsafe { + Ok(self + .header + .cast::() + .byte_add(CONFIG_SPACE_OFFSET) + .byte_add(offset) + .read_volatile()) + } } } - fn write_config_space(&mut self, offset: usize, value: T) -> Result<(), Error> { + fn write_config_space( + &mut self, + offset: usize, + value: T, + ) -> Result<(), Error> { assert!(align_of::() <= 4, "Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.", align_of::()); assert!(offset % align_of::() == 0); - // SAFETY: The caller of `MmioTransport::new` guaranteed that the header pointer was valid, - // which includes the config space. - unsafe { - self.header - .cast::() - .byte_add(CONFIG_SPACE_OFFSET) - .byte_add(offset) - .write_volatile(value); + if self.config_space_size < offset + size_of::() { + Err(Error::ConfigSpaceTooSmall) + } else { + // SAFETY: The caller of `MmioTransport::new` guaranteed that the header pointer was valid, + // which includes the config space. + unsafe { + self.header + .cast::() + .byte_add(CONFIG_SPACE_OFFSET) + .byte_add(offset) + .write_volatile(value); + } + Ok(()) } - Ok(()) } } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index b258d15c..3e8a0990 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -11,7 +11,7 @@ use bitflags::{bitflags, Flags}; use core::{fmt::Debug, ops::BitAnd}; use log::debug; pub use some::SomeTransport; -use zerocopy::{FromBytes, IntoBytes}; +use zerocopy::{FromBytes, Immutable, IntoBytes}; /// A VirtIO transport layer. pub trait Transport { @@ -105,7 +105,11 @@ pub trait Transport { fn read_config_space(&self, offset: usize) -> Result; /// Writes a value to the device config space. - fn write_config_space(&mut self, offset: usize, value: T) -> Result<()>; + fn write_config_space( + &mut self, + offset: usize, + value: T, + ) -> Result<()>; } bitflags! { diff --git a/src/transport/pci.rs b/src/transport/pci.rs index 5ac41145..59bd78d1 100644 --- a/src/transport/pci.rs +++ b/src/transport/pci.rs @@ -18,6 +18,7 @@ use core::{ mem::{align_of, size_of}, ptr::{addr_of_mut, NonNull}, }; +use zerocopy::{FromBytes, Immutable, IntoBytes}; /// The PCI vendor ID for VirtIO devices. const VIRTIO_VENDOR_ID: u16 = 0x1af4; @@ -325,7 +326,7 @@ impl Transport for PciTransport { isr_status & 0x3 != 0 } - fn read_config_space(&self, offset: usize) -> Result { + fn read_config_space(&self, offset: usize) -> Result { assert!(align_of::() <= 4, "Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.", align_of::()); @@ -338,15 +339,18 @@ impl Transport for PciTransport { // SAFETY: If we have a config space pointer it must be valid for its length, and we just // checked that the offset and size of the access was within the length. unsafe { - // TODO: Use NonNull::as_non_null_ptr once it is stable. - Ok((config_space.as_ptr() as *mut T) + Ok((config_space.as_ptr().cast::()) .byte_add(offset) .read_volatile()) } } } - fn write_config_space(&mut self, offset: usize, value: T) -> Result<(), Error> { + fn write_config_space( + &mut self, + offset: usize, + value: T, + ) -> Result<(), Error> { assert!(align_of::() <= 4, "Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.", align_of::()); @@ -359,8 +363,7 @@ impl Transport for PciTransport { // SAFETY: If we have a config space pointer it must be valid for its length, and we just // checked that the offset and size of the access was within the length. unsafe { - // TODO: Use NonNull::as_non_null_ptr once it is stable. - (config_space.as_ptr() as *mut T) + (config_space.as_ptr().cast::()) .byte_add(offset) .write_volatile(value); } diff --git a/src/transport/some.rs b/src/transport/some.rs index 4806dfce..7da3f9f3 100644 --- a/src/transport/some.rs +++ b/src/transport/some.rs @@ -1,4 +1,4 @@ -use zerocopy::{FromBytes, IntoBytes}; +use zerocopy::{FromBytes, Immutable, IntoBytes}; use super::{mmio::MmioTransport, pci::PciTransport, DeviceStatus, DeviceType, Transport}; use crate::{PhysAddr, Result}; @@ -130,7 +130,11 @@ impl Transport for SomeTransport { } } - fn write_config_space(&mut self, offset: usize, value: T) -> Result<()> { + fn write_config_space( + &mut self, + offset: usize, + value: T, + ) -> Result<()> { match self { Self::Mmio(mmio) => mmio.write_config_space(offset, value), Self::Pci(pci) => pci.write_config_space(offset, value),