Skip to content

Commit

Permalink
replace RNG with raft RNG generator, issue rapidsai#7 and rapidsai#23
Browse files Browse the repository at this point in the history
…for wholegraph 23.10
  • Loading branch information
linhu-nv committed Aug 11, 2023
1 parent f8c08bc commit b8f20d5
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 179 deletions.
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ endfunction()
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${WHOLEGRAPH_MIN_VERSION_raft}
FORK rapidsai
PINNED_TAG branch-${WHOLEGRAPH_BRANCH_VERSION_raft}
PINNED_TAG pull-request/1568

# When PINNED_TAG above doesn't match wholegraph,
# force local raft clone in build directory
Expand Down
28 changes: 22 additions & 6 deletions cpp/src/wholegraph_ops/raft_random_gen.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include <cmath>
#include <wholememory/wholegraph_op.h>
#include <wholememory_ops/raft_random.cuh>

#include <raft/random/rng_state.hpp>
#include <raft/random/rng_device.cuh>

#include "error.hpp"
#include "logger.hpp"
Expand All @@ -37,15 +39,25 @@ wholememory_error_code_t generate_random_positive_int_cpu(int64_t random_seed,
}

auto* output_ptr = wholememory_tensor_get_data_pointer(output);
PCGenerator rng((unsigned long long)random_seed, subsequence, 0);

raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState <raft::random::detail::PCGenerator> rngstate(_rngstate);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence);

for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) {
if (output_tensor_desc.dtype == WHOLEMEMORY_DT_INT) {
raft::random::detail::UniformDistParams<int32_t> params;
params.start = 0;
params.end = 1;
int32_t random_num;
rng.next(random_num);
raft::random::detail::custom_next(rng, &random_num, params, 0, 0);
static_cast<int*>(output_ptr)[i] = random_num;
} else {
raft::random::detail::UniformDistParams<int64_t> params;
params.start = 0;
params.end = 1;
int64_t random_num;
rng.next(random_num);
raft::random::detail::custom_next(rng, &random_num, params, 0, 0);
static_cast<int64_t*>(output_ptr)[i] = random_num;
}
}
Expand All @@ -65,9 +77,13 @@ wholememory_error_code_t generate_exponential_distribution_negative_float_cpu(
return WHOLEMEMORY_INVALID_INPUT;
}
auto* output_ptr = wholememory_tensor_get_data_pointer(output);
PCGenerator rng((unsigned long long)random_seed, subsequence, 0);
raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState <raft::random::detail::PCGenerator> rngstate(_rngstate);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence);
for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) {
float u = -rng.next_float(1.0f, 0.5f);
float u = 0.0;
rng.next(u);
u = -(0.5 + 0.5*u);
uint64_t random_num2 = 0;
int seed_count = -1;
do {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
#include <thrust/scan.h>

#include <raft/util/integer_utils.hpp>
#include <raft/random/rng_state.hpp>
#include <raft/random/rng_device.cuh>
#include <wholememory/device_reference.cuh>
#include <wholememory/env_func_ptrs.h>
#include <wholememory/global_reference.h>
#include <wholememory/tensor_description.h>

#include "wholememory_ops/output_memory_handle.hpp"
#include "wholememory_ops/raft_random.cuh"
#include "wholememory_ops/temp_memory_handle.hpp"
#include "wholememory_ops/thrust_allocator.hpp"

Expand Down Expand Up @@ -65,7 +66,7 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr,
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
Expand All @@ -75,8 +76,7 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr,
int input_idx = blockIdx.x;
if (input_idx >= input_node_count) return;
int gidx = threadIdx.x + blockIdx.x * blockDim.x;
PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0);

raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);
wholememory::device_reference<WMOffsetType> csr_row_ptr_gen(wm_csr_row_ptr);
wholememory::device_reference<WMIdType> csr_col_ptr_gen(wm_csr_col_ptr);

Expand Down Expand Up @@ -104,8 +104,11 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr,
}
__syncthreads();
for (int idx = max_sample_count + threadIdx.x; idx < neighbor_count; idx += blockDim.x) {
raft::random::detail::UniformDistParams<int32_t> params;
params.start = 0;
params.end = 1;
int32_t rand_num;
rng.next(rand_num);
raft::random::detail::custom_next(rng, &rand_num, params, 0, 0);
rand_num %= idx + 1;
if (rand_num < max_sample_count) { atomicMax((int*)(output + offset + rand_num), idx); }
}
Expand Down Expand Up @@ -139,15 +142,15 @@ __global__ void unweighted_sample_without_replacement_kernel(
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
int* src_lid,
int64_t* output_edge_gid_ptr)
{
int gidx = threadIdx.x + blockIdx.x * blockDim.x;
PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);
int input_idx = blockIdx.x;
if (input_idx >= input_node_count) return;

Expand Down Expand Up @@ -193,9 +196,12 @@ __global__ void unweighted_sample_without_replacement_kernel(
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; i++) {
int idx = i * BLOCK_DIM + threadIdx.x;
int32_t random_num;
rng.next(random_num);
int32_t r = idx < M ? (random_num % (N - idx)) : N;
raft::random::detail::UniformDistParams<int32_t> params;
params.start = 0;
params.end = 1;
int32_t rand_num;
raft::random::detail::custom_next(rng, &rand_num, params, 0, 0);
int32_t r = idx < M ? rand_num % ( N - idx ) : N;
sa_p[i] = ((uint64_t)r << 32UL) | idx;
}
__syncthreads();
Expand Down Expand Up @@ -364,6 +370,8 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
(int64_t*)gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64);
}
// sample node
raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState <raft::random::detail::PCGenerator> rngstate(_rngstate);
if (max_sample_count <= 0) {
sample_all_kernel<IdType, int, WMIdType, int64_t>
<<<center_node_count, 64, 0, stream>>>(wm_csr_row_ptr,
Expand Down Expand Up @@ -392,7 +400,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
(const int*)output_sample_offset,
output_sample_offset_desc,
(WMIdType*)output_dest_node_ptr,
Expand All @@ -410,7 +418,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
Expand Down Expand Up @@ -460,7 +468,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
(const int*)output_sample_offset,
output_sample_offset_desc,
(WMIdType*)output_dest_node_ptr,
Expand Down
25 changes: 15 additions & 10 deletions cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
#include <random>
#include <thrust/scan.h>

#include <raft/random/rng_state.hpp>
#include <raft/random/rng_device.cuh>
#include <raft/util/integer_utils.hpp>
#include <wholememory/device_reference.cuh>
#include <wholememory/env_func_ptrs.h>
#include <wholememory/global_reference.h>
#include <wholememory/tensor_description.h>

#include "wholememory_ops/output_memory_handle.hpp"
#include "wholememory_ops/raft_random.cuh"
#include "wholememory_ops/temp_memory_handle.hpp"
#include "wholememory_ops/thrust_allocator.hpp"

Expand All @@ -37,9 +38,11 @@
namespace wholegraph_ops {

template <typename WeightType>
__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, PCGenerator& rng)
__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, raft::random::detail::PCGenerator& rng)
{
float u = -rng.next_float(1.0f, 0.5f);
float u = 0.0;
rng.next(u);
u = -(0.5 + 0.5*u);
uint64_t random_num2 = 0;
int seed_count = -1;
do {
Expand Down Expand Up @@ -73,7 +76,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
const int* target_neighbor_offset,
Expand Down Expand Up @@ -109,7 +112,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen
return;
}

PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);
for (int id = threadIdx.x; id < neighbor_count; id += BLOCK_SIZE) {
WeightType thread_weight = csr_weight_ptr_gen[start + id];
weight_keys_local_buff[id] =
Expand Down Expand Up @@ -240,7 +243,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
Expand Down Expand Up @@ -272,7 +275,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen
}
return;
} else {
PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);

float weight_keys[ITEMS_PER_THREAD];
int neighbor_idxs[ITEMS_PER_THREAD];
Expand Down Expand Up @@ -443,6 +446,8 @@ void wholegraph_csr_weighted_sample_without_replacement_func(
(int64_t*)gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64);
}

raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState <raft::random::detail::PCGenerator> rngstate(_rngstate);
if (max_sample_count > sample_count_threshold) {
wholememory_ops::wm_thrust_allocator tmp_thrust_allocator(p_env_fns);
thrust::exclusive_scan(thrust::cuda::par(tmp_thrust_allocator).on(stream),
Expand Down Expand Up @@ -480,7 +485,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
(const int*)output_sample_offset,
output_sample_offset_desc,
tmp_neighbor_counts_offset,
Expand Down Expand Up @@ -522,7 +527,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func(
const IdType*,
const int,
const int,
unsigned long long,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator>,
const int*,
wholememory_array_description_t,
WMIdType*,
Expand Down Expand Up @@ -592,7 +597,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
(const int*)output_sample_offset,
output_sample_offset_desc,
(WMIdType*)output_dest_node_ptr,
Expand Down
Loading

0 comments on commit b8f20d5

Please sign in to comment.