Skip to content

Commit

Permalink
completing change
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Jan 30, 2025
1 parent 9abe842 commit 6ff906b
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 23 deletions.
2 changes: 1 addition & 1 deletion cpp/src/tsne/tsne_runner.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class TSNE_runner {
{
distance_and_perplexity();

const auto NNZ = COO_Matrix.nnz;
const auto NNZ = (value_idx)COO_Matrix.nnz;
auto* VAL = COO_Matrix.vals();
const auto* COL = COO_Matrix.cols();
const auto* ROW = COO_Matrix.rows();
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/umap/init_embed/spectral_algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void launcher(const raft::handle_t& handle,
coo->rows(),
coo->cols(),
coo->vals(),
coo->safe_nnz,
coo->nnz,
n,
params->n_components,
tmp_storage.data(),
Expand Down
19 changes: 7 additions & 12 deletions cpp/src/umap/runner.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -456,18 +456,13 @@ void _transform(const raft::handle_t& handle,

raft::sparse::convert::sorted_coo_to_csr(&graph_coo, row_ind.data(), stream);

rmm::device_uvector<value_t> vals_normed(graph_coo.safe_nnz, stream);
RAFT_CUDA_TRY(
cudaMemsetAsync(vals_normed.data(), 0, graph_coo.safe_nnz * sizeof(value_t), stream));
rmm::device_uvector<value_t> vals_normed(graph_coo.nnz, stream);
RAFT_CUDA_TRY(cudaMemsetAsync(vals_normed.data(), 0, graph_coo.nnz * sizeof(value_t), stream));

CUML_LOG_DEBUG("Performing L1 normalization");

raft::sparse::linalg::csr_row_normalize_l1<value_t>(row_ind.data(),
graph_coo.vals(),
graph_coo.safe_nnz,
graph_coo.n_rows,
vals_normed.data(),
stream);
raft::sparse::linalg::csr_row_normalize_l1<value_t>(
row_ind.data(), graph_coo.vals(), graph_coo.nnz, graph_coo.n_rows, vals_normed.data(), stream);

init_transform<TPB_X, value_t><<<grid_n, blk, 0, stream>>>(graph_coo.cols(),
vals_normed.data(),
Expand Down Expand Up @@ -502,7 +497,7 @@ void _transform(const raft::handle_t& handle,
raft::linalg::unaryOp<value_t>(
graph_coo.vals(),
graph_coo.vals(),
graph_coo.safe_nnz,
graph_coo.nnz,
[=] __device__(value_t input) {
if (input < (max / float(n_epochs)))
return 0.0f;
Expand All @@ -525,7 +520,7 @@ void _transform(const raft::handle_t& handle,
rmm::device_uvector<value_t> epochs_per_sample(nnz, stream);

SimplSetEmbedImpl::make_epochs_per_sample(
comp_coo.vals(), comp_coo.safe_nnz, n_epochs, epochs_per_sample.data(), stream);
comp_coo.vals(), comp_coo.nnz, n_epochs, epochs_per_sample.data(), stream);

CUML_LOG_DEBUG("Performing optimization");

Expand All @@ -542,7 +537,7 @@ void _transform(const raft::handle_t& handle,
embedding_n,
comp_coo.rows(),
comp_coo.cols(),
comp_coo.safe_nnz,
comp_coo.nnz,
epochs_per_sample.data(),
params->repulsion_strength,
params,
Expand Down
12 changes: 6 additions & 6 deletions cpp/src/umap/simpl_set_embed/algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ template <int TPB_X, typename T>
void launcher(
int m, int n, raft::sparse::COO<T>* in, UMAPParams* params, T* embedding, cudaStream_t stream)
{
uint64_t nnz = in->safe_nnz;
uint64_t nnz = in->nnz;

/**
* Find vals.max()
Expand Down Expand Up @@ -334,14 +334,14 @@ void launcher(
raft::sparse::COO<T> out(stream);
raft::sparse::op::coo_remove_zeros<T>(in, &out, stream);

rmm::device_uvector<T> epochs_per_sample(out.safe_nnz, stream);
RAFT_CUDA_TRY(cudaMemsetAsync(epochs_per_sample.data(), 0, out.safe_nnz * sizeof(T), stream));
rmm::device_uvector<T> epochs_per_sample(out.nnz, stream);
RAFT_CUDA_TRY(cudaMemsetAsync(epochs_per_sample.data(), 0, out.nnz * sizeof(T), stream));

make_epochs_per_sample(out.vals(), out.safe_nnz, n_epochs, epochs_per_sample.data(), stream);
make_epochs_per_sample(out.vals(), out.nnz, n_epochs, epochs_per_sample.data(), stream);

if (ML::default_logger().should_log(ML::level_enum::debug)) {
std::stringstream ss;
ss << raft::arr2Str(epochs_per_sample.data(), out.safe_nnz, "epochs_per_sample", stream);
ss << raft::arr2Str(epochs_per_sample.data(), out.nnz, "epochs_per_sample", stream);
CUML_LOG_DEBUG(ss.str().c_str());
}

Expand All @@ -351,7 +351,7 @@ void launcher(
m,
out.rows(),
out.cols(),
out.safe_nnz,
out.nnz,
epochs_per_sample.data(),
params->repulsion_strength,
params,
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/umap/supervised.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ void reset_local_connectivity(raft::sparse::COO<T>* in_coo,
* and this will update the fuzzy simplicial set to respect that label
* data.
*/
template <typename value_t, int TPB_X>
template <typename value_t, uint64_t TPB_X>
void categorical_simplicial_set_intersection(raft::sparse::COO<value_t>* graph_coo,
value_t* target,
cudaStream_t stream,
Expand All @@ -121,7 +121,7 @@ void categorical_simplicial_set_intersection(raft::sparse::COO<value_t>* graph_c
far_dist);
}

template <typename value_t, int TPB_X>
template <typename value_t, uint64_t TPB_X>
CUML_KERNEL void sset_intersection_kernel(int* row_ind1,
int* cols1,
value_t* vals1,
Expand Down Expand Up @@ -179,7 +179,7 @@ CUML_KERNEL void sset_intersection_kernel(int* row_ind1,
* Computes the CSR column index pointer and values
* for the general simplicial set intersecftion.
*/
template <typename T, int TPB_X>
template <typename T, uint64_t TPB_X>
void general_simplicial_set_intersection(int* row1_ind,
raft::sparse::COO<T>* in1,
int* row2_ind,
Expand Down

0 comments on commit 6ff906b

Please sign in to comment.