Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor usages of std::ostringstream #912

Merged
merged 10 commits into from
Dec 1, 2023
20 changes: 5 additions & 15 deletions bindings/c/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ namespace fs = std::filesystem;
// Internal implementation details
namespace {

template <typename OS>
char *get_c_string(OS const &);
char *get_c_string(std::string const &);

kore_pattern *kore_string_pattern_new_internal(std::string const &);

Expand Down Expand Up @@ -85,9 +84,7 @@ struct kore_symbol {
/* KOREPattern */

char *kore_pattern_dump(kore_pattern const *pat) {
auto os = std::ostringstream{};
pat->ptr_->print(os);
return get_c_string(os);
return get_c_string(ast_to_string(*pat->ptr_));
}

char *kore_pattern_pretty_print(kore_pattern const *pat) {
Expand Down Expand Up @@ -320,9 +317,7 @@ kore_string_pattern_new_with_len(char const *contents, size_t len) {
/* KORESort */

char *kore_sort_dump(kore_sort const *sort) {
auto os = std::ostringstream{};
sort->ptr_->print(os);
return get_c_string(os);
return get_c_string(ast_to_string(*sort->ptr_));
}

void kore_sort_free(kore_sort const *sort) {
Expand Down Expand Up @@ -372,9 +367,7 @@ void kore_symbol_free(kore_symbol const *sym) {
}

char *kore_symbol_dump(kore_symbol const *sym) {
auto os = std::ostringstream{};
sym->ptr_->print(os);
return get_c_string(os);
return get_c_string(ast_to_string(*sym->ptr_));
}

void kore_symbol_add_formal_argument(kore_symbol *sym, kore_sort const *sort) {
Expand All @@ -394,10 +387,7 @@ void kllvm_free_all_memory(void) {

namespace {

template <typename OS>
char *get_c_string(OS const &os) {
auto str = os.str();

char *get_c_string(std::string const &str) {
// Include null terminator
auto total_length = str.length() + 1;

Expand Down
23 changes: 7 additions & 16 deletions include/kllvm/ast/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ std::string decodeKore(std::string);
* just want the string representation of a node, rather than to print it to a
* stream.
*/
template <typename T>
std::string ast_to_string(T &&node) {
template <typename T, typename... Args>
std::string ast_to_string(T &&node, Args &&...args) {
auto os = std::ostringstream{};
std::forward<T>(node).print(os);
std::forward<T>(node).print(os, std::forward<Args>(args)...);
return os.str();
}

Expand Down Expand Up @@ -74,9 +74,7 @@ static inline std::ostream &operator<<(std::ostream &out, const KORESort &s) {

struct HashSort {
size_t operator()(const kllvm::KORESort &s) const noexcept {
std::ostringstream Out;
s.print(Out);
return std::hash<std::string>{}(Out.str());
return std::hash<std::string>{}(ast_to_string(s));
}
};

Expand All @@ -88,9 +86,7 @@ struct EqualSortPtr {

struct HashSortPtr {
size_t operator()(kllvm::KORESort *const &s) const noexcept {
std::ostringstream Out;
s->print(Out);
return std::hash<std::string>{}(Out.str());
return std::hash<std::string>{}(ast_to_string(*s));
}
};

Expand Down Expand Up @@ -293,18 +289,13 @@ struct HashSymbol {

struct EqualSymbolPtr {
bool operator()(KORESymbol *const &first, KORESymbol *const &second) const {
std::ostringstream Out1, Out2;
first->print(Out1);
second->print(Out2);
return Out1.str() == Out2.str();
return ast_to_string(*first) == ast_to_string(*second);
}
};

struct HashSymbolPtr {
size_t operator()(kllvm::KORESymbol *const &s) const noexcept {
std::ostringstream Out;
s->print(Out);
return std::hash<std::string>{}(Out.str());
return std::hash<std::string>{}(ast_to_string(*s));
}
};

Expand Down
13 changes: 3 additions & 10 deletions lib/ast/AST.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -970,9 +970,7 @@ sptr<KOREPattern> KORECompositePattern::dedupeDisjuncts(void) {
flatten(this, "\\or", items);
std::set<std::string> printed;
for (sptr<KOREPattern> item : items) {
std::ostringstream Out;
item->print(Out);
if (printed.insert(Out.str()).second) {
if (printed.insert(ast_to_string(*item)).second) {
dedupedItems.push_back(item);
}
}
Expand Down Expand Up @@ -1170,10 +1168,7 @@ bool KOREVariablePattern::matches(
substitution &subst, SubsortMap const &subsorts, SymbolMap const &overloads,
sptr<KOREPattern> subject) {
if (subst[name->getName()]) {
std::ostringstream Out1, Out2;
subst[name->getName()]->print(Out1);
subject->print(Out2);
return Out1.str() == Out2.str();
return ast_to_string(*subst[name->getName()]) == ast_to_string(*subject);
} else {
subst[name->getName()] = subject;
return true;
Expand Down Expand Up @@ -1796,9 +1791,7 @@ void KOREDefinition::preprocess() {
symbol->firstTag = symbol->lastTag = instantiations.at(*symbol);
symbol->layout = layouts.at(layoutStr);
objectSymbols[symbol->firstTag] = symbol;
std::ostringstream Out;
symbol->print(Out);
allObjectSymbols[Out.str()] = symbol;
allObjectSymbols[ast_to_string(*symbol)] = symbol;
}
}
uint32_t lastTag = nextSymbol - 1;
Expand Down
26 changes: 10 additions & 16 deletions lib/codegen/CreateTerm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,10 +652,9 @@ llvm::Value *CreateTerm::createHook(
std::string domain = name.substr(0, name.find('.'));
if (domain == "ARRAY") {
// array is not really hooked in llvm, it's implemented in K
std::ostringstream Out;
pattern->getConstructor()->print(Out, 0, false);
return createFunctionCall(
"eval_" + Out.str(), pattern, false, true, locationStack);
auto fn_name = fmt::format(
"eval_{}", ast_to_string(*pattern->getConstructor(), 0, false));
return createFunctionCall(fn_name, pattern, false, true, locationStack);
}
std::string hookName
= "hook_" + domain + "_" + name.substr(name.find('.') + 1);
Expand Down Expand Up @@ -900,11 +899,10 @@ CreateTerm::createAllocation(KOREPattern *pattern, std::string locationStack) {

return std::make_pair(val, true);
} else {
std::ostringstream Out;
symbol->print(Out, 0, false);
auto fn_name = fmt::format("eval_{}", ast_to_string(*symbol, 0, false));
return std::make_pair(
createFunctionCall(
"eval_" + Out.str(), constructor, false, true, locationStack),
fn_name, constructor, false, true, locationStack),
true);
}
} else if (auto cat = dynamic_cast<KORECompositeSort *>(
Expand Down Expand Up @@ -1008,10 +1006,8 @@ bool makeFunction(
return false;
}
auto cat = sort->getCategory(definition);
std::ostringstream Out;
sort->print(Out);
llvm::Type *paramType = getValueType(cat, Module);
debugArgs.push_back(getDebugType(cat, Out.str()));
debugArgs.push_back(getDebugType(cat, ast_to_string(*sort)));
switch (cat.cat) {
case SortCategory::Map:
case SortCategory::RangeMap:
Expand Down Expand Up @@ -1045,11 +1041,11 @@ bool makeFunction(
if (axiom->getAttributes().count("label")) {
debugName = axiom->getStringAttribute("label") + postfix;
}
std::ostringstream Out;
termSort(pattern)->print(Out);
initDebugFunction(
debugName, debugName,
getDebugFunctionType(getDebugType(returnCat, Out.str()), debugArgs),
getDebugFunctionType(
getDebugType(returnCat, ast_to_string(*termSort(pattern))),
debugArgs),
definition, applyRule);
if (tailcc) {
applyRule->setCallingConv(llvm::CallingConv::Tail);
Expand Down Expand Up @@ -1129,10 +1125,8 @@ std::string makeApplyRuleFunction(
return "";
}
auto cat = sort->getCategory(definition);
std::ostringstream Out;
sort->print(Out);
llvm::Type *paramType = getValueType(cat, Module);
debugArgs.push_back(getDebugType(cat, Out.str()));
debugArgs.push_back(getDebugType(cat, ast_to_string(*sort)));
switch (cat.cat) {
case SortCategory::Map:
case SortCategory::RangeMap:
Expand Down
54 changes: 22 additions & 32 deletions lib/codegen/Decision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
#include <llvm/IR/Value.h>
#include <llvm/Support/Casting.h>

#include <fmt/format.h>

#include <iostream>
#include <limits>
#include <memory>
#include <set>
#include <type_traits>

namespace kllvm {

static std::string LAYOUTITEM_STRUCT = "layoutitem";
Expand Down Expand Up @@ -104,22 +107,19 @@ getFailPattern(DecisionCase const &_case, bool isInt) {
+ std::to_string(bitwidth) + "\")");
}
} else {
std::ostringstream symbol;
_case.getConstructor()->print(symbol);
std::ostringstream returnSort;
_case.getConstructor()->getSort()->print(returnSort);
std::string result = symbol.str() + "(";
auto result = fmt::format("{}(", ast_to_string(*_case.getConstructor()));

std::string conn = "";
for (int i = 0; i < _case.getConstructor()->getArguments().size(); i++) {
result += conn;
result += "Var'Unds'";
std::ostringstream argSort;
_case.getConstructor()->getArguments()[i]->print(argSort);
result += ":" + argSort.str();
result += fmt::format(
"{}Var'Unds':{}", conn,
ast_to_string(*_case.getConstructor()->getArguments()[i]));
conn = ",";
}
result += ")";
return std::make_pair(returnSort.str(), result);

auto return_sort = ast_to_string(*_case.getConstructor()->getSort());
return std::make_pair(return_sort, result);
}
}

Expand Down Expand Up @@ -732,18 +732,15 @@ void makeEvalOrAnywhereFunction(
auto returnSort = dynamic_cast<KORECompositeSort *>(function->getSort().get())
->getCategory(definition);
auto returnType = getParamType(returnSort, module);
std::ostringstream Out;
function->getSort()->print(Out);
auto debugReturnType = getDebugType(returnSort, Out.str());
auto debugReturnType
= getDebugType(returnSort, ast_to_string(*function->getSort()));
std::vector<llvm::Type *> args;
std::vector<llvm::Metadata *> debugArgs;
std::vector<ValueType> cats;
for (auto &sort : function->getArguments()) {
auto cat = dynamic_cast<KORECompositeSort *>(sort.get())
->getCategory(definition);
std::ostringstream Out;
sort->print(Out);
debugArgs.push_back(getDebugType(cat, Out.str()));
debugArgs.push_back(getDebugType(cat, ast_to_string(*sort)));
switch (cat.cat) {
case SortCategory::Map:
case SortCategory::RangeMap:
Expand All @@ -760,9 +757,7 @@ void makeEvalOrAnywhereFunction(
}
llvm::FunctionType *funcType
= llvm::FunctionType::get(returnType, args, false);
std::ostringstream Out2;
function->print(Out2, 0, false);
std::string name = "eval_" + Out2.str();
std::string name = fmt::format("eval_{}", ast_to_string(*function, 0, false));
llvm::Function *matchFunc = getOrInsertFunction(module, name, funcType);
KORESymbolDeclaration *symbolDecl
= definition->getSymbolDeclarations().at(function->getName());
Expand Down Expand Up @@ -791,9 +786,9 @@ void makeEvalOrAnywhereFunction(
++val, ++i) {
val->setName("_" + std::to_string(i + 1));
codegen.store(std::make_pair(val->getName().str(), val->getType()), val);
std::ostringstream Out;
function->getArguments()[i]->print(Out);
initDebugParam(matchFunc, i, val->getName().str(), cats[i], Out.str());
initDebugParam(
matchFunc, i, val->getName().str(), cats[i],
ast_to_string(*function->getArguments()[i]));
}
addStuck(stuck, module, function, codegen, definition);

Expand All @@ -804,9 +799,7 @@ void abortWhenStuck(
llvm::BasicBlock *CurrentBlock, llvm::Module *Module, KORESymbol *symbol,
Decision &codegen, KOREDefinition *d) {
auto &Ctx = Module->getContext();
std::ostringstream Out;
symbol->print(Out);
symbol = d->getAllSymbols().at(Out.str());
symbol = d->getAllSymbols().at(ast_to_string(*symbol));
auto BlockType = getBlockType(Module, d, symbol);
llvm::Value *Ptr;
auto BlockPtr = llvm::PointerType::getUnqual(
Expand Down Expand Up @@ -1276,9 +1269,7 @@ void makeStepFunction(
auto argSort
= dynamic_cast<KORECompositeSort *>(res.pattern->getSort().get());
auto cat = argSort->getCategory(definition);
std::ostringstream Out;
argSort->print(Out);
debugTypes.push_back(getDebugType(cat, Out.str()));
debugTypes.push_back(getDebugType(cat, ast_to_string(*argSort)));
switch (cat.cat) {
case SortCategory::Map:
case SortCategory::RangeMap:
Expand Down Expand Up @@ -1334,9 +1325,8 @@ void makeStepFunction(
auto cat = dynamic_cast<KORECompositeSort *>(sort.get())
->getCategory(definition);
types.push_back(cat);
std::ostringstream Out;
sort->print(Out);
initDebugParam(matchFunc, i, "_" + std::to_string(i + 1), cat, Out.str());
initDebugParam(
matchFunc, i, "_" + std::to_string(i + 1), cat, ast_to_string(*sort));
}
auto header = stepFunctionHeader(
axiom->getOrdinal(), module, definition, block, stuck, args, types);
Expand Down
Loading