Skip to content

Commit

Permalink
Rework precision specification. Generalize towards using this for oth…
Browse files Browse the repository at this point in the history
…er functions.
  • Loading branch information
mcourteaux committed Nov 11, 2024
1 parent 3ced523 commit b35f7fa
Show file tree
Hide file tree
Showing 8 changed files with 416 additions and 415 deletions.
108 changes: 108 additions & 0 deletions src/ApproximationTables.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#include "ApproximationTables.h"

namespace Halide {
namespace Internal {

// clang-format off
// Generate this table with:
// python3 src/polynomial_optimizer.py atan --order 1 2 3 4 5 6 7 8 --loss mse mae mulpe mulpe_mae --no-gui --format table
static std::vector<Approximation> table_atan = {
{ApproximationPrecision::MSE, 9.249650e-04, 7.078984e-02, 2.411547e+06, {+8.56188008e-01}},
{ApproximationPrecision::MSE, 1.026356e-05, 9.214909e-03, 3.985505e+05, {+9.76213454e-01, -2.00030200e-01}},
{ApproximationPrecision::MSE, 1.577588e-07, 1.323851e-03, 6.724566e+04, {+9.95982073e-01, -2.92278128e-01, +8.30180680e-02}},
{ApproximationPrecision::MSE, 2.849011e-09, 1.992218e-04, 1.142204e+04, {+9.99316541e-01, -3.22286501e-01, +1.49032461e-01, -4.08635592e-02}},
{ApproximationPrecision::MSE, 5.667504e-11, 3.080100e-05, 1.945614e+03, {+9.99883373e-01, -3.30599535e-01, +1.81451316e-01, -8.71733830e-02, +2.18671936e-02}},
{ApproximationPrecision::MSE, 1.202662e-12, 4.846916e-06, 3.318677e+02, {+9.99980065e-01, -3.32694393e-01, +1.94019697e-01, -1.17694732e-01, +5.40822080e-02, -1.22995279e-02}},
{ApproximationPrecision::MSE, 2.672889e-14, 7.722732e-07, 5.664632e+01, {+9.99996589e-01, -3.33190090e-01, +1.98232868e-01, -1.32941469e-01, +8.07623712e-02, -3.46124853e-02, +7.15115276e-03}},
{ApproximationPrecision::MSE, 6.147315e-16, 1.245768e-07, 9.764224e+00, {+9.99999416e-01, -3.33302229e-01, +1.99511173e-01, -1.39332647e-01, +9.70944891e-02, -5.68823386e-02, +2.25679012e-02, -4.25772648e-03}},

{ApproximationPrecision::MAE, 1.097847e-03, 4.801638e-02, 2.793645e+06, {+8.33414544e-01}},
{ApproximationPrecision::MAE, 1.209593e-05, 4.968992e-03, 4.623251e+05, {+9.72410454e-01, -1.91981283e-01}},
{ApproximationPrecision::MAE, 1.839382e-07, 6.107084e-04, 7.766697e+04, {+9.95360080e-01, -2.88702052e-01, +7.93508437e-02}},
{ApproximationPrecision::MAE, 3.296902e-09, 8.164167e-05, 1.313615e+04, {+9.99214108e-01, -3.21178073e-01, +1.46272006e-01, -3.89915187e-02}},
{ApproximationPrecision::MAE, 6.523525e-11, 1.147459e-05, 2.229646e+03, {+9.99866373e-01, -3.30305517e-01, +1.80162434e-01, -8.51611537e-02, +2.08475020e-02}},
{ApproximationPrecision::MAE, 1.378842e-12, 1.667328e-06, 3.792091e+02, {+9.99977226e-01, -3.32622991e-01, +1.93541452e-01, -1.16429278e-01, +5.26504600e-02, -1.17203722e-02}},
{ApproximationPrecision::MAE, 3.055131e-14, 2.480947e-07, 6.457187e+01, {+9.99996113e-01, -3.33173716e-01, +1.98078484e-01, -1.32334692e-01, +7.96260166e-02, -3.36062649e-02, +6.81247117e-03}},
{ApproximationPrecision::MAE, 7.013215e-16, 3.757868e-08, 1.102324e+01, {+9.99999336e-01, -3.33298615e-01, +1.99465749e-01, -1.39086791e-01, +9.64233077e-02, -5.59142254e-02, +2.18643190e-02, -4.05495427e-03}},

{ApproximationPrecision::MULPE, 1.355602e-03, 1.067325e-01, 1.808493e+06, {+8.92130617e-01}},
{ApproximationPrecision::MULPE, 2.100588e-05, 1.075508e-02, 1.822095e+05, {+9.89111122e-01, -2.14468039e-01}},
{ApproximationPrecision::MULPE, 3.573985e-07, 1.316370e-03, 2.227347e+04, {+9.98665077e-01, -3.02990987e-01, +9.10404434e-02}},
{ApproximationPrecision::MULPE, 6.474958e-09, 1.548508e-04, 2.619892e+03, {+9.99842198e-01, -3.26272641e-01, +1.56294460e-01, -4.46207045e-02}},
{ApproximationPrecision::MULPE, 1.313474e-10, 2.533532e-05, 4.294794e+02, {+9.99974110e-01, -3.31823782e-01, +1.85886095e-01, -9.30024008e-02, +2.43894760e-02}},
{ApproximationPrecision::MULPE, 3.007880e-12, 3.530685e-06, 5.983830e+01, {+9.99996388e-01, -3.33036463e-01, +1.95959706e-01, -1.22068745e-01, +5.83403647e-02, -1.37966171e-02}},
{ApproximationPrecision::MULPE, 6.348880e-14, 4.882649e-07, 8.276351e+00, {+9.99999499e-01, -3.33273408e-01, +1.98895454e-01, -1.35153794e-01, +8.43185278e-02, -3.73434598e-02, +7.95583230e-03}},
{ApproximationPrecision::MULPE, 1.369569e-15, 7.585036e-08, 1.284979e+00, {+9.99999922e-01, -3.33320840e-01, +1.99708563e-01, -1.40257063e-01, +9.93094012e-02, -5.97138046e-02, +2.44056181e-02, -4.73371006e-03}},


{ApproximationPrecision::MULPE_MAE, 9.548909e-04, 6.131488e-02, 2.570520e+06, {+8.46713042e-01}},
{ApproximationPrecision::MULPE_MAE, 1.159917e-05, 6.746680e-03, 3.778023e+05, {+9.77449762e-01, -1.98798279e-01}},
{ApproximationPrecision::MULPE_MAE, 1.783646e-07, 8.575388e-04, 6.042236e+04, {+9.96388826e-01, -2.92591679e-01, +8.24585555e-02}},
{ApproximationPrecision::MULPE_MAE, 3.265269e-09, 1.190548e-04, 9.505190e+03, {+9.99430906e-01, -3.22774535e-01, +1.49370817e-01, -4.07480795e-02}},
{ApproximationPrecision::MULPE_MAE, 6.574962e-11, 1.684690e-05, 1.515116e+03, {+9.99909079e-01, -3.30795737e-01, +1.81810037e-01, -8.72860225e-02, +2.17776539e-02}},
{ApproximationPrecision::MULPE_MAE, 1.380489e-12, 2.497538e-06, 2.510721e+02, {+9.99984893e-01, -3.32748885e-01, +1.94193211e-01, -1.17865932e-01, +5.40633775e-02, -1.22309990e-02}},
{ApproximationPrecision::MULPE_MAE, 3.053218e-14, 3.784868e-07, 4.181995e+01, {+9.99997480e-01, -3.33205127e-01, +1.98309644e-01, -1.33094430e-01, +8.08643094e-02, -3.45859503e-02, +7.11261604e-03}},
{ApproximationPrecision::MULPE_MAE, 7.018877e-16, 5.862915e-08, 6.942196e+00, {+9.99999581e-01, -3.33306326e-01, +1.99542180e-01, -1.39433369e-01, +9.72462857e-02, -5.69734398e-02, +2.25639390e-02, -4.24074590e-03}},
};
// clang-format on

const Approximation *find_best_approximation(const std::vector<Approximation> &table, ApproximationPrecision precision) {
const Approximation *best = nullptr;
constexpr int term_cost = 20;
constexpr int extra_term_cost = 200;
double best_score = 0;
//std::printf("Looking for min_terms=%d, max_absolute_error=%f\n", precision.constraint_min_poly_terms, precision.constraint_max_absolute_error);
for (size_t i = 0; i < table.size(); ++i) {
const Approximation &e = table[i];

double penalty = 0.0;

int obj_score = e.objective == precision.optimized_for ? 100 * term_cost : 0;
if (precision.optimized_for == ApproximationPrecision::MULPE_MAE && e.objective == ApproximationPrecision::MULPE) {
obj_score = 50 * term_cost; // When MULPE_MAE is not available, prefer MULPE.
}

int num_terms = int(e.coefficients.size());
int term_count_score = (12 - num_terms) * term_cost;
if (num_terms < precision.constraint_min_poly_terms) {
penalty += (precision.constraint_min_poly_terms - num_terms) * extra_term_cost;
}

double precision_score = 0;
// If we don't care about the maximum number of terms, we maximize precision.
switch (precision.optimized_for) {
case ApproximationPrecision::MSE:
precision_score = -std::log(e.mse);
break;
case ApproximationPrecision::MAE:
precision_score = -std::log(e.mae);
break;
case ApproximationPrecision::MULPE:
precision_score = -std::log(e.mulpe);
break;
case ApproximationPrecision::MULPE_MAE:
precision_score = -0.5 * std::log(e.mulpe * e.mae);
break;
}

if (precision.constraint_max_absolute_error > 0.0 && precision.constraint_max_absolute_error < e.mae) {
penalty += 20 * extra_term_cost; // penalty for not getting the required precision.
}

double score = obj_score + term_count_score + precision_score - penalty;
//std::printf("Score for %zu (%zu terms): %f = %d + %d + %f - penalty %f\n", i, e.coefficients.size(), score, obj_score, term_count_score, precision_score, penalty);
if (score > best_score) {
best = &e;
best_score = score;
}
}
//std::printf("Best score: %f\n", best_score);
return best;
}

const Approximation *best_atan_approximation(Halide::ApproximationPrecision precision) {
return find_best_approximation(table_atan, precision);
}

} // namespace Internal
} // namespace Halide
21 changes: 21 additions & 0 deletions src/ApproximationTables.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#include <vector>

#include "IROperator.h"

namespace Halide {
namespace Internal {

struct Approximation {
ApproximationPrecision::OptimizationObjective objective;
double mse;
double mae;
double mulpe;
std::vector<double> coefficients;
};

const Approximation *best_atan_approximation(Halide::ApproximationPrecision precision);

} // namespace Internal
} // namespace Halide
4 changes: 2 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,7 @@ target_sources(
WrapCalls.h
)

# The sources that go into libHalide. For the sake of IDE support, headers that
# exist in src/ but are not public should be included here.
# The sources that go into libHalide.
target_sources(
Halide
PRIVATE
Expand All @@ -232,6 +231,7 @@ target_sources(
AlignLoads.cpp
AllocationBoundsInference.cpp
ApplySplit.cpp
ApproximationTables.cpp
Argument.cpp
AssociativeOpsTable.cpp
Associativity.cpp
Expand Down
104 changes: 6 additions & 98 deletions src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "IRPrinter.h"
#include "Interval.h"
#include "Util.h"
#include "ApproximationTables.h"
#include "Var.h"

using namespace Halide::Internal;
Expand Down Expand Up @@ -1388,7 +1389,7 @@ Expr fast_sin_cos(const Expr &x_full, bool is_sin) {
Expr sin_usecos = is_sin ? ((k_mod4 == 1) || (k_mod4 == 3)) : ((k_mod4 == 0) || (k_mod4 == 2));
Expr flip_sign = is_sin ? (k_mod4 > 1) : ((k_mod4 == 1) || (k_mod4 == 2));

// Reduce the angle modulo pi/2.
// Reduce the angle modulo pi/2: i.e., to the angle within the quadrant.
Expr x = x_full - k_real * pi_over_two;

const float sin_c2 = -0.16666667163372039794921875f;
Expand Down Expand Up @@ -1447,106 +1448,13 @@ Expr fast_atan_approximation(const Expr &x_full, ApproximationPrecision precisio

// The table is huge, so let's put clang-format off and handle the layout manually:
// clang-format off
std::vector<float> c;
switch (precision) {
// == MSE Optimized == //
case ApproximationPrecision::MSE_Poly2: // (MSE=1.0264e-05, MAE=9.2149e-03, MaxUlpE=3.9855e+05)
c = {+9.762134539879e-01f, -2.000301999499e-01f};
break;
case ApproximationPrecision::MSE_Poly3: // (MSE=1.5776e-07, MAE=1.3239e-03, MaxUlpE=6.7246e+04)
c = {+9.959820734941e-01f, -2.922781275652e-01f, +8.301806798764e-02f};
break;
case ApproximationPrecision::MSE_Poly4: // (MSE=2.8490e-09, MAE=1.9922e-04, MaxUlpE=1.1422e+04)
c = {+9.993165406918e-01f, -3.222865011143e-01f, +1.490324612527e-01f, -4.086355921512e-02f};
break;
case ApproximationPrecision::MSE_Poly5: // (MSE=5.6675e-11, MAE=3.0801e-05, MaxUlpE=1.9456e+03)
c = {+9.998833730470e-01f, -3.305995351168e-01f, +1.814513158372e-01f, -8.717338298570e-02f,
+2.186719361787e-02f};
break;
case ApproximationPrecision::MSE_Poly6: // (MSE=1.2027e-12, MAE=4.8469e-06, MaxUlpE=3.3187e+02)
c = {+9.999800646964e-01f, -3.326943930673e-01f, +1.940196968486e-01f, -1.176947321238e-01f,
+5.408220801540e-02f, -1.229952788751e-02f};
break;
case ApproximationPrecision::MSE_Poly7: // (MSE=2.6729e-14, MAE=7.7227e-07, MaxUlpE=5.6646e+01)
c = {+9.999965889517e-01f, -3.331900904961e-01f, +1.982328680483e-01f, -1.329414694644e-01f,
+8.076237117606e-02f, -3.461248530394e-02f, +7.151152759080e-03f};
break;
case ApproximationPrecision::MSE_Poly8: // (MSE=6.1506e-16, MAE=1.2419e-07, MaxUlpE=9.6914e+00)
c = {+9.999994159669e-01f, -3.333022219271e-01f, +1.995110884308e-01f, -1.393321817395e-01f,
+9.709319573480e-02f, -5.688043380309e-02f, +2.256648487698e-02f, -4.257308331872e-03f};
break;

// == MAE Optimized == //
case ApproximationPrecision::MAE_1e_2:
case ApproximationPrecision::MAE_Poly2: // (MSE=1.2096e-05, MAE=4.9690e-03, MaxUlpE=4.6233e+05)
c = {+9.724104536788e-01f, -1.919812827495e-01f};
break;
case ApproximationPrecision::MAE_1e_3:
case ApproximationPrecision::MAE_Poly3: // (MSE=1.8394e-07, MAE=6.1071e-04, MaxUlpE=7.7667e+04)
c = {+9.953600796593e-01f, -2.887020515559e-01f, +7.935084373856e-02f};
break;
case ApproximationPrecision::MAE_1e_4:
case ApproximationPrecision::MAE_Poly4: // (MSE=3.2969e-09, MAE=8.1642e-05, MaxUlpE=1.3136e+04)
c = {+9.992141075707e-01f, -3.211780734117e-01f, +1.462720063085e-01f, -3.899151874271e-02f};
break;
case ApproximationPrecision::MAE_Poly5: // (MSE=6.5235e-11, MAE=1.1475e-05, MaxUlpE=2.2296e+03)
c = {+9.998663727249e-01f, -3.303055171903e-01f, +1.801624340886e-01f, -8.516115366058e-02f,
+2.084750202717e-02f};
break;
case ApproximationPrecision::MAE_1e_5:
case ApproximationPrecision::MAE_Poly6: // (MSE=1.3788e-12, MAE=1.6673e-06, MaxUlpE=3.7921e+02)
c = {+9.999772256973e-01f, -3.326229914097e-01f, +1.935414518077e-01f, -1.164292778405e-01f,
+5.265046001895e-02f, -1.172037220425e-02f};
break;
case ApproximationPrecision::MAE_1e_6:
case ApproximationPrecision::MAE_Poly7: // (MSE=3.0551e-14, MAE=2.4809e-07, MaxUlpE=6.4572e+01)
c = {+9.999961125922e-01f, -3.331737159104e-01f, +1.980784841430e-01f, -1.323346922675e-01f,
+7.962601662878e-02f, -3.360626486524e-02f, +6.812471171209e-03f};
break;
case ApproximationPrecision::MAE_Poly8: // (MSE=7.0132e-16, MAE=3.7579e-08, MaxUlpE=1.1023e+01)
c = {+9.999993357462e-01f, -3.332986153129e-01f, +1.994657492754e-01f, -1.390867909988e-01f,
+9.642330770840e-02f, -5.591422536378e-02f, +2.186431903729e-02f, -4.054954273090e-03f};
break;


// == Max ULP Optimized == //
case ApproximationPrecision::MULPE_Poly2: // (MSE=2.1006e-05, MAE=1.0755e-02, MaxUlpE=1.8221e+05)
c = {+9.891111216318e-01f, -2.144680385336e-01f};
break;
case ApproximationPrecision::MULPE_1e_2:
case ApproximationPrecision::MULPE_Poly3: // (MSE=3.5740e-07, MAE=1.3164e-03, MaxUlpE=2.2273e+04)
c = {+9.986650768126e-01f, -3.029909865833e-01f, +9.104044335898e-02f};
break;
case ApproximationPrecision::MULPE_1e_3:
case ApproximationPrecision::MULPE_Poly4: // (MSE=6.4750e-09, MAE=1.5485e-04, MaxUlpE=2.6199e+03)
c = {+9.998421981586e-01f, -3.262726405770e-01f, +1.562944595469e-01f, -4.462070448745e-02f};
break;
case ApproximationPrecision::MULPE_1e_4:
case ApproximationPrecision::MULPE_Poly5: // (MSE=1.3135e-10, MAE=2.5335e-05, MaxUlpE=4.2948e+02)
c = {+9.999741103798e-01f, -3.318237821017e-01f, +1.858860952571e-01f, -9.300240079057e-02f,
+2.438947597681e-02f};
break;
case ApproximationPrecision::MULPE_1e_5:
case ApproximationPrecision::MULPE_Poly6: // (MSE=3.0079e-12, MAE=3.5307e-06, MaxUlpE=5.9838e+01)
c = {+9.999963876702e-01f, -3.330364633925e-01f, +1.959597060284e-01f, -1.220687452250e-01f,
+5.834036471395e-02f, -1.379661708254e-02f};
break;
case ApproximationPrecision::MULPE_1e_6:
case ApproximationPrecision::MULPE_Poly7: // (MSE=6.3489e-14, MAE=4.8826e-07, MaxUlpE=8.2764e+00)
c = {+9.999994992400e-01f, -3.332734078379e-01f, +1.988954540598e-01f, -1.351537940907e-01f,
+8.431852775558e-02f, -3.734345976535e-02f, +7.955832300869e-03f};
break;
case ApproximationPrecision::MULPE_Poly8: // (MSE=1.3696e-15, MAE=7.5850e-08, MaxUlpE=1.2850e+00)
c = {+9.999999220612e-01f, -3.333208398432e-01f, +1.997085632112e-01f, -1.402570625577e-01f,
+9.930940122930e-02f, -5.971380457112e-02f, +2.440561807586e-02f, -4.733710058459e-03f};
break;
}
// clang-format on
const Internal::Approximation *approx = Internal::best_atan_approximation(precision);
const std::vector<double> &c = approx->coefficients;

Expr x2 = x * x;
Expr result = c.back();
Expr result = float(c.back());
for (size_t i = 1; i < c.size(); ++i) {
result = x2 * result + c[c.size() - i - 1];
result = x2 * result + float(c[c.size() - i - 1]);
}
result *= x;

Expand Down
Loading

0 comments on commit b35f7fa

Please sign in to comment.