diff --git a/include/analyze/comp_unique_bounds_pb.h b/include/analyze/comp_unique_bounds_pb.h index 870bc6d9d..f404237de 100644 --- a/include/analyze/comp_unique_bounds_pb.h +++ b/include/analyze/comp_unique_bounds_pb.h @@ -21,7 +21,6 @@ class CompUniqueBoundsPB : public CompUniqueBounds { public: class Bound : public CompUniqueBounds::Bound { public: // Visible to CompUniqueBoundsPB's subclasses - Ref ctx_; // isl var -> ft expr, the demangling map yielded from GenPBExpr // shared from CompUniqueBoundsPB::cachedFreeVars_ Ref> demangleMap_; @@ -30,11 +29,9 @@ class CompUniqueBoundsPB : public CompUniqueBounds { PBSet bound_; public: - Bound(Ref ctx, - Ref> demangleMap, + Bound(Ref> demangleMap, PBSet bound) - : ctx_(std::move(ctx)), demangleMap_(std::move(demangleMap)), - bound_(std::move(bound)) {} + : demangleMap_(std::move(demangleMap)), bound_(std::move(bound)) {} BoundType type() const override { return BoundType::Presburger; } diff --git a/include/analyze/deps.h b/include/analyze/deps.h index 42211b69a..927142cd1 100644 --- a/include/analyze/deps.h +++ b/include/analyze/deps.h @@ -256,7 +256,6 @@ struct Dependence { // not only counting the nearest, but all PBMap later2EarlierIterAllPossible_; PBMap extConstraint_; - PBCtx &presburger_; AnalyzeDeps &self_; // Helper functions @@ -363,39 +362,40 @@ class AnalyzeDeps { static std::string makeCond(GenPBExpr &genPBExpr, GenPBExpr::VarMap &externals, bool eraseOutsideVarDef, const AccessPoint &ap); - static PBMap makeAccMapStatic(PBCtx &presburger, const AccessPoint &p, - int iterDim, int accDim, + static PBMap makeAccMapStatic(const Ref &presburger, + const AccessPoint &p, int iterDim, int accDim, const std::string &extSuffix, GenPBExpr::VarMap &externals, const ASTHashSet &noNeedToBeVars, bool eraseOutsideVarDef); private: - PBMap makeAccMap(PBCtx &presburger, const AccessPoint &p, int iterDim, - int accDim, const std::string &extSuffix, + PBMap makeAccMap(const Ref &presburger, const AccessPoint &p, + int iterDim, int accDim, const std::string &extSuffix, GenPBExpr::VarMap &externals, const ASTHashSet &noNeedToBeVars) { return makeAccMapStatic(presburger, p, iterDim, accDim, extSuffix, externals, noNeedToBeVars, eraseOutsideVarDef_); } - PBMap makeEqForBothOps(PBCtx &presburger, + PBMap makeEqForBothOps(const Ref &presburger, const std::vector> &coord, int iterDim) const; - PBMap makeIneqBetweenOps(PBCtx &presburger, DepDirection mode, int iterId, - int iterDim) const; + PBMap makeIneqBetweenOps(const Ref &presburger, DepDirection mode, + int iterId, int iterDim) const; - PBMap makeSerialToAll(PBCtx &presburger, int iterDim, + PBMap makeSerialToAll(const Ref &presburger, int iterDim, const std::vector &point) const; static int countSerial(const std::vector &point); - PBMap makeExternalEq(PBCtx &presburger, int iterDim, + PBMap makeExternalEq(const Ref &presburger, int iterDim, const std::string &ext1, const std::string &ext2); - PBMap makeConstraintOfSingleLoop(PBCtx &presburger, const ID &loop, - DepDirection mode, int iterDim); + PBMap makeConstraintOfSingleLoop(const Ref &presburger, + const ID &loop, DepDirection mode, + int iterDim); - PBMap makeConstraintOfParallelScope(PBCtx &presburger, + PBMap makeConstraintOfParallelScope(const Ref &presburger, const ParallelScope ¶llel, DepDirection mode, int iterDim, const AccessPoint &later, @@ -415,14 +415,14 @@ class AnalyzeDeps { * * There will be no dependences of a[0] across i */ - PBMap makeEraseVarDefConstraint(PBCtx &presburger, + PBMap makeEraseVarDefConstraint(const Ref &presburger, const Ref &point, int iterDim); /** * Constraint for loops that explicitly marked as no_deps by users */ - PBMap makeNoDepsConstraint(PBCtx &presburger, const std::string &var, - int iterDim); + PBMap makeNoDepsConstraint(const Ref &presburger, + const std::string &var, int iterDim); /** * Constraint for external variables inside loop @@ -438,7 +438,7 @@ class AnalyzeDeps { * idx[i] + j must be different for the same i but different j, but * idx[i] + j may be the same for different i */ - PBMap makeExternalVarConstraint(PBCtx &presburger, + PBMap makeExternalVarConstraint(const Ref &presburger, const Ref &later, const Ref &earlier, const GenPBExpr::VarMap &laterExternals, @@ -460,13 +460,15 @@ class AnalyzeDeps { * FindDepsMode::Dep mode, we do not care about the result. Therefore, we * project out these dimensions */ - PBMap projectOutPrivateAxis(PBCtx &presburger, int iterDim, int since); - void projectOutPrivateAxis(PBCtx &presburger, const Ref &point, + PBMap projectOutPrivateAxis(const Ref &presburger, int iterDim, + int since); + void projectOutPrivateAxis(const Ref &presburger, + const Ref &point, const std::vector> &otherList, std::vector &otherMapList, int iterDim); int numCommonDims(const Ref &p1, const Ref &p2); - void checkAgainstCond(PBCtx &presburger, const Ref &later, + void checkAgainstCond(const Ref &later, const Ref &earlier, const PBMap &depAll, const PBMap &nearest, const PBMap &laterMap, const PBMap &earlierMap, const PBMap &extConstraint, @@ -485,7 +487,8 @@ class AnalyzeDeps { checkDepLatestEarlier(const Ref &later, const std::vector> &earlierList); void - checkDepLatestEarlierImpl(PBCtx &presburger, const Ref &later, + checkDepLatestEarlierImpl(const Ref &presburger, + const Ref &later, const std::vector> &earlierList); /** @@ -498,7 +501,7 @@ class AnalyzeDeps { void checkDepEarliestLater(const std::vector> &laterList, const Ref &earlier); void - checkDepEarliestLaterImpl(PBCtx &presburger, + checkDepEarliestLaterImpl(const Ref &presburger, const std::vector> &laterList, const Ref &earlier); }; diff --git a/include/math/parse_pb_expr.h b/include/math/parse_pb_expr.h index 2189b3359..20f5cb37a 100644 --- a/include/math/parse_pb_expr.h +++ b/include/math/parse_pb_expr.h @@ -26,8 +26,12 @@ typedef std::vector PBFuncAST; /** * Parse a PBFunc to be ASTs + * + * @{ */ -PBFuncAST parsePBFunc(const std::string &str); +PBFuncAST parsePBFunc(const PBFunc::Serialized &f); +PBFuncAST parsePBFunc(const PBSingleFunc::Serialized &f); +/** @} */ /** * Construct AST from PBSet while preserving min and max with a special hack to @@ -35,8 +39,8 @@ PBFuncAST parsePBFunc(const std::string &str); * * @{ */ -PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set); -PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBMap &map); +PBFuncAST parsePBFuncReconstructMinMax(const PBSet &set); +PBFuncAST parsePBFuncReconstructMinMax(const PBMap &map); /** @} */ /** @@ -44,16 +48,15 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBMap &map); * * @{ */ -inline SimplePBFuncAST parseSimplePBFunc(const std::string &str) { - auto ret = parsePBFunc(str); +inline SimplePBFuncAST parseSimplePBFunc(const auto &f) { + auto ret = parsePBFunc(f); if (ret.size() != 1) { - throw ParserError(str + " is not a simple PBFunc"); + throw ParserError(FT_MSG << f << " is not a simple PBFunc"); } return ret.front(); } -inline SimplePBFuncAST parseSimplePBFuncReconstructMinMax(const PBCtx &ctx, - const auto &f) { - auto ret = parsePBFuncReconstructMinMax(ctx, f); +inline SimplePBFuncAST parseSimplePBFuncReconstructMinMax(const auto &f) { + auto ret = parsePBFuncReconstructMinMax(f); if (ret.size() != 1) { throw ParserError(FT_MSG << f << " is not a simple PBFunc"); } diff --git a/include/math/presburger.h b/include/math/presburger.h index 060b70e9a..93914f7e7 100644 --- a/include/math/presburger.h +++ b/include/math/presburger.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -47,38 +48,39 @@ template T *MOVE_ISL_PTR(T *&ptr) { return ret; } +/** + * Context for presburger operation + * + * - All operands of a presburger operation should be on the same context. + * - Operations in the same context is NOT thread-safe. Explitly transfer to a + * different context if you want to use in multiple threads. + */ class PBCtx { isl_ctx *ctx_ = nullptr; - bool dontFree_ = false; // Tolerate memory leak public: PBCtx() : ctx_(isl_ctx_alloc()) { isl_options_set_on_error(ctx_, ISL_ON_ERROR_ABORT); } - ~PBCtx() { - if (!dontFree_) { - isl_ctx_free(ctx_); - } - } + ~PBCtx() { isl_ctx_free(ctx_); } PBCtx(const PBCtx &other) = delete; PBCtx &operator=(const PBCtx &other) = delete; PBCtx(PBCtx &&other) = delete; PBCtx &operator=(PBCtx &&other) = delete; - void setDontFree(bool flag = true) { dontFree_ = flag; } - isl_ctx *get() const { return GET_ISL_PTR(ctx_); } }; class PBMap { + Ref ctx_; isl_map *map_ = nullptr; public: PBMap() {} - PBMap(isl_map *map) : map_(map) {} - PBMap(const PBCtx &ctx, const std::string &str) - : map_(isl_map_read_from_str(ctx.get(), str.c_str())) { + PBMap(const Ref &ctx, isl_map *map) : ctx_(ctx), map_(map) {} + PBMap(const Ref &ctx, const std::string &str) + : ctx_(ctx), map_(isl_map_read_from_str(ctx->get(), str.c_str())) { if (map_ == nullptr) { ERROR("Unable to construct an PBMap from " + str); } @@ -89,8 +91,9 @@ class PBMap { } } - PBMap(const PBMap &other) : map_(other.copy()) {} + PBMap(const PBMap &other) : ctx_(other.ctx_), map_(other.copy()) {} PBMap &operator=(const PBMap &other) { + ctx_ = other.ctx_; if (map_ != nullptr) { isl_map_free(map_); } @@ -98,8 +101,9 @@ class PBMap { return *this; } - PBMap(PBMap &&other) : map_(other.move()) {} + PBMap(PBMap &&other) : ctx_(std::move(other.ctx_)), map_(other.move()) {} PBMap &operator=(PBMap &&other) { + ctx_ = std::move(other.ctx_); if (map_ != nullptr) { isl_map_free(map_); } @@ -109,10 +113,29 @@ class PBMap { bool isValid() const { return map_ != nullptr; } + const auto &ctx() const { return ctx_; } + auto &ctx() { return ctx_; } + isl_map *get() const { return GET_ISL_PTR(map_); } isl_map *copy() const { return COPY_ISL_PTR(map_, map); } isl_map *move() { return MOVE_ISL_PTR(map_); } + class Serialized { + std::string data_; + + public: + Serialized() {} + Serialized(const std::string &data) : data_(data) {} + PBMap to(const Ref &ctx) const { return {ctx, data_}; } + bool isValid() const { return !data_.empty(); } + const auto &data() const { return data_; } + friend std::ostream &operator<<(std::ostream &os, const Serialized &s) { + return os << s.data_; + } + }; + Serialized toSerialized() const { return {isl_map_to_str(get())}; } + PBMap to(const Ref &ctx) const { return toSerialized().to(ctx); } + bool empty() const { DEBUG_PROFILE("empty"); return isl_map_is_empty(get()); @@ -142,19 +165,21 @@ class PBMap { }; class PBVal { + Ref ctx_; isl_val *val_ = nullptr; public: PBVal() {} - PBVal(isl_val *val) : val_(val) {} + PBVal(const Ref &ctx, isl_val *val) : ctx_(ctx), val_(val) {} ~PBVal() { if (val_ != nullptr) { isl_val_free(val_); } } - PBVal(const PBVal &other) : val_(other.copy()) {} + PBVal(const PBVal &other) : ctx_(other.ctx_), val_(other.copy()) {} PBVal &operator=(const PBVal &other) { + ctx_ = other.ctx_; if (val_ != nullptr) { isl_val_free(val_); } @@ -162,8 +187,9 @@ class PBVal { return *this; } - PBVal(PBVal &&other) : val_(other.move()) {} + PBVal(PBVal &&other) : ctx_(std::move(other.ctx_)), val_(other.move()) {} PBVal &operator=(PBVal &&other) { + ctx_ = std::move(other.ctx_); if (val_ != nullptr) { isl_val_free(val_); } @@ -173,6 +199,9 @@ class PBVal { bool isValid() const { return val_ != nullptr; } + const auto &ctx() const { return ctx_; } + auto &ctx() { return ctx_; } + isl_val *get() const { return GET_ISL_PTR(val_); } isl_val *copy() const { return COPY_ISL_PTR(val_, val); } isl_val *move() { return MOVE_ISL_PTR(val_); } @@ -192,13 +221,14 @@ class PBVal { }; class PBSet { + Ref ctx_; isl_set *set_ = nullptr; public: PBSet() {} - PBSet(isl_set *set) : set_(set) {} - PBSet(const PBCtx &ctx, const std::string &str) - : set_(isl_set_read_from_str(ctx.get(), str.c_str())) { + PBSet(const Ref &ctx, isl_set *set) : ctx_(ctx), set_(set) {} + PBSet(const Ref &ctx, const std::string &str) + : ctx_(ctx), set_(isl_set_read_from_str(ctx->get(), str.c_str())) { if (set_ == nullptr) { ERROR("Unable to construct an PBSet from " + str); } @@ -209,8 +239,9 @@ class PBSet { } } - PBSet(const PBSet &other) : set_(other.copy()) {} + PBSet(const PBSet &other) : ctx_(other.ctx_), set_(other.copy()) {} PBSet &operator=(const PBSet &other) { + ctx_ = other.ctx_; if (set_ != nullptr) { isl_set_free(set_); } @@ -218,8 +249,9 @@ class PBSet { return *this; } - PBSet(PBSet &&other) : set_(other.move()) {} + PBSet(PBSet &&other) : ctx_(std::move(other.ctx_)), set_(other.move()) {} PBSet &operator=(PBSet &&other) { + ctx_ = std::move(other.ctx_); if (set_ != nullptr) { isl_set_free(set_); } @@ -229,10 +261,29 @@ class PBSet { bool isValid() const { return set_ != nullptr; } + const auto &ctx() const { return ctx_; } + auto &ctx() { return ctx_; } + isl_set *get() const { return GET_ISL_PTR(set_); } isl_set *copy() const { return COPY_ISL_PTR(set_, set); } isl_set *move() { return MOVE_ISL_PTR(set_); } + class Serialized { + std::string data_; + + public: + Serialized() {} + Serialized(const std::string &data) : data_(data) {} + PBSet to(const Ref &ctx) const { return {ctx, data_}; } + bool isValid() const { return !data_.empty(); } + const auto &data() const { return data_; } + friend std::ostream &operator<<(std::ostream &os, const Serialized &s) { + return os << s.data_; + } + }; + Serialized toSerialized() const { return {isl_set_to_str(get())}; } + PBSet to(const Ref &ctx) const { return toSerialized().to(ctx); } + bool empty() const { DEBUG_PROFILE("empty"); return isl_set_is_empty(get()); @@ -265,21 +316,26 @@ class PBSet { }; class PBSpace { + Ref ctx_; isl_space *space_ = nullptr; public: PBSpace() {} - PBSpace(isl_space *space) : space_(space) {} - PBSpace(const PBSet &set) : space_(isl_set_get_space(set.get())) {} - PBSpace(const PBMap &map) : space_(isl_map_get_space(map.get())) {} + PBSpace(const Ref &ctx, isl_space *space) + : ctx_(ctx), space_(space) {} + PBSpace(const PBSet &set) + : ctx_(set.ctx()), space_(isl_set_get_space(set.get())) {} + PBSpace(const PBMap &map) + : ctx_(map.ctx()), space_(isl_map_get_space(map.get())) {} ~PBSpace() { if (space_ != nullptr) { isl_space_free(space_); } } - PBSpace(const PBSpace &other) : space_(other.copy()) {} + PBSpace(const PBSpace &other) : ctx_(other.ctx_), space_(other.copy()) {} PBSpace &operator=(const PBSpace &other) { + ctx_ = other.ctx_; if (space_ != nullptr) { isl_space_free(space_); } @@ -287,8 +343,10 @@ class PBSpace { return *this; } - PBSpace(PBSpace &&other) : space_(other.move()) {} + PBSpace(PBSpace &&other) + : ctx_(std::move(other.ctx_)), space_(other.move()) {} PBSpace &operator=(PBSpace &&other) { + ctx_ = std::move(other.ctx_); if (space_ != nullptr) { isl_space_free(space_); } @@ -298,6 +356,9 @@ class PBSpace { bool isValid() const { return space_ != nullptr; } + const auto &ctx() const { return ctx_; } + auto &ctx() { return ctx_; } + isl_space *get() const { return GET_ISL_PTR(space_); } isl_space *copy() const { return COPY_ISL_PTR(space_, space); } isl_space *move() { return MOVE_ISL_PTR(space_); } @@ -314,12 +375,21 @@ class PBSpace { }; class PBSingleFunc { + Ref ctx_; isl_pw_aff *func_ = nullptr; public: PBSingleFunc() {} - PBSingleFunc(isl_pw_aff *func) : func_(func) {} - explicit PBSingleFunc(isl_aff *func) : func_(isl_pw_aff_from_aff(func)) {} + PBSingleFunc(const Ref &ctx, isl_pw_aff *func) + : ctx_(ctx), func_(func) {} + PBSingleFunc(const Ref &ctx, const std::string &str) + : ctx_(ctx), func_(isl_pw_aff_read_from_str(ctx->get(), str.c_str())) { + if (func_ == nullptr) { + ERROR("Unable to construct an PBSingleFunc from " + str); + } + } + explicit PBSingleFunc(const Ref &ctx, isl_aff *func) + : ctx_(ctx), func_(isl_pw_aff_from_aff(func)) {} ~PBSingleFunc() { if (func_ != nullptr) { @@ -327,8 +397,10 @@ class PBSingleFunc { } } - PBSingleFunc(const PBSingleFunc &other) : func_(other.copy()) {} + PBSingleFunc(const PBSingleFunc &other) + : ctx_(other.ctx_), func_(other.copy()) {} PBSingleFunc &operator=(const PBSingleFunc &other) { + ctx_ = other.ctx_; if (func_ != nullptr) { isl_pw_aff_free(func_); } @@ -336,8 +408,10 @@ class PBSingleFunc { return *this; } - PBSingleFunc(PBSingleFunc &&other) : func_(other.move()) {} + PBSingleFunc(PBSingleFunc &&other) + : ctx_(std::move(other.ctx_)), func_(other.move()) {} PBSingleFunc &operator=(PBSingleFunc &&other) { + ctx_ = std::move(other.ctx_); if (func_ != nullptr) { isl_pw_aff_free(func_); } @@ -347,24 +421,48 @@ class PBSingleFunc { bool isValid() const { return func_ != nullptr; } + const auto &ctx() const { return ctx_; } + auto &ctx() { return ctx_; } + isl_pw_aff *get() const { return GET_ISL_PTR(func_); } isl_pw_aff *copy() const { return COPY_ISL_PTR(func_, pw_aff); } isl_pw_aff *move() { return MOVE_ISL_PTR(func_); } + class Serialized { + std::string data_; + + public: + Serialized() {} + Serialized(const std::string &data) : data_(data) {} + PBSingleFunc to(const Ref &ctx) const { return {ctx, data_}; } + bool isValid() const { return !data_.empty(); } + const auto &data() const { return data_; } + friend std::ostream &operator<<(std::ostream &os, const Serialized &s) { + return os << s.data_; + } + }; + Serialized toSerialized() const { return {isl_pw_aff_to_str(get())}; } + PBSingleFunc to(const Ref &ctx) const { + return toSerialized().to(ctx); + } + isl_size nInDims() const { return isl_pw_aff_dim(get(), isl_dim_in); } std::vector> pieces() const { - std::vector> result; + typedef std::vector> Result; + std::pair> userData; isl_pw_aff_foreach_piece( get(), - [](isl_set *set, isl_aff *piece, void *user) { - ((std::vector> *)user) - ->emplace_back(PBSet(set), - PBSingleFunc(isl_pw_aff_from_aff(piece))); + [](isl_set *set, isl_aff *piece, void *userDataRaw) { + auto &[result, ctx] = + *(std::pair> *)userDataRaw; + result.emplace_back( + PBSet(ctx, set), + PBSingleFunc(ctx, isl_pw_aff_from_aff(piece))); return isl_stat_ok; }, - &result); - return result; + &userData); + return userData.first; } friend std::ostream &operator<<(std::ostream &os, @@ -374,22 +472,39 @@ class PBSingleFunc { }; class PBFunc { + Ref ctx_; isl_pw_multi_aff *func_ = nullptr; public: PBFunc() {} - PBFunc(isl_pw_multi_aff *func) : func_(func) {} + PBFunc(const Ref &ctx, isl_pw_multi_aff *func) + : ctx_(ctx), func_(func) {} + PBFunc(const Ref &ctx, const std::string &str) + : ctx_(ctx), + func_(isl_pw_multi_aff_read_from_str(ctx->get(), str.c_str())) { + if (func_ == nullptr) { + ERROR("Unable to construct an PBFunc from " + str); + } + } PBFunc(const PBSingleFunc &singleFunc) - : func_(isl_pw_multi_aff_from_pw_aff(singleFunc.copy())) {} + : ctx_(singleFunc.ctx()), + func_(isl_pw_multi_aff_from_pw_aff(singleFunc.copy())) {} PBFunc(PBSingleFunc &&singleFunc) - : func_(isl_pw_multi_aff_from_pw_aff(singleFunc.move())) {} + : ctx_(std::move(singleFunc.ctx())), + func_(isl_pw_multi_aff_from_pw_aff(singleFunc.move())) {} - PBFunc(const PBMap &map) : func_(isl_pw_multi_aff_from_map(map.copy())) {} - PBFunc(PBMap &&map) : func_(isl_pw_multi_aff_from_map(map.move())) {} + PBFunc(const PBMap &map) + : ctx_(map.ctx()), func_(isl_pw_multi_aff_from_map(map.copy())) {} + PBFunc(PBMap &&map) + : ctx_(std::move(map.ctx())), + func_(isl_pw_multi_aff_from_map(map.move())) {} - PBFunc(const PBSet &set) : func_(isl_pw_multi_aff_from_set(set.copy())) {} - PBFunc(PBSet &&set) : func_(isl_pw_multi_aff_from_set(set.move())) {} + PBFunc(const PBSet &set) + : ctx_(set.ctx()), func_(isl_pw_multi_aff_from_set(set.copy())) {} + PBFunc(PBSet &&set) + : ctx_(std::move(set.ctx())), + func_(isl_pw_multi_aff_from_set(set.move())) {} ~PBFunc() { if (func_ != nullptr) { @@ -397,8 +512,9 @@ class PBFunc { } } - PBFunc(const PBFunc &other) : func_(other.copy()) {} + PBFunc(const PBFunc &other) : ctx_(other.ctx_), func_(other.copy()) {} PBFunc &operator=(const PBFunc &other) { + ctx_ = other.ctx_; if (func_ != nullptr) { isl_pw_multi_aff_free(func_); } @@ -406,8 +522,9 @@ class PBFunc { return *this; } - PBFunc(PBFunc &&other) : func_(other.move()) {} + PBFunc(PBFunc &&other) : ctx_(std::move(other.ctx_)), func_(other.move()) {} PBFunc &operator=(PBFunc &&other) { + ctx_ = std::move(other.ctx_); if (func_ != nullptr) { isl_pw_multi_aff_free(func_); } @@ -417,32 +534,53 @@ class PBFunc { bool isValid() const { return func_ != nullptr; } + const auto &ctx() const { return ctx_; } + auto &ctx() { return ctx_; } + isl_pw_multi_aff *get() const { return GET_ISL_PTR(func_); } isl_pw_multi_aff *copy() const { return COPY_ISL_PTR(func_, pw_multi_aff); } isl_pw_multi_aff *move() { return MOVE_ISL_PTR(func_); } + class Serialized { + std::string data_; + + public: + Serialized() {} + Serialized(const std::string &data) : data_(data) {} + PBFunc to(const Ref &ctx) const { return {ctx, data_}; } + bool isValid() const { return !data_.empty(); } + const auto &data() const { return data_; } + friend std::ostream &operator<<(std::ostream &os, const Serialized &s) { + return os << s.data_; + } + }; + Serialized toSerialized() const { return {isl_pw_multi_aff_to_str(get())}; } + PBFunc to(const Ref &ctx) const { return toSerialized().to(ctx); } + isl_size nInDims() const { return isl_pw_multi_aff_dim(get(), isl_dim_in); } isl_size nOutDims() const { return isl_pw_multi_aff_dim(get(), isl_dim_out); } PBSingleFunc operator[](isl_size i) const { - return isl_pw_multi_aff_get_pw_aff(get(), 0); + return {ctx_, isl_pw_multi_aff_get_pw_aff(get(), 0)}; } std::vector> pieces() const { - std::vector> result; + typedef std::vector> Result; + std::pair> userData; isl_pw_multi_aff_foreach_piece( get(), - [](isl_set *set, isl_multi_aff *piece, void *user) { - ((std::vector> *)user) - ->emplace_back( - PBSet(set), - PBFunc(isl_pw_multi_aff_from_multi_aff(piece))); + [](isl_set *set, isl_multi_aff *piece, void *userDataRaw) { + auto &[result, ctx] = + *(std::pair> *)userDataRaw; + result.emplace_back( + PBSet(ctx, set), + PBFunc(ctx, isl_pw_multi_aff_from_multi_aff(piece))); return isl_stat_ok; }, - &result); - return result; + &userData); + return userData.first; } friend std::ostream &operator<<(std::ostream &os, const PBFunc &func) { @@ -451,11 +589,13 @@ class PBFunc { }; class PBPoint { + Ref ctx_; isl_point *point_ = nullptr; public: PBPoint() {} - PBPoint(isl_point *point) : point_(point) {} + PBPoint(const Ref &ctx, isl_point *point) + : ctx_(ctx), point_(point) {} ~PBPoint() { if (point_ != nullptr) { @@ -463,8 +603,9 @@ class PBPoint { } } - PBPoint(const PBPoint &other) : point_(other.copy()) {} + PBPoint(const PBPoint &other) : ctx_(other.ctx_), point_(other.copy()) {} PBPoint &operator=(const PBPoint &other) { + ctx_ = other.ctx_; if (point_ != nullptr) { isl_point_free(point_); } @@ -472,8 +613,10 @@ class PBPoint { return *this; } - PBPoint(PBPoint &&other) : point_(other.move()) {} + PBPoint(PBPoint &&other) + : ctx_(std::move(other.ctx_)), point_(other.move()) {} PBPoint &operator=(PBPoint &&other) { + ctx_ = std::move(other.ctx_); if (point_ != nullptr) { isl_point_free(point_); } @@ -489,14 +632,17 @@ class PBPoint { bool isVoid() const { return isl_point_is_void(point_); } + const auto &ctx() const { return ctx_; } + auto &ctx() { return ctx_; } + std::vector coordinates() const { ASSERT(!isVoid()); std::vector result; isl_size nCoord = isl_space_dim( - PBSpace(isl_point_get_space(point_)).get(), isl_dim_set); + PBSpace(ctx_, isl_point_get_space(point_)).get(), isl_dim_set); for (isl_size i = 0; i < nCoord; ++i) result.emplace_back( - isl_point_get_coordinate_val(point_, isl_dim_set, i)); + ctx_, isl_point_get_coordinate_val(point_, isl_dim_set, i)); return result; } }; @@ -522,107 +668,131 @@ template auto PBRefTake(std::remove_reference_t &&t) { return t.move(); } +Ref commonCtx(const auto &lhs, const auto &rhs) { + if (lhs.ctx()->get() != rhs.ctx()->get()) { + ERROR( + "Operands of a Presburger operation should be on the same context"); + } + return lhs.ctx(); +} + template PBSet projectOutAllParams(T &&set) { - return isl_set_project_out_all_params(PBRefTake(set)); + return {set.ctx(), isl_set_project_out_all_params(PBRefTake(set))}; } template PBMap projectOutAllParams(T &&map) { - return isl_map_project_out_all_params(PBRefTake(map)); + return {map.ctx(), isl_map_project_out_all_params(PBRefTake(map))}; } template PBSet projectOutParamById(T &&set, const std::string &name) { - isl_ctx *ctx = isl_set_get_ctx(set.get()); - return isl_set_project_out_param_id( - PBRefTake(set), isl_id_alloc(ctx, name.c_str(), nullptr)); + return {set.ctx(), + isl_set_project_out_param_id( + PBRefTake(set), + isl_id_alloc(set.ctx()->get(), name.c_str(), nullptr))}; } template PBSet projectOutParamDims(T &&set, unsigned first, unsigned n) { - return isl_set_project_out(PBRefTake(set), isl_dim_param, first, n); + return {set.ctx(), + isl_set_project_out(PBRefTake(set), isl_dim_param, first, n)}; } template PBSet projectOutDims(T &&set, unsigned first, unsigned n) { - return isl_set_project_out(PBRefTake(set), isl_dim_set, first, n); + return {set.ctx(), + isl_set_project_out(PBRefTake(set), isl_dim_set, first, n)}; } template PBMap projectOutParamById(T &&map, const std::string &name) { - isl_ctx *ctx = isl_map_get_ctx(map.get()); - return isl_map_project_out_param_id( - PBRefTake(map), isl_id_alloc(ctx, name.c_str(), nullptr)); + return {map.ctx(), + isl_map_project_out_param_id( + PBRefTake(map), + isl_id_alloc(map.ctx()->get(), name.c_str(), nullptr))}; } template PBMap projectOutParamDims(T &&map, unsigned first, unsigned n) { - return isl_map_project_out(PBRefTake(map), isl_dim_param, first, n); + return {map.ctx(), + isl_map_project_out(PBRefTake(map), isl_dim_param, first, n)}; } template PBMap projectOutInputDims(T &&map, unsigned first, unsigned n) { - return isl_map_project_out(PBRefTake(map), isl_dim_in, first, n); + return {map.ctx(), + isl_map_project_out(PBRefTake(map), isl_dim_in, first, n)}; } template PBMap projectOutOutputDims(T &&map, unsigned first, unsigned n) { - return isl_map_project_out(PBRefTake(map), isl_dim_out, first, n); + return {map.ctx(), + isl_map_project_out(PBRefTake(map), isl_dim_out, first, n)}; } template PBSet insertDims(T &&set, unsigned first, unsigned n) { - return isl_set_insert_dims(PBRefTake(set), isl_dim_set, first, n); + return {set.ctx(), + isl_set_insert_dims(PBRefTake(set), isl_dim_set, first, n)}; } template PBMap insertInputDims(T &&map, unsigned first, unsigned n) { - return isl_map_insert_dims(PBRefTake(map), isl_dim_in, first, n); + return {map.ctx(), + isl_map_insert_dims(PBRefTake(map), isl_dim_in, first, n)}; } template PBMap insertOutputDims(T &&map, unsigned first, unsigned n) { - return isl_map_insert_dims(PBRefTake(map), isl_dim_out, first, n); + return {map.ctx(), + isl_map_insert_dims(PBRefTake(map), isl_dim_out, first, n)}; } template PBSet fixDim(T &&set, unsigned pos, int x) { - return isl_set_fix_si(PBRefTake(set), isl_dim_set, pos, x); + return {set.ctx(), isl_set_fix_si(PBRefTake(set), isl_dim_set, pos, x)}; } template PBMap fixInputDim(T &&map, unsigned pos, int x) { - return isl_map_fix_si(PBRefTake(map), isl_dim_in, pos, x); + return {map.set(), isl_map_fix_si(PBRefTake(map), isl_dim_in, pos, x)}; } template PBMap fixOutputDim(T &&map, unsigned pos, int x) { - return isl_map_fix_si(PBRefTake(map), isl_dim_out, pos, x); + return {map.ctx(), isl_map_fix_si(PBRefTake(map), isl_dim_out, pos, x)}; } template PBSet lowerBoundDim(T &&set, unsigned pos, int x) { - return isl_set_lower_bound_si(PBRefTake(set), isl_dim_set, pos, x); + return {set.ctx(), + isl_set_lower_bound_si(PBRefTake(set), isl_dim_set, pos, x)}; } template PBMap lowerBoundInputDim(T &&map, unsigned pos, int x) { - return isl_map_lower_bound_si(PBRefTake(map), isl_dim_in, pos, x); + return {map.ctx(), + isl_map_lower_bound_si(PBRefTake(map), isl_dim_in, pos, x)}; } template PBMap lowerBoundOutputDim(T &&map, unsigned pos, int x) { - return isl_map_lower_bound_si(PBRefTake(map), isl_dim_out, pos, x); + return {map.ctx(), + isl_map_lower_bound_si(PBRefTake(map), isl_dim_out, pos, x)}; } template PBSet upperBoundDim(T &&set, unsigned pos, int x) { - return isl_set_upper_bound_si(PBRefTake(set), isl_dim_set, pos, x); + return {set.ctx(), + isl_set_upper_bound_si(PBRefTake(set), isl_dim_set, pos, x)}; } template PBMap upperBoundInputDim(T &&map, unsigned pos, int x) { - return isl_map_upper_bound_si(PBRefTake(map), isl_dim_in, pos, x); + return {map.ctx(), + isl_map_upper_bound_si(PBRefTake(map), isl_dim_in, pos, x)}; } template PBMap upperBoundOutputDim(T &&map, unsigned pos, int x) { - return isl_map_upper_bound_si(PBRefTake(map), isl_dim_out, pos, x); + return {map.ctx(), + isl_map_upper_bound_si(PBRefTake(map), isl_dim_out, pos, x)}; } template PBMap newDomainOnlyMap(T &&set) { - return isl_map_from_domain(PBRefTake(set)); + return {set.ctx(), isl_map_from_domain(PBRefTake(set))}; } template PBMap newRangeOnlyMap(T &&set) { - return isl_map_from_range(PBRefTake(set)); + return {set.ctx(), isl_map_from_range(PBRefTake(set))}; } template PBMap moveDimsInputToOutput(T &&map, unsigned first, unsigned n, unsigned target) { - return isl_map_move_dims(PBRefTake(map), isl_dim_out, target, isl_dim_in, - first, n); + return {map.ctx(), isl_map_move_dims(PBRefTake(map), isl_dim_out, target, + isl_dim_in, first, n)}; } template PBMap moveDimsOutputToInput(T &&map, unsigned first, unsigned n, unsigned target) { - return isl_map_move_dims(PBRefTake(map), isl_dim_in, target, isl_dim_out, - first, n); + return {map.ctx(), isl_map_move_dims(PBRefTake(map), isl_dim_in, target, + isl_dim_out, first, n)}; } /** * Move other dimensions to be parameters. @@ -636,38 +806,38 @@ PBMap moveDimsOutputToInput(T &&map, unsigned first, unsigned n, template PBMap moveDimsInputToParam(T &&map, unsigned first, unsigned n, unsigned target) { - return isl_map_move_dims(PBRefTake(map), isl_dim_param, target, - isl_dim_in, first, n); + return {map.ctx(), isl_map_move_dims(PBRefTake(map), isl_dim_param, + target, isl_dim_in, first, n)}; } template PBMap moveDimsOutputToParam(T &&map, unsigned first, unsigned n, unsigned target) { - return isl_map_move_dims(PBRefTake(map), isl_dim_param, target, - isl_dim_out, first, n); + return {map.ctx(), isl_map_move_dims(PBRefTake(map), isl_dim_param, + target, isl_dim_out, first, n)}; } /** @} */ template PBMap moveDimsParamToInput(T &&map, unsigned first, unsigned n, unsigned target) { - return isl_map_move_dims(PBRefTake(map), isl_dim_in, target, - isl_dim_param, first, n); + return {map.ctx(), isl_map_move_dims(PBRefTake(map), isl_dim_in, target, + isl_dim_param, first, n)}; } template PBMap moveDimsParamToOutput(T &&map, unsigned first, unsigned n, unsigned target) { - return isl_map_move_dims(PBRefTake(map), isl_dim_out, target, - isl_dim_param, first, n); + return {map.ctx(), isl_map_move_dims(PBRefTake(map), isl_dim_out, target, + isl_dim_param, first, n)}; } template PBSet moveDimsSetToParam(T &&set, unsigned first, unsigned n, unsigned target) { - return isl_set_move_dims(PBRefTake(set), isl_dim_param, target, - isl_dim_set, first, n); + return {set.ctx(), isl_set_move_dims(PBRefTake(set), isl_dim_param, + target, isl_dim_set, first, n)}; } template PBSet moveDimsParamToSet(T &&set, unsigned first, unsigned n, unsigned target) { - return isl_set_move_dims(PBRefTake(set), isl_dim_set, target, - isl_dim_param, first, n); + return {set.ctx(), isl_set_move_dims(PBRefTake(set), isl_dim_set, target, + isl_dim_param, first, n)}; } template @@ -680,254 +850,273 @@ std::pair padToSameDims(T &&lhs, U &&rhs) { template PBSet complement(T &&set) { DEBUG_PROFILE("complement"); - return isl_set_complement(PBRefTake(set)); + return {set.ctx(), isl_set_complement(PBRefTake(set))}; } template PBMap complement(T &&map) { DEBUG_PROFILE("complement"); - return isl_map_complement(PBRefTake(map)); + return {map.ctx(), isl_map_complement(PBRefTake(map))}; } template PBMap reverse(T &&map) { DEBUG_PROFILE("reverse"); - return isl_map_reverse(PBRefTake(map)); + return {map.ctx(), isl_map_reverse(PBRefTake(map))}; } template PBMap subtract(T &&lhs, U &&rhs) { DEBUG_PROFILE_VERBOSE("subtract", "nBasic=" + std::to_string(lhs.nBasic()) + "," + std::to_string(rhs.nBasic())); - return isl_map_subtract(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_map_subtract(PBRefTake(lhs), PBRefTake(rhs))}; } template PBSet subtract(T &&lhs, U &&rhs) { DEBUG_PROFILE_VERBOSE("subtract", "nBasic=" + std::to_string(lhs.nBasic()) + "," + std::to_string(rhs.nBasic())); - return isl_set_subtract(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_set_subtract(PBRefTake(lhs), PBRefTake(rhs))}; } template PBMap intersect(T &&lhs, U &&rhs) { DEBUG_PROFILE_VERBOSE("intersect", "nBasic=" + std::to_string(lhs.nBasic()) + "," + std::to_string(rhs.nBasic())); - return isl_map_intersect(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_map_intersect(PBRefTake(lhs), PBRefTake(rhs))}; } template PBSet intersect(T &&lhs, U &&rhs) { DEBUG_PROFILE_VERBOSE("intersect", "nBasic=" + std::to_string(lhs.nBasic()) + "," + std::to_string(rhs.nBasic())); - return isl_set_intersect(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_set_intersect(PBRefTake(lhs), PBRefTake(rhs))}; } template PBMap intersectDomain(T &&lhs, U &&rhs) { DEBUG_PROFILE_VERBOSE("intersectDomain", "nBasic=" + std::to_string(lhs.nBasic()) + "," + std::to_string(rhs.nBasic())); - return isl_map_intersect_domain(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_map_intersect_domain(PBRefTake(lhs), PBRefTake(rhs))}; } template PBMap intersectRange(T &&lhs, U &&rhs) { DEBUG_PROFILE_VERBOSE("intersectRange", "nBasic=" + std::to_string(lhs.nBasic()) + "," + std::to_string(rhs.nBasic())); - return isl_map_intersect_range(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_map_intersect_range(PBRefTake(lhs), PBRefTake(rhs))}; } template PBSingleFunc intersectDomain(T &&lhs, U &&rhs) { - return isl_pw_aff_intersect_domain(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_pw_aff_intersect_domain(PBRefTake(lhs), PBRefTake(rhs))}; } template PBFunc intersectDomain(T &&lhs, U &&rhs) { - return isl_multi_pw_aff_intersect_domain(PBRefTake(lhs), - PBRefTake(rhs)); + return {commonCtx(lhs, rhs), isl_multi_pw_aff_intersect_domain( + PBRefTake(lhs), PBRefTake(rhs))}; } template PBSet intersectParams(T &&lhs, U &&rhs) { - return isl_set_intersect_params(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_set_intersect_params(PBRefTake(lhs), PBRefTake(rhs))}; } template PBMap intersectParams(T &&lhs, U &&rhs) { - return isl_map_intersect_params(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_map_intersect_params(PBRefTake(lhs), PBRefTake(rhs))}; } template PBMap uni(T &&lhs, U &&rhs) { DEBUG_PROFILE_VERBOSE("uni", "nBasic=" + std::to_string(lhs.nBasic()) + "," + std::to_string(rhs.nBasic())); - return isl_map_union(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_map_union(PBRefTake(lhs), PBRefTake(rhs))}; } template PBSet uni(T &&lhs, U &&rhs) { DEBUG_PROFILE_VERBOSE("uni", "nBasic=" + std::to_string(lhs.nBasic()) + "," + std::to_string(rhs.nBasic())); - return isl_set_union(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_set_union(PBRefTake(lhs), PBRefTake(rhs))}; } template PBSet apply(T &&lhs, U &&rhs) { DEBUG_PROFILE("apply"); - return isl_set_apply(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_set_apply(PBRefTake(lhs), PBRefTake(rhs))}; } template PBMap applyDomain(T &&lhs, U &&rhs) { DEBUG_PROFILE("applyDomain"); - return isl_map_apply_domain(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_map_apply_domain(PBRefTake(lhs), PBRefTake(rhs))}; } template PBMap applyRange(T &&lhs, U &&rhs) { DEBUG_PROFILE("applyRange"); - return isl_map_apply_range(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_map_apply_range(PBRefTake(lhs), PBRefTake(rhs))}; } template PBMap sum(T &&lhs, U &&rhs) { DEBUG_PROFILE("sum"); - return isl_map_sum(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_map_sum(PBRefTake(lhs), PBRefTake(rhs))}; } template PBSet sum(T &&lhs, U &&rhs) { DEBUG_PROFILE("sum"); - return isl_set_sum(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_set_sum(PBRefTake(lhs), PBRefTake(rhs))}; } -template PBMap neg(T &&lhs) { +template PBMap neg(T &&map) { DEBUG_PROFILE("neg"); - return isl_map_neg(PBRefTake(lhs)); + return {map.ctx(), isl_map_neg(PBRefTake(map))}; } -template PBSet neg(T &&lhs) { +template PBSet neg(T &&set) { DEBUG_PROFILE("neg"); - return isl_set_neg(PBRefTake(lhs)); + return {set.ctx(), isl_set_neg(PBRefTake(set))}; } template PBMap lexmax(T &&map) { DEBUG_PROFILE_VERBOSE("lexmax", "nBasic=" + std::to_string(map.nBasic())); - return isl_map_lexmax(PBRefTake(map)); + return {map.ctx(), isl_map_lexmax(PBRefTake(map))}; } template PBMap lexmin(T &&map) { DEBUG_PROFILE_VERBOSE("lexmin", "nBasic=" + std::to_string(map.nBasic())); - return isl_map_lexmin(PBRefTake(map)); + return {map.ctx(), isl_map_lexmin(PBRefTake(map))}; } template PBSet lexmax(T &&set) { DEBUG_PROFILE_VERBOSE("lexmax", "nBasic=" + std::to_string(set.nBasic())); - return isl_set_lexmax(PBRefTake(set)); + return {set.ctx(), isl_set_lexmax(PBRefTake(set))}; } template PBSet lexmin(T &&set) { DEBUG_PROFILE_VERBOSE("lexmin", "nBasic=" + std::to_string(set.nBasic())); - return isl_set_lexmin(PBRefTake(set)); + return {set.ctx(), isl_set_lexmin(PBRefTake(set))}; } template PBMap identity(T &&space) { DEBUG_PROFILE("identity"); - return isl_map_identity(PBRefTake(space)); + return {space.ctx(), isl_map_identity(PBRefTake(space))}; } template PBMap lexGE(T &&space) { DEBUG_PROFILE("lexGE"); - return isl_map_lex_ge(PBRefTake(space)); + return {space.ctx(), isl_map_lex_ge(PBRefTake(space))}; } template PBMap lexGT(T &&space) { DEBUG_PROFILE("lexGT"); - return isl_map_lex_gt(PBRefTake(space)); + return {space.ctx(), isl_map_lex_gt(PBRefTake(space))}; } template PBMap lexLE(T &&space) { DEBUG_PROFILE("lexLE"); - return isl_map_lex_le(PBRefTake(space)); + return {space.ctx(), isl_map_lex_le(PBRefTake(space))}; } template PBMap lexLT(T &&space) { DEBUG_PROFILE("lexLT"); - return isl_map_lex_lt(PBRefTake(space)); + return {space.ctx(), isl_map_lex_lt(PBRefTake(space))}; } -inline PBSpace spaceAlloc(const PBCtx &ctx, unsigned nparam, unsigned nIn, +inline PBSpace spaceAlloc(const Ref &ctx, unsigned nparam, unsigned nIn, unsigned nOut) { - return isl_space_alloc(ctx.get(), nparam, nIn, nOut); + return {ctx, isl_space_alloc(ctx->get(), nparam, nIn, nOut)}; } -inline PBSpace spaceSetAlloc(const PBCtx &ctx, unsigned nparam, unsigned dim) { - return isl_space_set_alloc(ctx.get(), nparam, dim); +inline PBSpace spaceSetAlloc(const Ref &ctx, unsigned nparam, + unsigned dim) { + return {ctx, isl_space_set_alloc(ctx->get(), nparam, dim)}; } template PBSet emptySet(T &&space) { - return isl_set_empty(PBRefTake(space)); + return {space.ctx(), isl_set_empty(PBRefTake(space))}; } template PBMap emptyMap(T &&space) { - return isl_map_empty(PBRefTake(space)); + return {space.ctx(), isl_map_empty(PBRefTake(space))}; } template PBSet universeSet(T &&space) { - return isl_set_universe(PBRefTake(space)); + return {space.ctx(), isl_set_universe(PBRefTake(space))}; } template PBMap universeMap(T &&space) { - return isl_map_universe(PBRefTake(space)); + return {space.ctx(), isl_map_universe(PBRefTake(space))}; } template PBSet domain(T &&map) { - return isl_map_domain(PBRefTake(map)); + return {map.ctx(), isl_map_domain(PBRefTake(map))}; } template PBSet range(T &&map) { - return isl_map_range(PBRefTake(map)); + return {map.ctx(), isl_map_range(PBRefTake(map))}; } template PBSet domain(T &&func) { - return isl_pw_aff_domain(PBRefTake(func)); + return {func.ctx(), isl_pw_aff_domain(PBRefTake(func))}; } template PBSet domain(T &&func) { - return isl_multi_pw_aff_domain(PBRefTake(func)); + return {func.ctx(), isl_multi_pw_aff_domain(PBRefTake(func))}; } template PBSet params(T &&set) { - return isl_set_params(PBRefTake(set)); + return {set.ctx(), isl_set_params(PBRefTake(set))}; } template PBSet params(T &&map) { - return isl_map_params(PBRefTake(map)); + return {map.ctx(), isl_map_params(PBRefTake(map))}; } template PBSet coalesce(T &&set) { DEBUG_PROFILE("coalesce"); - return isl_set_coalesce(PBRefTake(set)); + return {set.ctx(), isl_set_coalesce(PBRefTake(set))}; } template PBMap coalesce(T &&map) { DEBUG_PROFILE("coalesce"); - return isl_map_coalesce(PBRefTake(map)); + return {map.ctx(), isl_map_coalesce(PBRefTake(map))}; } template PBSet cartesianProduct(T &&lhs, U &&rhs) { - return isl_set_flat_product(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_set_flat_product(PBRefTake(lhs), PBRefTake(rhs))}; } template PBVal dimMaxVal(T &&set, int pos) { - return isl_set_dim_max_val(PBRefTake(set), pos); + return {set.ctx(), isl_set_dim_max_val(PBRefTake(set), pos)}; } template PBVal dimMinVal(T &&set, int pos) { - return isl_set_dim_min_val(PBRefTake(set), pos); + return {set.ctx(), isl_set_dim_min_val(PBRefTake(set), pos)}; } inline PBVal dimFixVal(const PBSet &set, int pos) { - return isl_set_plain_get_val_if_fixed(set.get(), isl_dim_set, pos); + return {set.ctx(), + isl_set_plain_get_val_if_fixed(set.get(), isl_dim_set, pos)}; } template PBSpace spaceMapFromSet(T &&space) { - return isl_space_map_from_set(PBRefTake(space)); + return {space.ctx(), isl_space_map_from_set(PBRefTake(space))}; } template PBSet wrap(T &&map) { - return isl_map_wrap(PBRefTake(map)); + return {map.ctx(), isl_map_wrap(PBRefTake(map))}; } template PBMap unwrap(T &&set) { - return isl_set_unwrap(PBRefTake(set)); + return {set.ctx(), isl_set_unwrap(PBRefTake(set))}; } template PBSet flatten(T &&set) { - return isl_set_flatten(PBRefTake(set)); + return {set.ctx(), isl_set_flatten(PBRefTake(set))}; } template PBMap flattenDomain(T &&map) { - return isl_map_flatten_domain(PBRefTake(map)); + return {map.ctx(), isl_map_flatten_domain(PBRefTake(map))}; } template PBMap flattenRange(T &&map) { - return isl_map_flatten_range(PBRefTake(map)); + return {map.ctx(), isl_map_flatten_range(PBRefTake(map))}; } template PBSet flattenMapToSet(T &&map) { @@ -935,17 +1124,19 @@ template PBSet flattenMapToSet(T &&map) { } template PBPoint sample(T &&set) { - return isl_set_sample_point(PBRefTake(set)); + return {set.ctx(), isl_set_sample_point(PBRefTake(set))}; } template PBSingleFunc min(T &&lhs, U &&rhs) { - return isl_pw_aff_min(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_pw_aff_min(PBRefTake(lhs), PBRefTake(rhs))}; } template PBSingleFunc max(T &&lhs, U &&rhs) { - return isl_pw_aff_max(PBRefTake(lhs), PBRefTake(rhs)); + return {commonCtx(lhs, rhs), + isl_pw_aff_max(PBRefTake(lhs), PBRefTake(rhs))}; } /** @@ -974,7 +1165,8 @@ template PBSet coefficients(T &&set, int64_t c = 0) { auto cPoint = isl_point_zero(paramsSpace); isl_point_set_coordinate_val(cPoint, isl_dim_set, nParams - 1, isl_val_int_from_si(ctx, -c)); - return apply(PBSet(isl_set_from_point(cPoint)), PBMap(coefficientsMap)); + return apply(PBSet(set.ctx(), isl_set_from_point(cPoint)), + PBMap(set.ctx(), coefficientsMap)); } inline bool isSubset(const PBSet &small, const PBSet &big) { @@ -1166,7 +1358,7 @@ class PBMapBuilder : public PBBuilder { const std::vector &outputs() const { return outputs_; } void clearOutputs() { outputs_.clear(); } - PBMap build(const PBCtx &ctx) const; + PBMap build(const Ref &ctx) const; }; class PBSetBuilder : public PBBuilder { @@ -1189,17 +1381,15 @@ class PBSetBuilder : public PBBuilder { const std::vector &vars() const { return vars_; } void clearVars() { vars_.clear(); } - PBSet build(const PBCtx &ctx) const; + PBSet build(const Ref &ctx) const; }; -auto pbFuncWithTimeout(PBCtx &ctx, const auto &func, int seconds, - const auto &...args) +auto pbFuncWithTimeout(const auto &func, int seconds, const auto &...args) -> std::optional { decltype(func(args...)) ret; if (timeout([&]() { ret = func(args...); }, seconds)) { return ret; } else { - ctx.setDontFree(); return std::nullopt; } } diff --git a/src/analyze/check_not_modified.cc b/src/analyze/check_not_modified.cc index e7783778d..8178a3645 100644 --- a/src/analyze/check_not_modified.cc +++ b/src/analyze/check_not_modified.cc @@ -176,16 +176,13 @@ bool checkNotModified(const Stmt &op, const Expr &s0Expr, const Expr &s1Expr, return ret; }; - // write -> serialized PBSet - std::unordered_map writesWAR; + std::unordered_map writesWAR; std::mutex m; auto foundWAR = [&](const Dependence &dep) { - // Serialize WAR map because it is from a random PBCtx - auto strWAR = - toString(apply(domain(dep.later2EarlierIter_), dep.laterIter2Idx_)); + auto strWAR = apply(domain(dep.later2EarlierIter_), dep.laterIter2Idx_); // only lock for writing the map std::lock_guard l(m); - writesWAR[dep.later_.stmt_] = strWAR; + writesWAR[dep.later_.stmt_] = strWAR.toSerialized(); }; FindDeps() .direction({dir}) @@ -200,8 +197,8 @@ bool checkNotModified(const Stmt &op, const Expr &s0Expr, const Expr &s1Expr, .noProjectOutPrivateAxis(true)(tmpOp, unsyncFunc(foundWAR)); auto foundRAW = [&](const Dependence &dep) { - // re-construct WAR map from stored string in current PBCtx - auto w0 = PBSet(dep.presburger_, writesWAR[dep.earlier_.stmt_]); + auto w0 = + writesWAR[dep.earlier_.stmt_].to(dep.later2EarlierIter_.ctx()); auto w1 = apply(range(dep.later2EarlierIter_), dep.earlierIter2Idx_); if (!intersect(std::move(w0), std::move(w1)).empty()) throw ModifiedException{}; diff --git a/src/analyze/comp_unique_bounds_pb.cc b/src/analyze/comp_unique_bounds_pb.cc index 424c50a9e..522741d0b 100644 --- a/src/analyze/comp_unique_bounds_pb.cc +++ b/src/analyze/comp_unique_bounds_pb.cc @@ -40,7 +40,7 @@ std::optional CompUniqueBoundsPB::Bound::getInt() const { namespace { Expr translateBoundFunc( - PBCtx &ctx, const PBSet &boundSet, + const PBSet &boundSet, const std::unordered_map &demangleMap) { if (boundSet.empty()) { @@ -49,7 +49,7 @@ Expr translateBoundFunc( // TODO: clear out those not related params PBSet compactedBoundSet = coalesce(boundSet); - auto parsed = parsePBFuncReconstructMinMax(ctx, compactedBoundSet); + auto parsed = parsePBFuncReconstructMinMax(compactedBoundSet); Expr result; ReplaceIter demangler(demangleMap); @@ -73,26 +73,26 @@ Expr translateBoundFunc( Expr CompUniqueBoundsPB::Bound::lowerExpr() const { return bound_.hasLowerBound(0) - ? translateBoundFunc(*ctx_, lexmin(bound_), *demangleMap_) + ? translateBoundFunc(lexmin(bound_), *demangleMap_) : nullptr; } Expr CompUniqueBoundsPB::Bound::upperExpr() const { return bound_.hasUpperBound(0) - ? translateBoundFunc(*ctx_, lexmax(bound_), *demangleMap_) + ? translateBoundFunc(lexmax(bound_), *demangleMap_) : nullptr; } std::tuple CompUniqueBoundsPB::Bound::lowerUpperDiffExpr() const { PBSet l = bound_.hasLowerBound(0) ? lexmin(bound_) : PBSet(); PBSet u = bound_.hasUpperBound(0) ? lexmax(bound_) : PBSet(); - PBSet diff = l.isValid() && u.isValid() - ? coalesce(apply(cartesianProduct(u, l), - PBMap(*ctx_, "{[u, l] -> [u - l]}"))) - : PBSet(); - return {l.isValid() ? translateBoundFunc(*ctx_, l, *demangleMap_) : nullptr, - u.isValid() ? translateBoundFunc(*ctx_, u, *demangleMap_) : nullptr, - diff.isValid() ? translateBoundFunc(*ctx_, diff, *demangleMap_) - : nullptr}; + PBSet diff = + l.isValid() && u.isValid() + ? coalesce(apply(cartesianProduct(u, l), + PBMap(bound_.ctx(), "{[u, l] -> [u - l]}"))) + : PBSet(); + return {l.isValid() ? translateBoundFunc(l, *demangleMap_) : nullptr, + u.isValid() ? translateBoundFunc(u, *demangleMap_) : nullptr, + diff.isValid() ? translateBoundFunc(diff, *demangleMap_) : nullptr}; } Ref CompUniqueBoundsPB::Bound::restrictScope( @@ -109,7 +109,7 @@ Ref CompUniqueBoundsPB::Bound::restrictScope( auto newBound = bound_; for (auto axes : views::reverse(axesToProject)) newBound = projectOutParamDims(newBound, axes, 1); - return Ref::make(ctx_, demangleMap_, newBound); + return Ref::make(demangleMap_, newBound); } Expr CompUniqueBoundsPB::Bound::simplestExpr( @@ -142,7 +142,7 @@ Expr CompUniqueBoundsPB::Bound::simplestExpr( restrictedBound = std::move(newRestrictedBound); minScopeLevel = scopeLevel; } - auto resultExpr = translateBoundFunc(*ctx_, restrictedBound, *demangleMap_); + auto resultExpr = translateBoundFunc(restrictedBound, *demangleMap_); if (!resultExpr.isValid()) { return nullptr; } @@ -187,8 +187,8 @@ CompUniqueBoundsPB::CompUniqueBoundsPB( } str += str.empty() ? subStr : "; " + subStr; } - cachedConds_ = PBSet(*ctx_, "[" + (varMap | views::values | join(", ")) + - "] -> {" + str + "}"); + cachedConds_ = PBSet(ctx_, "[" + (varMap | views::values | join(", ")) + + "] -> {" + str + "}"); // initialize known demangle map cachedFreeVars_ = decltype(cachedFreeVars_)::make(); @@ -233,8 +233,8 @@ Ref CompUniqueBoundsPB::getBound(const Expr &op) { str += str.empty() ? subStr : "; " + subStr; } auto bound = - (intersect(PBSet(*ctx_, "[" + (varMap | views::values | join(", ")) + - "] -> {" + str + "}"), + (intersect(PBSet(ctx_, "[" + (varMap | views::values | join(", ")) + + "] -> {" + str + "}"), cachedConds_)); // update free variables for (auto &&[expr, pbVar] : varMap) { @@ -244,7 +244,7 @@ Ref CompUniqueBoundsPB::getBound(const Expr &op) { else (*cachedFreeVars_)[pbVar] = expr; } - return cachedValues_[op] = Ref::make(ctx_, cachedFreeVars_, bound); + return cachedValues_[op] = Ref::make(cachedFreeVars_, bound); } bool CompUniqueBoundsPB::alwaysLE(const Expr &lhs, const Expr &rhs) { @@ -252,7 +252,7 @@ bool CompUniqueBoundsPB::alwaysLE(const Expr &lhs, const Expr &rhs) { r = insertDims(getBound(rhs).as()->bound_, 0, 1); // we check for the emptiness of l > r; if empty, it means we never have l > // r, or equivalently always have l <= r - auto combined = intersect(intersect(l, r), PBSet(*ctx_, "{[l, r]: l > r}")); + auto combined = intersect(intersect(l, r), PBSet(ctx_, "{[l, r]: l > r}")); return combined.empty(); } @@ -260,8 +260,7 @@ bool CompUniqueBoundsPB::alwaysLT(const Expr &lhs, const Expr &rhs) { auto l = insertDims(getBound(lhs).as()->bound_, 1, 1), r = insertDims(getBound(rhs).as()->bound_, 0, 1); // similar to alwaysLE, but !LT = GE - auto combined = - intersect(intersect(l, r), PBSet(*ctx_, "{[l, r]: l >= r}")); + auto combined = intersect(intersect(l, r), PBSet(ctx_, "{[l, r]: l >= r}")); return combined.empty(); } @@ -275,8 +274,8 @@ Ref CompUniqueBoundsPB::unionBoundsAsBound( _bounds | views::transform([&](auto &&_bound) { ASSERT(_bound->type() == BoundType::Presburger); auto &&bound = _bound.template as(); - return Ref::make(ctx_, bound->demangleMap_, - PBSet(*ctx_, toString(bound->bound_))); + return Ref::make(bound->demangleMap_, + bound->bound_.to(ctx_)); })); // union the bounds @@ -305,7 +304,7 @@ Ref CompUniqueBoundsPB::unionBoundsAsBound( (*demangleMap)[dimName] = demangled; } - return Ref::make(ctx_, demangleMap, bound); + return Ref::make(demangleMap, bound); } std::pair CompUniqueBoundsPB::unionBounds( diff --git a/src/analyze/deps.cc b/src/analyze/deps.cc index c40475e75..9158505ee 100644 --- a/src/analyze/deps.cc +++ b/src/analyze/deps.cc @@ -424,9 +424,9 @@ std::string AnalyzeDeps::makeCond(GenPBExpr &genPBExpr, return ret; } -PBMap AnalyzeDeps::makeAccMapStatic(PBCtx &presburger, const AccessPoint &p, - int iterDim, int accDim, - const std::string &extSuffix, +PBMap AnalyzeDeps::makeAccMapStatic(const Ref &presburger, + const AccessPoint &p, int iterDim, + int accDim, const std::string &extSuffix, GenPBExpr::VarMap &externals, const ASTHashSet &noNeedToBeVars, bool eraseOutsideVarDef) { @@ -474,45 +474,57 @@ std::string AnalyzeDeps::makeNdList(const std::string &name, int n) { } PBMap AnalyzeDeps::makeEqForBothOps( - PBCtx &presburger, const std::vector> &coord, + const Ref &presburger, const std::vector> &coord, int iterDim) const { auto map = universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)).move(); for (auto &&[dim, val] : coord) map = isl_map_fix_si(isl_map_fix_si(map, isl_dim_out, dim, val), isl_dim_in, dim, val); - return PBMap(map); + return PBMap(presburger, map); } -PBMap AnalyzeDeps::makeIneqBetweenOps(PBCtx &presburger, DepDirection mode, - int iterId, int iterDim) const { +PBMap AnalyzeDeps::makeIneqBetweenOps(const Ref &presburger, + DepDirection mode, int iterId, + int iterDim) const { switch (mode) { case DepDirection::Inv: - return PBMap(isl_map_order_gt( - universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)).move(), - isl_dim_out, iterId, isl_dim_in, iterId)); + return PBMap( + presburger, + isl_map_order_gt( + universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)).move(), + isl_dim_out, iterId, isl_dim_in, iterId)); case DepDirection::Normal: - return PBMap(isl_map_order_lt( - universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)).move(), - isl_dim_out, iterId, isl_dim_in, iterId)); + return PBMap( + presburger, + isl_map_order_lt( + universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)).move(), + isl_dim_out, iterId, isl_dim_in, iterId)); case DepDirection::Same: - return PBMap(isl_map_equate( - universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)).move(), - isl_dim_out, iterId, isl_dim_in, iterId)); + return PBMap( + presburger, + isl_map_equate( + universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)).move(), + isl_dim_out, iterId, isl_dim_in, iterId)); case DepDirection::Different: return uni( - PBMap(isl_map_order_lt( - universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)).move(), - isl_dim_out, iterId, isl_dim_in, iterId)), - PBMap(isl_map_order_gt( - universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)).move(), - isl_dim_out, iterId, isl_dim_in, iterId))); + PBMap(presburger, + isl_map_order_lt( + universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)) + .move(), + isl_dim_out, iterId, isl_dim_in, iterId)), + PBMap(presburger, + isl_map_order_gt( + universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)) + .move(), + isl_dim_out, iterId, isl_dim_in, iterId))); default: ASSERT(false); } } -PBMap AnalyzeDeps::makeConstraintOfSingleLoop(PBCtx &presburger, const ID &loop, - DepDirection mode, int iterDim) { +PBMap AnalyzeDeps::makeConstraintOfSingleLoop(const Ref &presburger, + const ID &loop, DepDirection mode, + int iterDim) { if (!scope2coord_.count(loop)) { // If we don't have the scope in `scope2coord_`, it means the scope is // trivial, which has only one instance inside. So it must be `Same` @@ -548,7 +560,7 @@ PBMap AnalyzeDeps::makeConstraintOfSingleLoop(PBCtx &presburger, const ID &loop, makeIneqBetweenOps(presburger, mode, iterId, iterDim)); } -PBMap AnalyzeDeps::makeConstraintOfParallelScope(PBCtx &presburger, +PBMap AnalyzeDeps::makeConstraintOfParallelScope(const Ref &presburger, const ParallelScope ¶llel, DepDirection mode, int iterDim, const AccessPoint &later, @@ -598,7 +610,7 @@ PBMap AnalyzeDeps::makeConstraintOfParallelScope(PBCtx &presburger, " d" + std::to_string(laterDim) + "}"); } -PBMap AnalyzeDeps::makeExternalEq(PBCtx &presburger, int iterDim, +PBMap AnalyzeDeps::makeExternalEq(const Ref &presburger, int iterDim, const std::string &ext1, const std::string &ext2) { PBMap universe = universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)); @@ -621,7 +633,7 @@ const std::string &AnalyzeDeps::getVar(const AST &op) { } } -PBMap AnalyzeDeps::makeSerialToAll(PBCtx &presburger, int iterDim, +PBMap AnalyzeDeps::makeSerialToAll(const Ref &presburger, int iterDim, const std::vector &point) const { std::string to = makeNdList("d", iterDim), from; for (int i = 0; i < iterDim; i++) { @@ -635,7 +647,7 @@ PBMap AnalyzeDeps::makeSerialToAll(PBCtx &presburger, int iterDim, return PBMap(presburger, "{" + from + " -> " + to + "}"); } -PBMap AnalyzeDeps::makeEraseVarDefConstraint(PBCtx &presburger, +PBMap AnalyzeDeps::makeEraseVarDefConstraint(const Ref &presburger, const Ref &point, int iterDim) { PBMap ret = universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)); @@ -649,7 +661,7 @@ PBMap AnalyzeDeps::makeEraseVarDefConstraint(PBCtx &presburger, return ret; } -PBMap AnalyzeDeps::makeNoDepsConstraint(PBCtx &presburger, +PBMap AnalyzeDeps::makeNoDepsConstraint(const Ref &presburger, const std::string &var, int iterDim) { PBMap ret = universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)); if (noDepsLists_.count(var)) { @@ -663,7 +675,7 @@ PBMap AnalyzeDeps::makeNoDepsConstraint(PBCtx &presburger, } PBMap AnalyzeDeps::makeExternalVarConstraint( - PBCtx &presburger, const Ref &later, + const Ref &presburger, const Ref &later, const Ref &earlier, const GenPBExpr::VarMap &laterExternals, const GenPBExpr::VarMap &earlierExternals, int iterDim) { PBMap ret = universeMap(spaceAlloc(presburger, 0, iterDim, iterDim)); @@ -708,8 +720,8 @@ PBMap AnalyzeDeps::makeExternalVarConstraint( return ret; } -PBMap AnalyzeDeps::projectOutPrivateAxis(PBCtx &presburger, int iterDim, - int since) { +PBMap AnalyzeDeps::projectOutPrivateAxis(const Ref &presburger, + int iterDim, int since) { std::string from = makeNdList("d", iterDim); std::string to; for (int i = 0; i < iterDim; i++) { @@ -720,7 +732,7 @@ PBMap AnalyzeDeps::projectOutPrivateAxis(PBCtx &presburger, int iterDim, } void AnalyzeDeps::projectOutPrivateAxis( - PBCtx &presburger, const Ref &point, + const Ref &presburger, const Ref &point, const std::vector> &otherList, std::vector &otherMapList, int iterDim) { if (!noProjectOutPrivateAxis_) { @@ -771,8 +783,7 @@ int AnalyzeDeps::numCommonDims(const Ref &p1, return n; } -void AnalyzeDeps::checkAgainstCond(PBCtx &presburger, - const Ref &later, +void AnalyzeDeps::checkAgainstCond(const Ref &later, const Ref &earlier, const PBMap &depAll, const PBMap &nearest, const PBMap &laterMap, @@ -822,7 +833,6 @@ void AnalyzeDeps::checkAgainstCond(PBCtx &presburger, if (mode_ == FindDepsMode::KillEarlier || mode_ == FindDepsMode::KillBoth) { if (auto flag = pbFuncWithTimeout( - presburger, static_cast( isSubset), 10, realEarlierIter, range(depAll)); @@ -833,7 +843,6 @@ void AnalyzeDeps::checkAgainstCond(PBCtx &presburger, if (mode_ == FindDepsMode::KillLater || mode_ == FindDepsMode::KillBoth) { if (auto flag = pbFuncWithTimeout( - presburger, static_cast( isSubset), 10, realLaterIter, domain(depAll)); @@ -846,7 +855,6 @@ void AnalyzeDeps::checkAgainstCond(PBCtx &presburger, if (mode_ == FindDepsMode::KillEarlier || mode_ == FindDepsMode::KillBoth) { if (auto flag = pbFuncWithTimeout( - presburger, static_cast( isSubset), 10, realEarlierIter, range(nearest)); @@ -857,7 +865,6 @@ void AnalyzeDeps::checkAgainstCond(PBCtx &presburger, if (mode_ == FindDepsMode::KillLater || mode_ == FindDepsMode::KillBoth) { if (auto flag = pbFuncWithTimeout( - presburger, static_cast( isSubset), 10, realLaterIter, domain(nearest)); @@ -872,11 +879,11 @@ void AnalyzeDeps::checkAgainstCond(PBCtx &presburger, for (auto &&[nodeOrParallel, dir] : item) { if (nodeOrParallel.isNode_) { _requires.emplace_back(makeConstraintOfSingleLoop( - presburger, nodeOrParallel.id_, dir, iterDim)); + depAll.ctx(), nodeOrParallel.id_, dir, iterDim)); } else { _requires.emplace_back(makeConstraintOfParallelScope( - presburger, nodeOrParallel.parallel_, dir, iterDim, *later, - *earlier)); + depAll.ctx(), nodeOrParallel.parallel_, dir, iterDim, + *later, *earlier)); } } @@ -900,13 +907,13 @@ void AnalyzeDeps::checkAgainstCond(PBCtx &presburger, if (noProjectOutPrivateAxis_) { found_(Dependence{item, getVar(later->op_), *later, *earlier, iterDim, res, laterMap, earlierMap, possible, - extConstraint, presburger, *this}); + extConstraint, *this}); } else { // It will be misleading if we pass Presburger maps to users in // this case found_(Dependence{item, getVar(later->op_), *later, *earlier, iterDim, PBMap(), PBMap(), PBMap(), PBMap(), - PBMap(), presburger, *this}); + PBMap(), *this}); } fail:; } @@ -999,8 +1006,7 @@ void AnalyzeDeps::checkDepLatestEarlier( return; } tasks_.emplace_back([later, earlierList = std::move(earlierList), this]() { - PBCtx presburger; - checkDepLatestEarlierImpl(presburger, later, earlierList); + checkDepLatestEarlierImpl(Ref::make(), later, earlierList); }); } @@ -1022,13 +1028,12 @@ void AnalyzeDeps::checkDepEarliestLater( return; } tasks_.emplace_back([laterList = std::move(laterList), earlier, this]() { - PBCtx presburger; - checkDepEarliestLaterImpl(presburger, laterList, earlier); + checkDepEarliestLaterImpl(Ref::make(), laterList, earlier); }); } void AnalyzeDeps::checkDepLatestEarlierImpl( - PBCtx &presburger, const Ref &later, + const Ref &presburger, const Ref &later, const std::vector> &earlierList) { int accDim = later->access_.size(); int iterDim = later->iter_.size(); @@ -1132,7 +1137,7 @@ void AnalyzeDeps::checkDepLatestEarlierImpl( views::zip(earlierList, es2aList, earlierMapList, depAllList)) { if (depAll.isValid()) { checkAgainstCond( - presburger, later, earlier, depAll, + later, earlier, depAll, intersect(applyRange(psNearest, std::move(es2a)), depAll), laterMap, earlierMap, extConstraint, iterDim); } @@ -1140,7 +1145,8 @@ void AnalyzeDeps::checkDepLatestEarlierImpl( } void AnalyzeDeps::checkDepEarliestLaterImpl( - PBCtx &presburger, const std::vector> &laterList, + const Ref &presburger, + const std::vector> &laterList, const Ref &earlier) { int accDim = earlier->access_.size(); int iterDim = earlier->iter_.size(); @@ -1243,7 +1249,7 @@ void AnalyzeDeps::checkDepEarliestLaterImpl( views::zip(laterList, ls2aList, laterMapList, depAllList)) { if (depAll.isValid()) { checkAgainstCond( - presburger, later, earlier, depAll, + later, earlier, depAll, intersect(applyDomain(spNearest, std::move(ls2a)), depAll), laterMap, earlierMap, extConstraint, iterDim); } @@ -1299,11 +1305,11 @@ PBMap Dependence::extraCheck(PBMap dep, PBMap require; if (nodeOrParallel.isNode_) { require = self_.makeConstraintOfSingleLoop( - presburger_, nodeOrParallel.id_, dir, iterDim_); + later2EarlierIter_.ctx(), nodeOrParallel.id_, dir, iterDim_); } else { require = self_.makeConstraintOfParallelScope( - presburger_, nodeOrParallel.parallel_, dir, iterDim_, later_, - earlier_); + later2EarlierIter_.ctx(), nodeOrParallel.parallel_, dir, iterDim_, + later_, earlier_); } dep = intersect(std::move(dep), std::move(require)); return dep; diff --git a/src/autograd/invert_stmts.cc b/src/autograd/invert_stmts.cc index 764f43189..3ed07ddbb 100644 --- a/src/autograd/invert_stmts.cc +++ b/src/autograd/invert_stmts.cc @@ -58,7 +58,7 @@ struct CondInfo { Expr cond_; }; -PBMap anythingTo1(PBCtx &presburger, int nDims) { +PBMap anythingTo1(const Ref &presburger, int nDims) { std::ostringstream os; os << "{["; for (int i = 0; i < nDims; i++) { @@ -68,7 +68,7 @@ PBMap anythingTo1(PBCtx &presburger, int nDims) { return PBMap(presburger, os.str()); } -void genCondExpr(PBCtx &presburger, CondInfo *info) { +void genCondExpr(const Ref &presburger, CondInfo *info) { // Use less basic sets to express info->when_. It also makes our final // expression simpler info->when_ = coalesce(info->when_); @@ -82,7 +82,7 @@ void genCondExpr(PBCtx &presburger, CondInfo *info) { PBMap indicator = intersectDomain( anythingTo1(presburger, info->when_.nDims()), info->when_); for (auto &&[args, _, factorRange] : - parsePBFunc(toString(PBFunc(indicator)))) { + parsePBFunc(PBFunc(indicator).toSerialized())) { if (!allReads(factorRange).empty()) { throw ParserError("External variable in recomputing condition " "is not yet supported"); @@ -230,7 +230,7 @@ invertStmts(const Stmt &op, // invertible statement, and when X reaches the end of its lifetime. To // detect the latter case, we insert a fake self-assigning statement just // before the lifetime's end before dependence analysis. - PBCtx presburger; + auto presburger = Ref::make(); std::unordered_map unrecoverableInfo, toInvertInfo; std::unordered_map allPossibleIters; // TODO: We can apply an additional filter to only invert Y if it already to @@ -247,9 +247,8 @@ invertStmts(const Stmt &op, } return false; })(InsertLifetimeEndChecker{*idsNeeded}(op), [&](const Dependence &d) { - // Serialize and deserialize to change PBCtx - auto earlierIterSet = - PBSet(presburger, toString(range(d.later2EarlierIter_))); + auto later2EarlierIter = d.later2EarlierIter_.to(presburger); + auto earlierIterSet = range(later2EarlierIter); // Trim paddings earlierIterSet = projectOutDims( @@ -259,9 +258,7 @@ invertStmts(const Stmt &op, auto toInvert = d.later_.stmt_->id(); auto toRecover = d.earlier_.stmt_->id(); if (invertibles.count(d.later_.stmt_->id())) { // Can invert - // Serialize and deserialize to change PBCtx - auto laterIterSet = - PBSet(presburger, toString(domain(d.later2EarlierIter_))); + auto laterIterSet = domain(later2EarlierIter); // Trim paddings laterIterSet = projectOutDims( diff --git a/src/math/parse_pb_expr.cc b/src/math/parse_pb_expr.cc index 9fba14c91..109e982d0 100644 --- a/src/math/parse_pb_expr.cc +++ b/src/math/parse_pb_expr.cc @@ -99,9 +99,7 @@ class RecoverBoolVars : public Mutator { } }; -} // Anonymous namespace - -PBFuncAST parsePBFunc(const std::string &str) { +PBFuncAST parsePBFuncImpl(const std::string &str) { try { antlr4::ANTLRInputStream charStream(str); pb_lexer lexer(&charStream); @@ -127,6 +125,15 @@ PBFuncAST parsePBFunc(const std::string &str) { } } +} // Anonymous namespace + +PBFuncAST parsePBFunc(const PBFunc::Serialized &f) { + return parsePBFuncImpl(f.data()); +} +PBFuncAST parsePBFunc(const PBSingleFunc::Serialized &f) { + return parsePBFuncImpl(f.data()); +} + namespace { Expr isl2Expr(__isl_take isl_ast_expr *e) { @@ -304,7 +311,7 @@ isl2Func(__isl_take isl_ast_node *node) { } // Anonymous namespace -PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) { +PBFuncAST parsePBFuncReconstructMinMax(const PBSet &set) { // This is a hack to isl's schedule. Treat the set as an iteration domain. // For a single-valued set, the domain will be zero or one statement, // implemented by a statement in multiple branches. We can recover Expr from @@ -324,10 +331,10 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) { }) | ranges::to(); - isl_options_set_ast_build_detect_min_max(ctx.get(), 1); + isl_options_set_ast_build_detect_min_max(set.ctx()->get(), 1); PBFuncAST ret; - isl_ast_build *build = isl_ast_build_alloc(ctx.get()); + isl_ast_build *build = isl_ast_build_alloc(set.ctx()->get()); try { isl_schedule *s = isl_schedule_from_domain(isl_union_set_from_set(set.copy())); @@ -347,7 +354,7 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) { namespace { -template PBMap moveAllInputDimsToParam(const PBCtx &ctx, T &&map) { +template PBMap moveAllInputDimsToParam(T &&map) { // A name is required for the parameter, so we can't simply use // isl_map_move_dims. We constuct a map to apply on the set to move the // dimension. Example map: [i1, i2] -> {[i1, i2] -> []}. The parameters are @@ -366,15 +373,14 @@ template PBMap moveAllInputDimsToParam(const PBCtx &ctx, T &&map) { }) | join(",")) << "] -> []}"; - PBMap moving(ctx, os.str()); + PBMap moving(map.ctx(), os.str()); return applyDomain(std::forward(map), std::move(moving)); } } // Anonymous namespace -PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBMap &map) { - return parsePBFuncReconstructMinMax( - ctx, range(moveAllInputDimsToParam(ctx, map))); +PBFuncAST parsePBFuncReconstructMinMax(const PBMap &map) { + return parsePBFuncReconstructMinMax(range(moveAllInputDimsToParam(map))); } } // namespace freetensor diff --git a/src/math/presburger.cc b/src/math/presburger.cc index 096374d1a..4a47361c9 100644 --- a/src/math/presburger.cc +++ b/src/math/presburger.cc @@ -70,7 +70,7 @@ std::vector PBMapBuilder::newOutputs(int n, return ret; } -PBMap PBMapBuilder::build(const PBCtx &ctx) const { +PBMap PBMapBuilder::build(const Ref &ctx) const { return {ctx, "{ [" + join(inputs_, ", ") + "] -> [" + join(outputs_, ", ") + "]: " + getConstraintsStr() + " }"}; } @@ -89,7 +89,7 @@ std::vector PBSetBuilder::newVars(int n, return ret; } -PBSet PBSetBuilder::build(const PBCtx &ctx) const { +PBSet PBSetBuilder::build(const Ref &ctx) const { return {ctx, "{ [" + join(vars_, ", ") + "]: " + getConstraintsStr() + " }"}; } diff --git a/src/pass/prop_one_time_use.cc b/src/pass/prop_one_time_use.cc index 7753f5add..e18412354 100644 --- a/src/pass/prop_one_time_use.cc +++ b/src/pass/prop_one_time_use.cc @@ -19,7 +19,7 @@ namespace { struct ReplaceInfo { std::vector earlierIters_, laterIters_; - std::string funcStr_; + PBFunc::Serialized func_; }; std::vector>> @@ -69,8 +69,8 @@ Stmt propOneTimeUse(const Stmt &_op, const ID &subAST) { r2wCandidates; std::unordered_map> r2wMay; std::unordered_set wCandidates; - std::unordered_map>> + std::unordered_map< + Stmt, std::vector>> w2rMay; std::unordered_map stmts; std::mutex lock; @@ -115,7 +115,6 @@ Stmt propOneTimeUse(const Stmt &_op, const ID &subAST) { // we not only need `singleValued`, but also `bijective`, to // ensure it is really used "one time" if (auto f = pbFuncWithTimeout( - d.presburger_, [](const PBMap &map) { return PBFunc(map); }, 10, d.later2EarlierIter_); f.has_value()) { @@ -125,7 +124,7 @@ Stmt propOneTimeUse(const Stmt &_op, const ID &subAST) { // which may propagate d.earlier().as(), ReplaceInfo{d.earlier_.iter_, d.later_.iter_, - toString(*f)}); + f->toSerialized()}); wCandidates.emplace(d.earlier().as()); stmts[d.later()] = d.later_.stmt_; } @@ -154,7 +153,7 @@ Stmt propOneTimeUse(const Stmt &_op, const ID &subAST) { writeIter.nDims() - d.earlier_.iter_.size()); r2wMay[d.later()].emplace_back(d.earlier().as()); w2rMay[d.earlier().as()].emplace_back( - d.later(), toString(writeIter)); + d.later(), writeIter.toSerialized()); }); // Filter single-valued and one-time-used @@ -174,10 +173,10 @@ Stmt propOneTimeUse(const Stmt &_op, const ID &subAST) { // Check one-time use: All read statements should read different items // written by the write statement ASSERT(w2rMay.count(write.first)); - PBCtx ctx; + auto ctx = Ref::make(); PBSet writeIterUnion; - for (auto &&[read, writeIterStr] : w2rMay.at(write.first)) { - PBSet writeIter = PBSet(ctx, writeIterStr); + for (auto &&[read, _writeIter] : w2rMay.at(write.first)) { + PBSet writeIter = _writeIter.to(ctx); if (writeIterUnion.isValid()) { if (!intersect(writeIterUnion, writeIter).empty()) { goto failure; // Not one-time-used @@ -213,7 +212,7 @@ Stmt propOneTimeUse(const Stmt &_op, const ID &subAST) { if (!allIters(toProp).empty()) { try { auto &&[args, values, cond] = - parseSimplePBFunc(repInfo.funcStr_); // later -> earlier + parseSimplePBFunc(repInfo.func_); // later -> earlier ASSERT(repInfo.earlierIters_.size() <= values.size()); // maybe padded ASSERT(repInfo.laterIters_.size() <= args.size()); diff --git a/src/pass/remove_writes.cc b/src/pass/remove_writes.cc index 09f7171a9..2d9fb78bc 100644 --- a/src/pass/remove_writes.cc +++ b/src/pass/remove_writes.cc @@ -18,7 +18,7 @@ namespace { struct ReplaceInfo { std::vector earlierIters_, laterIters_; - std::string funcStr_; + PBFunc::Serialized func_; }; } // Anonymous namespace @@ -219,7 +219,7 @@ Stmt removeWrites(const Stmt &_op, const ID &singleDefId) { }) .ignoreReductionWAW(false)(op, foundSelfDependent); - PBCtx presburger; + auto presburger = Ref::make(); // {(later, earlier, toKill, replaceInfo)} std::vector> overwrites; std::unordered_map> usesRAW; // W -> R @@ -229,19 +229,16 @@ Stmt removeWrites(const Stmt &_op, const ID &singleDefId) { auto earlier = d.earlier().as(); auto later = d.later().as(); if (!kill.count(earlier)) { - kill[earlier] = - PBSet(presburger, toString(domain(d.earlierIter2Idx_))); + kill[earlier] = domain(d.earlierIter2Idx_.to(presburger)); } - auto extConstraint = - PBSet(presburger, toString(range(d.extConstraint_))); + auto extConstraint = range(d.extConstraint_.to(presburger)); std::tie(kill[earlier], extConstraint) = padToSameDims(std::move(kill[earlier]), std::move(extConstraint)); kill[earlier] = intersect(std::move(kill[earlier]), std::move(extConstraint)); - overwrites.emplace_back( - later, earlier, - PBSet(presburger, toString(range(d.later2EarlierIter_))), - ReplaceInfo{}); + overwrites.emplace_back(later, earlier, + range(d.later2EarlierIter_.to(presburger)), + ReplaceInfo{}); suspect.insert(d.def()); }; std::mutex lock; @@ -251,7 +248,6 @@ Stmt removeWrites(const Stmt &_op, const ID &singleDefId) { sameParent(d.later_.stmt_, d.earlier_.stmt_))) { if (d.later2EarlierIter_.isSingleValued()) { if (auto f = pbFuncWithTimeout( - d.presburger_, [](const PBMap &map) { return PBFunc(map); }, 10, d.later2EarlierIter_); f.has_value()) { @@ -259,15 +255,14 @@ Stmt removeWrites(const Stmt &_op, const ID &singleDefId) { auto earlier = d.earlier().as(); auto later = d.later().as(); if (!kill.count(earlier)) { - kill[earlier] = PBSet( - presburger, toString(domain(d.earlierIter2Idx_))); + kill[earlier] = + domain(d.earlierIter2Idx_.to(presburger)); } overwrites.emplace_back( later, earlier, - PBSet(presburger, - toString(range(d.later2EarlierIter_))), + range(d.later2EarlierIter_.to(presburger)), ReplaceInfo{d.earlier_.iter_, d.later_.iter_, - toString(*f)}); + f->toSerialized()}); suspect.insert(d.def()); } } @@ -431,7 +426,7 @@ Stmt removeWrites(const Stmt &_op, const ID &singleDefId) { if (!allIters(expr).empty()) { try { auto &&[args, values, cond] = - parseSimplePBFunc(repInfo.funcStr_); // later -> earlier + parseSimplePBFunc(repInfo.func_); // later -> earlier ASSERT(repInfo.earlierIters_.size() <= values.size()); // maybe padded ASSERT(repInfo.laterIters_.size() <= args.size()); diff --git a/src/pass/shrink_for.cc b/src/pass/shrink_for.cc index dfc0b5cfd..f09da3a30 100644 --- a/src/pass/shrink_for.cc +++ b/src/pass/shrink_for.cc @@ -20,8 +20,7 @@ namespace freetensor { namespace { template -PBSet moveDimToNamedParam(const PBCtx &ctx, T &&set, int dim, - const std::string ¶m) { +PBSet moveDimToNamedParam(T &&set, int dim, const std::string ¶m) { // A name is required for the parameter, so we can't simply use // isl_set_move_dims. We constuct a map to apply on the set to move the // dimension. Example map: [p] -> {[_1, _2, p] -> [_1, _2]} @@ -39,7 +38,7 @@ PBSet moveDimToNamedParam(const PBCtx &ctx, T &&set, int dim, views::transform([](int i) { return "_" + std::to_string(i); }) | join(",")) << "]}"; - PBMap map(ctx, os.str()); + PBMap map(set.ctx(), os.str()); return apply(std::forward(set), std::move(map)); } @@ -48,13 +47,15 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB { std::pair getStride(const Ref &bound, bool requireConst) { isl_stride_info *info = isl_set_get_stride_info(bound->bound_.get(), 0); - auto stride = PBVal(isl_stride_info_get_stride(info)); - auto offset = PBSingleFunc(isl_stride_info_get_offset(info)); + auto stride = + PBVal(bound->bound_.ctx(), isl_stride_info_get_stride(info)); + auto offset = + PBSingleFunc(bound->bound_.ctx(), isl_stride_info_get_offset(info)); isl_stride_info_free(info); ASSERT(stride.denSi() == 1); auto strideInt = stride.numSi(); ReplaceIter demangler(*bound->demangleMap_); - auto offsetSimpleFunc = parseSimplePBFunc(toString(offset)); + auto offsetSimpleFunc = parseSimplePBFunc(offset.toSerialized()); // offsetSimpleFunc.args_ should be a dummy variable equals to `bound`'s // value. Leave it. ASSERT(offsetSimpleFunc.values_.size() == 1); @@ -117,7 +118,7 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB { PBSet set = bound->bound_; // Reveal local dimensions - set = isl_set_lift(set.move()); + set = PBSet{set.ctx(), isl_set_lift(set.move())}; // Put local dimension at front, so we can represent the target // dimension by local dimensions, instead of representing local @@ -141,7 +142,6 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB { } auto thisLoopBound = Ref::make( - bound->ctx_, Ref>::make(demangleMap), thisLoopSet); Expr l, u, diff; @@ -166,8 +166,7 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB { // represented by outer loops. The parameter name used here is // temporary, and will be replaced later. auto paramName = "ft_shrink_for_tmp_" + std::to_string(i++); - set = moveDimToNamedParam(*bound->ctx_, std::move(set), 0, - paramName); + set = moveDimToNamedParam(std::move(set), 0, paramName); demangleMap[paramName] = makeVar(paramName); } } diff --git a/src/pass/tensor_prop_const.cc b/src/pass/tensor_prop_const.cc index 7611f56e6..be48d6568 100644 --- a/src/pass/tensor_prop_const.cc +++ b/src/pass/tensor_prop_const.cc @@ -17,7 +17,7 @@ namespace { struct ReplaceInfo { std::vector earlierIters_, laterIters_; - std::string funcStr_; + PBFunc::Serialized func_; }; } // namespace @@ -111,7 +111,6 @@ Stmt tensorPropConst(const Stmt &_op, const ID &bothInSubAST, .isSingleValued()) { // Check before converting into // PBFunc if (auto f = pbFuncWithTimeout( - d.presburger_, [](const PBMap &map) { return PBFunc(map); }, 10, d.later2EarlierIter_); f.has_value()) { @@ -119,7 +118,7 @@ Stmt tensorPropConst(const Stmt &_op, const ID &bothInSubAST, r2w[d.later()].emplace_back( d.earlier().as(), ReplaceInfo{d.earlier_.iter_, d.later_.iter_, - toString(*f)}); + f->toSerialized()}); } } })); @@ -156,7 +155,7 @@ Stmt tensorPropConst(const Stmt &_op, const ID &bothInSubAST, if (!allIters(store->expr_).empty()) { try { auto &&[args, values, cond] = - parseSimplePBFunc(repInfo.funcStr_); // later -> earlier + parseSimplePBFunc(repInfo.func_); // later -> earlier ASSERT(repInfo.earlierIters_.size() <= values.size()); // maybe padded ASSERT(repInfo.laterIters_.size() <= args.size()); diff --git a/src/schedule/inlining.cc b/src/schedule/inlining.cc index e141e93b3..379381718 100644 --- a/src/schedule/inlining.cc +++ b/src/schedule/inlining.cc @@ -91,7 +91,7 @@ Stmt inlining(const Stmt &_ast, const ID &def) { throw ParserError("ISL map is not single-valued"); } auto &&[args, values, cond] = parseSimplePBFunc( - toString(PBFunc(dep.later2EarlierIter_))); + PBFunc(dep.later2EarlierIter_).toSerialized()); ASSERT(dep.earlier_.iter_.size() <= values.size()); // maybe padded ASSERT(dep.later_.iter_.size() <= args.size()); diff --git a/src/schedule/parallelize_as.cc b/src/schedule/parallelize_as.cc index bfe3b72df..2437b3acf 100644 --- a/src/schedule/parallelize_as.cc +++ b/src/schedule/parallelize_as.cc @@ -39,7 +39,6 @@ class AddParScopes : public TrackStmt> { typedef TrackStmt> BaseClass; ID nest_, defId_; - const PBCtx &presburger_; const std::vector &orderedScopes_; const std::unordered_map &scope2Idx2Iter_; @@ -55,11 +54,11 @@ class AddParScopes : public TrackStmt> { std::unordered_map>> threadGuard_; public: - AddParScopes(const ID &nest, const ID &defId, const PBCtx &presburger, + AddParScopes(const ID &nest, const ID &defId, const std::vector &orderedScopes, const std::unordered_map &scope2Idx2Iter) - : nest_(nest), defId_(defId), presburger_(presburger), - orderedScopes_(orderedScopes), scope2Idx2Iter_(scope2Idx2Iter) {} + : nest_(nest), defId_(defId), orderedScopes_(orderedScopes), + scope2Idx2Iter_(scope2Idx2Iter) {} const auto &newScopeIds() const { return newScopeIds_; } const auto &newNestId() const { return newNestId_; } @@ -74,8 +73,7 @@ class AddParScopes : public TrackStmt> { auto &&idx2iter = coalesce(scope2Idx2Iter_.at(scope->id())); SimplePBFuncAST f; try { - f = parseSimplePBFuncReconstructMinMax(presburger_, - idx2iter); + f = parseSimplePBFuncReconstructMinMax(idx2iter); } catch (const ParserError &e) { throw InvalidSchedule( FT_MSG << "Thread mapping is not a simple function: " @@ -187,7 +185,7 @@ Stmt parallelizeAs(const Stmt &_ast, const ID &nest, const ID &reference, })}; finder.doFind(ast); - PBCtx presburger; + auto presburger = Ref::make(); std::unordered_map scope2Idx2Iter; for (const Ref &acc : views::concat(finder.reads(), finder.writes())) { @@ -261,7 +259,7 @@ Stmt parallelizeAs(const Stmt &_ast, const ID &nest, const ID &reference, } } - AddParScopes adder{nest, defId, presburger, orderedScopes, scope2Idx2Iter}; + AddParScopes adder{nest, defId, orderedScopes, scope2Idx2Iter}; ast = adder(ast); // Shrink original loops in `nest` according to the gaurds with just add. If diff --git a/src/schedule/permute.cc b/src/schedule/permute.cc index ec5e60035..a30c81aed 100644 --- a/src/schedule/permute.cc +++ b/src/schedule/permute.cc @@ -175,7 +175,7 @@ std::pair> permute( permuteMapStr = oss.str(); } // check if the map is bijective - PBCtx pbCtx; + auto pbCtx = Ref::make(); PBMap permuteMap(pbCtx, permuteMapStr); { //! FIXME: if step != 1 or -1, this check will be more strict than @@ -216,7 +216,7 @@ std::pair> permute( }; size_t numBackDims; - std::string iter2permuted; + PBMap::Serialized iter2permuted; FindDeps() .direction({dir.begin(), dir.end()}) @@ -228,7 +228,7 @@ std::pair> permute( [&](const ID &defId, const std::unordered_map> &scope2coord) { - if (iter2permuted.empty() && scope2coord.count(loopsId[0])) { + if (!iter2permuted.isValid() && scope2coord.count(loopsId[0])) { // compute number of backing dimensions, outer than our // outermost loop numBackDims = scope2coord.at(loopsId[0]).size() - 1; @@ -239,21 +239,22 @@ std::pair> permute( // prepare iter -> permuted map string, for later use in // `found' - iter2permuted = toString(permuteMap); + iter2permuted = permuteMap.toSerialized(); } }) .filterSubAST(loops.front()->id())(ast, [&](const Dependence &d) { // Construct map for iter -> permuted - auto iter2permutedMap = PBMap(d.presburger_, iter2permuted); + auto iter2permutedMap = + iter2permuted.to(d.later2EarlierIter_.ctx()); // laterIter -> permuted auto &&permuteMapLater = applyRange( - PBMap(d.presburger_, + PBMap(d.later2EarlierIter_.ctx(), getExtractDimMap(numBackDims, numBackDims + loops.size(), d.later_.iter_.size())), iter2permutedMap); // earlierIter -> permuted auto &&permuteMapEarlier = applyRange( - PBMap(d.presburger_, + PBMap(d.later2EarlierIter_.ctx(), getExtractDimMap(numBackDims, numBackDims + loops.size(), d.earlier_.iter_.size())), iter2permutedMap); @@ -275,7 +276,7 @@ std::pair> permute( // compute the function for reverse permute; we iterate against permuted // space, thus we need to compute the original iterators from the new ones auto reversePermute = - parseSimplePBFunc(toString(PBFunc(reverse(permuteMap)))); + parseSimplePBFunc(PBFunc(reverse(permuteMap)).toSerialized()); // perform transformation Permute permuter(loopsId, reversePermute); diff --git a/src/schedule/pluto.cc b/src/schedule/pluto.cc index 94e618844..2e8c75230 100644 --- a/src/schedule/pluto.cc +++ b/src/schedule/pluto.cc @@ -89,7 +89,7 @@ orthogonalMatrix(const std::vector> &vectors) { return ortho == 0; }; - PBCtx ctx; + auto ctx = Ref::make(); builder.addConstraint(nonZeroConstraint(vars, delta)); builder.addConstraints(views::zip_with( [](auto &&abs, auto &&x) { return abs >= x && abs >= -x; }, abss, @@ -155,7 +155,7 @@ class InjectFakeAccess : public Mutator { } }; -PBMap combineExternal(PBMap l2e, PBCtx &ctx, bool isFakeAccess) { +PBMap combineExternal(PBMap l2e, const Ref &ctx, bool isFakeAccess) { auto nParams = l2e.nParamDims(); // combined <-> first met original std::map orig2comb; @@ -179,10 +179,11 @@ PBMap combineExternal(PBMap l2e, PBCtx &ctx, bool isFakeAccess) { builder.addConstraint(comb2origVar[comb] == origVar); } auto eqConstraints = builder.build(ctx); - eqConstraints = isl_set_move_dims(eqConstraints.move(), isl_dim_param, 0, - isl_dim_set, 0, nParams); + eqConstraints = + PBSet{ctx, isl_set_move_dims(eqConstraints.move(), isl_dim_param, 0, + isl_dim_set, 0, nParams)}; PBMap constrainedL2e( - isl_map_intersect_params(l2e.copy(), eqConstraints.copy())); + ctx, isl_map_intersect_params(l2e.copy(), eqConstraints.copy())); // Fake access uses AllPossible which doesn't have the corresponding // constraints. Thus we don't check fake accesses and always use the // constrained one. @@ -194,10 +195,11 @@ PBMap combineExternal(PBMap l2e, PBCtx &ctx, bool isFakeAccess) { for (int i = nParams - 1; i >= 0; --i) { std::string orig = isl_map_get_dim_name(l2e.get(), isl_dim_param, i); if (orig2comb.contains(orig)) - l2e = isl_map_set_dim_name(l2e.move(), isl_dim_param, i, - orig2comb[orig].c_str()); + l2e = PBMap{ctx, isl_map_set_dim_name(l2e.move(), isl_dim_param, i, + orig2comb[orig].c_str())}; else - l2e = isl_map_remove_dims(l2e.move(), isl_dim_param, i, 1); + l2e = PBMap{ctx, + isl_map_remove_dims(l2e.move(), isl_dim_param, i, 1)}; } return l2e; @@ -252,7 +254,7 @@ struct PermuteInfo { const std::vector> &cIterValue, const std::vector ¶mExprs, const std::vector &oldLoopAxes, - const std::vector &loopVars, const PBCtx &ctx, + const std::vector &loopVars, const Ref &ctx, const PBSet &loopSet) { size_t nParams = paramExprs.size(), nestLevel = loopVars.size(); ASSERT(oldLoopAxes.size() == loopVars.size()); @@ -275,7 +277,7 @@ struct PermuteInfo { newToOld = intersectRange(std::move(newToOld), loopSet); newToOld = moveDimsOutputToInput(std::move(newToOld), 0, nParams, 0); - auto func = parseSimplePBFunc(toString(PBFunc(newToOld))); + auto func = parseSimplePBFunc(PBFunc(newToOld).toSerialized()); ASSERT(func.args_.size() == unsigned(nParams + nestLevel)); ASSERT(func.values_.size() == unsigned(nestLevel)); @@ -416,8 +418,8 @@ plutoFuseImpl(Stmt ast, const ID &loop0Id, const ID &loop1Id, int _nestLevel0, n++; if (n < expectedN) throw InvalidSchedule( - "PlutoFuse: not enough loop nests found for " + - toString(loop)); + FT_MSG << "PlutoFuse: not enough loop nests found for " + << loop); } return std::pair{n, inner->parentStmt()->id()}; }; @@ -456,8 +458,8 @@ plutoFuseImpl(Stmt ast, const ID &loop0Id, const ID &loop1Id, int _nestLevel0, for (auto &&f : outers) outersSame[0].emplace_back(f->id(), DepDirection::Same); - PBCtx ctx; - std::string loop0SetStr, loop1SetStr; + auto ctx = Ref::make(); + PBSet::Serialized loop0SetStr, loop1SetStr; std::vector outerAxes, loop0Axes, loop1Axes; auto getDeps = [&](const For &l0, int n0, const For &l1, int n1, @@ -496,7 +498,8 @@ plutoFuseImpl(Stmt ast, const ID &loop0Id, const ID &loop1Id, int _nestLevel0, // combine external params from earlier and later, since we // don't expect them to change during the two loops in Pluto - hMap = combineExternal(std::move(hMap), d.presburger_, + hMap = combineExternal(std::move(hMap), + d.earlierIter2Idx_.ctx(), d.var_ == FAKE_ACCESS_VAR); auto nRealParams = hMap.nParamDims(); @@ -519,9 +522,11 @@ plutoFuseImpl(Stmt ast, const ID &loop0Id, const ID &loop1Id, int _nestLevel0, hMap = moveDimsOutputToParam( std::move(hMap), 0, outerDims0.size(), nRealParams); for (size_t i = 0; i < outerDims0.size(); ++i) - hMap = isl_map_set_dim_name( - hMap.move(), isl_dim_param, nRealParams + i, - ("out_" + std::to_string(i)).c_str()); + hMap = PBMap{hMap.ctx(), + isl_map_set_dim_name( + hMap.move(), isl_dim_param, + nRealParams + i, + ("out_" + std::to_string(i)).c_str())}; // remove inner dims for later auto [pos1, outerDims1] = @@ -537,8 +542,8 @@ plutoFuseImpl(Stmt ast, const ID &loop0Id, const ID &loop1Id, int _nestLevel0, // if fake access, set the loop sets instead of store deps if (d.var_ == FAKE_ACCESS_VAR) { // hMap is later -> earlier, hence loop1 -> loop0. - loop1SetStr = toString(std::move(domain(hMap))); - loop0SetStr = toString(std::move(range(hMap))); + loop1SetStr = domain(hMap).toSerialized(); + loop0SetStr = range(hMap).toSerialized(); return; } @@ -548,7 +553,8 @@ plutoFuseImpl(Stmt ast, const ID &loop0Id, const ID &loop1Id, int _nestLevel0, auto hSet = flattenMapToSet(std::move(hMap)); // overapproximate to allow coefficients on strided // dependences - hSet = isl_set_remove_unknown_divs(hSet.move()); + hSet = PBSet{hSet.ctx(), + isl_set_remove_unknown_divs(hSet.move())}; auto strSet = toString(std::move(hSet)); std::lock_guard l(m); @@ -567,7 +573,7 @@ plutoFuseImpl(Stmt ast, const ID &loop0Id, const ID &loop1Id, int _nestLevel0, auto dep1to0 = getDeps(loop0, nestLevel0, loop1, nestLevel1, true); // construct loop sets - PBSet loop0Set(ctx, loop0SetStr), loop1Set(ctx, loop1SetStr); + PBSet loop0Set = loop0SetStr.to(ctx), loop1Set = loop1SetStr.to(ctx); // align external params and move to set dimensions const auto [nParams, paramExprs] = [&] { @@ -575,17 +581,20 @@ plutoFuseImpl(Stmt ast, const ID &loop0Id, const ID &loop1Id, int _nestLevel0, // that no dependence exists to bring them into the space auto paramsSpace = spaceSetAlloc(ctx, outerAxes.size(), 0); for (size_t i = 0; i < outerAxes.size(); ++i) - paramsSpace = + paramsSpace = PBSpace{ + paramsSpace.ctx(), isl_space_set_dim_name(paramsSpace.move(), isl_dim_param, i, - ("out_" + std::to_string(i)).c_str()); + ("out_" + std::to_string(i)).c_str())}; DisjointSet paramsConnect; // setup a common root, which represents combination of set dimensions paramsConnect.find(""); for (auto &d : {&dep0, &dep1, &dep1to0}) for (const auto &dd : *d) { - paramsSpace = isl_space_align_params(paramsSpace.move(), - PBSpace(dd).move()); + paramsSpace = + PBSpace{paramsSpace.ctx(), + isl_space_align_params(paramsSpace.move(), + PBSpace(dd).move())}; isl_set_foreach_basic_set( dd.get(), [](isl_basic_set *bset, void *user) { @@ -638,14 +647,17 @@ plutoFuseImpl(Stmt ast, const ID &loop0Id, const ID &loop1Id, int _nestLevel0, nAllParams - ranges::accumulate(isRedundants, 0); auto align = [&](PBSet &s) { - s = isl_set_align_params(s.move(), paramsSpace.copy()); + s = PBSet{s.ctx(), + isl_set_align_params(s.move(), paramsSpace.copy())}; int j = 0; for (int i = 0; i < nAllParams; ++i) if (isRedundants[i]) - s = isl_set_project_out(s.move(), isl_dim_param, 0, 1); + s = PBSet{s.ctx(), isl_set_project_out( + s.move(), isl_dim_param, 0, 1)}; else - s = isl_set_move_dims(s.move(), isl_dim_set, j++, - isl_dim_param, 0, 1); + s = PBSet{s.ctx(), + isl_set_move_dims(s.move(), isl_dim_set, j++, + isl_dim_param, 0, 1)}; }; for (auto &d : {&dep0, &dep1, &dep1to0}) @@ -1014,7 +1026,7 @@ plutoFuseImpl(Stmt ast, const ID &loop0Id, const ID &loop1Id, int _nestLevel0, builder.addOutputs(p); builder.addOutput(result); auto projectedLoopRange = apply(loopSet, builder.build(ctx)); - return PBSet(projectedLoopRange.move()); + return PBSet(ctx, projectedLoopRange.move()); }; auto loop0Range = loopSetToRange(loop0Set, nParams + 1, nestLevel0); auto loop1Range = loopSetToRange( @@ -1046,10 +1058,14 @@ plutoFuseImpl(Stmt ast, const ID &loop0Id, const ID &loop1Id, int _nestLevel0, bool hugeNonOverlap = false; auto check = [&, nParams = nParams](const PBSet &_aRange, const PBSet &_bRange) { - PBSet aRange = isl_set_move_dims(_aRange.copy(), isl_dim_param, - 0, isl_dim_set, 0, nParams); - PBSet bRange = isl_set_move_dims(_bRange.copy(), isl_dim_param, - 0, isl_dim_set, 0, nParams); + PBSet aRange = + PBSet{_aRange.ctx(), + isl_set_move_dims(_aRange.copy(), isl_dim_param, 0, + isl_dim_set, 0, nParams)}; + PBSet bRange = + PBSet{_bRange.ctx(), + isl_set_move_dims(_bRange.copy(), isl_dim_param, 0, + isl_dim_set, 0, nParams)}; auto tol = fixDim(universeSet(PBSpace(aRange)), 0, fusableNonOverlapTolerance);