From 823ec7fc00a9fc43b95f4b7f86b376840682d072 Mon Sep 17 00:00:00 2001 From: Andrew Myers Date: Fri, 7 Feb 2025 08:27:28 -0800 Subject: [PATCH] Fix recursion-related nvlink warning in RandomGamma (#4324) --- Src/Base/AMReX_Random.H | 67 +++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/Src/Base/AMReX_Random.H b/Src/Base/AMReX_Random.H index a8bda8b036d..09b2c18c696 100644 --- a/Src/Base/AMReX_Random.H +++ b/Src/Base/AMReX_Random.H @@ -119,6 +119,39 @@ namespace amrex #endif } + namespace random_util { + + AMREX_GPU_DEVICE AMREX_FORCE_INLINE + Real RandomGamma_alpha_ge_1 (Real alpha, Real beta, RandomEngine const& random_engine) + { + AMREX_ASSERT(alpha >= 1); + AMREX_ASSERT(beta > 0); + + Real x, v, u; + Real d = alpha - 1.0_rt / 3.0_rt; + Real c = (1.0_rt / 3.0_rt) / std::sqrt(d); + + while (true) { + do { + x = amrex::RandomNormal(0.0_rt, 1.0_rt, random_engine); + v = 1.0_rt + c * x; + } while (v <= 0.0_rt); + + v = v * v * v; + u = amrex::Random(random_engine); + + if (u < 1.0_rt - 0.0331_rt * x * x * x * x) { + break; + } + + if (std::log(u) < 0.5_rt * x * x + d * (1.0_rt - v + std::log(v))) { + break; + } + } + return beta * d * v; + } + } + /** * \brief Generate a psuedo-random floating point number from the Gamma distribution * @@ -142,40 +175,16 @@ namespace amrex if (alpha < 1) { Real u = amrex::Random(random_engine); - return RandomGamma(1.0_rt + alpha, beta, random_engine) * std::pow(u, 1.0_rt / alpha); - } - - { - Real x, v, u; - Real d = alpha - 1.0_rt / 3.0_rt; - Real c = (1.0_rt / 3.0_rt) / std::sqrt(d); - - while (1) { - do { - x = amrex::RandomNormal(0.0_rt, 1.0_rt, random_engine); - v = 1.0_rt + c * x; - } while (v <= 0.0_rt); - - v = v * v * v; - u = amrex::Random(random_engine); - - if (u < 1.0_rt - 0.0331_rt * x * x * x * x) { - break; - } - - if (std::log(u) < 0.5_rt * x * x + d * (1.0_rt - v + std::log(v))) { - break; - } - } - return beta * d * v; + return amrex::random_util::RandomGamma_alpha_ge_1(1.0_rt + alpha, beta, random_engine) * std::pow(u, 1.0_rt / alpha); + } else { + return amrex::random_util::RandomGamma_alpha_ge_1(alpha, beta, random_engine); } )) AMREX_IF_ON_HOST(( - amrex::ignore_unused(random_engine); - return RandomGamma(alpha, beta); + amrex::ignore_unused(random_engine); + return RandomGamma(alpha, beta); )) - } /**