Skip to content

Commit

Permalink
fix(board): avoid infinite wait in sync_vcpus()
Browse files Browse the repository at this point in the history
Signed-off-by: Changyuan Lyu <[email protected]>
  • Loading branch information
Lencerf committed Jan 13, 2025
1 parent d123ca3 commit db11aab
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 91 deletions.
161 changes: 104 additions & 57 deletions alioth/src/board/board.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ mod x86_64;
#[cfg(target_os = "linux")]
use std::collections::HashMap;
use std::ffi::CStr;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::mpsc::{Receiver, Sender};
use std::sync::Arc;
use std::thread::JoinHandle;
Expand Down Expand Up @@ -94,14 +93,26 @@ pub enum Error {
Firmware { error: std::io::Error },
#[snafu(display("Failed to notify the VMM thread"))]
NotifyVmm,
#[snafu(display("Another VCPU thread has signaled failure"))]
PeerFailure,
}

type Result<T, E = Error> = std::result::Result<T, E>;

pub const STATE_CREATED: u8 = 0;
pub const STATE_RUNNING: u8 = 1;
pub const STATE_SHUTDOWN: u8 = 2;
pub const STATE_REBOOT_PENDING: u8 = 3;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BoardState {
Created,
Running,
Shutdown,
RebootPending,
Failure,
}

#[derive(Debug)]
struct MpSync {
state: BoardState,
count: u32,
}

pub const PCIE_MMIO_64_SIZE: u64 = 1 << 40;

Expand Down Expand Up @@ -129,9 +140,7 @@ where
pub vcpus: Arc<RwLock<Vec<VcpuHandle>>>,
pub arch: ArchBoard<V>,
pub config: BoardConfig,
pub state: AtomicU8,
pub payload: RwLock<Option<Payload>>,
pub mp_sync: Arc<(Mutex<u32>, Condvar)>,
pub io_devs: RwLock<Vec<(u16, Arc<dyn Mmio>)>>,
#[cfg(target_arch = "aarch64")]
pub mmio_devs: RwLock<Vec<(u64, Arc<MemRegion>)>>,
Expand All @@ -142,12 +151,52 @@ where
pub vfio_ioases: Mutex<HashMap<Box<str>, Arc<Ioas>>>,
#[cfg(target_os = "linux")]
pub vfio_containers: Mutex<HashMap<Box<str>, Arc<Container>>>,

mp_sync: Mutex<MpSync>,
cond_var: Condvar,
}

impl<V> Board<V>
where
V: Vm,
{
pub fn new(vm: V, memory: Memory, arch: ArchBoard<V>, config: BoardConfig) -> Self {
Board {
vm,
memory,
arch,
config,
payload: RwLock::new(None),
vcpus: Arc::new(RwLock::new(Vec::new())),
io_devs: RwLock::new(Vec::new()),
#[cfg(target_arch = "aarch64")]
mmio_devs: RwLock::new(Vec::new()),
pci_bus: PciBus::new(),
#[cfg(target_arch = "x86_64")]
fw_cfg: Mutex::new(None),
#[cfg(target_os = "linux")]
vfio_ioases: Mutex::new(HashMap::new()),
#[cfg(target_os = "linux")]
vfio_containers: Mutex::new(HashMap::new()),

mp_sync: Mutex::new(MpSync {
state: BoardState::Created,
count: 0,
}),
cond_var: Condvar::new(),
}
}

pub fn boot(&self) -> Result<()> {
let vcpus = self.vcpus.read();
let mut mp_sync = self.mp_sync.lock();
mp_sync.state = BoardState::Running;
for (_, boot_tx) in vcpus.iter() {
boot_tx.send(()).unwrap();
}
Ok(())
}

fn load_payload(&self) -> Result<InitState, Error> {
let payload = self.payload.read();
let Some(payload) = payload.as_ref() else {
Expand Down Expand Up @@ -221,10 +270,11 @@ where
break Ok(true);
}
VmExit::Interrupted => {
let state = self.state.load(Ordering::Acquire);
match state {
STATE_SHUTDOWN => VmEntry::Shutdown,
STATE_REBOOT_PENDING => VmEntry::Reboot,
let mp_sync = self.mp_sync.lock();
match mp_sync.state {
BoardState::Shutdown => VmEntry::Shutdown,
BoardState::RebootPending => VmEntry::Reboot,
BoardState::Failure => break error::PeerFailure.fail(),
_ => VmEntry::None,
}
}
Expand All @@ -237,16 +287,25 @@ where
}
}

fn sync_vcpus(&self, vcpus: &VcpuGuard) {
let (lock, cvar) = &*self.mp_sync;
let mut count = lock.lock();
*count += 1;
if *count == vcpus.len() as u32 {
*count = 0;
cvar.notify_all();
fn sync_vcpus(&self, vcpus: &VcpuGuard) -> Result<()> {
let mut mp_sync = self.mp_sync.lock();
if mp_sync.state == BoardState::Failure {
return error::PeerFailure.fail();
}

mp_sync.count += 1;
if mp_sync.count == vcpus.len() as u32 {
mp_sync.count = 0;
self.cond_var.notify_all();
} else {
cvar.wait(&mut count)
self.cond_var.wait(&mut mp_sync)
}

if mp_sync.state == BoardState::Failure {
return error::PeerFailure.fail();
}

Ok(())
}

fn run_vcpu_inner(
Expand All @@ -257,7 +316,7 @@ where
) -> Result<(), Error> {
self.init_vcpu(id, vcpu)?;
boot_rx.recv().unwrap();
if self.state.load(Ordering::Acquire) != STATE_RUNNING {
if self.mp_sync.lock().state != BoardState::Running {
return Ok(());
}
loop {
Expand All @@ -279,37 +338,28 @@ where
}
self.init_ap(id, vcpu, &vcpus)?;
self.coco_finalize(id, &vcpus)?;
self.sync_vcpus(&vcpus)?;
drop(vcpus);

let reboot = self.vcpu_loop(vcpu, id)?;

let new_state = if reboot {
STATE_REBOOT_PENDING
} else {
STATE_SHUTDOWN
};
let vcpus = self.vcpus.read();
match self.state.compare_exchange(
STATE_RUNNING,
new_state,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(STATE_RUNNING) => {
for (vcpu_id, (handle, _)) in vcpus.iter().enumerate() {
if id != vcpu_id as u32 {
log::info!("vcpu{id} to kill {vcpu_id}");
V::stop_vcpu(vcpu_id as u32, handle).context(error::StopVcpu { id })?;
}
let mut mp_sync = self.mp_sync.lock();
if mp_sync.state == BoardState::Running {
mp_sync.state = if reboot {
BoardState::RebootPending
} else {
BoardState::Shutdown
};
drop(mp_sync);
for (vcpu_id, (handle, _)) in vcpus.iter().enumerate() {
if id != vcpu_id as u32 {
log::info!("VCPU-{id}: stopping VCPU-{vcpu_id}");
V::stop_vcpu(vcpu_id as u32, handle).context(error::StopVcpu { id })?;
}
}
Err(s) if s == new_state => {}
Ok(s) | Err(s) => {
log::error!("unexpected state: {s}");
}
}

self.sync_vcpus(&vcpus);
self.sync_vcpus(&vcpus)?;

if id == 0 {
let devices = self.pci_bus.segment.devices.read();
Expand All @@ -319,22 +369,14 @@ where
}
self.memory.reset()?;
}
self.reset_vcpu(id, vcpu)?;

if new_state == STATE_SHUTDOWN {
let mut mp_state = self.mp_sync.lock();
if mp_state.state == BoardState::Shutdown {
break Ok(());
} else {
mp_state.state = BoardState::Running;
}

match self.state.compare_exchange(
STATE_REBOOT_PENDING,
STATE_RUNNING,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(STATE_REBOOT_PENDING) | Err(STATE_RUNNING) => {}
_ => break Ok(()),
}

self.reset_vcpu(id, vcpu)?;
}
}

Expand All @@ -356,7 +398,12 @@ where
let mut vcpu = self.create_vcpu(id, &event_tx)?;

let ret = self.run_vcpu_inner(id, &mut vcpu, &boot_rx);
self.state.store(STATE_SHUTDOWN, Ordering::Release);
if ret.is_err() && !matches!(ret, Err(Error::PeerFailure { .. })) {
log::warn!("VCPU-{id} reported error, unblocking other VCPUs...");
let mut mp_lock = self.mp_sync.lock();
mp_lock.state = BoardState::Failure;
self.cond_var.notify_all();
}
event_tx.send(id).unwrap();
ret
}
Expand Down
6 changes: 3 additions & 3 deletions alioth/src/board/x86_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ where
Some(Coco::AmdSnp { .. }) => {}
_ => return Ok(()),
}
self.sync_vcpus(vcpus);
self.sync_vcpus(vcpus)?;
if id == 0 {
return Ok(());
}
Expand Down Expand Up @@ -319,7 +319,7 @@ where

pub fn coco_finalize(&self, id: u32, vcpus: &VcpuGuard) -> Result<()> {
if let Some(coco) = &self.config.coco {
self.sync_vcpus(vcpus);
self.sync_vcpus(vcpus)?;
if id == 0 {
match coco {
Coco::AmdSev { policy } => {
Expand All @@ -334,7 +334,7 @@ where
}
}
}
self.sync_vcpus(vcpus);
self.sync_vcpus(vcpus)?;
}
Ok(())
}
Expand Down
36 changes: 5 additions & 31 deletions alioth/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#[cfg(target_os = "linux")]
use std::collections::HashMap;
#[cfg(target_os = "linux")]
use std::path::Path;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

use parking_lot::{Condvar, Mutex, RwLock};
#[cfg(target_os = "linux")]
use parking_lot::Mutex;
use snafu::{ResultExt, Snafu};

#[cfg(target_arch = "aarch64")]
use crate::arch::layout::PL011_START;
#[cfg(target_arch = "x86_64")]
use crate::arch::layout::{PORT_COM1, PORT_FW_CFG_SELECTOR};
use crate::board::{ArchBoard, Board, BoardConfig, STATE_CREATED, STATE_RUNNING};
use crate::board::{ArchBoard, Board, BoardConfig};
#[cfg(target_arch = "x86_64")]
use crate::device::fw_cfg::{FwCfg, FwCfgItemParam};
#[cfg(target_arch = "aarch64")]
Expand All @@ -45,7 +43,6 @@ use crate::loader::Payload;
use crate::mem::Memory;
#[cfg(target_arch = "aarch64")]
use crate::mem::{MemRegion, MemRegionType};
use crate::pci::bus::PciBus;
use crate::pci::{Bdf, PciDevice};
#[cfg(target_os = "linux")]
use crate::vfio::bindings::VfioIommu;
Expand Down Expand Up @@ -134,26 +131,7 @@ where
let memory = Memory::new(vm_memory);
let arch = ArchBoard::new(&hv, &vm, &config)?;

let board = Arc::new(Board {
vm,
memory,
arch,
config,
state: AtomicU8::new(STATE_CREATED),
payload: RwLock::new(None),
vcpus: Arc::new(RwLock::new(Vec::new())),
mp_sync: Arc::new((Mutex::new(0), Condvar::new())),
io_devs: RwLock::new(Vec::new()),
#[cfg(target_arch = "aarch64")]
mmio_devs: RwLock::new(Vec::new()),
pci_bus: PciBus::new(),
#[cfg(target_arch = "x86_64")]
fw_cfg: Mutex::new(None),
#[cfg(target_os = "linux")]
vfio_ioases: Mutex::new(HashMap::new()),
#[cfg(target_os = "linux")]
vfio_containers: Mutex::new(HashMap::new()),
});
let board = Arc::new(Board::new(vm, memory, arch, config));

let (event_tx, event_rx) = mpsc::channel();

Expand Down Expand Up @@ -291,11 +269,7 @@ where
}

pub fn boot(&self) -> Result<(), Error> {
let vcpus = self.board.vcpus.read();
self.board.state.store(STATE_RUNNING, Ordering::Release);
for (_, boot_tx) in vcpus.iter() {
boot_tx.send(()).unwrap();
}
self.board.boot()?;
Ok(())
}

Expand Down

0 comments on commit db11aab

Please sign in to comment.