KMeans Sparse Init Update #2796
Changes from all commits
2f52a21
eee6863
328407d
137dee4
bc5138b
a5891d0
d7c2333
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ | |
* limitations under the License. | ||
*******************************************************************************/ | ||
|
||
#include <cmath> | ||
|
||
#include <daal/src/algorithms/kmeans/kmeans_init_kernel.h> | ||
|
||
#include "oneapi/dal/algo/kmeans_init/backend/cpu/compute_kernel.hpp" | ||
|
@@ -43,9 +45,18 @@ static compute_result<Task> call_daal_kernel(const context_cpu& ctx, | |
const std::int64_t column_count = data.get_column_count(); | ||
const std::int64_t cluster_count = desc.get_cluster_count(); | ||
|
||
//number of trials to pick each centroid from, 2 + int(ln(cluster_count)) works better than vanilla kmeans++ | ||
//https://github.com/scikit-learn/scikit-learn/blob/a63b021310ba13ea39ad3555f550d8aeec3002c5/sklearn/cluster/_kmeans.py#L108 | ||
std::int64_t trial_count = desc.get_local_trials_count(); | ||
if (trial_count == -1) { | ||
const auto additional = std::log(cluster_count); | ||
trial_count = 2 + std::int64_t(additional); | ||
} | ||
|
||
Comment on lines +48 to +55
What kind of improvement does this give? It seems to change the original behavior. Please also reflect it in the documentation if that has not already been done.
This is actually the original behavior in DAAL and in distributed oneDAL; it was somehow missed in the oneDAL CPU backend.
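To make the default discussed in this thread concrete, here is a minimal standalone sketch of the rule in the diff above: a requested value of -1 means "use the default", which is 2 plus the truncated natural logarithm of the cluster count. The helper name `resolve_trial_count` is hypothetical and not part of oneDAL; only the formula and the -1 sentinel come from the diff.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper (not a oneDAL API): resolves the number of local trials.
// A requested value of -1 is treated as "use the default", as in the diff.
inline std::int64_t resolve_trial_count(std::int64_t requested, std::int64_t cluster_count) {
    assert(cluster_count > 0);
    if (requested != -1) {
        return requested; // an explicit setting is passed through unchanged
    }
    // The cast truncates toward zero, so for cluster_count >= 1 this is
    // 2 + floor(ln(cluster_count)).
    const double additional = std::log(static_cast<double>(cluster_count));
    return 2 + static_cast<std::int64_t>(additional);
}
```

Under this rule the default grows slowly: 2 trials for a single cluster, 4 for 8 clusters, 6 for 100 clusters, and 8 for 1000 clusters.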
||
daal_kmeans_init::Parameter par(dal::detail::integral_cast<std::size_t>(cluster_count), | ||
0, | ||
dal::detail::integral_cast<std::size_t>(desc.get_seed())); | ||
par.nTrials = trial_count; | ||
|
||
const auto daal_data = interop::convert_to_daal_table<Float>(data); | ||
const std::size_t len_input = 1; | ||
|
Please add a description to the updateMinDistForITrials function.
I understand that it was not there before, but I hope that by adding a couple of comments at a time we can make oneDAL's code more readable.
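As background on what a per-centroid trial count controls, the following is a rough sketch of greedy k-means++ seeding, the scheme referenced by the scikit-learn link in the diff. It is an illustration only: it makes no claim about how updateMinDistForITrials or the DAAL kernel is written, and every name in it (`Point`, `sq_dist`, `greedy_kmeanspp`) is hypothetical. For each new centroid it samples `trial_count` candidates with probability proportional to the squared distance to the nearest chosen center, then keeps the candidate that most reduces the total of those distances.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <numeric>
#include <random>
#include <vector>

using Point = std::vector<double>;

// Squared Euclidean distance between two points of equal dimension.
inline double sq_dist(const Point& a, const Point& b) {
    assert(a.size() == b.size());
    double s = 0.0;
    for (std::size_t j = 0; j < a.size(); ++j) {
        const double d = a[j] - b[j];
        s += d * d;
    }
    return s;
}

// Illustrative greedy k-means++ seeding; returns the row indices of the chosen centers.
inline std::vector<std::size_t> greedy_kmeanspp(const std::vector<Point>& data,
                                                std::size_t cluster_count,
                                                std::size_t trial_count,
                                                unsigned seed) {
    assert(!data.empty() && cluster_count > 0 && cluster_count <= data.size());
    assert(trial_count > 0);
    std::mt19937 rng(seed);
    std::uniform_int_distribution<std::size_t> uniform(0, data.size() - 1);

    // First center: chosen uniformly at random.
    std::vector<std::size_t> centers{ uniform(rng) };

    // min_dist[i] = squared distance from point i to its nearest chosen center.
    std::vector<double> min_dist(data.size());
    for (std::size_t i = 0; i < data.size(); ++i)
        min_dist[i] = sq_dist(data[i], data[centers[0]]);

    while (centers.size() < cluster_count) {
        const double total = std::accumulate(min_dist.begin(), min_dist.end(), 0.0);
        if (total <= 0.0) { // every point coincides with a center; fall back to uniform
            centers.push_back(uniform(rng));
            continue;
        }
        // Sample candidates with probability proportional to min_dist.
        std::discrete_distribution<std::size_t> by_dist(min_dist.begin(), min_dist.end());
        double best_potential = std::numeric_limits<double>::infinity();
        std::size_t best_candidate = 0;
        for (std::size_t t = 0; t < trial_count; ++t) {
            const std::size_t cand = by_dist(rng);
            // Potential = sum of min distances if this candidate were added.
            double potential = 0.0;
            for (std::size_t i = 0; i < data.size(); ++i)
                potential += std::min(min_dist[i], sq_dist(data[i], data[cand]));
            if (potential < best_potential) {
                best_potential = potential;
                best_candidate = cand;
            }
        }
        centers.push_back(best_candidate);
        for (std::size_t i = 0; i < data.size(); ++i)
            min_dist[i] = std::min(min_dist[i], sq_dist(data[i], data[best_candidate]));
    }
    return centers;
}
```

With `trial_count` set to 1 this reduces to plain k-means++ sampling; larger values trade extra distance computations per centroid for a better-spread initialization.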