Skip to content

Commit

Permalink
changed stream handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandr-Solovev committed Feb 24, 2025
1 parent db5c08e commit 0cd591f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 84 deletions.
56 changes: 10 additions & 46 deletions cpp/daal/src/externals/service_rng_mkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,53 +252,26 @@ template <CpuType cpu>
class BaseRNG : public BaseRNGIface<cpu>
{
public:
BaseRNG(const unsigned int seed, const int brngId) : _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId)
BaseRNG(const unsigned int seed, const int brngId) : _stream(0)
{
services::Status s = allocSeeds(1);
if (s)
{
_seed[0] = seed;
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)1, &seed), errcode);
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)1, &seed), errcode);
}

BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937)
: _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId)
BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937) : _stream(0)
{
services::Status s = allocSeeds(n);
if (s)
{
if (seed)
{
for (size_t i = 0; i < n; i++)
{
_seed[i] = seed[i];
}
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)n, seed), errcode);
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)n, seed), errcode);
}

BaseRNG(const BaseRNG<cpu> & other) : _stream(0), _seed(nullptr), _seedSize(other._seedSize), _brngId(other._brngId)
BaseRNG(const BaseRNG<cpu> & other) : _stream(0)
{
services::Status s = allocSeeds(_seedSize);
if (s)
{
for (size_t i = 0; i < _seedSize; i++)
{
_seed[i] = other._seed[i];
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)_brngId, (const MKL_INT)_seedSize, _seed), errcode);
if (!errcode) __DAAL_VSLFN_CALL_NR(vslCopyStreamState, (_stream, other._stream), errcode);
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslCopyStream, (&_stream, other._stream), errcode);
}

~BaseRNG()
{
daal::services::daal_free((void *)_seed);
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslDeleteStream, (&_stream), errcode);
}
Expand Down Expand Up @@ -342,19 +315,10 @@ class BaseRNG : public BaseRNGIface<cpu>
void * getState() { return _stream; }

protected:
services::Status allocSeeds(const size_t n)
{
_seedSize = n;
_seed = (unsigned int *)daal::services::daal_malloc(sizeof(unsigned int) * n);
DAAL_CHECK_MALLOC(_seed);
return services::Status();
}
services::Status allocSeeds(const size_t n) { return services::Status(); }

private:
void * _stream;
unsigned int * _seed;
size_t _seedSize;
const int _brngId;
};

/*
Expand Down
47 changes: 9 additions & 38 deletions cpp/daal/src/externals/service_rng_openrng.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,48 +244,22 @@ template <CpuType cpu>
class BaseRNG : public BaseRNGIface<cpu>
{
public:
BaseRNG(const unsigned int seed, const int brngId) : _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId)
BaseRNG(const unsigned int seed, const int brngId) : _stream(0)
{
services::Status s = allocSeeds(1);
if (s)
{
_seed[0] = seed;
int errcode = 0;
errcode = vslNewStreamEx(&_stream, (openrng_int_t)brngId, 1, &seed);
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)1, &seed), errcode);
}

BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937)
: _stream(0), _seed(nullptr), _seedSize(0), _brngId(brngId)
BaseRNG(const size_t n, const unsigned int * seed, const int brngId = __DAAL_BRNG_MT19937) : _stream(0)
{
services::Status s = allocSeeds(n);
if (s)
{
if (seed)
{
for (size_t i = 0; i < n; i++)
{
_seed[i] = seed[i];
}
}
int errcode = 0;
errcode = vslNewStreamEx(&_stream, (openrng_int_t)brngId, (openrng_int_t)n, seed);
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslNewStreamEx, (&_stream, (const MKL_INT)brngId, (const MKL_INT)n, seed), errcode);
}

BaseRNG(const BaseRNG<cpu> & other) : _stream(0), _seed(nullptr), _seedSize(other._seedSize), _brngId(other._brngId)
BaseRNG(const BaseRNG<cpu> & other) : _stream(0)
{
services::Status s = allocSeeds(_seedSize);
if (s)
{
for (size_t i = 0; i < _seedSize; i++)
{
_seed[i] = other._seed[i];
}
int errcode = 0;
errcode = vslNewStreamEx(&_stream, _brngId, _seedSize, _seed);
if (!errcode) errcode = vslCopyStreamState(_stream, other._stream);
}
int errcode = 0;
__DAAL_VSLFN_CALL_NR(vslCopyStream, (&_stream, other._stream), errcode);
}

~BaseRNG()
Expand Down Expand Up @@ -344,9 +318,6 @@ class BaseRNG : public BaseRNGIface<cpu>

private:
void * _stream;
unsigned int * _seed;
size_t _seedSize;
const int _brngId;
};

/*
Expand Down

0 comments on commit 0cd591f

Please sign in to comment.