diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index 358b7643e..453928992 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -37,7 +37,8 @@ void add_node_core( const cuvs::neighbors::cagra::index& idx, raft::mdspan, raft::layout_stride, Accessor> additional_dataset_view, - raft::host_matrix_view updated_graph) + raft::host_matrix_view updated_graph, + const cuvs::neighbors::cagra::extend_params& extend_params) { using DistanceT = float; const std::size_t degree = idx.graph_degree(); @@ -68,7 +69,19 @@ void add_node_core( new_size, raft::resource::get_cuda_stream(handle)); - const std::size_t max_chunk_size = 1024; + std::size_t data_size_per_vector = + sizeof(IdxT) * base_degree + sizeof(DistanceT) * base_degree + sizeof(T) * dim; + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, additional_dataset_view.data_handle())); + if (attr.devicePointer == nullptr) { + // for batch_load_iterator + data_size_per_vector += sizeof(T) * dim; + } + + const std::size_t max_search_batch_size = + std::min(std::max(1lu, raft::resource::get_workspace_free_bytes(handle) / data_size_per_vector), + num_add); + RAFT_EXPECTS(max_search_batch_size > 0, "No enough working memory space is left."); cuvs::neighbors::cagra::search_params params; params.itopk_size = std::max(base_degree * 2lu, 256lu); @@ -77,24 +90,24 @@ void add_node_core( auto mr = raft::resource::get_workspace_resource(handle); auto neighbor_indices = raft::make_device_mdarray( - handle, mr, raft::make_extents(max_chunk_size, base_degree)); + handle, mr, raft::make_extents(max_search_batch_size, base_degree)); auto neighbor_distances = raft::make_device_mdarray( - handle, mr, raft::make_extents(max_chunk_size, base_degree)); + handle, mr, raft::make_extents(max_search_batch_size, base_degree)); auto queries = raft::make_device_mdarray( - handle, mr, raft::make_extents(max_chunk_size, dim)); + handle, mr, raft::make_extents(max_search_batch_size, dim)); auto host_neighbor_indices = - raft::make_host_matrix(max_chunk_size, base_degree); + raft::make_host_matrix(max_search_batch_size, base_degree); cuvs::spatial::knn::detail::utils::batch_load_iterator additional_dataset_batch( additional_dataset_view.data_handle(), num_add, additional_dataset_view.stride(0), - max_chunk_size, + max_search_batch_size, raft::resource::get_cuda_stream(handle), - raft::resource::get_workspace_resource(handle)); + mr); for (const auto& batch : additional_dataset_batch) { // Step 1: Obtain K (=base_degree) nearest neighbors of the new vectors by CAGRA search // Create queries @@ -254,7 +267,8 @@ void add_graph_nodes( const std::size_t degree = index.graph_degree(); const std::size_t dim = index.dim(); const std::size_t stride = input_updated_dataset_view.stride(0); - const std::size_t max_chunk_size_ = params.max_chunk_size == 0 ? 1 : params.max_chunk_size; + const std::size_t max_chunk_size_ = + params.max_chunk_size == 0 ? new_dataset_size : params.max_chunk_size; raft::copy(updated_graph_view.data_handle(), index.graph().data_handle(), @@ -298,7 +312,7 @@ void add_graph_nodes( stride); neighbors::cagra::add_node_core( - handle, internal_index, additional_dataset_view, updated_graph); + handle, internal_index, additional_dataset_view, updated_graph, params); raft::resource::sync_stream(handle); } }