Skip to content

Commit

Permalink
Refactoring proof trace code (#876)
Browse files Browse the repository at this point in the history
As discussed previously, this PR is an initial cleanup and refactoring
effort for the existing proof-hint generation code implemented for
rewrite, function and hook events. The changes in this PR don't add any
new features to the traces; they just reorganise the existing code.

Most of the changes made are documentation and identifying duplicated
code across different call-sites that can be merged together.

I plan to port the reorganisation in #862 over to the infrastructure in
this PR once it is merged.

The trace format is not tested in the backend currently (this is future
work), but I have verified that a proof trace (that uses all event
types) generated using this branch is byte-for-byte identical to one
generated from the current master branch.
  • Loading branch information
Baltoli authored Nov 9, 2023
1 parent cb8fe60 commit 0e24009
Show file tree
Hide file tree
Showing 4 changed files with 344 additions and 321 deletions.
108 changes: 95 additions & 13 deletions include/kllvm/codegen/ProofEvent.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,114 @@

#include "llvm/IR/Instructions.h"

namespace kllvm {
#include <map>
#include <tuple>

void writeUInt64(
llvm::Value *outputFile, llvm::Module *Module, uint64_t value,
llvm::BasicBlock *Block);
namespace kllvm {

class ProofEvent {
private:
KOREDefinition *Definition;
llvm::BasicBlock *CurrentBlock;
llvm::Module *Module;
llvm::LLVMContext &Ctx;

/*
* Load the boolean flag that controls whether proof hint output is enabled or
* not, then create a branch at the end of this basic block depending on the
* result.
*
* Returns a pair of blocks [proof enabled, merge]; the first of these is
* intended for self-contained behaviour only relevant in proof output mode,
* while the second is for the continuation of the interpreter's previous
* behaviour.
*/
std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
proofBranch(std::string label);
proofBranch(std::string const &label, llvm::BasicBlock *insertAtEnd);

/*
* Set up a standard event prelude by creating a pair of basic blocks for the
* proof output and continuation, then loading the output filename from its
* global.
*
* Returns a triple [proof enabled, merge, output_file]; see `proofBranch` and
* `emitGetOutputFileName`.
*/
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
eventPrelude(std::string const &label, llvm::BasicBlock *insertAtEnd);

/*
* Emit a call that will serialize `term` to the specified `outputFile` as
* binary KORE. This function can be called on any term, but the sort of that
* term must be known.
*/
llvm::CallInst *emitSerializeTerm(
KORECompositeSort &sort, llvm::Value *outputFile, llvm::Value *term,
llvm::BasicBlock *insertAtEnd);

/*
* Emit a call that will serialize `config` to the specified `outputFile` as
* binary KORE. This function does not require a sort, but the configuration
* passed must be a top-level configuration.
*/
llvm::CallInst *emitSerializeConfiguration(
llvm::Value *outputFile, llvm::Value *config,
llvm::BasicBlock *insertAtEnd);

/*
* Emit a call that will serialize `value` to the specified `outputFile`.
*/
llvm::CallInst *emitWriteUInt64(
llvm::Value *outputFile, uint64_t value, llvm::BasicBlock *insertAtEnd);

/*
* Emit a call that will serialize `str` to the specified `outputFile`.
*/
llvm::CallInst *emitWriteString(
llvm::Value *outputFile, std::string const &str,
llvm::BasicBlock *insertAtEnd);

/*
* Emit an instruction that has no effect and will be removed by optimization
* passes.
*
* We need this workaround because some callsites will try to use
* llvm::Instruction::insertAfter on the back of the MergeBlock after a proof
* branch is created. If the MergeBlock has no instructions, this has resulted
* in a segfault when printing the IR. Adding an effective no-op prevents this.
*/
llvm::BinaryOperator *emitNoOp(llvm::BasicBlock *insertAtEnd);

/*
* Emit instructions to load the path of the interpreter's current output
* file; used here for binary proof trace data.
*/
llvm::LoadInst *emitGetOutputFileName(llvm::BasicBlock *insertAtEnd);

public:
llvm::BasicBlock *hookEvent_pre(std::string name);
llvm::BasicBlock *hookEvent_post(llvm::Value *val, KORECompositeSort *sort);
llvm::BasicBlock *hookArg(llvm::Value *val, KORECompositeSort *sort);
[[nodiscard]] llvm::BasicBlock *
hookEvent_pre(std::string name, llvm::BasicBlock *current_block);

[[nodiscard]] llvm::BasicBlock *hookEvent_post(
llvm::Value *val, KORECompositeSort *sort,
llvm::BasicBlock *current_block);

[[nodiscard]] llvm::BasicBlock *hookArg(
llvm::Value *val, KORECompositeSort *sort,
llvm::BasicBlock *current_block);

[[nodiscard]] llvm::BasicBlock *rewriteEvent(
KOREAxiomDeclaration *axiom, llvm::Value *return_value, uint64_t arity,
std::map<std::string, KOREVariablePattern *> vars,
llvm::StringMap<llvm::Value *> const &subst,
llvm::BasicBlock *current_block);

[[nodiscard]] llvm::BasicBlock *functionEvent(
llvm::BasicBlock *current_block, KORECompositePattern *pattern,
std::string const &locationStack);

public:
ProofEvent(
KOREDefinition *Definition, llvm::BasicBlock *EntryBlock,
llvm::Module *Module)
ProofEvent(KOREDefinition *Definition, llvm::Module *Module)
: Definition(Definition)
, CurrentBlock(EntryBlock)
, Module(Module)
, Ctx(Module->getContext()) { }
};
Expand Down
150 changes: 12 additions & 138 deletions lib/codegen/CreateTerm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <iostream>

#include "runtime/header.h" //for macros

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
Expand Down Expand Up @@ -117,10 +118,6 @@ declare void @printConfiguration(i8 *, %block *)
}
} // namespace

void writeUInt64(
llvm::Value *outputFile, llvm::Module *Module, uint64_t value,
llvm::BasicBlock *Block);

std::unique_ptr<llvm::Module>
newModule(std::string name, llvm::LLVMContext &Context) {
llvm::SMDiagnostic Err;
Expand Down Expand Up @@ -318,8 +315,8 @@ llvm::Value *CreateTerm::alloc_arg(
llvm::Value *ret
= createAllocation(p, fmt::format("{}:{}", locationStack, idx)).first;
auto sort = dynamic_cast<KORECompositeSort *>(p->getSort().get());
ProofEvent e(Definition, CurrentBlock, Module);
CurrentBlock = e.hookArg(ret, sort);
ProofEvent e(Definition, Module);
CurrentBlock = e.hookArg(ret, sort, CurrentBlock);
return ret;
}

Expand Down Expand Up @@ -700,49 +697,8 @@ llvm::Value *CreateTerm::createFunctionCall(
}
}

llvm::Function *func = CurrentBlock->getParent();

auto ProofOutputFlag = Module->getOrInsertGlobal(
"proof_output", llvm::Type::getInt1Ty(Module->getContext()));
auto OutputFileName = Module->getOrInsertGlobal(
"output_file", llvm::Type::getInt8PtrTy(Module->getContext()));
auto proofOutput = new llvm::LoadInst(
llvm::Type::getInt1Ty(Module->getContext()), ProofOutputFlag,
"proof_output", CurrentBlock);
llvm::BasicBlock *TrueBlock
= llvm::BasicBlock::Create(Module->getContext(), "if", func);
auto outputFile = new llvm::LoadInst(
llvm::Type::getInt8PtrTy(Module->getContext()), OutputFileName, "output",
TrueBlock);
auto ir = new llvm::IRBuilder(TrueBlock);
llvm::BasicBlock *MergeBlock
= llvm::BasicBlock::Create(Module->getContext(), "tail", func);
llvm::BranchInst::Create(TrueBlock, MergeBlock, proofOutput, CurrentBlock);

std::ostringstream symbolName;
pattern->getConstructor()->print(symbolName);

auto symbolString
= ir->CreateGlobalStringPtr(symbolName.str(), "", 0, Module);
auto positionString = ir->CreateGlobalStringPtr(locationStack, "", 0, Module);
writeUInt64(outputFile, Module, 0xdddddddddddddddd, TrueBlock);
ir->CreateCall(
getOrInsertFunction(
Module, "printVariableToFile",
llvm::Type::getVoidTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext())),
{outputFile, symbolString});
ir->CreateCall(
getOrInsertFunction(
Module, "printVariableToFile",
llvm::Type::getVoidTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext())),
{outputFile, positionString});

llvm::BranchInst::Create(MergeBlock, TrueBlock);
CurrentBlock = MergeBlock;
auto event = ProofEvent(Definition, Module);
CurrentBlock = event.functionEvent(CurrentBlock, pattern, locationStack);

return createFunctionCall(name, returnCat, args, sret, tailcc, locationStack);
}
Expand Down Expand Up @@ -932,13 +888,12 @@ CreateTerm::createAllocation(KOREPattern *pattern, std::string locationStack) {
.get());
std::string name = strPattern->getContents();

ProofEvent p1(Definition, CurrentBlock, Module);
CurrentBlock = p1.hookEvent_pre(name);
ProofEvent p(Definition, Module);
CurrentBlock = p.hookEvent_pre(name, CurrentBlock);
llvm::Value *val = createHook(
symbolDecl->getAttributes().at("hook").get(), constructor,
locationStack);
ProofEvent p2(Definition, CurrentBlock, Module);
CurrentBlock = p2.hookEvent_post(val, sort);
CurrentBlock = p.hookEvent_post(val, sort, CurrentBlock);

return std::make_pair(val, true);
} else {
Expand Down Expand Up @@ -1114,91 +1069,10 @@ bool makeFunction(

auto CurrentBlock = creator.getCurrentBlock();
if (apply && bigStep) {
auto ProofOutputFlag = Module->getOrInsertGlobal(
"proof_output", llvm::Type::getInt1Ty(Module->getContext()));
auto OutputFileName = Module->getOrInsertGlobal(
"output_file", llvm::Type::getInt8PtrTy(Module->getContext()));
auto proofOutput = new llvm::LoadInst(
llvm::Type::getInt1Ty(Module->getContext()), ProofOutputFlag,
"proof_output", CurrentBlock);
llvm::BasicBlock *TrueBlock
= llvm::BasicBlock::Create(Module->getContext(), "if", applyRule);
auto ir = new llvm::IRBuilder(TrueBlock);
llvm::BasicBlock *MergeBlock
= llvm::BasicBlock::Create(Module->getContext(), "tail", applyRule);
llvm::BranchInst::Create(TrueBlock, MergeBlock, proofOutput, CurrentBlock);
auto outputFile = new llvm::LoadInst(
llvm::Type::getInt8PtrTy(Module->getContext()), OutputFileName,
"output", TrueBlock);
writeUInt64(outputFile, Module, axiom->getOrdinal(), TrueBlock);
writeUInt64(
outputFile, Module, applyRule->arg_end() - applyRule->arg_begin(),
TrueBlock);
for (auto entry = subst.begin(); entry != subst.end(); ++entry) {
auto key = entry->getKey();
auto val = entry->getValue();
auto var = vars[key.str()];
auto sort = dynamic_cast<KORECompositeSort *>(var->getSort().get());
auto cat = sort->getCategory(definition);
std::ostringstream Out;
sort->print(Out);
auto sortptr = ir->CreateGlobalStringPtr(Out.str(), "", 0, Module);
auto varname = ir->CreateGlobalStringPtr(key, "", 0, Module);
ir->CreateCall(
getOrInsertFunction(
Module, "printVariableToFile",
llvm::Type::getVoidTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext())),
{outputFile, varname});
if (cat.cat == SortCategory::Symbol
|| cat.cat == SortCategory::Variable) {
ir->CreateCall(
getOrInsertFunction(
Module, "serializeTermToFile",
llvm::Type::getVoidTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext()),
getValueType({SortCategory::Symbol, 0}, Module),
llvm::Type::getInt8PtrTy(Module->getContext())),
{outputFile, val, sortptr});
} else if (val->getType()->isIntegerTy()) {
val = ir->CreateIntToPtr(
val, llvm::Type::getInt8PtrTy(Module->getContext()));
ir->CreateCall(
getOrInsertFunction(
Module, "serializeRawTermToFile",
llvm::Type::getVoidTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext())),
{outputFile, val, sortptr});
} else {
val = ir->CreatePointerCast(
val, llvm::Type::getInt8PtrTy(Module->getContext()));
ir->CreateCall(
getOrInsertFunction(
Module, "serializeRawTermToFile",
llvm::Type::getVoidTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext())),
{outputFile, val, sortptr});
}
writeUInt64(outputFile, Module, 0xcccccccccccccccc, TrueBlock);
}

writeUInt64(outputFile, Module, 0xffffffffffffffff, TrueBlock);
ir->CreateCall(
getOrInsertFunction(
Module, "serializeConfigurationToFile",
llvm::Type::getVoidTy(Module->getContext()),
llvm::Type::getInt8PtrTy(Module->getContext()),
getValueType({SortCategory::Symbol, 0}, Module)),
{outputFile, retval});
writeUInt64(outputFile, Module, 0xcccccccccccccccc, TrueBlock);

llvm::BranchInst::Create(MergeBlock, TrueBlock);
CurrentBlock = MergeBlock;
auto event = ProofEvent(definition, Module);
CurrentBlock = event.rewriteEvent(
axiom, retval, applyRule->arg_end() - applyRule->arg_begin(), vars,
subst, CurrentBlock);
}

if (bigStep) {
Expand Down
2 changes: 1 addition & 1 deletion lib/codegen/EmitConfigParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,7 @@ static llvm::Constant *getOffsetOfMember(
auto offset
= llvm::DataLayout(mod).getStructLayout(struct_ty)->getElementOffset(
nth_member);
auto offset_ty = llvm::Type::getInt32Ty(mod->getContext());
auto offset_ty = llvm::Type::getInt64Ty(mod->getContext());
return llvm::ConstantInt::get(offset_ty, offset);
#else
return llvm::ConstantExpr::getOffsetOf(struct_ty, nth_member);
Expand Down
Loading

0 comments on commit 0e24009

Please sign in to comment.