From de4b6d3be68e67098b8af02567b768e6f0830003 Mon Sep 17 00:00:00 2001 From: Antony Chan Date: Mon, 13 Jan 2025 08:21:23 -0800 Subject: [PATCH] Expose the auxiliary dimension for scalar reduction Refactor the zero-dimensional reduction `sumsq() += u * u` into the one-dimensional `sumsq(x) += u * u`, whose output holds only one pixel. This enables CUDA codegen to fuse multiple reduction stages into a single GPU kernel. --- .../halide/src/algorithm/linearized-admm.h | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/proximal/halide/src/algorithm/linearized-admm.h b/proximal/halide/src/algorithm/linearized-admm.h index 2e07adc..c758219 100644 --- a/proximal/halide/src/algorithm/linearized-admm.h +++ b/proximal/halide/src/algorithm/linearized-admm.h @@ -16,12 +16,12 @@ namespace utils { Func normSquared(const Func& v, const RDom& r) { Func sumsq{"sumsq"}; - sumsq() = 0.0f; + sumsq(x) = 0.0f; if (v.dimensions() == 4) { - sumsq() += v(r.x, r.y, r.z, r.w) * v(r.x, r.y, r.z, r.w); + sumsq(x) += v(r.x, r.y, r.z, r.w) * v(r.x, r.y, r.z, r.w); } else { // n_dim == 3 - sumsq() += v(r.x, r.y, r.z) * v(r.x, r.y, r.z); + sumsq(x) += v(r.x, r.y, r.z) * v(r.x, r.y, r.z); } return sumsq; @@ -31,13 +31,13 @@ template Func normSquared(const FuncTuple& v, const RDom& r) { Func sumsq{"sumsq"}; - sumsq() = 0.0f; + sumsq(x) = 0.0f; for (const auto& _v : v) { if (_v.dimensions() == 4) { - sumsq() += _v(r.x, r.y, r.z, r.w) * _v(r.x, r.y, r.z, r.w); + sumsq(x) += _v(r.x, r.y, r.z, r.w) * _v(r.x, r.y, r.z, r.w); } else { // n_dim == 3 - sumsq() += _v(r.x, r.y, r.z) * _v(r.x, r.y, r.z); + sumsq(x) += _v(r.x, r.y, r.z) * _v(r.x, r.y, r.z); } } @@ -73,7 +73,7 @@ iterate(const Func& v, const FuncTuple& z, const FuncTuple& u, G& K, const }); const Func v2 = K.adjoint(Kvzu); - Func v3; + Func v3{"v3"}; v3(x, y, c) = v(x, y, c) - (mu / lmb) * v2(x, y, c); v_new = omega_fn(v3, 1.0f / mu, b); @@ -153,15 +153,16 @@ computeConvergence(const Func& v, const FuncTuple& z, const FuncTuple& u, const Func Kv_norm = normSquared(Kv, 
output_dimensions); const Func z_norm = normSquared(z, output_dimensions); - const Expr eps_pri = eps_rel * sqrt(max(Kv_norm(), z_norm())) + std::sqrt(float(output_size)) * eps_abs; + const Expr eps_pri = + eps_rel * sqrt(max(Kv_norm(0), z_norm(0))) + std::sqrt(float(output_size)) * eps_abs; const Func KTu_norm = normSquared(KTu, input_dimensions); const Expr eps_dual = - sqrt(KTu_norm()) * eps_rel / (1.0f / lmb) + std::sqrt(float(input_size)) * eps_abs; + sqrt(KTu_norm(0)) * eps_rel / (1.0f / lmb) + std::sqrt(float(input_size)) * eps_abs; const Func r_norm = normSquared(r, output_dimensions); const Func s_norm = normSquared(s, input_dimensions); - return {sqrt(r_norm()), sqrt(s_norm()), eps_pri, eps_dual}; + return {sqrt(r_norm(0)), sqrt(s_norm(0)), eps_pri, eps_dual}; } } // namespace linearized_admm } // namespace algorithm