From 2caa0fc736d8c8b75159c7d12334e83412ab503b Mon Sep 17 00:00:00 2001 From: Bruce Collie Date: Fri, 1 Dec 2023 13:18:11 +0000 Subject: [PATCH 01/10] Make ast_to_string forward args --- include/kllvm/ast/AST.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/kllvm/ast/AST.h b/include/kllvm/ast/AST.h index 4f3afc20f..a7fc19241 100644 --- a/include/kllvm/ast/AST.h +++ b/include/kllvm/ast/AST.h @@ -41,10 +41,10 @@ std::string decodeKore(std::string); * just want the string representation of a node, rather than to print it to a * stream. */ -template -std::string ast_to_string(T &&node) { +template +std::string ast_to_string(T &&node, Args &&...args) { auto os = std::ostringstream{}; - std::forward(node).print(os); + std::forward(node).print(os, std::forward(args)...); return os.str(); } From 46535ee3b2a5547769919ca0ca7feb35df6e6b1c Mon Sep 17 00:00:00 2001 From: Bruce Collie Date: Fri, 1 Dec 2023 13:02:49 +0000 Subject: [PATCH 02/10] Refactor include/kllvm/ast/AST.h --- include/kllvm/ast/AST.h | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/include/kllvm/ast/AST.h b/include/kllvm/ast/AST.h index a7fc19241..9c0bb604e 100644 --- a/include/kllvm/ast/AST.h +++ b/include/kllvm/ast/AST.h @@ -74,9 +74,7 @@ static inline std::ostream &operator<<(std::ostream &out, const KORESort &s) { struct HashSort { size_t operator()(const kllvm::KORESort &s) const noexcept { - std::ostringstream Out; - s.print(Out); - return std::hash{}(Out.str()); + return std::hash{}(ast_to_string(s)); } }; @@ -88,9 +86,7 @@ struct EqualSortPtr { struct HashSortPtr { size_t operator()(kllvm::KORESort *const &s) const noexcept { - std::ostringstream Out; - s->print(Out); - return std::hash{}(Out.str()); + return std::hash{}(ast_to_string(*s)); } }; @@ -293,18 +289,13 @@ struct HashSymbol { struct EqualSymbolPtr { bool operator()(KORESymbol *const &first, KORESymbol *const &second) const { - std::ostringstream Out1, Out2; - first->print(Out1); - second->print(Out2); - return Out1.str() == Out2.str(); + return ast_to_string(*first) == ast_to_string(*second); } }; struct HashSymbolPtr { size_t operator()(kllvm::KORESymbol *const &s) const noexcept { - std::ostringstream Out; - s->print(Out); - return std::hash{}(Out.str()); + return std::hash{}(ast_to_string(*s)); } }; From cf63b3d3adcd06cde05c675da6b103ee4c79afa5 Mon Sep 17 00:00:00 2001 From: Bruce Collie Date: Fri, 1 Dec 2023 13:03:00 +0000 Subject: [PATCH 03/10] Refactor tools/llvm-kompile-codegen/main.cpp --- tools/llvm-kompile-codegen/main.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tools/llvm-kompile-codegen/main.cpp b/tools/llvm-kompile-codegen/main.cpp index 2225d0966..dfeb4d78d 100644 --- a/tools/llvm-kompile-codegen/main.cpp +++ b/tools/llvm-kompile-codegen/main.cpp @@ -189,11 +189,10 @@ int main(int argc, char **argv) { auto funcDt = parseYamlDecisionTree( mod.get(), filename, definition->getAllSymbols(), definition->getHookedSorts()); - std::ostringstream Out; - decl->getSymbol()->print(Out); + makeAnywhereFunction( - definition->getAllSymbols().at(Out.str()), definition.get(), - mod.get(), funcDt); + definition->getAllSymbols().at(ast_to_string(*decl->getSymbol())), + definition.get(), mod.get(), funcDt); } } From 2c268d4c92775bde5624f789a1f00a1110a29645 Mon Sep 17 00:00:00 2001 From: Bruce Collie Date: Fri, 1 Dec 2023 13:04:59 +0000 Subject: [PATCH 04/10] Refactor lib/codegen/EmitConfigParser.cpp --- lib/codegen/EmitConfigParser.cpp | 41 ++++++++++++++------------------ 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/lib/codegen/EmitConfigParser.cpp b/lib/codegen/EmitConfigParser.cpp index 70b68103a..e717c5319 100644 --- a/lib/codegen/EmitConfigParser.cpp +++ b/lib/codegen/EmitConfigParser.cpp @@ -24,6 +24,8 @@ #include #include +#include + #include #include #include @@ -43,14 +45,13 @@ namespace kllvm { static llvm::Constant *getSymbolNamePtr( KORESymbol *symbol, llvm::BasicBlock *SetBlockName, llvm::Module *module) { llvm::LLVMContext &Ctx = module->getContext(); - std::ostringstream Out; - symbol->print(Out); + auto name = ast_to_string(*symbol); if (SetBlockName) { - SetBlockName->setName(Out.str()); + SetBlockName->setName(name); } - auto Str = llvm::ConstantDataArray::getString(Ctx, Out.str(), true); - auto global - = module->getOrInsertGlobal("sym_name_" + Out.str(), Str->getType()); + auto Str = llvm::ConstantDataArray::getString(Ctx, name, true); + auto global = module->getOrInsertGlobal( + fmt::format("sym_name_{}", name), Str->getType()); llvm::GlobalVariable *globalVar = llvm::dyn_cast(global); if (!globalVar->hasInitializer()) { @@ -457,12 +458,10 @@ emitGetTagForFreshSort(KOREDefinition *definition, llvm::Module *module) { auto CaseBlock = llvm::BasicBlock::Create(Ctx, name, func); llvm::BranchInst::Create(CaseBlock, FalseBlock, icmp, CurrentBlock); auto symbol = definition->getFreshFunctions().at(name); - std::ostringstream Out; - symbol->print(Out); Phi->addIncoming( llvm::ConstantInt::get( llvm::Type::getInt32Ty(Ctx), - definition->getAllSymbols().at(Out.str())->getTag()), + definition->getAllSymbols().at(ast_to_string(*symbol))->getTag()), CaseBlock); llvm::BranchInst::Create(MergeBlock, CaseBlock); CurrentBlock = FalseBlock; @@ -954,11 +953,10 @@ static void getVisitor( "", CaseBlock); llvm::Value *Child = new llvm::LoadInst( getValueType(cat, module), ChildPtr, "", CaseBlock); - std::ostringstream Out; - sort->print(Out); - auto Str = llvm::ConstantDataArray::getString(Ctx, Out.str(), true); - auto global - = module->getOrInsertGlobal("sort_name_" + Out.str(), Str->getType()); + auto sort_name = ast_to_string(*sort); + auto Str = llvm::ConstantDataArray::getString(Ctx, sort_name, true); + auto global = module->getOrInsertGlobal( + fmt::format("sort_name_{}", sort_name), Str->getType()); llvm::GlobalVariable *globalVar = llvm::dyn_cast(global); if (!globalVar->hasInitializer()) { @@ -1270,10 +1268,8 @@ static void emitSortTable(KOREDefinition *definition, llvm::Module *module) { auto subtableType = llvm::ArrayType::get( llvm::Type::getInt8PtrTy(Ctx), symbol->getArguments().size()); - std::ostringstream Out; - symbol->print(Out); - auto subtable - = module->getOrInsertGlobal("sorts_" + Out.str(), subtableType); + auto subtable = module->getOrInsertGlobal( + fmt::format("sorts_{}", ast_to_string(*symbol)), subtableType); llvm::GlobalVariable *subtableVar = llvm::dyn_cast(subtable); initDebugGlobal( @@ -1291,12 +1287,11 @@ static void emitSortTable(KOREDefinition *definition, llvm::Module *module) { std::vector subvalues; for (size_t i = 0; i < symbol->getArguments().size(); ++i) { - std::ostringstream Out; - symbol->getArguments()[i]->print(Out); + auto arg_str = ast_to_string(*symbol->getArguments()[i]); auto strType = llvm::ArrayType::get( - llvm::Type::getInt8Ty(Ctx), Out.str().size() + 1); - auto sortName - = module->getOrInsertGlobal("sort_name_" + Out.str(), strType); + llvm::Type::getInt8Ty(Ctx), arg_str.size() + 1); + auto sortName = module->getOrInsertGlobal( + fmt::format("sort_name_{}", arg_str), strType); subvalues.push_back(llvm::ConstantExpr::getInBoundsGetElementPtr( strType, sortName, indices)); } From c9b0139a8bb7daf793ad5743195a951993a39f9f Mon Sep 17 00:00:00 2001 From: Bruce Collie Date: Fri, 1 Dec 2023 13:15:21 +0000 Subject: [PATCH 05/10] Refactor runtime/util/ConfigurationParser.cpp --- runtime/util/ConfigurationParser.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/runtime/util/ConfigurationParser.cpp b/runtime/util/ConfigurationParser.cpp index 25847633a..25e188e8e 100644 --- a/runtime/util/ConfigurationParser.cpp +++ b/runtime/util/ConfigurationParser.cpp @@ -44,9 +44,8 @@ uint32_t getTagForSymbolName(const char *name) { } static uint32_t getTagForSymbol(KORESymbol const &symbol) { - std::ostringstream out; - symbol.print(out); - return getTagForSymbolName(out.str().c_str()); + auto name = ast_to_string(symbol); + return getTagForSymbolName(name.c_str()); } void *constructCompositePattern(uint32_t tag, std::vector &arguments) { From 5ee1b647a5b0991aca341d883140c40c91681040 Mon Sep 17 00:00:00 2001 From: Bruce Collie Date: Fri, 1 Dec 2023 13:16:42 +0000 Subject: [PATCH 06/10] Refactor lib/codegen/ProofEvent.cpp --- lib/codegen/ProofEvent.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/lib/codegen/ProofEvent.cpp b/lib/codegen/ProofEvent.cpp index 65d458898..37893ebef 100644 --- a/lib/codegen/ProofEvent.cpp +++ b/lib/codegen/ProofEvent.cpp @@ -18,10 +18,8 @@ namespace { template llvm::Constant *createGlobalSortStringPtr( IRBuilder &B, KORECompositeSort &sort, llvm::Module *mod) { - auto os = std::ostringstream{}; - sort.print(os); return B.CreateGlobalStringPtr( - os.str(), fmt::format("{}_str", sort.getName()), 0, mod); + ast_to_string(sort), fmt::format("{}_str", sort.getName()), 0, mod); } constexpr uint64_t word(uint8_t byte) { @@ -293,11 +291,9 @@ llvm::BasicBlock *ProofEvent::functionEvent_pre( auto [true_block, merge_block, outputFile] = eventPrelude("function_pre", current_block); - std::ostringstream symbolName; - pattern->getConstructor()->print(symbolName); - emitWriteUInt64(outputFile, word(0xDD), true_block); - emitWriteString(outputFile, symbolName.str(), true_block); + emitWriteString( + outputFile, ast_to_string(*pattern->getConstructor()), true_block); emitWriteString(outputFile, locationStack, true_block); llvm::BranchInst::Create(merge_block, true_block); From 4deae821bb131c5a75c7fffbc51fe0ba0d285319 Mon Sep 17 00:00:00 2001 From: Bruce Collie Date: Fri, 1 Dec 2023 13:22:55 +0000 Subject: [PATCH 07/10] Refactor lib/codegen/CreateTerm.cpp --- lib/codegen/CreateTerm.cpp | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/lib/codegen/CreateTerm.cpp b/lib/codegen/CreateTerm.cpp index 0efda6b4f..98b263c2b 100644 --- a/lib/codegen/CreateTerm.cpp +++ b/lib/codegen/CreateTerm.cpp @@ -652,10 +652,9 @@ llvm::Value *CreateTerm::createHook( std::string domain = name.substr(0, name.find('.')); if (domain == "ARRAY") { // array is not really hooked in llvm, it's implemented in K - std::ostringstream Out; - pattern->getConstructor()->print(Out, 0, false); - return createFunctionCall( - "eval_" + Out.str(), pattern, false, true, locationStack); + auto fn_name = fmt::format( + "eval_{}", ast_to_string(*pattern->getConstructor(), 0, false)); + return createFunctionCall(fn_name, pattern, false, true, locationStack); } std::string hookName = "hook_" + domain + "_" + name.substr(name.find('.') + 1); @@ -900,11 +899,10 @@ CreateTerm::createAllocation(KOREPattern *pattern, std::string locationStack) { return std::make_pair(val, true); } else { - std::ostringstream Out; - symbol->print(Out, 0, false); + auto fn_name = fmt::format("eval_{}", ast_to_string(*symbol, 0, false)); return std::make_pair( createFunctionCall( - "eval_" + Out.str(), constructor, false, true, locationStack), + fn_name, constructor, false, true, locationStack), true); } } else if (auto cat = dynamic_cast( @@ -1008,10 +1006,8 @@ bool makeFunction( return false; } auto cat = sort->getCategory(definition); - std::ostringstream Out; - sort->print(Out); llvm::Type *paramType = getValueType(cat, Module); - debugArgs.push_back(getDebugType(cat, Out.str())); + debugArgs.push_back(getDebugType(cat, ast_to_string(*sort))); switch (cat.cat) { case SortCategory::Map: case SortCategory::RangeMap: @@ -1045,11 +1041,11 @@ bool makeFunction( if (axiom->getAttributes().count("label")) { debugName = axiom->getStringAttribute("label") + postfix; } - std::ostringstream Out; - termSort(pattern)->print(Out); initDebugFunction( debugName, debugName, - getDebugFunctionType(getDebugType(returnCat, Out.str()), debugArgs), + getDebugFunctionType( + getDebugType(returnCat, ast_to_string(*termSort(pattern))), + debugArgs), definition, applyRule); if (tailcc) { applyRule->setCallingConv(llvm::CallingConv::Tail); @@ -1129,10 +1125,8 @@ std::string makeApplyRuleFunction( return ""; } auto cat = sort->getCategory(definition); - std::ostringstream Out; - sort->print(Out); llvm::Type *paramType = getValueType(cat, Module); - debugArgs.push_back(getDebugType(cat, Out.str())); + debugArgs.push_back(getDebugType(cat, ast_to_string(*sort))); switch (cat.cat) { case SortCategory::Map: case SortCategory::RangeMap: From c6e01ceac190ce8208624d9bd9660979367cce37 Mon Sep 17 00:00:00 2001 From: Bruce Collie Date: Fri, 1 Dec 2023 13:50:16 +0000 Subject: [PATCH 08/10] Refactor lib/codegen/Decision.cpp --- lib/codegen/Decision.cpp | 54 ++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/lib/codegen/Decision.cpp b/lib/codegen/Decision.cpp index cb7630047..cbfa9e44f 100644 --- a/lib/codegen/Decision.cpp +++ b/lib/codegen/Decision.cpp @@ -29,11 +29,14 @@ #include #include +#include + #include #include #include #include #include + namespace kllvm { static std::string LAYOUTITEM_STRUCT = "layoutitem"; @@ -104,22 +107,19 @@ getFailPattern(DecisionCase const &_case, bool isInt) { + std::to_string(bitwidth) + "\")"); } } else { - std::ostringstream symbol; - _case.getConstructor()->print(symbol); - std::ostringstream returnSort; - _case.getConstructor()->getSort()->print(returnSort); - std::string result = symbol.str() + "("; + auto result = fmt::format("{}(", ast_to_string(*_case.getConstructor())); + std::string conn = ""; for (int i = 0; i < _case.getConstructor()->getArguments().size(); i++) { - result += conn; - result += "Var'Unds'"; - std::ostringstream argSort; - _case.getConstructor()->getArguments()[i]->print(argSort); - result += ":" + argSort.str(); + result += fmt::format( + "{}Var'Unds':{}", conn, + ast_to_string(*_case.getConstructor()->getArguments()[i])); conn = ","; } result += ")"; - return std::make_pair(returnSort.str(), result); + + auto return_sort = ast_to_string(*_case.getConstructor()->getSort()); + return std::make_pair(return_sort, result); } } @@ -732,18 +732,15 @@ void makeEvalOrAnywhereFunction( auto returnSort = dynamic_cast(function->getSort().get()) ->getCategory(definition); auto returnType = getParamType(returnSort, module); - std::ostringstream Out; - function->getSort()->print(Out); - auto debugReturnType = getDebugType(returnSort, Out.str()); + auto debugReturnType + = getDebugType(returnSort, ast_to_string(*function->getSort())); std::vector args; std::vector debugArgs; std::vector cats; for (auto &sort : function->getArguments()) { auto cat = dynamic_cast(sort.get()) ->getCategory(definition); - std::ostringstream Out; - sort->print(Out); - debugArgs.push_back(getDebugType(cat, Out.str())); + debugArgs.push_back(getDebugType(cat, ast_to_string(*sort))); switch (cat.cat) { case SortCategory::Map: case SortCategory::RangeMap: @@ -760,9 +757,7 @@ void makeEvalOrAnywhereFunction( } llvm::FunctionType *funcType = llvm::FunctionType::get(returnType, args, false); - std::ostringstream Out2; - function->print(Out2, 0, false); - std::string name = "eval_" + Out2.str(); + std::string name = fmt::format("eval_{}", ast_to_string(*function, 0, false)); llvm::Function *matchFunc = getOrInsertFunction(module, name, funcType); KORESymbolDeclaration *symbolDecl = definition->getSymbolDeclarations().at(function->getName()); @@ -791,9 +786,9 @@ void makeEvalOrAnywhereFunction( ++val, ++i) { val->setName("_" + std::to_string(i + 1)); codegen.store(std::make_pair(val->getName().str(), val->getType()), val); - std::ostringstream Out; - function->getArguments()[i]->print(Out); - initDebugParam(matchFunc, i, val->getName().str(), cats[i], Out.str()); + initDebugParam( + matchFunc, i, val->getName().str(), cats[i], + ast_to_string(*function->getArguments()[i])); } addStuck(stuck, module, function, codegen, definition); @@ -804,9 +799,7 @@ void abortWhenStuck( llvm::BasicBlock *CurrentBlock, llvm::Module *Module, KORESymbol *symbol, Decision &codegen, KOREDefinition *d) { auto &Ctx = Module->getContext(); - std::ostringstream Out; - symbol->print(Out); - symbol = d->getAllSymbols().at(Out.str()); + symbol = d->getAllSymbols().at(ast_to_string(*symbol)); auto BlockType = getBlockType(Module, d, symbol); llvm::Value *Ptr; auto BlockPtr = llvm::PointerType::getUnqual( @@ -1276,9 +1269,7 @@ void makeStepFunction( auto argSort = dynamic_cast(res.pattern->getSort().get()); auto cat = argSort->getCategory(definition); - std::ostringstream Out; - argSort->print(Out); - debugTypes.push_back(getDebugType(cat, Out.str())); + debugTypes.push_back(getDebugType(cat, ast_to_string(*argSort))); switch (cat.cat) { case SortCategory::Map: case SortCategory::RangeMap: @@ -1334,9 +1325,8 @@ void makeStepFunction( auto cat = dynamic_cast(sort.get()) ->getCategory(definition); types.push_back(cat); - std::ostringstream Out; - sort->print(Out); - initDebugParam(matchFunc, i, "_" + std::to_string(i + 1), cat, Out.str()); + initDebugParam( + matchFunc, i, "_" + std::to_string(i + 1), cat, ast_to_string(*sort)); } auto header = stepFunctionHeader( axiom->getOrdinal(), module, definition, block, stuck, args, types); From 628c64b0d90e26fbcc8aa7a44b45a81e2ade54c8 Mon Sep 17 00:00:00 2001 From: Bruce Collie Date: Fri, 1 Dec 2023 13:52:46 +0000 Subject: [PATCH 09/10] Refactor lib/ast/AST.cpp --- lib/ast/AST.cpp | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/lib/ast/AST.cpp b/lib/ast/AST.cpp index fa9372d94..86694ff85 100644 --- a/lib/ast/AST.cpp +++ b/lib/ast/AST.cpp @@ -970,9 +970,7 @@ sptr KORECompositePattern::dedupeDisjuncts(void) { flatten(this, "\\or", items); std::set printed; for (sptr item : items) { - std::ostringstream Out; - item->print(Out); - if (printed.insert(Out.str()).second) { + if (printed.insert(ast_to_string(*item)).second) { dedupedItems.push_back(item); } } @@ -1170,10 +1168,7 @@ bool KOREVariablePattern::matches( substitution &subst, SubsortMap const &subsorts, SymbolMap const &overloads, sptr subject) { if (subst[name->getName()]) { - std::ostringstream Out1, Out2; - subst[name->getName()]->print(Out1); - subject->print(Out2); - return Out1.str() == Out2.str(); + return ast_to_string(*subst[name->getName()]) == ast_to_string(*subject); } else { subst[name->getName()] = subject; return true; @@ -1796,9 +1791,7 @@ void KOREDefinition::preprocess() { symbol->firstTag = symbol->lastTag = instantiations.at(*symbol); symbol->layout = layouts.at(layoutStr); objectSymbols[symbol->firstTag] = symbol; - std::ostringstream Out; - symbol->print(Out); - allObjectSymbols[Out.str()] = symbol; + allObjectSymbols[ast_to_string(*symbol)] = symbol; } } uint32_t lastTag = nextSymbol - 1; From 7ed2684b1d1bdd7a8f528000815628ca014e6ef5 Mon Sep 17 00:00:00 2001 From: Bruce Collie Date: Fri, 1 Dec 2023 13:55:07 +0000 Subject: [PATCH 10/10] Refactor bindings/c/lib.cpp --- bindings/c/lib.cpp | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/bindings/c/lib.cpp b/bindings/c/lib.cpp index 9bd959254..0d8e14eb2 100644 --- a/bindings/c/lib.cpp +++ b/bindings/c/lib.cpp @@ -25,8 +25,7 @@ namespace fs = std::filesystem; // Internal implementation details namespace { -template -char *get_c_string(OS const &); +char *get_c_string(std::string const &); kore_pattern *kore_string_pattern_new_internal(std::string const &); @@ -85,9 +84,7 @@ struct kore_symbol { /* KOREPattern */ char *kore_pattern_dump(kore_pattern const *pat) { - auto os = std::ostringstream{}; - pat->ptr_->print(os); - return get_c_string(os); + return get_c_string(ast_to_string(*pat->ptr_)); } char *kore_pattern_pretty_print(kore_pattern const *pat) { @@ -320,9 +317,7 @@ kore_string_pattern_new_with_len(char const *contents, size_t len) { /* KORESort */ char *kore_sort_dump(kore_sort const *sort) { - auto os = std::ostringstream{}; - sort->ptr_->print(os); - return get_c_string(os); + return get_c_string(ast_to_string(*sort->ptr_)); } void kore_sort_free(kore_sort const *sort) { @@ -372,9 +367,7 @@ void kore_symbol_free(kore_symbol const *sym) { } char *kore_symbol_dump(kore_symbol const *sym) { - auto os = std::ostringstream{}; - sym->ptr_->print(os); - return get_c_string(os); + return get_c_string(ast_to_string(*sym->ptr_)); } void kore_symbol_add_formal_argument(kore_symbol *sym, kore_sort const *sort) { @@ -394,10 +387,7 @@ void kllvm_free_all_memory(void) { namespace { -template -char *get_c_string(OS const &os) { - auto str = os.str(); - +char *get_c_string(std::string const &str) { // Include null terminator auto total_length = str.length() + 1;