From ed046590f42c258c1d8f38dc6b014c5429e9a6cc Mon Sep 17 00:00:00 2001 From: kmcgrie Date: Mon, 24 Feb 2025 09:20:25 -0800 Subject: [PATCH] Stream handling changes for OpenRNG and refactoring RNG code for MKL (#3088) * Made the code consistent with clone. * Removed some commented out variables. * Refactored the code to use vslNewStreamEx. * Additional cleanup. * Added space at end of file. * fixed formatting. --------- Co-authored-by: Alexander Andreev --- cpp/daal/src/externals/service_rng_mkl.h | 43 ++++-------------- cpp/daal/src/externals/service_rng_openrng.h | 47 ++++---------------- 2 files changed, 17 insertions(+), 73 deletions(-) diff --git a/cpp/daal/src/externals/service_rng_mkl.h b/cpp/daal/src/externals/service_rng_mkl.h index d68de75e49d..82fa1b0f35e 100644 --- a/cpp/daal/src/externals/service_rng_mkl.h +++ b/cpp/daal/src/externals/service_rng_mkl.h @@ -252,36 +252,19 @@ template class BaseRNG : public BaseRNGIface { public: - BaseRNG(const unsigned int seed, const int brngId) : _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId) + BaseRNG(const unsigned int seed, const int brngId) : _stream(0) { - services::Status s = allocSeeds(1); - if (s) - { - _seed[0] = seed; - int errcode = 0; - __DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)1, &seed), errcode); - } + int errcode = 0; + __DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)1, &seed), errcode); } - BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937) - : _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId) + BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937) : _stream(0) { - services::Status s = allocSeeds(n); - if (s) - { - if (seed) - { - for (size_t i = 0; i < n; i++) - { - _seed[i] = seed[i]; - } - } - int errcode = 0; - __DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)n, seed), errcode); - } + int errcode = 0; + __DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)n, seed), errcode); } - BaseRNG(const BaseRNG & other) : _stream(0), _seed(nullptr), _seedSize(other._seedSize), _brngId(other._brngId) + BaseRNG(const BaseRNG & other) : _stream(0) { int errcode = 0; __DAAL_VSLFN_CALL_NR(vslCopyStream, (&_stream, other._stream), errcode); @@ -289,7 +272,6 @@ class BaseRNG : public BaseRNGIface ~BaseRNG() { - daal::services::daal_free((void *)_seed); int errcode = 0; __DAAL_VSLFN_CALL_NR(vslDeleteStream, (&_stream), errcode); } @@ -333,19 +315,10 @@ class BaseRNG : public BaseRNGIface void * getState() { return _stream; } protected: - services::Status allocSeeds(const size_t n) - { - _seedSize = n; - _seed = (unsigned int *)daal::services::daal_malloc(sizeof(unsigned int) * n); - DAAL_CHECK_MALLOC(_seed); - return services::Status(); - } + services::Status allocSeeds(const size_t n) { return services::Status(); } private: void * _stream; - unsigned int * _seed; - size_t _seedSize; - const int _brngId; }; /* diff --git a/cpp/daal/src/externals/service_rng_openrng.h b/cpp/daal/src/externals/service_rng_openrng.h index d8109ccc366..ec364952322 100644 --- a/cpp/daal/src/externals/service_rng_openrng.h +++ b/cpp/daal/src/externals/service_rng_openrng.h @@ -244,48 +244,22 @@ template class BaseRNG : public BaseRNGIface { public: - BaseRNG(const unsigned int seed, const int brngId) : _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId) + BaseRNG(const unsigned int seed, const int brngId) : _stream(0) { - services::Status s = allocSeeds(1); - if (s) - { - _seed[0] = seed; - int errcode = 0; - errcode = vslNewStreamEx(&_stream, (openrng_int_t)brngId, 1, &seed); - } + int errcode = 0; + __DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)1, &seed), errcode); } - BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937) - : _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId) + BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937) : _stream(0) { - services::Status s = allocSeeds(n); - if (s) - { - if (seed) - { - for (size_t i = 0; i < n; i++) - { - _seed[i] = seed[i]; - } - } - int errcode = 0; - errcode = vslNewStreamEx(&_stream, (openrng_int_t)brngId, (openrng_int_t)n, seed); - } + int errcode = 0; + __DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)n, seed), errcode); } - BaseRNG(const BaseRNG & other) : _stream(0), _seed(nullptr), _seedSize(other._seedSize), _brngId(other._brngId) + BaseRNG(const BaseRNG & other) : _stream(0) { - services::Status s = allocSeeds(_seedSize); - if (s) - { - for (size_t i = 0; i < _seedSize; i++) - { - _seed[i] = other._seed[i]; - } - int errcode = 0; - errcode = vslNewStreamEx(&_stream, _brngId, _seedSize, _seed); - if (!errcode) errcode = vslCopyStreamState(_stream, other._stream); - } + int errcode = 0; + __DAAL_VSLFN_CALL_NR(vslCopyStream, (&_stream, other._stream), errcode); } ~BaseRNG() @@ -344,9 +318,6 @@ class BaseRNG : public BaseRNGIface private: void * _stream; - unsigned int * _seed; - size_t _seedSize; - const int _brngId; }; /*