Skip to content

Commit

Permalink
Vm: Forward ECALL return value from host
Browse files Browse the repository at this point in the history
For certain ECALLs (for now, just debug console) from guest VMs, we want to
forward the call to the host and have the host provide the return value.
Do this by taking the A0/A1 values from the guest GPRs in the shared memory
area when running the vCPU after it exits due to a forwarded ECALL, just
like we do for emulated MMIO loads.

Signed-off-by: Andrew Bresticker <[email protected]>
  • Loading branch information
abrestic-rivos committed Jan 3, 2023
1 parent 7161da5 commit 40d9f0e
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 57 deletions.
15 changes: 12 additions & 3 deletions src/host_vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use riscv_regs::{
CSR_HTINST, CSR_HTVAL, CSR_SCAUSE, CSR_STVAL,
};
use s_mode_utils::print::*;
use sbi::{self, DebugConsoleFunction, SbiMessage, StateFunction};
use sbi::{self, DebugConsoleFunction, Error as SbiError, SbiMessage, SbiReturn, StateFunction};

use crate::guest_tracking::{GuestVm, Guests, Result as GuestTrackingResult};
use crate::smp;
Expand Down Expand Up @@ -428,11 +428,20 @@ impl HostVmRunner {
return ControlFlow::Continue(())
}
Ok(DebugConsole(DebugConsoleFunction::PutString { len, addr })) => {
// Can't do anything about errors right now.
let _ = self.handle_put_string(&vm, addr, len);
let sbi_ret = match self.handle_put_string(&vm, addr, len) {
Ok(n) => SbiReturn::success(n),
Err(n) => SbiReturn {
error_code: SbiError::InvalidAddress as i64,
return_value: n,
},
};

self.gprs.set_reg(GprIndex::A0, sbi_ret.error_code as u64);
self.gprs.set_reg(GprIndex::A1, sbi_ret.return_value);
}
Ok(PutChar(c)) => {
print!("{}", c as u8 as char);
self.gprs.set_reg(GprIndex::A0, 0);
}
_ => {
println!("Unhandled ECALL from host");
Expand Down
18 changes: 9 additions & 9 deletions src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ pub enum VmExitCause {
FatalEcall(SbiMessage),
ResumableEcall(SbiMessage),
BlockingEcall(SbiMessage, TlbVersion),
ForwardedEcall(SbiMessage),
PageFault(Exception, GuestPageAddr),
MmioFault(MmioOperation, GuestPhysAddr),
Wfi(DecodedInstruction),
Expand Down Expand Up @@ -197,6 +198,7 @@ enum EcallAction {
Continue(SbiReturn),
Break(VmExitCause, SbiReturn),
Retry(VmExitCause),
Forward(SbiMessage),
}

impl From<EcallResult<u64>> for EcallAction {
Expand Down Expand Up @@ -500,6 +502,9 @@ impl<'a, T: GuestStagePagingMode> FinalizedVm<'a, T> {
active_vcpu.set_ecall_result(Standard(sbi_ret));
break reason;
}
EcallAction::Forward(sbi_msg) => {
break VmExitCause::ForwardedEcall(sbi_msg);
}
EcallAction::Retry(reason) => {
break reason;
}
Expand Down Expand Up @@ -701,10 +706,7 @@ impl<'a, T: GuestStagePagingMode> FinalizedVm<'a, T> {
/// Handles ecalls from the guest.
fn handle_ecall(&self, msg: SbiMessage, active_vcpu: &mut ActiveVmCpu<T>) -> EcallAction {
match msg {
SbiMessage::PutChar(_) => {
// TODO: Let the host set the return value and forward it to the guest.
EcallAction::Break(VmExitCause::ResumableEcall(msg), SbiReturn::success(0))
}
SbiMessage::PutChar(_) => EcallAction::Forward(msg),
SbiMessage::Reset(ResetFunction::Reset { .. }) => {
EcallAction::Break(VmExitCause::FatalEcall(msg), SbiReturn::success(0))
}
Expand Down Expand Up @@ -885,11 +887,9 @@ impl<'a, T: GuestStagePagingMode> FinalizedVm<'a, T> {

fn handle_debug_console(&self, debug_con_func: DebugConsoleFunction) -> EcallAction {
match debug_con_func {
// TODO: Let the host set the return value and forward it to the guest.
DebugConsoleFunction::PutString { len, addr: _ } => EcallAction::Break(
VmExitCause::ResumableEcall(SbiMessage::DebugConsole(debug_con_func)),
SbiReturn::success(len),
),
DebugConsoleFunction::PutString { .. } => {
EcallAction::Forward(SbiMessage::DebugConsole(debug_con_func))
}
}
}

Expand Down
108 changes: 67 additions & 41 deletions src/vm_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use page_tracking::TlbVersion;
use riscv_page_tables::GuestStagePagingMode;
use riscv_pages::{GuestPhysAddr, GuestVirtAddr, PageOwnerId, RawAddr};
use riscv_regs::*;
use sbi::{self, api::tee_host::TsmShmemAreaRef, SbiMessage, SbiReturnType};
use sbi::{self, api::tee_host::TsmShmemAreaRef, SbiMessage, SbiReturn, SbiReturnType};
use spin::{Mutex, MutexGuard, Once, RwLock};

use crate::smp::PerCpu;
Expand Down Expand Up @@ -364,12 +364,18 @@ struct PrevTlb {
tlb_version: TlbVersion,
}

// An operation that's pending a return value from the vCPU's host.
enum PendingOperation {
Mmio(MmioOperation),
Ecall(SbiMessage),
}

// The architectural state of a vCPU.
struct VmCpuArchState {
regs: VmCpuRegisters,
pmu: VmPmuState,
prev_tlb: Option<PrevTlb>,
pending_mmio_op: Option<MmioOperation>,
pending_op: Option<PendingOperation>,
shmem_area: Option<PinnedTsmShmemArea>,
}

Expand Down Expand Up @@ -406,7 +412,7 @@ impl VmCpuArchState {
regs,
pmu: VmPmuState::default(),
prev_tlb: None,
pending_mmio_op: None,
pending_op: None,
shmem_area: None,
}
}
Expand Down Expand Up @@ -553,7 +559,7 @@ impl<'vcpu, 'pages, 'host, T: GuestStagePagingMode> ActiveVmCpu<'vcpu, 'pages, '

/// Runs this vCPU until it traps.
pub fn run(&mut self) -> VmCpuTrap {
self.complete_pending_mmio_op();
self.complete_pending_op();

match self.host_context {
VmCpuParent::HostVm(ref host_vcpu) => {
Expand Down Expand Up @@ -791,6 +797,10 @@ impl<'vcpu, 'pages, 'host, T: GuestStagePagingMode> ActiveVmCpu<'vcpu, 'pages, '
ResumableEcall(msg) | FatalEcall(msg) | BlockingEcall(msg, _) => {
self.report_ecall_exit(msg);
}
ForwardedEcall(msg) => {
self.report_ecall_exit(msg);
self.arch.pending_op = Some(PendingOperation::Ecall(msg));
}
PageFault(exception, page_addr) => {
self.report_pf_exit(exception, page_addr.into());
}
Expand All @@ -817,7 +827,7 @@ impl<'vcpu, 'pages, 'host, T: GuestStagePagingMode> ActiveVmCpu<'vcpu, 'pages, '
self.host_context.set_guest_gpr(GprIndex::A0, val);

// We'll complete a load instruction the next time this vCPU is run.
self.arch.pending_mmio_op = Some(mmio_op);
self.arch.pending_op = Some(PendingOperation::Mmio(mmio_op));
}
Wfi(inst) => {
self.report_vi_exit(inst.raw() as u64);
Expand Down Expand Up @@ -1052,45 +1062,61 @@ impl<'vcpu, 'pages, 'host, T: GuestStagePagingMode> ActiveVmCpu<'vcpu, 'pages, '
self.arch.shmem_area = None;
}

// Completes any pending MMIO operation for this CPU.
fn complete_pending_mmio_op(&mut self) {
// Complete any pending load operations. The host is expected to have written the value
// to complete the load to A0.
if let Some(mmio_op) = self.arch.pending_mmio_op {
let val = self.host_context.guest_gpr(GprIndex::A0);
use MmioOpcode::*;
// Write the value to the actual destination register.
match mmio_op.opcode() {
Load8 => {
self.set_gpr(mmio_op.register(), val as i8 as u64);
}
Load8U => {
self.set_gpr(mmio_op.register(), val as u8 as u64);
}
Load16 => {
self.set_gpr(mmio_op.register(), val as i16 as u64);
}
Load16U => {
self.set_gpr(mmio_op.register(), val as u16 as u64);
}
Load32 => {
self.set_gpr(mmio_op.register(), val as i32 as u64);
}
Load32U => {
self.set_gpr(mmio_op.register(), val as u32 as u64);
}
Load64 => {
self.set_gpr(mmio_op.register(), val);
}
_ => (),
};
// Completes any pending MMIO or ECALL result from the host for this vCPU.
fn complete_pending_op(&mut self) {
match self.arch.pending_op {
Some(PendingOperation::Mmio(mmio_op)) => {
// Complete any pending load operations. The host is expected to have written the
// value to complete the load to A0.
let val = self.host_context.guest_gpr(GprIndex::A0);
use MmioOpcode::*;
// Write the value to the actual destination register.
match mmio_op.opcode() {
Load8 => {
self.set_gpr(mmio_op.register(), val as i8 as u64);
}
Load8U => {
self.set_gpr(mmio_op.register(), val as u8 as u64);
}
Load16 => {
self.set_gpr(mmio_op.register(), val as i16 as u64);
}
Load16U => {
self.set_gpr(mmio_op.register(), val as u16 as u64);
}
Load32 => {
self.set_gpr(mmio_op.register(), val as i32 as u64);
}
Load32U => {
self.set_gpr(mmio_op.register(), val as u32 as u64);
}
Load64 => {
self.set_gpr(mmio_op.register(), val);
}
_ => (),
};
self.host_context.set_guest_gpr(GprIndex::A0, 0);

self.arch.pending_mmio_op = None;
self.host_context.set_guest_gpr(GprIndex::A0, 0);
// Advance SEPC past the faulting instruction.
self.inc_sepc(mmio_op.len() as u64);
}
Some(PendingOperation::Ecall(msg)) => {
// Forward the SBI call return value from the A0/A1 values provided by the host.
let sbi_ret = match msg {
SbiMessage::PutChar(_) => {
SbiReturnType::Legacy(self.host_context.guest_gpr(GprIndex::A0))
}
_ => SbiReturnType::Standard(SbiReturn {
error_code: self.host_context.guest_gpr(GprIndex::A0) as i64,
return_value: self.host_context.guest_gpr(GprIndex::A1),
}),
};

// Advance SEPC past the faulting instruction.
self.inc_sepc(mmio_op.len() as u64);
self.set_ecall_result(sbi_ret);
}
None => (),
}
self.arch.pending_op = None;
}

fn save(&mut self) {
Expand Down
18 changes: 14 additions & 4 deletions test-workloads/src/bin/tellus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ use s_mode_utils::ecall::ecall_send;
use s_mode_utils::{print::*, sbi_console::SbiConsole};
use sbi::api::{base, nacl, pmu, reset, tee_host, tee_interrupt};
use sbi::{
PmuCounterConfigFlags, PmuCounterStartFlags, PmuCounterStopFlags, PmuEventType, PmuFirmware,
PmuHardware, SbiMessage, EXT_PMU, EXT_TEE_HOST, EXT_TEE_INTERRUPT,
Error as SbiError, PmuCounterConfigFlags, PmuCounterStartFlags, PmuCounterStopFlags,
PmuEventType, PmuFirmware, PmuHardware, SbiMessage, SbiReturn, EXT_PMU, EXT_TEE_HOST,
EXT_TEE_INTERRUPT,
};

// Dummy global allocator - panic if anything tries to do an allocation.
Expand Down Expand Up @@ -636,15 +637,24 @@ extern "C" fn kernel_init(hart_id: u64, fdt_addr: u64) {
}
}
Ok(DebugConsole(sbi::DebugConsoleFunction::PutString { len, addr })) => {
let _ = do_guest_puts(
let sbi_ret = match do_guest_puts(
dbcn_gpa_range.clone(),
dbcn_spa_range.clone(),
addr,
len,
);
) {
Ok(n) => SbiReturn::success(n),
Err(n) => SbiReturn {
error_code: SbiError::InvalidAddress as i64,
return_value: n,
},
};
shmem.set_gpr(GprIndex::A0 as usize, sbi_ret.error_code as u64);
shmem.set_gpr(GprIndex::A1 as usize, sbi_ret.return_value);
}
Ok(PutChar(c)) => {
print!("{}", c as u8 as char);
shmem.set_gpr(GprIndex::A0 as usize, 0);
}
_ => {
println!("Unexpected ECALL from guest");
Expand Down

0 comments on commit 40d9f0e

Please sign in to comment.