diff --git a/nntrainer/tensor/float_tensor.cpp b/nntrainer/tensor/float_tensor.cpp index 9c31c40f2..a33cfbe20 100644 --- a/nntrainer/tensor/float_tensor.cpp +++ b/nntrainer/tensor/float_tensor.cpp @@ -440,21 +440,6 @@ Tensor &FloatTensor::add_strided(Tensor const &input, Tensor &output, return output; } -int FloatTensor::add_i(Tensor const &m, Tensor &output, float const alpha) { - auto f = [&](const BroadcastInfo &e, const float *buf, const float *m_buf, - float *out_buf) { - saxpy(e.buffer_size, alpha, m_buf, e.strides[3], out_buf, strides[3]); - }; - - try { - apply_broadcast(m, f, output); - } catch (std::exception &err) { - ml_loge("%s %s", typeid(err).name(), err.what()); - return ML_ERROR_INVALID_PARAMETER; - } - return ML_ERROR_NONE; -} - int FloatTensor::add_i_partial(unsigned int len, unsigned int addr_idx, Tensor &m, unsigned int incX, unsigned int incY, const Tensor alphas, unsigned int alpha_idx) { diff --git a/nntrainer/tensor/float_tensor.h b/nntrainer/tensor/float_tensor.h index 23681fc33..17d797181 100644 --- a/nntrainer/tensor/float_tensor.h +++ b/nntrainer/tensor/float_tensor.h @@ -308,11 +308,6 @@ class FloatTensor : public TensorBase { Tensor &add_strided(Tensor const &input, Tensor &output, const float beta) const override; - /** - * @copydoc Tensor::add_i(Tensor const &m, float const alpha) - */ - int add_i(Tensor const &m, Tensor &output, float const alpha) override; - /** * @copydoc Tensor::add_i_partial() */ diff --git a/nntrainer/tensor/half_tensor.cpp b/nntrainer/tensor/half_tensor.cpp index bdc509041..da4bf76a0 100644 --- a/nntrainer/tensor/half_tensor.cpp +++ b/nntrainer/tensor/half_tensor.cpp @@ -422,22 +422,6 @@ Tensor &HalfTensor::add_strided(Tensor const &input, Tensor &output, return output; } -int HalfTensor::add_i(Tensor const &m, Tensor &output, float const alpha) { - auto f = [&](const BroadcastInfo &e, const _FP16 *buf, const _FP16 *m_buf, - _FP16 *out_buf) { - saxpy(e.buffer_size, alpha, m_buf, e.strides[3], out_buf, strides[3]); - /// @todo: saxpy is not valid for _FP16 - }; - - try { - apply_broadcast(m, f, output); - } catch (std::exception &err) { - ml_loge("%s %s", typeid(err).name(), err.what()); - return ML_ERROR_INVALID_PARAMETER; - } - return ML_ERROR_NONE; -} - int HalfTensor::add_i_partial(unsigned int len, unsigned int addr_idx, Tensor &m, unsigned int incX, unsigned int incY, const Tensor alphas, unsigned int alpha_idx) { diff --git a/nntrainer/tensor/half_tensor.h b/nntrainer/tensor/half_tensor.h index 206a8482d..3540c2703 100644 --- a/nntrainer/tensor/half_tensor.h +++ b/nntrainer/tensor/half_tensor.h @@ -308,11 +308,6 @@ class HalfTensor : public TensorBase { Tensor &add_strided(Tensor const &input, Tensor &output, const float beta) const override; - /** - * @copydoc Tensor::add_i(Tensor const &m, float const alpha) - */ - int add_i(Tensor const &m, Tensor &output, float const alpha) override; - /** * @copydoc Tensor::add_i_partial() */ diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp index b0cbae110..3887a58f4 100644 --- a/nntrainer/tensor/tensor.cpp +++ b/nntrainer/tensor/tensor.cpp @@ -419,10 +419,6 @@ Tensor &Tensor::multiply(Tensor const &m, Tensor &output, std::invalid_argument) << getName() << " is not contiguous, cannot multiply"; - NNTR_THROW_IF(!getContiguous() || !m.getContiguous() || - !output.getContiguous(), - std::invalid_argument) - << getName() << " is not contiguous, cannot multiply"; itensor->multiply(m, output, beta); return output; } @@ -521,7 +517,13 @@ Tensor &Tensor::add(float const &value, Tensor &output) const { } int Tensor::add_i(Tensor const &m, float const alpha) { - return itensor->add_i(m, *this, alpha); + try { + itensor->add(m, *this, alpha); + } catch (std::exception &err) { + ml_loge("%s %s", typeid(err).name(), err.what()); + return ML_ERROR_INVALID_PARAMETER; + } + return ML_ERROR_NONE; } int Tensor::add_i_partial(unsigned int len, unsigned int addr_idx, Tensor &m, @@ -537,6 +539,11 @@ Tensor Tensor::add(Tensor const &m, float const alpha) const { } Tensor &Tensor::add(Tensor const &m, Tensor &output, float const alpha) const { + NNTR_THROW_IF(m.getFormat() != this->getFormat(), std::invalid_argument) + << "Tensor Format of " << getName() << ":" + << ((bool)(this->getFormat()) ? "NHWC" : "NCHW") << " is not match. (" + << ((bool)(m.getFormat()) ? "NHWC" : "NCHW") << ")"; + NNTR_THROW_IF(!itensor->getContiguous() || !m.getContiguous() || !output.getContiguous(), std::invalid_argument) diff --git a/nntrainer/tensor/tensor_base.cpp b/nntrainer/tensor/tensor_base.cpp index 0711504f8..16a68a3d0 100644 --- a/nntrainer/tensor/tensor_base.cpp +++ b/nntrainer/tensor/tensor_base.cpp @@ -410,12 +410,6 @@ Tensor &TensorBase::add_strided(Tensor const &input, Tensor &output, getStringDataType()); } -int TensorBase::add_i(Tensor const &m, Tensor &output, float const alpha) { - throw std::invalid_argument( - "Tensor::add_i() is currently not supported in tensor data type " + - getStringDataType()); -} - int TensorBase::add_i_partial(unsigned int len, unsigned int addr_idx, Tensor &m, unsigned int incX, unsigned int incY, const Tensor alphas, unsigned int alpha_idx) { diff --git a/nntrainer/tensor/tensor_base.h b/nntrainer/tensor/tensor_base.h index cc6ad0c2b..bab74bd04 100644 --- a/nntrainer/tensor/tensor_base.h +++ b/nntrainer/tensor/tensor_base.h @@ -306,11 +306,6 @@ class TensorBase { virtual Tensor &add_strided(Tensor const &input, Tensor &output, const float beta) const; - /** - * @copydoc Tensor::add_i(Tensor const &m, float const alpha) - */ - virtual int add_i(Tensor const &m, Tensor &output, float const alpha); - /** * @copydoc Tensor::add_i_partial() */