diff --git a/src/workload/operation-space.cpp b/src/workload/operation-space.cpp index 88360620..0cdefd19 100644 --- a/src/workload/operation-space.cpp +++ b/src/workload/operation-space.cpp @@ -52,8 +52,10 @@ OperationSpace::OperationSpace(const Workload* wc, const OperationPoint& low, co // a data-space may not result in the exclusive high point in that data-space. for (unsigned space_id = 0; space_id < wc->GetShape()->NumDataSpaces; space_id++) { - auto space_low = Project(space_id, workload_, low); - auto space_high = Project(space_id, workload_, high); + Point space_low(workload_->GetShape()->DataSpaceOrder.at(space_id)); + Point space_high(workload_->GetShape()->DataSpaceOrder.at(space_id)); + + ProjectLowHigh(space_id, workload_, low, high, space_low, space_high); // Increment the high points by 1 because the AAHR constructor wants // an exclusive max point. @@ -62,6 +64,46 @@ OperationSpace::OperationSpace(const Workload* wc, const OperationPoint& low, co } } +void OperationSpace::ProjectLowHigh(Shape::DataSpaceID d, + const Workload* wc, + const OperationPoint& problem_low, + const OperationPoint& problem_high, + Point& data_space_low, + Point& data_space_high) +{ + for (unsigned data_space_dim = 0; data_space_dim < wc->GetShape()->DataSpaceOrder.at(d); data_space_dim++) + { + data_space_low[data_space_dim] = 0; + data_space_high[data_space_dim] = 0; + + for (auto& term : wc->GetShape()->Projections.at(d).at(data_space_dim)) + { + Coordinate low = problem_low[term.second]; + Coordinate high = problem_high[term.second]; + if (term.first != wc->GetShape()->NumCoefficients) + { + // If Coefficient is negative, flip high/low. + auto coeff = wc->GetCoefficient(term.first); + if (coeff < 0) + { + data_space_low[data_space_dim] += (high * coeff); + data_space_high[data_space_dim] += (low * coeff); + } + else + { + data_space_low[data_space_dim] += (low * coeff); + data_space_high[data_space_dim] += (high * coeff); + } + } + else + { + data_space_low[data_space_dim] += low; + data_space_high[data_space_dim] += high; + } + } + } +} + Point OperationSpace::Project(Shape::DataSpaceID d, const Workload* wc, const OperationPoint& problem_point) diff --git a/src/workload/operation-space.hpp b/src/workload/operation-space.hpp index afa5ea1c..2456dea6 100644 --- a/src/workload/operation-space.hpp +++ b/src/workload/operation-space.hpp @@ -61,6 +61,12 @@ class OperationSpace private: Point Project(Shape::DataSpaceID d, const Workload* wc, const OperationPoint& problem_point); + void ProjectLowHigh(Shape::DataSpaceID d, + const Workload* wc, + const OperationPoint& problem_low, + const OperationPoint& problem_high, + Point& data_space_low, + Point& data_space_high); public: OperationSpace();