Skip to content

Commit

Permalink
Stream handling changes for OpenRNG and refactoring RNG code for MKL (#…
Browse files Browse the repository at this point in the history
…3088)

* Made the code consistent with clone.

* Removed some commented out variables.

* Refactored the code to use vslNewStreamEx.

* Additional cleanup.

* Added space at end of file.

* fixed formatting.

---------

Co-authored-by: Alexander Andreev <[email protected]>
  • Loading branch information
KateBlueSky and Alexsandruss authored Feb 24, 2025
1 parent 90ac585 commit ed04659
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 73 deletions.
43 changes: 8 additions & 35 deletions cpp/daal/src/externals/service_rng_mkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,44 +252,26 @@ template <CpuType cpu>
class BaseRNG : public BaseRNGIface<cpu>
{
public:
BaseRNG(const unsigned int seed, const int brngId) : _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId)
BaseRNG(const unsigned int seed, const int brngId) : _stream(0)
{
services::Status s = allocSeeds(1);
if (s)
{
_seed[0] = seed;
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)1, &seed), errcode);
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)1, &seed), errcode);
}

BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937)
: _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId)
BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937) : _stream(0)
{
services::Status s = allocSeeds(n);
if (s)
{
if (seed)
{
for (size_t i = 0; i < n; i++)
{
_seed[i] = seed[i];
}
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)n, seed), errcode);
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)n, seed), errcode);
}

BaseRNG(const BaseRNG<cpu> & other) : _stream(0), _seed(nullptr), _seedSize(other._seedSize), _brngId(other._brngId)
BaseRNG(const BaseRNG<cpu> & other) : _stream(0)
{
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslCopyStream, (&_stream, other._stream), errcode);
}

~BaseRNG()
{
daal::services::daal_free((void *)_seed);
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslDeleteStream, (&_stream), errcode);
}
Expand Down Expand Up @@ -333,19 +315,10 @@ class BaseRNG : public BaseRNGIface<cpu>
void * getState() { return _stream; }

protected:
services::Status allocSeeds(const size_t n)
{
_seedSize = n;
_seed = (unsigned int *)daal::services::daal_malloc(sizeof(unsigned int) * n);
DAAL_CHECK_MALLOC(_seed);
return services::Status();
}
services::Status allocSeeds(const size_t n) { return services::Status(); }

private:
void * _stream;
unsigned int * _seed;
size_t _seedSize;
const int _brngId;
};

/*
Expand Down
47 changes: 9 additions & 38 deletions cpp/daal/src/externals/service_rng_openrng.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,48 +244,22 @@ template <CpuType cpu>
class BaseRNG : public BaseRNGIface<cpu>
{
public:
BaseRNG(const unsigned int seed, const int brngId) : _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId)
BaseRNG(const unsigned int seed, const int brngId) : _stream(0)
{
services::Status s = allocSeeds(1);
if (s)
{
_seed[0] = seed;
int errcode = 0;
errcode = vslNewStreamEx(&_stream, (openrng_int_t)brngId, 1, &seed);
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)1, &seed), errcode);
}

BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937)
: _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId)
BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937) : _stream(0)
{
services::Status s = allocSeeds(n);
if (s)
{
if (seed)
{
for (size_t i = 0; i < n; i++)
{
_seed[i] = seed[i];
}
}
int errcode = 0;
errcode = vslNewStreamEx(&_stream, (openrng_int_t)brngId, (openrng_int_t)n, seed);
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)n, seed), errcode);
}

BaseRNG(const BaseRNG<cpu> & other) : _stream(0), _seed(nullptr), _seedSize(other._seedSize), _brngId(other._brngId)
BaseRNG(const BaseRNG<cpu> & other) : _stream(0)
{
services::Status s = allocSeeds(_seedSize);
if (s)
{
for (size_t i = 0; i < _seedSize; i++)
{
_seed[i] = other._seed[i];
}
int errcode = 0;
errcode = vslNewStreamEx(&_stream, _brngId, _seedSize, _seed);
if (!errcode) errcode = vslCopyStreamState(_stream, other._stream);
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslCopyStream, (&_stream, other._stream), errcode);
}

~BaseRNG()
Expand Down Expand Up @@ -344,9 +318,6 @@ class BaseRNG : public BaseRNGIface<cpu>

private:
void * _stream;
unsigned int * _seed;
size_t _seedSize;
const int _brngId;
};

/*
Expand Down

0 comments on commit ed04659

Please sign in to comment.