diff --git a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h index 728c0a02f..37e238e86 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h @@ -54,7 +54,9 @@ struct TensorForEach { #if defined (CUTLASS_ENABLE_SYCL) // TODO: query the queue for block size block_size = 128; - grid_size = (size(size) + block_size - 1) / block_size; + grid_size = (size.product() + block_size - 1) / block_size; + int sm_count = KernelHardwareInfo::query_device_multiprocessor_count(); + grid_size = grid_size > sm_count / 2 ? sm_count / 2 : grid_size; #else // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API cudaError_t result = cudaOccupancyMaxPotentialBlockSize( @@ -75,7 +77,7 @@ struct TensorForEach { #if defined(CUTLASS_ENABLE_SYCL) const auto sycl_block = syclcompat::dim3(block_size, 1, 1); const auto sycl_grid = syclcompat::dim3(grid_size, 1, 1); - syclcompat::launch>(sycl_grid, sycl_block, 0, size, params); + syclcompat::launch>(sycl_grid, sycl_block, size, params); #else dim3 grid(grid_size, 1, 1); dim3 block(block_size, 1, 1); @@ -103,7 +105,7 @@ struct TensorDiagonalForEach { #if defined(CUTLASS_ENABLE_SYCL) const auto sycl_block = syclcompat::dim3(block_size, 1, 1); const auto sycl_grid = syclcompat::dim3((end - start + block_size - 1) / block_size, 1, 1); - syclcompat::launch>(sycl_grid, sycl_block, 0, size, params, start, end); + syclcompat::launch>(sycl_grid, sycl_block, size, params, start, end); #else dim3 block(block_size, 1, 1); dim3 grid((end - start + block_size - 1) / block_size, 1, 1); @@ -153,7 +155,7 @@ struct BlockForEach { #if defined(CUTLASS_ENABLE_SYCL) const auto sycl_block = syclcompat::dim3(block_size, 1, 1); const auto sycl_grid = syclcompat::dim3(grid_size, 1, 1); - syclcompat::launch>(sycl_grid, sycl_block, 0, ptr, capacity, params); + syclcompat::launch>(sycl_grid, sycl_block, ptr, capacity, params); #else dim3 grid(grid_size, 1, 1); dim3 block(block_size, 1, 1);