Skip to content

Commit

Permalink
Better function, variable and call generation (galaxy-lang#60 from we…
Browse files Browse the repository at this point in the history
…suRage/main)

Better function, variable and call generation
  • Loading branch information
wesuRage authored Dec 17, 2024
2 parents 90422de + 454464c commit 19b51cd
Show file tree
Hide file tree
Showing 14 changed files with 121 additions and 86 deletions.
13 changes: 6 additions & 7 deletions examples/a.glx
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
if (num > 0) :
int macaco := 0;
elif (num > 10) :
int macaco := 10;
else:
int macaco := 100;
end;
def main( ) -> int:
int num := 3;
int quatro := num + 1;
return quatro;
end;

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ extern "C" {
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"

llvm::Value *generate_binary_expr(BinaryExprNode *node, llvm::LLVMContext &Context, llvm::IRBuilder<> &Builder, llvm::Module &TheModule);
llvm::Value *generate_binary_expr(BinaryExprNode *node, llvm::LLVMContext &Context, llvm::IRBuilder<> &Builder, llvm::Module &Module);

#endif // GENERATE_BINARY_EXPR_H
2 changes: 1 addition & 1 deletion include/backend/generator/expressions/generate_call.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ extern "C" {
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/IRBuilder.h>

llvm::Value *generate_call(CallNode *node, llvm::LLVMContext &Context, llvm::IRBuilder<> &Builder, llvm::Module &TheModule);
llvm::Value *generate_call(CallNode *call_node, llvm::LLVMContext &Context, llvm::IRBuilder<> &Builder, llvm::Module &Module);

#endif // GENERATE_CALL_H
5 changes: 3 additions & 2 deletions include/backend/generator/symbols/identifier_symbol_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
#include <unordered_map>
#include <string>
#include <llvm/IR/Function.h>
#include "backend/generator/symbols/symbol_stack.hpp"

llvm::Value* find_identifier(const std::string &name);
void add_identifier(const std::string &name, llvm::Value *value);
const SymbolInfo *find_identifier(const std::string &name);
void add_identifier(const std::string &name, llvm::Value* value, llvm::Type* type);

#endif // IDENTIFIER_SYMBOL_TABLE_H
7 changes: 6 additions & 1 deletion include/backend/generator/symbols/symbol_stack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
#include <stack>
#include <llvm/IR/Value.h>

using SymbolTable = std::unordered_map<std::string, llvm::Value*>;
struct SymbolInfo {
llvm::Value* value;
llvm::Type* type;
};

using SymbolTable = std::unordered_map<std::string, SymbolInfo>;

extern std::stack<SymbolTable> symbol_stack;
void enter_scope(void);
Expand Down
104 changes: 72 additions & 32 deletions src/backend/generator/expressions/generate_binary_expr.cpp
Original file line number Diff line number Diff line change
@@ -1,29 +1,69 @@
#include "backend/generator/expressions/generate_binary_expr.hpp"
#include "backend/generator/expressions/generate_expr.hpp"
#include "backend/generator/symbols/identifier_symbol_table.hpp"

llvm::Value *generate_binary_expr(BinaryExprNode *node, llvm::LLVMContext &Context, llvm::IRBuilder<> &Builder, llvm::Module &Module) {
// Generate LLVM IR for the left-hand side and right-hand side expressions
llvm::Value *L = generate_expr(node->left, Context, Builder, Module);
llvm::Value *R = generate_expr(node->right, Context, Builder, Module);

// Determine the types of L and R
bool isLInteger = L->getType()->isIntegerTy();
bool isRInteger = R->getType()->isIntegerTy();
llvm::Type *LType = L->getType();
llvm::Type *RType = R->getType();

bool isLInteger = LType->isIntegerTy();
bool isRInteger = RType->isIntegerTy();
bool isLFloating = LType->isFloatingPointTy();
bool isRFloating = RType->isFloatingPointTy();
bool isLPointer = LType->isPointerTy();
bool isRPointer = RType->isPointerTy();

if (isLPointer) {
const SymbolInfo* symbolInfo = find_identifier(static_cast<IdentifierNode*>(node->left->data)->symbol);
if (!symbolInfo) {
throw std::runtime_error("Unknown identifier for left operand");
}
llvm::Type* pointeeType = symbolInfo->type;

if (pointeeType->isIntegerTy()) {
L = Builder.CreateLoad(pointeeType, L, "loadtmp");
isLInteger = true;
} else if (pointeeType->isFloatingPointTy()) {
L = Builder.CreateLoad(pointeeType, L, "loadtmp");
isLFloating = true;
} else {
throw std::runtime_error("Unsupported type for left pointer operand");
}
}

if (isRPointer) {
const SymbolInfo* symbolInfo = find_identifier(static_cast<IdentifierNode*>(node->left->data)->symbol);
if (!symbolInfo) {
throw std::runtime_error("Unknown identifier for right operand");
}
llvm::Type* pointeeType = symbolInfo->type;

if (pointeeType->isIntegerTy()) {
R = Builder.CreateLoad(pointeeType, R, "loadtmp");
isRInteger = true;
} else if (pointeeType->isFloatingPointTy()) {
R = Builder.CreateLoad(pointeeType, R, "loadtmp");
isRFloating = true;
} else {
throw std::runtime_error("Unsupported type for right pointer operand");
}
}

// If both are integers, handle integer-specific operations
if (isLInteger && isRInteger) {
if (strcmp(node->op, "+") == 0) return Builder.CreateAdd(L, R, "addtmp");
if (strcmp(node->op, "-") == 0) return Builder.CreateSub(L, R, "subtmp");
if (strcmp(node->op, "*") == 0) return Builder.CreateMul(L, R, "multmp");
if (strcmp(node->op, "/") == 0) return Builder.CreateSDiv(L, R, "divtmp"); // Signed division
if (strcmp(node->op, "%") == 0) return Builder.CreateSRem(L, R, "modtmp"); // Modulus
if (strcmp(node->op, "/") == 0) return Builder.CreateSDiv(L, R, "divtmp");
if (strcmp(node->op, "%") == 0) return Builder.CreateSRem(L, R, "modtmp");
if (strcmp(node->op, "&") == 0) return Builder.CreateAnd(L, R, "andtmp");
if (strcmp(node->op, "|") == 0) return Builder.CreateOr(L, R, "ortmp");
if (strcmp(node->op, "^") == 0) return Builder.CreateXor(L, R, "xortmp");
if (strcmp(node->op, ">>") == 0) return Builder.CreateAShr(L, R, "shrtmp"); // Arithmetic shift right
if (strcmp(node->op, "<<") == 0) return Builder.CreateShl(L, R, "shltmp"); // Shift left
if (strcmp(node->op, ">>") == 0) return Builder.CreateAShr(L, R, "shrtmp");
if (strcmp(node->op, "<<") == 0) return Builder.CreateShl(L, R, "shltmp");

// Comparison operators
if (strcmp(node->op, "==") == 0) return Builder.CreateICmpEQ(L, R, "eqtmp");
if (strcmp(node->op, "!=") == 0) return Builder.CreateICmpNE(L, R, "netmp");
if (strcmp(node->op, "<") == 0) return Builder.CreateICmpSLT(L, R, "lttmp");
Expand All @@ -34,28 +74,28 @@ llvm::Value *generate_binary_expr(BinaryExprNode *node, llvm::LLVMContext &Conte
throw std::runtime_error("Unknown binary operator for integers");
}

// If either is floating-point, cast both to floating-point and perform floating-point operations
if (isLInteger) {
L = Builder.CreateSIToFP(L, llvm::Type::getDoubleTy(Context), "cast_to_fp_L");
}
if (isRInteger) {
R = Builder.CreateSIToFP(R, llvm::Type::getDoubleTy(Context), "cast_to_fp_R");
if (isLFloating || isRFloating) {
if (isLInteger) {
L = Builder.CreateSIToFP(L, llvm::Type::getDoubleTy(Context), "cast_to_fp_L");
}
if (isRInteger) {
R = Builder.CreateSIToFP(R, llvm::Type::getDoubleTy(Context), "cast_to_fp_R");
}

if (strcmp(node->op, "+") == 0) return Builder.CreateFAdd(L, R, "addtmp");
if (strcmp(node->op, "-") == 0) return Builder.CreateFSub(L, R, "subtmp");
if (strcmp(node->op, "*") == 0) return Builder.CreateFMul(L, R, "multmp");
if (strcmp(node->op, "/") == 0) return Builder.CreateFDiv(L, R, "divtmp");

if (strcmp(node->op, "==") == 0) return Builder.CreateFCmpOEQ(L, R, "eqtmp");
if (strcmp(node->op, "!=") == 0) return Builder.CreateFCmpONE(L, R, "netmp");
if (strcmp(node->op, "<") == 0) return Builder.CreateFCmpOLT(L, R, "lttmp");
if (strcmp(node->op, "<=") == 0) return Builder.CreateFCmpOLE(L, R, "letmp");
if (strcmp(node->op, ">") == 0) return Builder.CreateFCmpOGT(L, R, "gttmp");
if (strcmp(node->op, ">=") == 0) return Builder.CreateFCmpOGE(L, R, "getmp");

throw std::runtime_error("Unsupported operator for floating-point numbers");
}

// Floating-point operations
if (strcmp(node->op, "+") == 0) return Builder.CreateFAdd(L, R, "addtmp");
if (strcmp(node->op, "-") == 0) return Builder.CreateFSub(L, R, "subtmp");
if (strcmp(node->op, "*") == 0) return Builder.CreateFMul(L, R, "multmp");
if (strcmp(node->op, "/") == 0) return Builder.CreateFDiv(L, R, "divtmp");

// Comparison operators for floating-point
if (strcmp(node->op, "==") == 0) return Builder.CreateFCmpOEQ(L, R, "eqtmp");
if (strcmp(node->op, "!=") == 0) return Builder.CreateFCmpONE(L, R, "netmp");
if (strcmp(node->op, "<") == 0) return Builder.CreateFCmpOLT(L, R, "lttmp");
if (strcmp(node->op, "<=") == 0) return Builder.CreateFCmpOLE(L, R, "letmp");
if (strcmp(node->op, ">") == 0) return Builder.CreateFCmpOGT(L, R, "gttmp");
if (strcmp(node->op, ">=") == 0) return Builder.CreateFCmpOGE(L, R, "getmp");

// Handle unsupported operators for floating-point
throw std::runtime_error("Unsupported operator for floating-point numbers");
throw std::runtime_error("Unsupported types for binary operation");
}
4 changes: 1 addition & 3 deletions src/backend/generator/expressions/generate_call.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "backend/generator/statements/generate_function_declaration_stmt.hpp"
#include "backend/generator/statements/generate_stmt.hpp"
#include "backend/generator/statements/generate_call.hpp"
#include "backend/generator/expressions/generate_expr.hpp"
#include "backend/generator/types/generate_type.hpp"
#include "backend/generator/symbols/function_symbol_table.hpp"

llvm::Value *generate_call(CallNode *call_node, llvm::LLVMContext &Context, llvm::IRBuilder<> &Builder, llvm::Module &Module) {
Expand Down
7 changes: 4 additions & 3 deletions src/backend/generator/expressions/generate_expr.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#include <future>
#include <vector>
#include <stdexcept>
#include <mutex>
#include "backend/generator/expressions/generate_expr.hpp"
#include "backend/generator/expressions/generate_numeric_literal.hpp"
#include "backend/generator/expressions/generate_identifier.hpp"
Expand All @@ -12,8 +16,6 @@
#include "backend/generator/expressions/generate_string.hpp"

llvm::Value *generate_expr(AstNode *node, llvm::LLVMContext &Context, llvm::IRBuilder<> &Builder, llvm::Module &Module) {
// Checks the node kind, casts the node data into
// the perspective node type and then generates it.
switch (node->kind) {
case NODE_STRING: {
StringNode *string_node = (StringNode *)node->data;
Expand Down Expand Up @@ -60,7 +62,6 @@ llvm::Value *generate_expr(AstNode *node, llvm::LLVMContext &Context, llvm::IRBu
return generate_assignment_expr(assignNode, Context, Builder, Module);
}
default: {
// TODO
return nullptr;
}
}
Expand Down
9 changes: 3 additions & 6 deletions src/backend/generator/expressions/generate_identifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
#include "backend/generator/symbols/identifier_symbol_table.hpp"

llvm::Value *generate_identifier(IdentifierNode *node) {
// Usa find_identifier para buscar o valor associado ao símbolo no escopo atual
llvm::Value *value = find_identifier(node->symbol);
const SymbolInfo *id = find_identifier(node->symbol);

// Verifica se o identificador foi encontrado
if (!value) {
if (!id) {
throw std::runtime_error("Error: identifier not found!");
}

// Retorna o LLVM Value associado ao identificador
return value;
return id->value;
}
2 changes: 1 addition & 1 deletion src/backend/generator/statements/generate_extern_stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ llvm::Value* generate_extern_stmt(ExternNode *node, llvm::LLVMContext &Context,

global_var->setLinkage(llvm::GlobalValue::ExternalLinkage);

add_identifier(node->identifier, global_var);
add_identifier(node->identifier, global_var, decl_type);

return global_var; // Return the variable declaration
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
#include "backend/generator/statements/generate_function_declaration_stmt.hpp"
#include "backend/generator/statements/generate_stmt.hpp"
#include "backend/generator/expressions/generate_expr.hpp"
#include "backend/generator/types/generate_type.hpp"
#include "backend/generator/symbols/identifier_symbol_table.hpp"
#include "backend/generator/symbols/function_symbol_table.hpp"
#include "backend/generator/symbols/symbol_stack.hpp"
#include "backend/generator/symbols/identifier_symbol_table.hpp"
#include "backend/generator/types/generate_type.hpp"
#include "backend/generator/expressions/generate_expr.hpp"

llvm::Value* generate_function_declaration_stmt(FunctionNode *node, llvm::LLVMContext &Context, llvm::IRBuilder<> &Builder, llvm::Module &Module) {
if (!node || !node->name || !node->parameters) {
throw std::runtime_error("Invalid function: node, name, or parameters are null.");
}


// Generate the return type for the function
llvm::Type *return_type = generate_type(node->type, Context);

// Generate parameter types
Expand All @@ -23,29 +21,30 @@ llvm::Value* generate_function_declaration_stmt(FunctionNode *node, llvm::LLVMCo
param_types.push_back(param_type);
}

// Create the function type and function
llvm::FunctionType *func_type = llvm::FunctionType::get(return_type, param_types, false);
llvm::Function *function = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, node->name, &Module);

// Saves the function on a symbol table
function_symbol_table[node->name] = function;

// Assign parameter names

enter_scope();
int idx = 0;
// Assign parameter names
for (auto &arg : function->args()) {
ParameterNode *param = static_cast<ParameterNode*>(node->parameters->parameters[idx]->data);
arg.setName(param->name);

add_identifier(param->name, &arg); // Stores in the identifier symbol table
llvm::Type *type = generate_type(param->type, Context);

// Stores in the identifier symbol table
add_identifier(param->name, &arg, type);
++idx;
}

if (node->body != nullptr) {
llvm::BasicBlock *entry = llvm::BasicBlock::Create(Context, "entry", function);
Builder.SetInsertPoint(entry);

// Map function parameters to LLVM variables
std::map<std::string, llvm::Value*> variable_map;
idx = 0;
for (auto &arg : function->args()) {
Expand All @@ -63,7 +62,6 @@ llvm::Value* generate_function_declaration_stmt(FunctionNode *node, llvm::LLVMCo
ReturnNode *return_node = static_cast<ReturnNode*>(statement->data);
AstNode *value_node = return_node->value;

// Generate the return expression
return_value = generate_expr(value_node, Context, Builder, Module);

if (!return_value) {
Expand All @@ -73,20 +71,16 @@ llvm::Value* generate_function_declaration_stmt(FunctionNode *node, llvm::LLVMCo
} else {
return_value = llvm::Constant::getNullValue(return_type);
Builder.CreateRet(return_value);
return function;
}
} else {
Builder.CreateRet(return_value);
return function;
}

} else {
// Handle other statement types
generate_stmt(statement, Context, Module, Builder);
}
}

// Add a default return if no explicit return is provided
if (!return_value) {
if (return_type->isVoidTy()) {
Builder.CreateRetVoid();
Expand Down
8 changes: 5 additions & 3 deletions src/backend/generator/statements/generate_stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,29 @@
#include "backend/generator/statements/generate_function_declaration_stmt.hpp"
#include "backend/generator/statements/generate_extern_stmt.hpp"
#include "backend/generator/expressions/generate_expr.hpp"
#include <future>

llvm::Value* generate_stmt(AstNode *node, llvm::LLVMContext &Context, llvm::Module &Module, llvm::IRBuilder<> &Builder) {
// Checks each statement, casts the it's node type from the node data
// and then performs a IR generation of it
switch (node->kind) {
case NODE_VARIABLE: {
VariableNode *varNode = (VariableNode *)node->data;
generate_variable_declaration_stmt(varNode, Context, Builder, Module);

return nullptr;
}
case NODE_FUNCTION: {
FunctionNode *funcNode = (FunctionNode *)node->data;
generate_function_declaration_stmt(funcNode, Context, Builder, Module);

return nullptr;
}
case NODE_EXTERN: {
ExternNode *externNode = (ExternNode *)node->data;
generate_extern_stmt(externNode, Context, Builder, Module);

return nullptr;
}
// Othewise generates an expression

default: {
return generate_expr(node, Context, Builder, Module);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,16 @@ llvm::Value* generate_variable_declaration_stmt(VariableNode *node, llvm::LLVMCo
// Create an AllocaInst to allocate space for the variable on the stack
llvm::AllocaInst *alloca = Builder.CreateAlloca(var_type, nullptr, node->name);

// If the variable has an initial value, generate the corresponding LLVM IR for the initialization
if (node->value != nullptr) {
// Generate the LLVM IR for the initialization expression
// Generate IR for the expression
llvm::Value *init_value = generate_expr(node->value, Context, Builder, Module);

// Store the initialized value into the allocated space (AllocaInst)
Builder.CreateStore(init_value, alloca);
}

// Stores the allocated variable in the identifier symbol table
add_identifier(node->name, alloca);
add_identifier(node->name, alloca, var_type);

// Return the AllocaInst, which represents the variable's storage in memory
return alloca;
}
Loading

0 comments on commit 19b51cd

Please sign in to comment.