Skip to content

Commit

Permalink
Merge branch 'dev' into bloom-filter
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Jul 29, 2022
2 parents faa5e85 + 1a0785b commit 5578405
Show file tree
Hide file tree
Showing 7 changed files with 273 additions and 76 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ We plan to add many GPU-accelerated, concurrent data structures to `cuCollection

`cuco::static_map` is a fixed-size hash table using open addressing with linear probing.

It provides both host, bulk APIs ([example](https://github.com/NVIDIA/cuCollections/blob/dev/examples/static_map/static_map_example.cu)) as well as device APIs for individual operations ([example (TODO)]()).
It provides both host, bulk APIs ([example](https://github.com/NVIDIA/cuCollections/blob/dev/examples/static_map/host_bulk_example.cu)) as well as device APIs for individual operations ([example](https://github.com/NVIDIA/cuCollections/blob/dev/examples/static_map/device_view_example.cu)).

See the Doxygen documentation in `static_map.cuh` for more detailed information.

Expand Down
3 changes: 2 additions & 1 deletion examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ endfunction(ConfigureExample)
### Example sources ###############################################################################
###################################################################################################

ConfigureExample(STATIC_MAP_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/static_map_example.cu")
ConfigureExample(HOST_BULK_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/host_bulk_example.cu")
ConfigureExample(DEVICE_SIDE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/device_view_example.cu")
ConfigureExample(CUSTOM_TYPE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/custom_type_example.cu")
ConfigureExample(STATIC_MULTIMAP_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_multimap/static_multimap_example.cu")
ConfigureExample(BLOOM_FILTER_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/bloom_filter/bloom_filter_example.cu")
Expand Down
182 changes: 182 additions & 0 deletions examples/static_map/device_view_example.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuco/static_map.cuh>

#include <thrust/device_vector.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/logical.h>
#include <thrust/sequence.h>
#include <thrust/tuple.h>

#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>

/**
* @file device_side_example.cu
* @brief Demonstrates usage of the device side APIs for individual operations like insert/find.
*
* Individual operations like a single insert or find can be performed in device code via the
* static_map "device_view" types. Note that concurrent insert and find are not supported, and
* therefore there are separate view types for insert and find to help prevent undefined behavior.
*
* @note This example is for demonstration purposes only. It is not intended to show the most
* performant way to do the example algorithm.
*
*/

/**
* @brief Inserts keys that pass the specified predicated into the map.
*
* @tparam Map Type of the map returned from static_map::get_mutable_device_view
* @tparam KeyIter Input iterator whose value_type convertible to Map::key_type
* @tparam ValueIter Input iterator whose value_type is convertible to Map::mapped_type
* @tparam Predicate Unary predicate
*
* @param[in] map_view View of the map into which inserts will be performed
* @param[in] key_begin The beginning of the range of keys to insert
* @param[in] value_begin The beginning of the range of values associated with each key to insert
* @param[in] num_keys The total number of keys and values
* @param[in] pred Unary predicate applied to each key. Only keys that pass the predicated will be
* inserted.
* @param[out] num_inserted The total number of keys successfully inserted
*/
template <typename Map, typename KeyIter, typename ValueIter, typename Predicate>
__global__ void filtered_insert(Map map_view,
KeyIter key_begin,
ValueIter value_begin,
std::size_t num_keys,
Predicate pred,
int* num_inserted)
{
auto tid = threadIdx.x + blockIdx.x * blockDim.x;

std::size_t counter = 0;
while (tid < num_keys) {
// Only insert keys that pass the predicate
if (pred(key_begin[tid])) {
// mutable_device_view::insert returns `true` if it is the first time the given key was
// inserted and `false` if the key already existed
if (map_view.insert({key_begin[tid], value_begin[tid]})) {
++counter; // Count number of successfully inserted keys
}
}
tid += gridDim.x * blockDim.x;
}

// Update global count of inserted keys
atomicAdd(num_inserted, counter);
}

/**
* @brief For keys that have a match in the map, increments their corresponding value by one.
*
* @tparam Map Type of the map returned from static_map::get_device_view
* @tparam KeyIter Input iterator whose value_type convertible to Map::key_type
*
* @param map_view View of the map into which queries will be performed
* @param key_begin The beginning of the range of keys to query
* @param num_keys The total number of keys
*/
template <typename Map, typename KeyIter>
__global__ void increment_values(Map map_view, KeyIter key_begin, std::size_t num_keys)
{
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
while (tid < num_keys) {
// If the key exists in the map, find returns an iterator to the specified key. Otherwise it
// returns map.end()
auto found = map_view.find(key_begin[tid]);
if (found != map_view.end()) {
// If the key exists, atomically increment the associated value
// The value type of the iterator is pair<cuda::atomic<Key>, cuda::atomic<Value>>
found->second.fetch_add(1, cuda::memory_order_relaxed);
}
tid += gridDim.x * blockDim.x;
}
}

int main(void)
{
using Key = int;
using Value = int;

// Empty slots are represented by reserved "sentinel" values. These values should be selected such
// that they never occur in your input data.
Key constexpr empty_key_sentinel = -1;
Value constexpr empty_value_sentinel = -1;

// Number of key/value pairs to be inserted
std::size_t constexpr num_keys = 50'000;

// Create a sequence of keys and values {{0,0}, {1,1}, ... {i,i}}
thrust::device_vector<Key> insert_keys(num_keys);
thrust::sequence(insert_keys.begin(), insert_keys.end(), 0);
thrust::device_vector<Value> insert_values(num_keys);
thrust::sequence(insert_values.begin(), insert_values.end(), 0);

// Compute capacity based on a 50% load factor
auto constexpr load_factor = 0.5;
std::size_t const capacity = std::ceil(num_keys / load_factor);

// Constructs a map with "capacity" slots using -1 and -1 as the empty key/value sentinels.
cuco::static_map<Key, Value> map{capacity,
cuco::sentinel::empty_key{empty_key_sentinel},
cuco::sentinel::empty_value{empty_value_sentinel}};

// Get a non-owning, mutable view of the map that allows inserts to pass by value into the kernel
auto device_insert_view = map.get_device_mutable_view();

// Predicate will only insert even keys
auto is_even = [] __device__(auto key) { return (key % 2) == 0; };

// Allocate storage for count of number of inserted keys
thrust::device_vector<int> num_inserted(1);

auto constexpr block_size = 256;
auto const grid_size = (num_keys + block_size - 1) / block_size;
filtered_insert<<<grid_size, block_size>>>(device_insert_view,
insert_keys.begin(),
insert_values.begin(),
num_keys,
is_even,
num_inserted.data().get());

std::cout << "Number of keys inserted: " << num_inserted[0] << std::endl;

// Get a non-owning view of the map that allows find operations to pass by value into the kernel
auto device_find_view = map.get_device_view();

increment_values<<<grid_size, block_size>>>(device_find_view, insert_keys.begin(), num_keys);

// Retrieve contents of all the non-empty slots in the map
thrust::device_vector<Key> contained_keys(num_inserted[0]);
thrust::device_vector<Value> contained_values(num_inserted[0]);
map.retrieve_all(contained_keys.begin(), contained_values.begin());

auto tuple_iter =
thrust::make_zip_iterator(thrust::make_tuple(contained_keys.begin(), contained_values.begin()));
// Iterate over all slot contents and verify that `slot.key + 1 == slot.value` is always true.
auto result = thrust::all_of(
thrust::device, tuple_iter, tuple_iter + num_inserted[0], [] __device__(auto const& tuple) {
return thrust::get<0>(tuple) + 1 == thrust::get<1>(tuple);
});

if (result) { std::cout << "Success! Target values are properly incremented.\n"; }

return 0;
}
86 changes: 86 additions & 0 deletions examples/static_map/host_bulk_example.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuco/static_map.cuh>

#include <thrust/device_vector.h>
#include <thrust/equal.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/sequence.h>
#include <thrust/transform.h>

#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>

/**
* @file host_bulk_example.cu
* @brief Demonstrates usage of the static_map "bulk" host APIs.
*
* The bulk APIs are only invocable from the host and are used for doing operations like insert or
* find on a set of keys.
*
*/

int main(void)
{
using Key = int;
using Value = int;

// Empty slots are represented by reserved "sentinel" values. These values should be selected such
// that they never occur in your input data.
Key constexpr empty_key_sentinel = -1;
Value constexpr empty_value_sentinel = -1;

// Number of key/value pairs to be inserted
std::size_t constexpr num_keys = 50'000;

// Compute capacity based on a 50% load factor
auto constexpr load_factor = 0.5;
std::size_t const capacity = std::ceil(num_keys / load_factor);

// Constructs a map with "capacity" slots using -1 and -1 as the empty key/value sentinels.
cuco::static_map<Key, Value> map{capacity,
cuco::sentinel::empty_key{empty_key_sentinel},
cuco::sentinel::empty_value{empty_value_sentinel}};

// Create a sequence of keys and values {{0,0}, {1,1}, ... {i,i}}
thrust::device_vector<Key> insert_keys(num_keys);
thrust::sequence(insert_keys.begin(), insert_keys.end(), 0);
thrust::device_vector<Value> insert_values(num_keys);
thrust::sequence(insert_values.begin(), insert_values.end(), 0);
auto zipped =
thrust::make_zip_iterator(thrust::make_tuple(insert_keys.begin(), insert_values.begin()));

// Inserts all pairs into the map
map.insert(zipped, zipped + insert_keys.size());

// Storage for found values
thrust::device_vector<Value> found_values(num_keys);

// Finds all keys {0, 1, 2, ...} and stores associated values into `found_values`
// If a key `keys_to_find[i]` doesn't exist, `found_values[i] == empty_value_sentinel`
map.find(insert_keys.begin(), insert_keys.end(), found_values.begin());

// Verify that all the found values match the inserted values
bool const all_values_match =
thrust::equal(found_values.begin(), found_values.end(), insert_values.begin());

if (all_values_match) { std::cout << "Success! Found all values.\n"; }

return 0;
}
72 changes: 0 additions & 72 deletions examples/static_map/static_map_example.cu

This file was deleted.

2 changes: 1 addition & 1 deletion include/cuco/detail/__config
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#endif

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
#define CUCO_HAS_LARGE_TYPE_SUPPORT
#define CUCO_HAS_INDEPENDENT_THREADS
#endif

#if defined(CUDART_VERSION) && (CUDART_VERSION >= 11500) && defined(__CUDA_ARCH__) && \
Expand Down
2 changes: 1 addition & 1 deletion include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class static_map {
using counter_allocator_type = typename std::allocator_traits<Allocator>::rebind_alloc<
atomic_ctr_type>; ///< Type of the allocator to (de)allocate atomic counters

#if !defined(CUCO_HAS_LARGE_TYPE_SUPPORT)
#if !defined(CUCO_HAS_INDEPENDENT_THREADS)
static_assert(atomic_key_type::is_always_lock_free,
"A key type larger than 8B is supported for only sm_70 and up.");
static_assert(atomic_mapped_type::is_always_lock_free,
Expand Down

0 comments on commit 5578405

Please sign in to comment.