From 1531779b06d07513bfe1524a11937e64b2e6807c Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Thu, 21 Jul 2022 20:41:56 +0000 Subject: [PATCH 1/9] add static_for in util_quda.h It accepts a functor with a templated operator(). --- include/util_quda.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/include/util_quda.h b/include/util_quda.h index 533df01970..3da8f84489 100644 --- a/include/util_quda.h +++ b/include/util_quda.h @@ -14,6 +14,15 @@ namespace quda constexpr bool str_slant(const char *str) { return *str == '/' ? true : (*str ? str_slant(str + 1) : false); } constexpr const char *r_slant(const char *str) { return *str == '/' ? (str + 1) : r_slant(str - 1); } constexpr const char *file_name(const char *str) { return str_slant(str) ? r_slant(str_end(str)) : str; } + + template + __attribute__((always_inline)) void static_for(F&&f) + { + if constexpr (A(); + __attribute__((musttail)) return static_for(std::forward(f)); + } + } } // namespace quda /** From 3216cb4a0731ecfe31ca922a1f0e8f50638da9a1 Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Thu, 21 Jul 2022 20:44:39 +0000 Subject: [PATCH 2/9] flip the switch in complex i_ function --- include/complex_quda.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/complex_quda.h b/include/complex_quda.h index 66b5eee609..a6a4a3e3e1 100644 --- a/include/complex_quda.h +++ b/include/complex_quda.h @@ -1207,7 +1207,7 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real()); template __host__ __device__ inline complex i_(const complex &a) { // FIXME compiler generates worse code with "optimal" code -#if 1 +#if 0 return complex(0.0, 1.0) * a; #else return complex(-a.imag(), a.real()); From 16a632cbb49996c8bb78fe99d0ee54aef52ec33d Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Thu, 21 Jul 2022 20:45:49 +0000 Subject: [PATCH 3/9] replace switch with if constexpr in color_spinor.h --- include/color_spinor.h | 614 ++++++++---------- include/kernels/clover_outer_product.cuh | 13 +- .../kernels/clover_sigma_outer_product.cuh | 2 +- include/kernels/dslash_gamma_helper.cuh | 16 +- include/kernels/dslash_pack.cuh | 16 +- include/kernels/dslash_wilson.cuh | 13 +- 6 files changed, 309 insertions(+), 365 deletions(-) diff --git a/include/color_spinor.h b/include/color_spinor.h index 3f7e32ff74..70817e65f0 100644 --- a/include/color_spinor.h +++ b/include/color_spinor.h @@ -75,18 +75,18 @@ namespace quda { __device__ __host__ inline void operator=(const colorspinor_ghost_wrapper &s); /** - @brief 2-d accessor functor - @param[in] s Spin index - @param[in] c Color index - @return Complex number at this spin and color index + @brief 2-d accessor functor + @param[in] s Spin index + @param[in] c Color index + @return Complex number at this spin and color index */ __device__ __host__ inline complex& operator()(int s, int c) { return data[s*Nc + c]; } /** - @brief 2-d accessor functor - @param[in] s Spin index - @param[in] c Color index - @return Complex number at this spin and color index + @brief 2-d accessor functor + @param[in] s Spin index + @param[in] c Color index + @return Complex number at this spin and color index */ __device__ __host__ inline const complex& operator()(int s, int c) const { return data[s*Nc + c]; } @@ -146,120 +146,114 @@ namespace quda { } /** - Return this application of gamma_dim to this spinor - @param dim Which dimension gamma matrix we are applying - @return The new spinor + Return this application of gamma_dim to this spinor + @param dim Which dimension gamma matrix we are applying + @return The new spinor */ - __device__ __host__ inline ColorSpinor gamma(int dim) { + template + __device__ __host__ inline ColorSpinor gamma() { ColorSpinor a; const auto &t = *this; - switch (dim) { - case 0: // x dimension + static_assert(0<=dim && dim<=4, "dim must be 0-4"); + + if constexpr (dim==0) { // x dimension #pragma unroll - for (int i=0; i igamma(int dim) { + template + __device__ __host__ inline ColorSpinor igamma() { ColorSpinor a; const auto &t = *this; - switch (dim) { - case 0: // x dimension + static_assert(0<=dim && dim<=4, "dim must be 0-4"); + + if constexpr (dim==0) { // x dimension #pragma unroll - for (int i=0; i project(int dim, int sign) const + template + __device__ __host__ inline ColorSpinor project() const { ColorSpinor proj; const auto &t = *this; - switch (dim) { - case 0: // x dimension - switch (sign) { - case 1: // positive projector + + static_assert(0<=dim && dim<=4, "dim must be 0-4"); + static_assert(sign==-1 || sign==1, "sign must be -1 or 1"); + + if constexpr (dim==0) { // x dimension + if constexpr (sign==1) { // positive projector #pragma unroll - for (int i=0; i(2.0) * t(0, i); proj(1, i) = static_cast(2.0) * t(1, i); } - break; - case -1: // negative projector + } else if constexpr (sign==-1) { // negative projector #pragma unroll - for (int i=0; i(2.0) * t(2, i); proj(1, i) = static_cast(2.0) * t(3, i); } - break; } - break; - case 4: - switch (sign) { - case 1: // positive projector + } else if constexpr (dim==4) { // gamma_5 + if constexpr (sign==1) { // positive projector #pragma unroll for (int i = 0; i < Nc; i++) { proj(0, i) = t(0, i) + t(2, i); proj(1, i) = t(1, i) + t(3, i); } - break; - case -1: // negative projector + } else if constexpr (sign==-1) { // negative projector #pragma unroll for (int i = 0; i < Nc; i++) { proj(0, i) = t(0, i) - t(2, i); proj(1, i) = t(1, i) - t(3, i); } - break; } - break; } return proj; @@ -394,165 +372,149 @@ namespace quda { sigma(0,1) = i 0 0 0 0 -i 0 0 - 0 0 i 0 - 0 0 0 -i + 0 0 i 0 + 0 0 0 -i sigma(0,2) = 0 -1 0 0 1 0 0 0 - 0 0 0 -1 - 0 0 1 0 + 0 0 0 -1 + 0 0 1 0 sigma(0,3) = 0 0 0 -i 0 0 -i 0 - 0 -i i 0 - -i 0 0 0 + 0 -i i 0 + -i 0 0 0 sigma(1,2) = 0 i 0 0 i 0 0 0 - 0 0 0 i - 0 0 i 0 + 0 0 0 i + 0 0 i 0 sigma(1,3) = 0 0 0 -1 0 0 1 0 - 0 -1 0 0 - 1 0 0 0 + 0 -1 0 0 + 1 0 0 0 sigma(2,3) = 0 0 -i 0 0 0 0 i - -i 0 0 0 - 0 i 0 0 + -i 0 0 0 + 0 i 0 0 */ - __device__ __host__ inline ColorSpinor sigma(int mu, int nu) const + template + __device__ __host__ inline ColorSpinor sigma() const { ColorSpinor a; const ColorSpinor &b = *this; complex j(0.0,1.0); - switch(mu) { - case 0: - switch(nu) { - case 1: -#pragma unroll - for (int i=0; i a; #pragma unroll for (int c=0; c a; #pragma unroll for (int c=0; c reconstruct(int dim, int sign) const + template + __device__ __host__ inline ColorSpinor reconstruct() const { ColorSpinor recon; - const auto t = *this; + const auto &t = *this; + + static_assert(sign==1||sign==-1, "sign must be 1 or -1"); + static_assert(dim==0||dim==1||dim==2||dim==3||dim==4, "dim must be 0-4"); - switch (dim) { - case 0: // x dimension - switch (sign) { - case 1: // positive projector + if constexpr (dim==0) { + if constexpr (sign==1) { #pragma unroll - for (int i=0; i([&]{ int shift[4] = {0, 0, 0, 0}; shift[dim] = 1; const int nbr_idx = neighborIndex(x_cb, shift, arg.partitioned, arg.parity, arg.X); @@ -73,10 +72,10 @@ namespace quda { Spinor B_shift = arg.inB(nbr_idx, 0); Spinor D_shift = arg.inD(nbr_idx, 0); - B_shift = (B_shift.project(dim,1)).reconstruct(dim,1); + B_shift = (B_shift.template project()).template reconstruct(); Link result = outerProdSpinTrace(B_shift,A); - D_shift = (D_shift.project(dim,-1)).reconstruct(dim,-1); + D_shift = (D_shift.template project()).template reconstruct(); result += outerProdSpinTrace(D_shift,C); Link temp = arg.force(dim, x_cb, arg.parity); @@ -84,7 +83,7 @@ namespace quda { result = temp + U*result*arg.coeff; arg.force(dim, x_cb, arg.parity) = result; } - } // dim + }); // dim } }; @@ -107,11 +106,11 @@ namespace quda { Spinor C = arg.inC(bulk_cb_idx, 0); HalfSpinor projected_tmp = arg.inB.Ghost(Arg::dim, 1, x_cb, 0); - Spinor B_shift = projected_tmp.reconstruct(Arg::dim, 1); + Spinor B_shift = projected_tmp.template reconstruct(); Link result = outerProdSpinTrace(B_shift,A); projected_tmp = arg.inD.Ghost(Arg::dim, 1, x_cb, 0); - Spinor D_shift = projected_tmp.reconstruct(Arg::dim,-1); + Spinor D_shift = projected_tmp.template reconstruct(); result += outerProdSpinTrace(D_shift,C); Link temp = arg.force(Arg::dim, bulk_cb_idx, arg.parity); diff --git a/include/kernels/clover_sigma_outer_product.cuh b/include/kernels/clover_sigma_outer_product.cuh index 91d2de8ef5..88dcfb0ebc 100644 --- a/include/kernels/clover_sigma_outer_product.cuh +++ b/include/kernels/clover_sigma_outer_product.cuh @@ -51,7 +51,7 @@ namespace quda for (int i = 0; i < Arg::nvector; i++) { const Spinor A = arg.inA[i](x_cb, parity); const Spinor B = arg.inB[i](x_cb, parity); - Spinor C = A.sigma(nu, mu); // multiply by sigma_mu_nu + Spinor C = A.template sigma(); // multiply by sigma_mu_nu result += arg.coeff[i][parity] * outerProdSpinTrace(C, B); } diff --git a/include/kernels/dslash_gamma_helper.cuh b/include/kernels/dslash_gamma_helper.cuh index 3b5e27492a..cb95074d95 100644 --- a/include/kernels/dslash_gamma_helper.cuh +++ b/include/kernels/dslash_gamma_helper.cuh @@ -78,11 +78,11 @@ namespace quda { { ColorSpinor in = arg.in(x_cb, parity); switch(arg.d) { - case 0: arg.out(x_cb, parity) = in.gamma(0); - case 1: arg.out(x_cb, parity) = in.gamma(1); - case 2: arg.out(x_cb, parity) = in.gamma(2); - case 3: arg.out(x_cb, parity) = in.gamma(3); - case 4: arg.out(x_cb, parity) = in.gamma(4); + case 0: arg.out(x_cb, parity) = in.template gamma<0>(); + case 1: arg.out(x_cb, parity) = in.template gamma<1>(); + case 2: arg.out(x_cb, parity) = in.template gamma<2>(); + case 3: arg.out(x_cb, parity) = in.template gamma<3>(); + case 4: arg.out(x_cb, parity) = in.template gamma<4>(); } } }; @@ -101,12 +101,12 @@ namespace quda { constexpr int d = 4; if (!arg.doublet) { fermion_t in = arg.in(x_cb, parity); - arg.out(x_cb, parity) = arg.a * (in + arg.b * in.igamma(d)); + arg.out(x_cb, parity) = arg.a * (in + arg.b * in.template igamma()); } else { fermion_t in_1 = arg.in(x_cb+0*arg.volumeCB, parity); fermion_t in_2 = arg.in(x_cb+1*arg.volumeCB, parity); - arg.out(x_cb + 0 * arg.volumeCB, parity) = arg.a * (in_1 + arg.b * in_1.igamma(d) + arg.c * in_2); - arg.out(x_cb + 1 * arg.volumeCB, parity) = arg.a * (in_2 - arg.b * in_2.igamma(d) + arg.c * in_1); + arg.out(x_cb + 0 * arg.volumeCB, parity) = arg.a * (in_1 + arg.b * in_1.template igamma() + arg.c * in_2); + arg.out(x_cb + 1 * arg.volumeCB, parity) = arg.a * (in_2 - arg.b * in_2.template igamma() + arg.c * in_1); } } }; diff --git a/include/kernels/dslash_pack.cuh b/include/kernels/dslash_pack.cuh index b4c22a4bb9..1ebdde58e1 100644 --- a/include/kernels/dslash_pack.cuh +++ b/include/kernels/dslash_pack.cuh @@ -154,16 +154,16 @@ namespace quda constexpr int proj_dir = dagger ? +1 : -1; Vector f = arg.in_pack(idx + s * arg.dc.volume_4d_cb, spinor_parity); if (twist == 1) { - f = arg.twist_a * (f + arg.twist_b * f.igamma(4)); + f = arg.twist_a * (f + arg.twist_b * f.template igamma<4>()); } else if (twist == 2) { Vector f1 = arg.in_pack(idx + (1 - s) * arg.dc.volume_4d_cb, spinor_parity); // load other flavor if (s == 0) - f = arg.twist_a * (f + arg.twist_b * f.igamma(4) + arg.twist_c * f1); + f = arg.twist_a * (f + arg.twist_b * f.template igamma<4>() + arg.twist_c * f1); else - f = arg.twist_a * (f - arg.twist_b * f.igamma(4) + arg.twist_c * f1); + f = arg.twist_a * (f - arg.twist_b * f.template igamma<4>() + arg.twist_c * f1); } if (arg.spin_project) { - arg.in_pack.Ghost(dim, 0, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f.project(dim, proj_dir); + arg.in_pack.Ghost(dim, 0, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f.template project(); } else { arg.in_pack.Ghost(dim, 0, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f; } @@ -173,16 +173,16 @@ namespace quda constexpr int proj_dir = dagger ? -1 : +1; Vector f = arg.in_pack(idx + s * arg.dc.volume_4d_cb, spinor_parity); if (twist == 1) { - f = arg.twist_a * (f + arg.twist_b * f.igamma(4)); + f = arg.twist_a * (f + arg.twist_b * f.template igamma<4>()); } else if (twist == 2) { Vector f1 = arg.in_pack(idx + (1 - s) * arg.dc.volume_4d_cb, spinor_parity); // load other flavor if (s == 0) - f = arg.twist_a * (f + arg.twist_b * f.igamma(4) + arg.twist_c * f1); + f = arg.twist_a * (f + arg.twist_b * f.template igamma<4>() + arg.twist_c * f1); else - f = arg.twist_a * (f - arg.twist_b * f.igamma(4) + arg.twist_c * f1); + f = arg.twist_a * (f - arg.twist_b * f.template igamma<4>() + arg.twist_c * f1); } if (arg.spin_project) { - arg.in_pack.Ghost(dim, 1, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f.project(dim, proj_dir); + arg.in_pack.Ghost(dim, 1, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f.template project(); } else { arg.in_pack.Ghost(dim, 1, ghost_idx + s * arg.dc.ghostFaceCB[dim], spinor_parity) = f; } diff --git a/include/kernels/dslash_wilson.cuh b/include/kernels/dslash_wilson.cuh index cd7575974a..99f9e3786e 100644 --- a/include/kernels/dslash_wilson.cuh +++ b/include/kernels/dslash_wilson.cuh @@ -77,8 +77,7 @@ namespace quda // parity for gauge field - include residual parity from 5-d => 4-d checkerboarding const int gauge_parity = (Arg::nDim == 5 ? (coord.x_cb / arg.dc.volume_4d_cb + parity) % 2 : parity); -#pragma unroll - for (int d = 0; d < 4; d++) { // loop over dimension - 4 and not nDim since this is used for DWF as well + static_for<0,4>([&]{ // loop over dimension - 4 and not nDim since this is used for DWF as well { // Forward gather - compute fwd offset for vector fetch const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc); const int gauge_idx = (Arg::nDim == 5 ? coord.x_cb % arg.dc.volume_4d_cb : coord.x_cb); @@ -95,13 +94,13 @@ namespace quda Link U = arg.U(d, gauge_idx, gauge_parity); HalfVector in = arg.in.Ghost(d, 1, ghost_idx + coord.s * arg.dc.ghostFaceCB[d], their_spinor_parity); - out += (U * in).reconstruct(d, proj_dir); + out += (U * in).template reconstruct(); } else if (doBulk() && !ghost) { Link U = arg.U(d, gauge_idx, gauge_parity); Vector in = arg.in(fwd_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); - out += (U * in.project(d, proj_dir)).reconstruct(d, proj_dir); + out += (U * in.template project()).template reconstruct(); } } @@ -121,16 +120,16 @@ namespace quda Link U = arg.U.Ghost(d, gauge_ghost_idx, 1 - gauge_parity); HalfVector in = arg.in.Ghost(d, 0, ghost_idx + coord.s * arg.dc.ghostFaceCB[d], their_spinor_parity); - out += (conj(U) * in).reconstruct(d, proj_dir); + out += (conj(U) * in).template reconstruct(); } else if (doBulk() && !ghost) { Link U = arg.U(d, gauge_idx, 1 - gauge_parity); Vector in = arg.in(back_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); - out += (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir); + out += (conj(U) * in.template project()).template reconstruct(); } } - } // nDim + }); // nDim } template struct wilson : dslash_default { From 6b83e57ee51262f82a75ca7180478775e9dd2b70 Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Thu, 21 Jul 2022 23:36:55 +0000 Subject: [PATCH 4/9] static_for: guard against missing musttail if not clang --- include/util_quda.h | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/include/util_quda.h b/include/util_quda.h index 3da8f84489..2142b4890f 100644 --- a/include/util_quda.h +++ b/include/util_quda.h @@ -7,6 +7,14 @@ #include #include +#ifndef QUDA_MUSTTAIL +#ifdef __clang__ +#define QUDA_MUSTTAIL __attribute__((musttail)) +#else +#define QUDA_MUSTTAIL +#endif +#endif + namespace quda { // strip path from __FILE__ @@ -20,7 +28,7 @@ namespace quda { if constexpr (A(); - __attribute__((musttail)) return static_for(std::forward(f)); + QUDA_MUSTTAIL return static_for(std::forward(f)); } } } // namespace quda From aa67c80f1101ab8e48f862106f5c8a9a8d8920d8 Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Thu, 21 Jul 2022 23:37:35 +0000 Subject: [PATCH 5/9] constexpr: fix domain wall and twisted mass color spinor functions --- include/kernels/dslash_domain_wall_5d.cuh | 8 ++-- include/kernels/dslash_domain_wall_m5.cuh | 48 +++++++++---------- include/kernels/dslash_mdw_fused.cuh | 9 ++-- include/kernels/dslash_mobius_eofa.cuh | 18 +++---- include/kernels/dslash_ndeg_twisted_mass.cuh | 4 +- ...slash_ndeg_twisted_mass_preconditioned.cuh | 8 ++-- include/kernels/dslash_twisted_mass.cuh | 2 +- .../dslash_twisted_mass_preconditioned.cuh | 29 ++++++----- 8 files changed, 62 insertions(+), 64 deletions(-) diff --git a/include/kernels/dslash_domain_wall_5d.cuh b/include/kernels/dslash_domain_wall_5d.cuh index c9c344df3d..6ade2dd993 100644 --- a/include/kernels/dslash_domain_wall_5d.cuh +++ b/include/kernels/dslash_domain_wall_5d.cuh @@ -55,9 +55,9 @@ namespace quda constexpr int proj_dir = dagger ? +1 : -1; Vector in = arg.in(fwd_idx, their_spinor_parity); if (s == arg.Ls - 1) { - out += (-arg.m_f * in.project(d, proj_dir)).reconstruct(d, proj_dir); + out += (-arg.m_f * in.template project()).template reconstruct(); } else { - out += in.project(d, proj_dir).reconstruct(d, proj_dir); + out += in.template project().template reconstruct(); } } @@ -66,9 +66,9 @@ namespace quda constexpr int proj_dir = dagger ? -1 : +1; Vector in = arg.in(back_idx, their_spinor_parity); if (s == 0) { - out += (-arg.m_f * in.project(d, proj_dir)).reconstruct(d, proj_dir); + out += (-arg.m_f * in.template project()).template reconstruct(); } else { - out += in.project(d, proj_dir).reconstruct(d, proj_dir); + out += in.template project().template reconstruct(); } } } diff --git a/include/kernels/dslash_domain_wall_m5.cuh b/include/kernels/dslash_domain_wall_m5.cuh index 4919a5374e..006f086fcb 100644 --- a/include/kernels/dslash_domain_wall_m5.cuh +++ b/include/kernels/dslash_domain_wall_m5.cuh @@ -226,7 +226,7 @@ namespace quda constexpr int proj_dir = dagger ? +1 : -1; if (shared) { if (sync) { cache.sync(); } - cache.save(in.project(4, proj_dir)); + cache.save(in.template project<4, proj_dir>()); cache.sync(); } const int fwd_s = (s + 1) % arg.Ls; @@ -236,12 +236,12 @@ namespace quda half_in = cache.load(threadIdx.x, fwd_s, parity); } else { Vector full_in = arg.in(fwd_idx, parity); - half_in = full_in.project(4, proj_dir); + half_in = full_in.template project<4, proj_dir>(); } if (s == arg.Ls - 1) { - out += (-arg.m_f * half_in).reconstruct(4, proj_dir); + out += (-arg.m_f * half_in).template reconstruct<4, proj_dir>(); } else { - out += half_in.reconstruct(4, proj_dir); + out += half_in.template reconstruct<4, proj_dir>(); } } @@ -249,7 +249,7 @@ namespace quda constexpr int proj_dir = dagger ? -1 : +1; if (shared) { cache.sync(); - cache.save(in.project(4, proj_dir)); + cache.save(in.template project<4, proj_dir>()); cache.sync(); } const int back_s = (s + arg.Ls - 1) % arg.Ls; @@ -259,12 +259,12 @@ namespace quda half_in = cache.load(threadIdx.x, back_s, parity); } else { Vector full_in = arg.in(back_idx, parity); - half_in = full_in.project(4, proj_dir); + half_in = full_in.template project<4, proj_dir>(); } if (s == 0) { - out += (-arg.m_f * half_in).reconstruct(4, proj_dir); + out += (-arg.m_f * half_in).template reconstruct<4, proj_dir>(); } else { - out += half_in.reconstruct(4, proj_dir); + out += half_in.template reconstruct<4, proj_dir>(); } } @@ -284,9 +284,9 @@ namespace quda const Vector in = shared ? cache.load(threadIdx.x, fwd_s, parity) : arg.in(fwd_idx, parity); constexpr int proj_dir = dagger ? +1 : -1; if (s == arg.Ls - 1) { - out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); + out += (-arg.m_f * in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>(); } else { - out += in.project(4, proj_dir).reconstruct(4, proj_dir); + out += in.template project<4, proj_dir>().template reconstruct<4, proj_dir>(); } } @@ -296,9 +296,9 @@ namespace quda const Vector in = shared ? cache.load(threadIdx.x, back_s, parity) : arg.in(back_idx, parity); constexpr int proj_dir = dagger ? -1 : +1; if (s == 0) { - out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); + out += (-arg.m_f * in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>(); } else { - out += in.project(4, proj_dir).reconstruct(4, proj_dir); + out += in.template project<4, proj_dir>().template reconstruct<4, proj_dir>(); } } } // use_half_vector @@ -395,14 +395,14 @@ namespace quda int exp = s_ < s ? arg.Ls - s + s_ : s_ - s; real factorR = inv * fpow(k, exp) * (s_ < s ? -arg.m_f : static_cast(1.0)); constexpr int proj_dir = dagger ? -1 : +1; - out += factorR * (in.project(4, proj_dir)).reconstruct(4, proj_dir); + out += factorR * (in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>(); } { int exp = s_ > s ? arg.Ls - s_ + s : s - s_; real factorL = inv * fpow(k, exp) * (s_ > s ? -arg.m_f : static_cast(1.0)); constexpr int proj_dir = dagger ? +1 : -1; - out += factorL * (in.project(4, proj_dir)).reconstruct(4, proj_dir); + out += factorL * (in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>(); } } @@ -443,7 +443,7 @@ namespace quda if (shared) { if (sync) { cache.sync(); } - cache.save(in.project(4, proj_dir)); + cache.save(in.template project<4, proj_dir>()); cache.sync(); } @@ -457,21 +457,21 @@ namespace quda r += factorR * cache.load(threadIdx.x, s, parity); } else { Vector in = arg.in(s * arg.volume_4d_cb + x_cb, parity); - r += factorR * in.project(4, proj_dir); + r += factorR * in.template project<4, proj_dir>(); } R *= coeff.kappa(s); s = (s + arg.Ls - 1) % arg.Ls; } - out += r.reconstruct(4, proj_dir); + out += r.template reconstruct<4, proj_dir>(); } { // second do L constexpr int proj_dir = dagger ? +1 : -1; if (shared) { cache.sync(); // ensure we finish R before overwriting cache - cache.save(in.project(4, proj_dir)); + cache.save(in.template project<4, proj_dir>()); cache.sync(); } @@ -485,14 +485,14 @@ namespace quda l += factorL * cache.load(threadIdx.x, s, parity); } else { Vector in = arg.in(s * arg.volume_4d_cb + x_cb, parity); - l += factorL * in.project(4, proj_dir); + l += factorL * in.template project<4, proj_dir>(); } L *= coeff.kappa(s); s = (s + 1) % arg.Ls; } - out += l.reconstruct(4, proj_dir); + out += l.template reconstruct<4, proj_dir>(); } } else { // use_half_vector SharedMemoryCache cache(target::block_dim()); @@ -512,13 +512,13 @@ namespace quda auto factorR = (s_ < s ? -arg.m_f * R : R); Vector in = shared ? cache.load(threadIdx.x, s, parity) : arg.in(s * arg.volume_4d_cb + x_cb, parity); - r += factorR * in.project(4, proj_dir); + r += factorR * in.template project<4, proj_dir>(); R *= coeff.kappa(s); s = (s + arg.Ls - 1) % arg.Ls; } - out += r.reconstruct(4, proj_dir); + out += r.template reconstruct<4, proj_dir>(); } { // second do L @@ -531,13 +531,13 @@ namespace quda auto factorL = (s_ > s ? -arg.m_f * L : L); Vector in = shared ? cache.load(threadIdx.x, s, parity) : arg.in(s * arg.volume_4d_cb + x_cb, parity); - l += factorL * in.project(4, proj_dir); + l += factorL * in.template project<4, proj_dir>(); L *= coeff.kappa(s); s = (s + 1) % arg.Ls; } - out += l.reconstruct(4, proj_dir); + out += l.template reconstruct<4, proj_dir>(); } } // use_half_vector diff --git a/include/kernels/dslash_mdw_fused.cuh b/include/kernels/dslash_mdw_fused.cuh index 8665b8dcc3..1b48a71bee 100644 --- a/include/kernels/dslash_mdw_fused.cuh +++ b/include/kernels/dslash_mdw_fused.cuh @@ -211,8 +211,7 @@ namespace quda { const int index_4d_cb = index_4d_cb_from_coordinate_4d(coordinate, arg.dim); -#pragma unroll - for (int d = 0; d < 4; d++) // loop over dimension + static_for<0,4>([&] // loop over dimension { int x[4] = {coordinate[0], coordinate[1], coordinate[2], coordinate[3]}; x[d] = (coordinate[d] == arg.dim[d] - 1 && !arg.comm[d]) ? 0 : coordinate[d] + 1; @@ -228,7 +227,7 @@ namespace quda { const Link U = arg.U(d, index_4d_cb, arg.parity); const Vector in = arg.in(fwd_idx, their_spinor_parity); - out += (U * in.project(d, proj_dir)).reconstruct(d, proj_dir); + out += (U * in.template project()).template reconstruct(); } x[d] = (coordinate[d] == 0 && !arg.comm[d]) ? arg.dim[d] - 1 : coordinate[d] - 1; if (!halo || !is_halo_4d(x, arg.dim, arg.halo_shift)) { @@ -245,9 +244,9 @@ namespace quda { const Link U = arg.U(d, gauge_idx, 1 - arg.parity); const Vector in = arg.in(back_idx, their_spinor_parity); - out += (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir); + out += (conj(U) * in.template project()).template reconstruct(); } - } // nDim + }); // nDim } /** diff --git a/include/kernels/dslash_mobius_eofa.cuh b/include/kernels/dslash_mobius_eofa.cuh index 169ab63fbb..cab8e719cd 100644 --- a/include/kernels/dslash_mobius_eofa.cuh +++ b/include/kernels/dslash_mobius_eofa.cuh @@ -119,9 +119,9 @@ namespace quda const Vector in = cache.load(threadIdx.x, (s + 1) % Ls, threadIdx.z); constexpr int proj_dir = Arg::dagger ? +1 : -1; if (s == Ls - 1) { - out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); + out += (-arg.m_f * in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>(); } else { - out += in.project(4, proj_dir).reconstruct(4, proj_dir); + out += in.template project<4, proj_dir>().template reconstruct<4, proj_dir>(); } } @@ -129,9 +129,9 @@ namespace quda const Vector in = cache.load(threadIdx.x, (s + Ls - 1) % Ls, threadIdx.z); constexpr int proj_dir = Arg::dagger ? -1 : +1; if (s == 0) { - out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); + out += (-arg.m_f * in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>(); } else { - out += in.project(4, proj_dir).reconstruct(4, proj_dir); + out += in.template project<4, proj_dir>().template reconstruct<4, proj_dir>(); } } @@ -145,12 +145,12 @@ namespace quda if (s == (Arg::pm ? Ls - 1 : 0)) { for (int sp = 0; sp < Ls; sp++) { out += (static_cast(0.5) * arg.coeff.u[sp]) - * cache.load(threadIdx.x, sp, threadIdx.z).project(4, proj_dir).reconstruct(4, proj_dir); + * cache.load(threadIdx.x, sp, threadIdx.z).template project<4, proj_dir>().template reconstruct<4, proj_dir>(); } } } else { out += (static_cast(0.5) * arg.coeff.u[s]) - * cache.load(threadIdx.x, Arg::pm ? Ls - 1 : 0, threadIdx.z).project(4, proj_dir).reconstruct(4, proj_dir); + * cache.load(threadIdx.x, Arg::pm ? Ls - 1 : 0, threadIdx.z).template project<4, proj_dir>().template reconstruct<4, proj_dir>(); } if (Arg::xpay) { // really axpy @@ -197,19 +197,19 @@ namespace quda int exp = s < sp ? arg.Ls - sp + s : s - sp; real factorR = 0.5 * arg.coeff.y[Arg::pm ? arg.Ls - exp - 1 : exp] * (s < sp ? -arg.m_f : static_cast(1.0)); constexpr int proj_dir = Arg::dagger ? -1 : +1; - out += factorR * (in.project(4, proj_dir)).reconstruct(4, proj_dir); + out += factorR * (in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>(); } { int exp = s > sp ? arg.Ls - s + sp : sp - s; real factorL = 0.5 * arg.coeff.y[Arg::pm ? arg.Ls - exp - 1 : exp] * (s > sp ? -arg.m_f : static_cast(1.0)); constexpr int proj_dir = Arg::dagger ? +1 : -1; - out += factorL * (in.project(4, proj_dir)).reconstruct(4, proj_dir); + out += factorL * (in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>(); } // The EOFA stuff { constexpr int proj_dir = Arg::pm ? +1 : -1; real t = Arg::dagger ? arg.coeff.y[s] * arg.coeff.x[sp] : arg.coeff.x[s] * arg.coeff.y[sp]; - out += (t * sherman_morrison) * (in.project(4, proj_dir)).reconstruct(4, proj_dir); + out += (t * sherman_morrison) * (in.template project<4, proj_dir>()).template reconstruct<4, proj_dir>(); } } if (Arg::xpay) { // really axpy diff --git a/include/kernels/dslash_ndeg_twisted_mass.cuh b/include/kernels/dslash_ndeg_twisted_mass.cuh index 2f88d5ee48..21fbc3772f 100644 --- a/include/kernels/dslash_ndeg_twisted_mass.cuh +++ b/include/kernels/dslash_ndeg_twisted_mass.cuh @@ -61,11 +61,11 @@ namespace quda if (flavor == 0) { out = x0 + arg.a * out; - out += arg.b * x0.igamma(4); + out += arg.b * x0.template igamma<4>(); out += arg.c * x1; } else { out = x1 + arg.a * out; - out += -arg.b * x1.igamma(4); + out += -arg.b * x1.template igamma<4>(); out += arg.c * x0; } diff --git a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh index bf0b42060e..d4161be454 100644 --- a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh @@ -80,9 +80,9 @@ namespace quda Vector x0 = arg.x(coord.x_cb + 0 * arg.dc.volume_4d_cb, my_spinor_parity); Vector x1 = arg.x(coord.x_cb + 1 * arg.dc.volume_4d_cb, my_spinor_parity); if (flavor == 0) - out += arg.a_inv * (x0 + arg.b_inv * x0.igamma(4) + arg.c_inv * x1); + out += arg.a_inv * (x0 + arg.b_inv * x0.template igamma<4>() + arg.c_inv * x1); else - out += arg.a_inv * (x1 - arg.b_inv * x1.igamma(4) + arg.c_inv * x0); + out += arg.a_inv * (x1 - arg.b_inv * x1.template igamma<4>() + arg.c_inv * x0); } else { Vector x = arg.x(my_flavor_idx, my_spinor_parity); out += x; // just directly add since twist already applied in the dslash @@ -104,9 +104,9 @@ namespace quda cache.sync(); // safe to sync in here since other threads will exit if (isComplete(arg, coord) && active) { if (flavor == 0) - out = arg.a * (out + arg.b * out.igamma(4) + arg.c * cache.load_y(1)); + out = arg.a * (out + arg.b * out.template igamma<4>() + arg.c * cache.load_y(1)); else - out = arg.a * (out - arg.b * out.igamma(4) + arg.c * cache.load_y(0)); + out = arg.a * (out - arg.b * out.template igamma<4>() + arg.c * cache.load_y(0)); } } diff --git a/include/kernels/dslash_twisted_mass.cuh b/include/kernels/dslash_twisted_mass.cuh index 85230c729d..94eca61d50 100644 --- a/include/kernels/dslash_twisted_mass.cuh +++ b/include/kernels/dslash_twisted_mass.cuh @@ -51,7 +51,7 @@ namespace quda if (mykernel_type == INTERIOR_KERNEL) { Vector x = arg.x(coord.x_cb, my_spinor_parity); - x += arg.b * x.igamma(4); + x += arg.b * x.template igamma<4>(); out = x + arg.a * out; } else if (active) { Vector x = arg.out(coord.x_cb, my_spinor_parity); diff --git a/include/kernels/dslash_twisted_mass_preconditioned.cuh b/include/kernels/dslash_twisted_mass_preconditioned.cuh index 496e12276d..231f56c4e6 100644 --- a/include/kernels/dslash_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_twisted_mass_preconditioned.cuh @@ -55,8 +55,7 @@ namespace quda typedef Matrix, Arg::nColor> Link; const int their_spinor_parity = nParity == 2 ? 1 - parity : 0; -#pragma unroll - for (int d = 0; d < Arg::nDim; d++) { // loop over dimension + static_for<0,Arg::nDim>([&]{ // loop over dimension { // Forward gather - compute fwd offset for vector fetch const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc); constexpr int proj_dir = dagger ? +1 : -1; @@ -72,24 +71,24 @@ namespace quda Link U = arg.U(d, coord.x_cb, parity); HalfVector in = arg.in.Ghost(d, 1, ghost_idx + coord.s * arg.dc.ghostFaceCB[d], their_spinor_parity); - out += (U * in).reconstruct(d, proj_dir); + out += (U * in).template reconstruct(); } else if (doBulk() && !ghost) { Link U = arg.U(d, coord.x_cb, parity); Vector in; if (twist == 1) { in = arg.in(fwd_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); - in = arg.a * (in + arg.b * in.igamma(4)); // apply A^{-1} to in + in = arg.a * (in + arg.b * in.template igamma<4>()); // apply A^{-1} to in } else { // twisted doublet Vector in0 = arg.in(fwd_idx + 0 * arg.dc.volume_4d_cb, their_spinor_parity); Vector in1 = arg.in(fwd_idx + 1 * arg.dc.volume_4d_cb, their_spinor_parity); if (coord.s == 0) - in = arg.a * (in0 + arg.b * in0.igamma(4) + arg.c * in1); + in = arg.a * (in0 + arg.b * in0.template igamma<4>() + arg.c * in1); else - in = arg.a * (in1 - arg.b * in1.igamma(4) + arg.c * in0); + in = arg.a * (in1 - arg.b * in1.template igamma<4>() + arg.c * in0); } - out += (U * in.project(d, proj_dir)).reconstruct(d, proj_dir); + out += (U * in.template project()).template reconstruct(); } } @@ -108,27 +107,27 @@ namespace quda Link U = arg.U.Ghost(d, ghost_idx, 1 - parity); HalfVector in = arg.in.Ghost(d, 0, ghost_idx + coord.s * arg.dc.ghostFaceCB[d], their_spinor_parity); - out += (conj(U) * in).reconstruct(d, proj_dir); + out += (conj(U) * in).template reconstruct(); } else if (doBulk() && !ghost) { Link U = arg.U(d, gauge_idx, 1 - parity); Vector in; if (twist == 1) { in = arg.in(back_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); - in = arg.a * (in + arg.b * in.igamma(4)); // apply A^{-1} to in + in = arg.a * (in + arg.b * in.template igamma<4>()); // apply A^{-1} to in } else { // twisted doublet Vector in0 = arg.in(back_idx + 0 * arg.dc.volume_4d_cb, their_spinor_parity); Vector in1 = arg.in(back_idx + 1 * arg.dc.volume_4d_cb, their_spinor_parity); if (coord.s == 0) - in = arg.a * (in0 + arg.b * in0.igamma(4) + arg.c * in1); + in = arg.a * (in0 + arg.b * in0.template igamma<4>() + arg.c * in1); else - in = arg.a * (in1 - arg.b * in1.igamma(4) + arg.c * in0); + in = arg.a * (in1 - arg.b * in1.template igamma<4>() + arg.c * in0); } - out += (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir); + out += (conj(U) * in.template project()).template reconstruct(); } } - } // nDim + }); // nDim } template @@ -167,7 +166,7 @@ namespace quda if (xpay && mykernel_type == INTERIOR_KERNEL) { Vector x = arg.x(coord.x_cb, my_spinor_parity); if (!dagger || Arg::asymmetric) { - out += arg.a_inv * (x + arg.b_inv * x.igamma(4)); // apply inverse twist which is undone below + out += arg.a_inv * (x + arg.b_inv * x.template igamma<4>()); // apply inverse twist which is undone below } else { out += x; // just directly add since twist already applied in the dslash } @@ -178,7 +177,7 @@ namespace quda } if (isComplete(arg, coord) && active) { - if (!dagger || Arg::asymmetric) out = arg.a * (out + arg.b * out.igamma(4)); // apply A^{-1} to D*in + if (!dagger || Arg::asymmetric) out = arg.a * (out + arg.b * out.template igamma<4>()); // apply A^{-1} to D*in } if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out(coord.x_cb, my_spinor_parity) = out; From fd79ac52acbae9733b480b079e76392ba79cdec9 Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Fri, 22 Jul 2022 02:45:34 +0000 Subject: [PATCH 6/9] make static_for work for c++17 --- include/kernels/clover_outer_product.cuh | 2 +- include/kernels/dslash_mdw_fused.cuh | 2 +- include/kernels/dslash_twisted_mass_preconditioned.cuh | 2 +- include/kernels/dslash_wilson.cuh | 2 +- include/util_quda.h | 10 ++++++++++ 5 files changed, 14 insertions(+), 4 deletions(-) diff --git a/include/kernels/clover_outer_product.cuh b/include/kernels/clover_outer_product.cuh index a80990457b..1b4d96afb1 100644 --- a/include/kernels/clover_outer_product.cuh +++ b/include/kernels/clover_outer_product.cuh @@ -63,7 +63,7 @@ namespace quda { Spinor A = arg.inA(x_cb, 0); Spinor C = arg.inC(x_cb, 0); - static_for<0,4>([&]{ + static_for<0,4>(static_for_var(dim){ int shift[4] = {0, 0, 0, 0}; shift[dim] = 1; const int nbr_idx = neighborIndex(x_cb, shift, arg.partitioned, arg.parity, arg.X); diff --git a/include/kernels/dslash_mdw_fused.cuh b/include/kernels/dslash_mdw_fused.cuh index 1b48a71bee..4c472d86a1 100644 --- a/include/kernels/dslash_mdw_fused.cuh +++ b/include/kernels/dslash_mdw_fused.cuh @@ -211,7 +211,7 @@ namespace quda { const int index_4d_cb = index_4d_cb_from_coordinate_4d(coordinate, arg.dim); - static_for<0,4>([&] // loop over dimension + static_for<0,4>(static_for_var(d) // loop over dimension { int x[4] = {coordinate[0], coordinate[1], coordinate[2], coordinate[3]}; x[d] = (coordinate[d] == arg.dim[d] - 1 && !arg.comm[d]) ? 0 : coordinate[d] + 1; diff --git a/include/kernels/dslash_twisted_mass_preconditioned.cuh b/include/kernels/dslash_twisted_mass_preconditioned.cuh index 231f56c4e6..83e77e01a7 100644 --- a/include/kernels/dslash_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_twisted_mass_preconditioned.cuh @@ -55,7 +55,7 @@ namespace quda typedef Matrix, Arg::nColor> Link; const int their_spinor_parity = nParity == 2 ? 1 - parity : 0; - static_for<0,Arg::nDim>([&]{ // loop over dimension + static_for<0,Arg::nDim>(static_for_var(d){ // loop over dimension { // Forward gather - compute fwd offset for vector fetch const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc); constexpr int proj_dir = dagger ? +1 : -1; diff --git a/include/kernels/dslash_wilson.cuh b/include/kernels/dslash_wilson.cuh index 99f9e3786e..714550950b 100644 --- a/include/kernels/dslash_wilson.cuh +++ b/include/kernels/dslash_wilson.cuh @@ -77,7 +77,7 @@ namespace quda // parity for gauge field - include residual parity from 5-d => 4-d checkerboarding const int gauge_parity = (Arg::nDim == 5 ? (coord.x_cb / arg.dc.volume_4d_cb + parity) % 2 : parity); - static_for<0,4>([&]{ // loop over dimension - 4 and not nDim since this is used for DWF as well + static_for<0,4>(static_for_var(d){ // loop over dimension - 4 and not nDim since this is used for DWF as well { // Forward gather - compute fwd offset for vector fetch const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc); const int gauge_idx = (Arg::nDim == 5 ? coord.x_cb % arg.dc.volume_4d_cb : coord.x_cb); diff --git a/include/util_quda.h b/include/util_quda.h index 2142b4890f..359b47c99e 100644 --- a/include/util_quda.h +++ b/include/util_quda.h @@ -27,12 +27,22 @@ namespace quda __attribute__((always_inline)) void static_for(F&&f) { if constexpr (A()); +#else f.template operator()(); +#endif QUDA_MUSTTAIL return static_for(std::forward(f)); } } } // namespace quda +#if __cplusplus < 202002L +#define static_for_var(i) [&](auto i) +#else +#define static_for_var(i) [&] +#endif + /** @brief Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_ENABLE_TUNING=0. @return If autotuning is enabled From 3ce2a7b0d0a9d19187a3c99d2b0131cb06201b61 Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Fri, 22 Jul 2022 02:59:13 +0000 Subject: [PATCH 7/9] make static_for __device__ --- include/util_quda.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/util_quda.h b/include/util_quda.h index 359b47c99e..206ac73624 100644 --- a/include/util_quda.h +++ b/include/util_quda.h @@ -24,7 +24,7 @@ namespace quda constexpr const char *file_name(const char *str) { return str_slant(str) ? r_slant(str_end(str)) : str; } template - __attribute__((always_inline)) void static_for(F&&f) + __attribute__((always_inline)) __host__ __device__ inline void static_for(F&&f) { if constexpr (A Date: Thu, 28 Jul 2022 19:59:11 +0000 Subject: [PATCH 8/9] static_for: use only the c++17 compat version and remove macros --- include/kernels/clover_outer_product.cuh | 2 +- include/kernels/dslash_mdw_fused.cuh | 2 +- .../dslash_twisted_mass_preconditioned.cuh | 2 +- include/kernels/dslash_wilson.cuh | 2 +- include/util_quda.h | 20 +------------------ 5 files changed, 5 insertions(+), 23 deletions(-) diff --git a/include/kernels/clover_outer_product.cuh b/include/kernels/clover_outer_product.cuh index 1b4d96afb1..06b49d03b4 100644 --- a/include/kernels/clover_outer_product.cuh +++ b/include/kernels/clover_outer_product.cuh @@ -63,7 +63,7 @@ namespace quda { Spinor A = arg.inA(x_cb, 0); Spinor C = arg.inC(x_cb, 0); - static_for<0,4>(static_for_var(dim){ + static_for<0,4>([&](auto dim){ int shift[4] = {0, 0, 0, 0}; shift[dim] = 1; const int nbr_idx = neighborIndex(x_cb, shift, arg.partitioned, arg.parity, arg.X); diff --git a/include/kernels/dslash_mdw_fused.cuh b/include/kernels/dslash_mdw_fused.cuh index 4c472d86a1..65c7210911 100644 --- a/include/kernels/dslash_mdw_fused.cuh +++ b/include/kernels/dslash_mdw_fused.cuh @@ -211,7 +211,7 @@ namespace quda { const int index_4d_cb = index_4d_cb_from_coordinate_4d(coordinate, arg.dim); - static_for<0,4>(static_for_var(d) // loop over dimension + static_for<0,4>([&](auto d) // loop over dimension { int x[4] = {coordinate[0], coordinate[1], coordinate[2], coordinate[3]}; x[d] = (coordinate[d] == arg.dim[d] - 1 && !arg.comm[d]) ? 0 : coordinate[d] + 1; diff --git a/include/kernels/dslash_twisted_mass_preconditioned.cuh b/include/kernels/dslash_twisted_mass_preconditioned.cuh index 83e77e01a7..723cfb3f45 100644 --- a/include/kernels/dslash_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_twisted_mass_preconditioned.cuh @@ -55,7 +55,7 @@ namespace quda typedef Matrix, Arg::nColor> Link; const int their_spinor_parity = nParity == 2 ? 1 - parity : 0; - static_for<0,Arg::nDim>(static_for_var(d){ // loop over dimension + static_for<0,Arg::nDim>([&](auto d){ // loop over dimension { // Forward gather - compute fwd offset for vector fetch const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc); constexpr int proj_dir = dagger ? +1 : -1; diff --git a/include/kernels/dslash_wilson.cuh b/include/kernels/dslash_wilson.cuh index 714550950b..0297c44a5e 100644 --- a/include/kernels/dslash_wilson.cuh +++ b/include/kernels/dslash_wilson.cuh @@ -77,7 +77,7 @@ namespace quda // parity for gauge field - include residual parity from 5-d => 4-d checkerboarding const int gauge_parity = (Arg::nDim == 5 ? (coord.x_cb / arg.dc.volume_4d_cb + parity) % 2 : parity); - static_for<0,4>(static_for_var(d){ // loop over dimension - 4 and not nDim since this is used for DWF as well + static_for<0,4>([&](auto d){ // loop over dimension - 4 and not nDim since this is used for DWF as well { // Forward gather - compute fwd offset for vector fetch const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc); const int gauge_idx = (Arg::nDim == 5 ? coord.x_cb % arg.dc.volume_4d_cb : coord.x_cb); diff --git a/include/util_quda.h b/include/util_quda.h index 206ac73624..dbe2402b3a 100644 --- a/include/util_quda.h +++ b/include/util_quda.h @@ -7,14 +7,6 @@ #include #include -#ifndef QUDA_MUSTTAIL -#ifdef __clang__ -#define QUDA_MUSTTAIL __attribute__((musttail)) -#else -#define QUDA_MUSTTAIL -#endif -#endif - namespace quda { // strip path from __FILE__ @@ -27,22 +19,12 @@ namespace quda __attribute__((always_inline)) __host__ __device__ inline void static_for(F&&f) { if constexpr (A()); -#else - f.template operator()(); -#endif - QUDA_MUSTTAIL return static_for(std::forward(f)); + static_for(std::forward(f)); } } } // namespace quda -#if __cplusplus < 202002L -#define static_for_var(i) [&](auto i) -#else -#define static_for_var(i) [&] -#endif - /** @brief Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_ENABLE_TUNING=0. @return If autotuning is enabled From 87d76e22d6ea8260131fa0e6624f3b0a248f6e54 Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Thu, 28 Jul 2022 20:07:00 +0000 Subject: [PATCH 9/9] Revert "flip the switch in complex i_ function" This reverts commit 3216cb4a0731ecfe31ca922a1f0e8f50638da9a1. --- include/complex_quda.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/complex_quda.h b/include/complex_quda.h index a6a4a3e3e1..66b5eee609 100644 --- a/include/complex_quda.h +++ b/include/complex_quda.h @@ -1207,7 +1207,7 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real()); template __host__ __device__ inline complex i_(const complex &a) { // FIXME compiler generates worse code with "optimal" code -#if 0 +#if 1 return complex(0.0, 1.0) * a; #else return complex(-a.imag(), a.real());