From eb884696fbe8bc0109e1c6f8dce98b9d6b3506c1 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 31 Dec 2024 08:59:05 +0100 Subject: [PATCH] Fix a cuda warning. --- candle-core/src/sort.rs | 83 ++++++++++++++++++++++------------------- 1 file changed, 44 insertions(+), 39 deletions(-) diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 614a37fe65..0ebb18357d 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -52,6 +52,49 @@ impl ArgSort { } } +#[cfg(feature = "cuda")] +mod cuda { + use super::*; + use crate::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits, + }; + use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr}; + use crate::{CudaDevice, WithDType}; + + impl crate::cuda_backend::Map1Any for ArgSort { + fn f) -> S>( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &crate::Layout, + _wrap: W, + ) -> Result { + let slice = match layout.contiguous_offsets() { + None => crate::bail!("input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let elem_count = layout.shape().elem_count(); + let dst = unsafe { dev.alloc::(elem_count) }.w()?; + let func = if self.asc { + dev.get_or_load_func(&kernel_name::("asort_asc"), kernels::SORT)? + } else { + dev.get_or_load_func(&kernel_name::("asort_desc"), kernels::SORT)? + }; + let ncols = self.last_dim; + let nrows = elem_count / ncols; + let ncols_pad = next_power_of_2(ncols); + let params = (&slice, &dst, ncols as i32, ncols_pad as i32); + let cfg = LaunchConfig { + grid_dim: (1, nrows as u32, 1), + block_dim: (ncols_pad as u32, 1, 1), + shared_mem_bytes: (ncols_pad * std::mem::size_of::()) as u32, + }; + unsafe { func.launch(cfg, params) }.w()?; + Ok(S::U32(dst)) + } + } +} + impl crate::CustomOp1 for ArgSort { fn name(&self) -> &'static str { "argsort" @@ -81,46 +124,8 @@ impl crate::CustomOp1 for ArgSort { storage: &crate::CudaStorage, layout: &crate::Layout, ) -> Result<(crate::CudaStorage, crate::Shape)> { - use crate::cuda_backend::cudarc::driver::{ - CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits, - }; - use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr}; - use crate::{CudaDevice, WithDType}; - - impl Map1Any for ArgSort { - fn f) -> S>( - &self, - src: &CudaSlice, - dev: &CudaDevice, - layout: &crate::Layout, - _wrap: W, - ) -> Result { - let slice = match layout.contiguous_offsets() { - None => crate::bail!("input has to be contiguous"), - Some((o1, o2)) => src.slice(o1..o2), - }; - let elem_count = layout.shape().elem_count(); - let dst = unsafe { dev.alloc::(elem_count) }.w()?; - let func = if self.asc { - dev.get_or_load_func(&kernel_name::("asort_asc"), kernels::SORT)? - } else { - dev.get_or_load_func(&kernel_name::("asort_desc"), kernels::SORT)? - }; - let ncols = self.last_dim; - let nrows = elem_count / ncols; - let ncols_pad = next_power_of_2(ncols); - let params = (&slice, &dst, ncols as i32, ncols_pad as i32); - let cfg = LaunchConfig { - grid_dim: (1, nrows as u32, 1), - block_dim: (ncols_pad as u32, 1, 1), - shared_mem_bytes: (ncols_pad * std::mem::size_of::()) as u32, - }; - unsafe { func.launch(cfg, params) }.w()?; - Ok(S::U32(dst)) - } - } - use crate::backend::BackendStorage; + use crate::cuda_backend::Map1Any; let dev = storage.device(); let slice = self.map(&storage.slice, dev, layout)?; let dst = crate::cuda_backend::CudaStorage {