Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace switch with if constexpr #1301

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
614 changes: 280 additions & 334 deletions include/color_spinor.h

Large diffs are not rendered by default.

13 changes: 6 additions & 7 deletions include/kernels/clover_outer_product.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ namespace quda {
Spinor A = arg.inA(x_cb, 0);
Spinor C = arg.inC(x_cb, 0);

#pragma unroll
for (int dim=0; dim<4; ++dim) {
static_for<0,4>([&](auto dim){
int shift[4] = {0, 0, 0, 0};
shift[dim] = 1;
const int nbr_idx = neighborIndex(x_cb, shift, arg.partitioned, arg.parity, arg.X);
Expand All @@ -73,18 +72,18 @@ namespace quda {
Spinor B_shift = arg.inB(nbr_idx, 0);
Spinor D_shift = arg.inD(nbr_idx, 0);

B_shift = (B_shift.project(dim,1)).reconstruct(dim,1);
B_shift = (B_shift.template project<dim,1>()).template reconstruct<dim,1>();
Link result = outerProdSpinTrace(B_shift,A);

D_shift = (D_shift.project(dim,-1)).reconstruct(dim,-1);
D_shift = (D_shift.template project<dim,-1>()).template reconstruct<dim,-1>();
result += outerProdSpinTrace(D_shift,C);

Link temp = arg.force(dim, x_cb, arg.parity);
Link U = arg.U(dim, x_cb, arg.parity);
result = temp + U*result*arg.coeff;
arg.force(dim, x_cb, arg.parity) = result;
}
} // dim
}); // dim
}
};

Expand All @@ -107,11 +106,11 @@ namespace quda {
Spinor C = arg.inC(bulk_cb_idx, 0);

HalfSpinor projected_tmp = arg.inB.Ghost(Arg::dim, 1, x_cb, 0);
Spinor B_shift = projected_tmp.reconstruct(Arg::dim, 1);
Spinor B_shift = projected_tmp.template reconstruct<Arg::dim, 1>();
Link result = outerProdSpinTrace(B_shift,A);

projected_tmp = arg.inD.Ghost(Arg::dim, 1, x_cb, 0);
Spinor D_shift = projected_tmp.reconstruct(Arg::dim,-1);
Spinor D_shift = projected_tmp.template reconstruct<Arg::dim,-1>();
result += outerProdSpinTrace(D_shift,C);

Link temp = arg.force(Arg::dim, bulk_cb_idx, arg.parity);
Expand Down
2 changes: 1 addition & 1 deletion include/kernels/clover_sigma_outer_product.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ namespace quda
for (int i = 0; i < Arg::nvector; i++) {
const Spinor A = arg.inA[i](x_cb, parity);
const Spinor B = arg.inB[i](x_cb, parity);
Spinor C = A.sigma(nu, mu); // multiply by sigma_mu_nu
Spinor C = A.template sigma<nu, mu>(); // multiply by sigma_mu_nu
result += arg.coeff[i][parity] * outerProdSpinTrace(C, B);
}

Expand Down
8 changes: 4 additions & 4 deletions include/kernels/dslash_domain_wall_5d.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ namespace quda
constexpr int proj_dir = dagger ? +1 : -1;
Vector in = arg.in(fwd_idx, their_spinor_parity);
if (s == arg.Ls - 1) {
out += (-arg.m_f * in.project(d, proj_dir)).reconstruct(d, proj_dir);
out += (-arg.m_f * in.template project<d, proj_dir>()).template reconstruct<d, proj_dir>();
} else {
out += in.project(d, proj_dir).reconstruct(d, proj_dir);
out += in.template project<d, proj_dir>().template reconstruct<d, proj_dir>();
}
}

Expand All @@ -66,9 +66,9 @@ namespace quda
constexpr int proj_dir = dagger ? -1 : +1;
Vector in = arg.in(back_idx, their_spinor_parity);
if (s == 0) {
out += (-arg.m_f * in.project(d, proj_dir)).reconstruct(d, proj_dir);
out += (-arg.m_f * in.template project<d, proj_dir>()).template reconstruct<d, proj_dir>();
} else {
out += in.project(d, proj_dir).reconstruct(d, proj_dir);
out += in.template project<d, proj_dir>().template reconstruct<d, proj_dir>();
}
}
}
Expand Down
48 changes: 24 additions & 24 deletions include/kernels/dslash_domain_wall_m5.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ namespace quda
constexpr int proj_dir = dagger ? +1 : -1;
if (shared) {
if (sync) { cache.sync(); }
cache.save(in.project(4, proj_dir));
cache.save(in.template project<4, proj_dir>());
cache.sync();
}
const int fwd_s = (s + 1) % arg.Ls;
Expand All @@ -236,20 +236,20 @@ namespace quda
half_in = cache.load(threadIdx.x, fwd_s, parity);
} else {
Vector full_in = arg.in(fwd_idx, parity);
half_in = full_in.project(4, proj_dir);
half_in = full_in.template project<4, proj_dir>();
}
if (s == arg.Ls - 1) {
out += (-arg.m_f * half_in).reconstruct(4, proj_dir);
out += (-arg.m_f * half_in).template reconstruct<4, proj_dir>();
} else {
out += half_in.reconstruct(4, proj_dir);
out += half_in.template reconstruct<4, proj_dir>();
}
}

{ // backwards direction
constexpr int proj_dir = dagger ? -1 : +1;
if (shared) {
cache.sync();
cache.save(in.project(4, proj_dir));
cache.save(in.template project<4, proj_dir>());
cache.sync();
}
const int back_s = (s + arg.Ls - 1) % arg.Ls;
Expand All @@ -259,12 +259,12 @@ namespace quda
half_in = cache.load(threadIdx.x, back_s, parity);
} else {
Vector full_in = arg.in(back_idx, parity);
half_in = full_in.project(4, proj_dir);
half_in = full_in.template project<4, proj_dir>();
}
if (s == 0) {
out += (-arg.m_f * half_in).reconstruct(4, proj_dir);
out += (-arg.m_f * half_in).template reconstruct<4, proj_dir>();
} else {
out += half_in.reconstruct(4, proj_dir);
out += half_in.template reconstruct<4, proj_dir>();
}
}

Expand All @@ -284,9 +284,9 @@ namespace quda
const Vector in = shared ? cache.load(threadIdx.x, fwd_s, parity) : arg.in(fwd_idx, parity);
constexpr int proj_dir = dagger ? +1 : -1;
if (s == arg.Ls - 1) {
out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir);
out += (-arg.m_f * in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>();
} else {
out += in.project(4, proj_dir).reconstruct(4, proj_dir);
out += in.template project<4, proj_dir>().template reconstruct<4, proj_dir>();
}
}

Expand All @@ -296,9 +296,9 @@ namespace quda
const Vector in = shared ? cache.load(threadIdx.x, back_s, parity) : arg.in(back_idx, parity);
constexpr int proj_dir = dagger ? -1 : +1;
if (s == 0) {
out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir);
out += (-arg.m_f * in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>();
} else {
out += in.project(4, proj_dir).reconstruct(4, proj_dir);
out += in.template project<4, proj_dir>().template reconstruct<4, proj_dir>();
}
}
} // use_half_vector
Expand Down Expand Up @@ -395,14 +395,14 @@ namespace quda
int exp = s_ < s ? arg.Ls - s + s_ : s_ - s;
real factorR = inv * fpow(k, exp) * (s_ < s ? -arg.m_f : static_cast<real>(1.0));
constexpr int proj_dir = dagger ? -1 : +1;
out += factorR * (in.project(4, proj_dir)).reconstruct(4, proj_dir);
out += factorR * (in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>();
}

{
int exp = s_ > s ? arg.Ls - s_ + s : s - s_;
real factorL = inv * fpow(k, exp) * (s_ > s ? -arg.m_f : static_cast<real>(1.0));
constexpr int proj_dir = dagger ? +1 : -1;
out += factorL * (in.project(4, proj_dir)).reconstruct(4, proj_dir);
out += factorL * (in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>();
}
}

Expand Down Expand Up @@ -443,7 +443,7 @@ namespace quda

if (shared) {
if (sync) { cache.sync(); }
cache.save(in.project(4, proj_dir));
cache.save(in.template project<4, proj_dir>());
cache.sync();
}

Expand All @@ -457,21 +457,21 @@ namespace quda
r += factorR * cache.load(threadIdx.x, s, parity);
} else {
Vector in = arg.in(s * arg.volume_4d_cb + x_cb, parity);
r += factorR * in.project(4, proj_dir);
r += factorR * in.template project<4, proj_dir>();
}

R *= coeff.kappa(s);
s = (s + arg.Ls - 1) % arg.Ls;
}

out += r.reconstruct(4, proj_dir);
out += r.template reconstruct<4, proj_dir>();
}

{ // second do L
constexpr int proj_dir = dagger ? +1 : -1;
if (shared) {
cache.sync(); // ensure we finish R before overwriting cache
cache.save(in.project(4, proj_dir));
cache.save(in.template project<4, proj_dir>());
cache.sync();
}

Expand All @@ -485,14 +485,14 @@ namespace quda
l += factorL * cache.load(threadIdx.x, s, parity);
} else {
Vector in = arg.in(s * arg.volume_4d_cb + x_cb, parity);
l += factorL * in.project(4, proj_dir);
l += factorL * in.template project<4, proj_dir>();
}

L *= coeff.kappa(s);
s = (s + 1) % arg.Ls;
}

out += l.reconstruct(4, proj_dir);
out += l.template reconstruct<4, proj_dir>();
}
} else { // use_half_vector
SharedMemoryCache<Vector> cache(target::block_dim());
Expand All @@ -512,13 +512,13 @@ namespace quda
auto factorR = (s_ < s ? -arg.m_f * R : R);

Vector in = shared ? cache.load(threadIdx.x, s, parity) : arg.in(s * arg.volume_4d_cb + x_cb, parity);
r += factorR * in.project(4, proj_dir);
r += factorR * in.template project<4, proj_dir>();

R *= coeff.kappa(s);
s = (s + arg.Ls - 1) % arg.Ls;
}

out += r.reconstruct(4, proj_dir);
out += r.template reconstruct<4, proj_dir>();
}

{ // second do L
Expand All @@ -531,13 +531,13 @@ namespace quda
auto factorL = (s_ > s ? -arg.m_f * L : L);

Vector in = shared ? cache.load(threadIdx.x, s, parity) : arg.in(s * arg.volume_4d_cb + x_cb, parity);
l += factorL * in.project(4, proj_dir);
l += factorL * in.template project<4, proj_dir>();

L *= coeff.kappa(s);
s = (s + 1) % arg.Ls;
}

out += l.reconstruct(4, proj_dir);
out += l.template reconstruct<4, proj_dir>();
}
} // use_half_vector

Expand Down
16 changes: 8 additions & 8 deletions include/kernels/dslash_gamma_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ namespace quda {
{
ColorSpinor<typename Arg::real, Arg::nColor, 4> in = arg.in(x_cb, parity);
switch(arg.d) {
case 0: arg.out(x_cb, parity) = in.gamma(0);
case 1: arg.out(x_cb, parity) = in.gamma(1);
case 2: arg.out(x_cb, parity) = in.gamma(2);
case 3: arg.out(x_cb, parity) = in.gamma(3);
case 4: arg.out(x_cb, parity) = in.gamma(4);
case 0: arg.out(x_cb, parity) = in.template gamma<0>();
case 1: arg.out(x_cb, parity) = in.template gamma<1>();
case 2: arg.out(x_cb, parity) = in.template gamma<2>();
case 3: arg.out(x_cb, parity) = in.template gamma<3>();
case 4: arg.out(x_cb, parity) = in.template gamma<4>();
}
}
};
Expand All @@ -101,12 +101,12 @@ namespace quda {
constexpr int d = 4;
if (!arg.doublet) {
fermion_t in = arg.in(x_cb, parity);
arg.out(x_cb, parity) = arg.a * (in + arg.b * in.igamma(d));
arg.out(x_cb, parity) = arg.a * (in + arg.b * in.template igamma<d>());
} else {
fermion_t in_1 = arg.in(x_cb+0*arg.volumeCB, parity);
fermion_t in_2 = arg.in(x_cb+1*arg.volumeCB, parity);
arg.out(x_cb + 0 * arg.volumeCB, parity) = arg.a * (in_1 + arg.b * in_1.igamma(d) + arg.c * in_2);
arg.out(x_cb + 1 * arg.volumeCB, parity) = arg.a * (in_2 - arg.b * in_2.igamma(d) + arg.c * in_1);
arg.out(x_cb + 0 * arg.volumeCB, parity) = arg.a * (in_1 + arg.b * in_1.template igamma<d>() + arg.c * in_2);
arg.out(x_cb + 1 * arg.volumeCB, parity) = arg.a * (in_2 - arg.b * in_2.template igamma<d>() + arg.c * in_1);
}
}
};
Expand Down
9 changes: 4 additions & 5 deletions include/kernels/dslash_mdw_fused.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ namespace quda {

const int index_4d_cb = index_4d_cb_from_coordinate_4d(coordinate, arg.dim);

#pragma unroll
for (int d = 0; d < 4; d++) // loop over dimension
static_for<0,4>([&](auto d) // loop over dimension
{
int x[4] = {coordinate[0], coordinate[1], coordinate[2], coordinate[3]};
x[d] = (coordinate[d] == arg.dim[d] - 1 && !arg.comm[d]) ? 0 : coordinate[d] + 1;
Expand All @@ -228,7 +227,7 @@ namespace quda {

const Link U = arg.U(d, index_4d_cb, arg.parity);
const Vector in = arg.in(fwd_idx, their_spinor_parity);
out += (U * in.project(d, proj_dir)).reconstruct(d, proj_dir);
out += (U * in.template project<d, proj_dir>()).template reconstruct<d, proj_dir>();
}
x[d] = (coordinate[d] == 0 && !arg.comm[d]) ? arg.dim[d] - 1 : coordinate[d] - 1;
if (!halo || !is_halo_4d(x, arg.dim, arg.halo_shift)) {
Expand All @@ -245,9 +244,9 @@ namespace quda {

const Link U = arg.U(d, gauge_idx, 1 - arg.parity);
const Vector in = arg.in(back_idx, their_spinor_parity);
out += (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir);
out += (conj(U) * in.template project<d, proj_dir>()).template reconstruct<d, proj_dir>();
}
} // nDim
}); // nDim
}

/**
Expand Down
18 changes: 9 additions & 9 deletions include/kernels/dslash_mobius_eofa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,19 @@ namespace quda
const Vector in = cache.load(threadIdx.x, (s + 1) % Ls, threadIdx.z);
constexpr int proj_dir = Arg::dagger ? +1 : -1;
if (s == Ls - 1) {
out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir);
out += (-arg.m_f * in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>();
} else {
out += in.project(4, proj_dir).reconstruct(4, proj_dir);
out += in.template project<4, proj_dir>().template reconstruct<4, proj_dir>();
}
}

{ // backwards direction
const Vector in = cache.load(threadIdx.x, (s + Ls - 1) % Ls, threadIdx.z);
constexpr int proj_dir = Arg::dagger ? -1 : +1;
if (s == 0) {
out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir);
out += (-arg.m_f * in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>();
} else {
out += in.project(4, proj_dir).reconstruct(4, proj_dir);
out += in.template project<4, proj_dir>().template reconstruct<4, proj_dir>();
}
}

Expand All @@ -145,12 +145,12 @@ namespace quda
if (s == (Arg::pm ? Ls - 1 : 0)) {
for (int sp = 0; sp < Ls; sp++) {
out += (static_cast<real>(0.5) * arg.coeff.u[sp])
* cache.load(threadIdx.x, sp, threadIdx.z).project(4, proj_dir).reconstruct(4, proj_dir);
* cache.load(threadIdx.x, sp, threadIdx.z).template project<4, proj_dir>().template reconstruct<4, proj_dir>();
}
}
} else {
out += (static_cast<real>(0.5) * arg.coeff.u[s])
* cache.load(threadIdx.x, Arg::pm ? Ls - 1 : 0, threadIdx.z).project(4, proj_dir).reconstruct(4, proj_dir);
* cache.load(threadIdx.x, Arg::pm ? Ls - 1 : 0, threadIdx.z).template project<4, proj_dir>().template reconstruct<4, proj_dir>();
}

if (Arg::xpay) { // really axpy
Expand Down Expand Up @@ -197,19 +197,19 @@ namespace quda
int exp = s < sp ? arg.Ls - sp + s : s - sp;
real factorR = 0.5 * arg.coeff.y[Arg::pm ? arg.Ls - exp - 1 : exp] * (s < sp ? -arg.m_f : static_cast<real>(1.0));
constexpr int proj_dir = Arg::dagger ? -1 : +1;
out += factorR * (in.project(4, proj_dir)).reconstruct(4, proj_dir);
out += factorR * (in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>();
}
{
int exp = s > sp ? arg.Ls - s + sp : sp - s;
real factorL = 0.5 * arg.coeff.y[Arg::pm ? arg.Ls - exp - 1 : exp] * (s > sp ? -arg.m_f : static_cast<real>(1.0));
constexpr int proj_dir = Arg::dagger ? +1 : -1;
out += factorL * (in.project(4, proj_dir)).reconstruct(4, proj_dir);
out += factorL * (in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>();
}
// The EOFA stuff
{
constexpr int proj_dir = Arg::pm ? +1 : -1;
real t = Arg::dagger ? arg.coeff.y[s] * arg.coeff.x[sp] : arg.coeff.x[s] * arg.coeff.y[sp];
out += (t * sherman_morrison) * (in.project(4, proj_dir)).reconstruct(4, proj_dir);
out += (t * sherman_morrison) * (in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>();
}
}
if (Arg::xpay) { // really axpy
Expand Down
4 changes: 2 additions & 2 deletions include/kernels/dslash_ndeg_twisted_mass.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ namespace quda

if (flavor == 0) {
out = x0 + arg.a * out;
out += arg.b * x0.igamma(4);
out += arg.b * x0.template igamma<4>();
out += arg.c * x1;
} else {
out = x1 + arg.a * out;
out += -arg.b * x1.igamma(4);
out += -arg.b * x1.template igamma<4>();
out += arg.c * x0;
}

Expand Down
Loading