From de60f7a8e49937d26138fef680676c96dd56ec47 Mon Sep 17 00:00:00 2001 From: Andrey Prokopenko Date: Mon, 6 Jan 2025 21:29:03 -0500 Subject: [PATCH] bugfix: make sure we don't exceed scratch size capacity in BF::query --- src/spatial/detail/ArborX_BruteForceImpl.hpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/spatial/detail/ArborX_BruteForceImpl.hpp b/src/spatial/detail/ArborX_BruteForceImpl.hpp index d849ea5d3..a294ba186 100644 --- a/src/spatial/detail/ArborX_BruteForceImpl.hpp +++ b/src/spatial/detail/ArborX_BruteForceImpl.hpp @@ -62,12 +62,24 @@ struct BruteForceImpl int const n_indexables = values.size(); int const n_predicates = predicates.size(); - int max_scratch_size = TeamPolicy::scratch_size_max(0); + int const max_scratch_size = TeamPolicy::scratch_size_max(0); + // FIXME: adjust max_scratch_size to compensate for potential alignment + // additions to make sure we don't accidentally exceed capacity + // 8 is a magical valie of scratch_memory_space::ALIGN in Kokkos + int indexable_alignment = + std::max({sizeof(IndexableType), alignof(IndexableType), + static_cast(8)}); + int predicate_alignment = + std::max({sizeof(PredicateType), alignof(PredicateType), + static_cast(8)}); + int available_scratch_size = Kokkos::max( + 0, max_scratch_size - indexable_alignment - predicate_alignment); + // half of the scratch memory used by predicates and half for indexables int const predicates_per_team = - max_scratch_size / 2 / sizeof(PredicateType); + available_scratch_size / 2 / sizeof(PredicateType); int const indexables_per_team = - max_scratch_size / 2 / sizeof(IndexableType); + available_scratch_size / 2 / sizeof(IndexableType); ARBORX_ASSERT(predicates_per_team > 0); ARBORX_ASSERT(indexables_per_team > 0); @@ -87,6 +99,7 @@ struct BruteForceImpl Kokkos::MemoryTraits>; int scratch_size = ScratchPredicateType::shmem_size(predicates_per_team) + ScratchIndexableType::shmem_size(indexables_per_team); + ARBORX_ASSERT(scratch_size <= max_scratch_size); Kokkos::parallel_for( "ArborX::BruteForce::query::spatial::"