Skip to content

Commit

Permalink
Expose the auxiliary dimension for scalar reduction
Browse files Browse the repository at this point in the history
Refactor `sumsq() += u * u` into `sumsq(x) += u * u`, where the new
dimension `x` spans only a single pixel. This enables the CUDA codegen
to fuse multiple reduction stages into a single GPU kernel.
  • Loading branch information
antonysigma committed Jan 20, 2025
1 parent 6934b4e commit de4b6d3
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions proximal/halide/src/algorithm/linearized-admm.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ namespace utils {
Func
normSquared(const Func& v, const RDom& r) {
Func sumsq{"sumsq"};
sumsq() = 0.0f;
sumsq(x) = 0.0f;

if (v.dimensions() == 4) {
sumsq() += v(r.x, r.y, r.z, r.w) * v(r.x, r.y, r.z, r.w);
sumsq(x) += v(r.x, r.y, r.z, r.w) * v(r.x, r.y, r.z, r.w);
} else { // n_dim == 3
sumsq() += v(r.x, r.y, r.z) * v(r.x, r.y, r.z);
sumsq(x) += v(r.x, r.y, r.z) * v(r.x, r.y, r.z);
}

return sumsq;
Expand All @@ -31,13 +31,13 @@ template <size_t N>
Func
normSquared(const FuncTuple<N>& v, const RDom& r) {
Func sumsq{"sumsq"};
sumsq() = 0.0f;
sumsq(x) = 0.0f;

for (const auto& _v : v) {
if (_v.dimensions() == 4) {
sumsq() += _v(r.x, r.y, r.z, r.w) * _v(r.x, r.y, r.z, r.w);
sumsq(x) += _v(r.x, r.y, r.z, r.w) * _v(r.x, r.y, r.z, r.w);
} else { // n_dim == 3
sumsq() += _v(r.x, r.y, r.z) * _v(r.x, r.y, r.z);
sumsq(x) += _v(r.x, r.y, r.z) * _v(r.x, r.y, r.z);
}
}

Expand Down Expand Up @@ -73,7 +73,7 @@ iterate(const Func& v, const FuncTuple<N>& z, const FuncTuple<N>& u, G& K, const
});

const Func v2 = K.adjoint(Kvzu);
Func v3;
Func v3{"v3"};
v3(x, y, c) = v(x, y, c) - (mu / lmb) * v2(x, y, c);

v_new = omega_fn(v3, 1.0f / mu, b);
Expand Down Expand Up @@ -153,15 +153,16 @@ computeConvergence(const Func& v, const FuncTuple<N>& z, const FuncTuple<N>& u,

const Func Kv_norm = normSquared(Kv, output_dimensions);
const Func z_norm = normSquared(z, output_dimensions);
const Expr eps_pri = eps_rel * sqrt(max(Kv_norm(), z_norm())) + std::sqrt(float(output_size)) * eps_abs;
const Expr eps_pri =
eps_rel * sqrt(max(Kv_norm(0), z_norm(0))) + std::sqrt(float(output_size)) * eps_abs;

const Func KTu_norm = normSquared(KTu, input_dimensions);
const Expr eps_dual =
sqrt(KTu_norm()) * eps_rel / (1.0f / lmb) + std::sqrt(float(input_size)) * eps_abs;
sqrt(KTu_norm(0)) * eps_rel / (1.0f / lmb) + std::sqrt(float(input_size)) * eps_abs;

const Func r_norm = normSquared(r, output_dimensions);
const Func s_norm = normSquared(s, input_dimensions);
return {sqrt(r_norm()), sqrt(s_norm()), eps_pri, eps_dual};
return {sqrt(r_norm(0)), sqrt(s_norm(0)), eps_pri, eps_dual};
}
} // namespace linearized_admm
} // namespace algorithm

0 comments on commit de4b6d3

Please sign in to comment.