Skip to content

Commit

Permalink
feat: LUA add register manager for lua backend
Browse files Browse the repository at this point in the history
  • Loading branch information
sbwtw committed Dec 17, 2023
1 parent 9800032 commit 21b486d
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 11 deletions.
17 changes: 15 additions & 2 deletions lib/src/ast/expression.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::ast::{
AssignExpression, AstVisitor, CallExpression, CompoAccessExpression, ExprStatement,
IntoStatement, LiteralExpression, OperatorExpression, Statement, VariableExpression,
AssignExpression, AstVisitor, AstVisitorMut, CallExpression, CompoAccessExpression,
ExprStatement, IntoStatement, LiteralExpression, OperatorExpression, Statement,
VariableExpression,
};
use crate::{impl_ast_display, impl_into_statement};

Expand All @@ -25,42 +26,54 @@ impl_into_statement!(Expression, |x| Statement::expr(Box::new(
)));

impl Expression {
#[inline]
pub fn accept_mut<V: AstVisitorMut>(&mut self, vis: &mut V) {
vis.visit_expression_mut(self)
}

#[inline]
pub fn assign(assign: Box<AssignExpression>) -> Self {
Self {
kind: ExprKind::Assign(assign),
}
}

#[inline]
pub fn call(call: Box<CallExpression>) -> Self {
Self {
kind: ExprKind::Call(call),
}
}

#[inline]
pub fn literal(literal: Box<LiteralExpression>) -> Self {
Self {
kind: ExprKind::Literal(literal),
}
}

#[inline]
pub fn operator(operator: Box<OperatorExpression>) -> Self {
Self {
kind: ExprKind::Operator(operator),
}
}

#[inline]
pub fn variable(variable: Box<VariableExpression>) -> Self {
Self {
kind: ExprKind::Variable(variable),
}
}

#[inline]
pub fn compo(compo: Box<CompoAccessExpression>) -> Self {
Self {
kind: ExprKind::Compo(compo),
}
}

#[inline]
pub fn get_variable_expression(&self) -> Option<&VariableExpression> {
match &self.kind {
ExprKind::Variable(var) => Some(var),
Expand Down
16 changes: 16 additions & 0 deletions lib/src/backend/lua/bytecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ pub enum LuaByteCode {
SetTabUp(u8, u8, u8),
/// A B: R[A] := K[Bx]
LoadK(u8, u32),
/// A B: R[A] := R[B]
Move(u8, u8),
/// A sBx: R[A] := sBx
LoadI(u8, u32),
}

impl LuaByteCode {
Expand All @@ -60,6 +64,8 @@ impl LuaByteCode {
LuaByteCode::GetTabUp(..) => "GETTABUP",
LuaByteCode::SetTabUp(..) => "SETTABUP",
LuaByteCode::LoadK(..) => "LOADK",
LuaByteCode::Move(..) => "MOVE",
LuaByteCode::LoadI(..) => "LOADI",
}
}

Expand All @@ -69,6 +75,8 @@ impl LuaByteCode {
LuaByteCode::GetTabUp(..) => LuaOpCode::OP_GETTABUP,
LuaByteCode::SetTabUp(..) => LuaOpCode::OP_SETTABUP,
LuaByteCode::LoadK(..) => LuaOpCode::OP_LOADK,
LuaByteCode::Move(..) => LuaOpCode::OP_MOVE,
LuaByteCode::LoadI(..) => LuaOpCode::OP_LOADI,
}
}
}
Expand All @@ -95,6 +103,14 @@ impl LuaCode {
LuaByteCode::LoadK(a, bx) => {
write!(s, "{a} {bx}").unwrap();
}
// AsBx
LuaByteCode::LoadI(a, sbx) => {
write!(s, "{a} {sbx}").unwrap();
}
// A B
LuaByteCode::Move(a, b) => {
write!(s, "{a} {b}").unwrap();
}
}

match code {
Expand Down
56 changes: 48 additions & 8 deletions lib/src/backend/lua/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@ use bytecode::{LuaByteCode, LuaCode, LuaConstants};

/// 32-bits Lua instruction bytecode encoding/decoding
mod encoding;
mod register;
mod utils;

use crate::backend::*;
use crate::parser::{LiteralValue, Operator};
use crate::prelude::*;

use crate::backend::lua::register::{RegisterId, RegisterManager};
use crate::backend::lua::utils::try_fit_sbx;
use indexmap::IndexSet;
use log::*;
use smallvec::{smallvec, SmallVec};
use std::rc::Rc;

type RegisterId = usize;

#[derive(Clone)]
pub struct LuaBackendAttribute {
variable: Option<Rc<Variable>>,
Expand Down Expand Up @@ -46,6 +48,7 @@ pub struct LuaBackend {
attributes: SmallVec<[LuaBackendAttribute; 32]>,
local_function: Option<Function>,
constants: IndexSet<LuaConstants>,
reg_mgr: RegisterManager,
}

impl LuaBackend {
Expand Down Expand Up @@ -120,6 +123,7 @@ impl CodeGenBackend for LuaBackend {
attributes: smallvec![],
local_function: None,
constants: IndexSet::new(),
reg_mgr: RegisterManager::new(),
}
}

Expand Down Expand Up @@ -166,6 +170,20 @@ impl AstVisitorMut for LuaBackend {
fn visit_literal_mut(&mut self, literal: &mut LiteralExpression) {
trace!("LuaGen: literal expression: {:?}", literal);

// Literals can't WRITE
assert!(!self
.top_attribute()
.access_mode
.contains(AccessModeFlags::WRITE));

// if literal can use LoadI instructions
if let Some(v) = try_fit_sbx(literal.literal()) {
let r = self.reg_mgr.alloc();
self.byte_codes.push(LuaByteCode::LoadI(r as u8, v));
self.top_attribute().register = Some(r);
return;
}

match literal.literal() {
LiteralValue::String(s) => {
let constant_index = self.add_string_constant(s);
Expand All @@ -184,22 +202,28 @@ impl AstVisitorMut for LuaBackend {
}
}

fn visit_variable_expression_mut(&mut self, variable: &mut VariableExpression) {
fn visit_variable_expression_mut(&mut self, var_expr: &mut VariableExpression) {
let scope = self.current_scope();
let var = scope.find_variable(variable.name());
let var = scope.find_variable(var_expr.name());

trace!(
"LuaGen: variable expression: {}: {:?}",
variable,
var_expr,
var.and_then(|x| x.ty())
);

// Callee process
if self.top_attribute().access_mode == AccessModeFlags::CALL {
self.top_attribute().constant_index =
Some(self.add_string_constant(variable.org_name()));
Some(self.add_string_constant(var_expr.org_name()));
return;
}

let scope = self.top_attribute().scope.as_ref().unwrap();
if let Some(variable) = scope.find_variable(var_expr.name()) {
self.top_attribute().register = Some(self.reg_mgr.alloc());
} else {
let scope = self.top_attribute().scope.as_ref().unwrap();
self.top_attribute().variable = scope.find_variable(variable.name());
// TODO: variable not found error
}
}

Expand Down Expand Up @@ -276,5 +300,21 @@ impl AstVisitorMut for LuaBackend {

fn visit_assign_expression_mut(&mut self, assign: &mut AssignExpression) {
trace!("LuaGen: assignment expression: {}", assign);

self.push_access_attribute(AccessModeFlags::READ);
assign.right_mut().accept_mut(self);
let rhs = self.pop_attribute();

self.push_access_attribute(AccessModeFlags::WRITE);
assign.left_mut().accept_mut(self);
let lhs = self.pop_attribute();

// free temporary registers
if let Some(r) = rhs.register {
self.byte_codes
.push(LuaByteCode::Move(lhs.register.unwrap() as u8, r as u8));

self.reg_mgr.free(&r)
}
}
}
44 changes: 44 additions & 0 deletions lib/src/backend/lua/register.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use std::collections::HashSet;

pub type RegisterId = usize;

const MAX_REGISTER_ID: usize = 255;

pub struct RegisterManager {
register_alloc_cursor: usize,
used_registers: HashSet<usize>,
}

impl RegisterManager {
#[inline]
pub fn new() -> Self {
Self {
register_alloc_cursor: 0,
used_registers: HashSet::with_capacity(MAX_REGISTER_ID),
}
}

pub fn alloc(&mut self) -> RegisterId {
// ensure has free register to allocate
assert!(self.used_registers.len() <= MAX_REGISTER_ID);

loop {
if self.used_registers.insert(self.register_alloc_cursor) {
return self.register_alloc_cursor;
}

self.register_alloc_cursor += 1;
self.register_alloc_cursor %= MAX_REGISTER_ID;
}
}

#[inline]
pub fn free(&mut self, id: &RegisterId) {
_ = self.used_registers.remove(id)
}

#[inline]
pub fn used_count(&self) -> usize {
self.used_registers.len()
}
}
53 changes: 53 additions & 0 deletions lib/src/backend/lua/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use crate::parser::{BitValue, LiteralValue};

/// sBx use 17 Bits
const SBX_BIT_SIZE: u32 = 17;
/// Max value of sBx is 2^18 - 1
const SBX_MAX_VALUE: i32 = 2_i32.pow(SBX_BIT_SIZE + 1) - 1;
/// Min value of sBx is -2^18
const SBX_MIN_VALUE: i32 = 0 - 2_i32.pow(SBX_BIT_SIZE);
/// sBx Mask, lower 17 Bits is 1
const SBX_MASK: u32 = 0b0001_1111_1111_1111_1111;

/// Returns true if literal can fit into sBx value
pub fn try_fit_sbx(literal: &LiteralValue) -> Option<u32> {
match literal {
LiteralValue::Bit(BitValue::Zero) => Some(0),
LiteralValue::Bit(BitValue::One) => Some(1),
LiteralValue::Bool(false) => Some(0),
LiteralValue::Bool(true) => Some(1),
LiteralValue::Byte(v) => Some(*v as u32),
LiteralValue::SInt(v) => Some((*v as i32) as u32 & SBX_MASK),
LiteralValue::Int(v) => Some((*v as i32) as u32 & SBX_MASK),
LiteralValue::UInt(v) => Some(*v as u32 & SBX_MASK),
LiteralValue::DInt(v) => {
if (SBX_MIN_VALUE..=SBX_MAX_VALUE).contains(v) {
Some(*v as u32 & SBX_MASK)
} else {
None
}
}
LiteralValue::UDInt(v) => {
if *v <= SBX_MAX_VALUE as u32 {
Some(v & SBX_MASK)
} else {
None
}
}
LiteralValue::LInt(v) => {
if (SBX_MIN_VALUE as i64..=SBX_MAX_VALUE as i64).contains(v) {
Some(*v as u32 & SBX_MASK)
} else {
None
}
}
LiteralValue::ULInt(v) => {
if *v <= SBX_MAX_VALUE as u64 {
Some(*v as u32 & SBX_MASK)
} else {
None
}
}
_ => None,
}
}
2 changes: 1 addition & 1 deletion viewer/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn main() {
let test_fun_decl = StDeclarationParser::new().parse(test_func).unwrap();
let test_fun_decl_id = app_ctx.write().add_declaration(test_fun_decl);

let test_func = StLexerBuilder::new().build_str("print(a + b);");
let test_func = StLexerBuilder::new().build_str("a := 1; b := 2; print(a + b);");
let test_fun_body = StFunctionParser::new().parse(test_func).unwrap();
app_ctx
.write()
Expand Down

0 comments on commit 21b486d

Please sign in to comment.