diff --git a/Applications/KNN/jni/meson.build b/Applications/KNN/jni/meson.build
index bc50dc0214..58ca099d75 100644
--- a/Applications/KNN/jni/meson.build
+++ b/Applications/KNN/jni/meson.build
@@ -15,4 +15,4 @@ e = executable('knn_sample',
   install_dir: application_install_dir
 )
 
-test('app_knn', e, args: [nntr_app_resdir / 'KNN'])
+test('app_knn', e, args: [nntr_app_resdir / 'KNN/'])
diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp
index 2d4cfdc769..b7f4d1cffd 100644
--- a/nntrainer/graph/network_graph.cpp
+++ b/nntrainer/graph/network_graph.cpp
@@ -768,9 +768,10 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
    * node is going to be used with in-place optimizations.
    */
   auto out_specs = init_context.getOutSpecs();
 
+  /// @note try to move in-place control to finalize
   bool shared_var = false, shared_grad = false;
-  if (lnode->executeInPlace() != InPlace::NONE) {
+  if (lnode->executeInPlace() != InPlace::NONE && lnode->supportInPlace()) {
     setInplaceSharedMemoryConfigByLayer(lnode, shared_var, shared_grad);
     for (unsigned int i = 0; i < out_specs.size(); ++i) {
       auto &s = out_specs.at(i);
@@ -1556,8 +1557,9 @@ void NetworkGraph::requestOptimizerVariable(
       const TensorDim &dim = w->getDim();
       std::vector<TensorDim> dims = cb(dim);
       w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables(
-        dims, w->getName(), TensorLifespan::MAX_LIFESPAN,
-        w->isGradientClipByGlobalNorm(), Tensor::Initializer::ZEROS));
+        dims, w->getName(), ":opt", TensorLifespan::MAX_LIFESPAN,
+        w->isGradientClipByGlobalNorm(), w->isMixedPrecision(),
+        Tensor::Initializer::ZEROS));
     }
   }
 }
diff --git a/nntrainer/layers/input_layer.cpp b/nntrainer/layers/input_layer.cpp
index eabd40b297..d9f058d8ce 100644
--- a/nntrainer/layers/input_layer.cpp
+++ b/nntrainer/layers/input_layer.cpp
@@ -33,8 +33,7 @@ namespace nntrainer {
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
 InputLayer::InputLayer() :
-  Layer(),
-  input_props(props::Normalization(), props::Standardization()) {}
+  Layer(), input_props(props::Normalization(), props::Standardization()) {}
 
 void InputLayer::setProperty(const std::vector<std::string> &values) {
   auto remain_props = loadProperties(values, input_props);
@@ -47,7 +46,7 @@ void InputLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
   if (!context.executeInPlace()) {
     Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
-    hidden_.copy(input_);
+    hidden_.copyData(input_);
   }
 
   if (std::get<props::Normalization>(input_props))
@@ -70,7 +69,21 @@ void InputLayer::finalize(InitLayerContext &context) {
 
   std::vector<TensorDim> output_dims = context.getInputDimensions();
 
+  for (auto &d : output_dims) {
+    d.setDataType(context.getActivationDataType());
+  }
+
   context.setOutputDimensions(output_dims);
+
+  is_inplace = true;
+
+  /**
+   * @note The input layer assumes that the incoming tensor is always FP32.
+   * Therefore, if the activation data type is not FP32, the layer cannot
+   * support in-place operation.
+   */
+  if (context.getActivationDataType() != ml::train::TensorDim::DataType::FP32)
+    is_inplace = false;
 }
 
 } /* namespace nntrainer */
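The input_layer.cpp change above ties in-place support to the activation data type: the data pipeline stages input in FP32, so an FP16 activation cannot safely alias the same buffer, and copyData() is used so the conversion lands in the separately allocated output (as inferred from the surrounding change). A standalone sketch of the decision rule, with illustrative names rather than the nntrainer API:

#include <iostream>

enum class DataType { FP16, FP32 };

// In-place execution would overwrite the FP32 staging buffer with a
// narrower type and lose the source data, so it is only safe when the
// activation type matches the FP32 input.
static bool input_layer_can_run_inplace(DataType activation_type) {
  return activation_type == DataType::FP32;
}

int main() {
  std::cout << input_layer_can_run_inplace(DataType::FP32) << '\n'; // 1
  std::cout << input_layer_can_run_inplace(DataType::FP16) << '\n'; // 0
}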
diff --git a/nntrainer/layers/input_layer.h b/nntrainer/layers/input_layer.h
index f6728d676b..e9183e23d1 100644
--- a/nntrainer/layers/input_layer.h
+++ b/nntrainer/layers/input_layer.h
@@ -82,7 +82,7 @@ class InputLayer : public Layer {
   /**
    * @copydoc Layer::supportInPlace()
    */
-  bool supportInPlace() const override { return true; }
+  bool supportInPlace() const override { return is_inplace; }
 
   /**
    * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
@@ -105,6 +105,7 @@ class InputLayer : public Layer {
 
 private:
   std::tuple<props::Normalization, props::Standardization> input_props;
+  bool is_inplace;
 };
 } // namespace nntrainer
diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
index fff2eb15ec..f0856c1dbb 100644
--- a/nntrainer/layers/layer_context.cpp
+++ b/nntrainer/layers/layer_context.cpp
@@ -169,6 +169,19 @@ Tensor &RunLayerContext::getWeightGrad(unsigned int idx) const {
   return weights[idx]->getGradientRef();
 }
 
+/**
+ * @brief Get the FP32 Weight tensor object
+ *
+ * @param idx Identifier of the weight
+ * @return Tensor& Reference to the FP32 weight tensor
+ */
+Tensor &RunLayerContext::getWeightFP32(unsigned int idx) const {
+  if (!weights[idx]->hasGradient())
+    throw std::invalid_argument(
+      "Requesting gradient for a non-trainable weight.");
+  return weights[idx]->getVariableFP32Ref();
+}
+
 /**
  * @brief Get the Weight Optimizer Variable tensor object
  *
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index e5c6759638..e2f428aa2c 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -463,6 +463,15 @@ class RunLayerContext {
   Tensor &getWeightGrad(unsigned int idx) const;
 
   /**
+   * @brief Get the FP32 Weight tensor object
+   *
+   * @param idx Identifier of the weight
+   * @return Tensor& Reference to the FP32 weight tensor
+   */
+  Tensor &getWeightFP32(unsigned int idx) const;
+
+  /**
    * @brief Get the Weight Optimizer Variable tensor object
    *
    * @param idx Identifier of the weight
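getWeightFP32() exposes the FP32 master copy that mixed-precision training relies on: per-step updates smaller than a 16-bit weight's rounding step vanish if accumulated directly in the compute copy, but survive in FP32. A self-contained demonstration, using bfloat16-style truncation as a stand-in for any 16-bit storage (no nntrainer calls):

#include <cstdint>
#include <cstring>
#include <iostream>

// Keep only the top 16 bits of a float, i.e. bfloat16-like precision.
static float to_low_precision(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float low = 1.0f, master = 1.0f;
  const float update = 1e-4f; // well below the rounding step near 1.0
  for (int i = 0; i < 100; ++i) {
    low = to_low_precision(low + update); // rounds back to 1.0 every step
    master += update;                     // accumulates normally in FP32
  }
  std::cout << "low-precision accumulation: " << low << '\n';    // 1
  std::cout << "FP32 master accumulation:   " << master << '\n'; // ~1.01
}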
diff --git a/nntrainer/layers/layer_node.h b/nntrainer/layers/layer_node.h
index 93e7ac7069..7dfb1bd1a0 100644
--- a/nntrainer/layers/layer_node.h
+++ b/nntrainer/layers/layer_node.h
@@ -487,6 +487,7 @@ class LayerNode final : public ml::train::Layer, public GraphNode {
   const std::vector<TensorDim> getOutputDimensions() const;
   /**
    * @brief Get the Weight object
+   *        currently, only unittests use this function.
    *
    * @param idx Identifier of the weight
    * @return Weight& Reference to the weight
@@ -495,11 +496,11 @@
     NNTR_THROW_IF(!run_context, std::runtime_error)
       << __func__ << " layer needs to be finalized first!";
     if (run_context->weightHasGradient(idx)) {
-      return Weight(run_context->getWeight(idx),
-                    run_context->getWeightGrad(idx),
-                    run_context->getWeightName(idx));
+      return Weight(
+        run_context->getWeight(idx), run_context->getWeightGrad(idx),
+        run_context->getWeightFP32(idx), run_context->getWeightName(idx));
     } else {
-      return Weight(run_context->getWeight(idx), Tensor(),
+      return Weight(run_context->getWeight(idx), Tensor(), Tensor(),
                     run_context->getWeightName(idx));
     }
   }
diff --git a/nntrainer/layers/loss/mse_loss_layer.cpp b/nntrainer/layers/loss/mse_loss_layer.cpp
index 7f7bd1626f..ec9bc9b844 100644
--- a/nntrainer/layers/loss/mse_loss_layer.cpp
+++ b/nntrainer/layers/loss/mse_loss_layer.cpp
@@ -20,7 +20,16 @@ static constexpr size_t SINGLE_INOUT_IDX = 0;
 
 void MSELossLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
-  Tensor &y = context.getInput(SINGLE_INOUT_IDX);
+
+  Tensor empty_tensor;
+  Tensor &y = context.getInput(SINGLE_INOUT_IDX).getDataType() ==
+                  ml::train::TensorDim::DataType::FP32
+                ? context.getInput(SINGLE_INOUT_IDX)
+                : empty_tensor;
+
+  if (y.empty())
+    y = context.getInput(SINGLE_INOUT_IDX)
+          .clone(ml::train::TensorDim::DataType::FP32);
 
   // hidden_ <- y2 - y;
   if (context.isLabelAvailable(SINGLE_INOUT_IDX)) {
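The rebinding trick above is easy to misread: y is a reference that selects between the real input and a local fallback, and the later assignment fills the fallback through the reference rather than touching the input. The same pattern reappears in Adam::applyGradient() below. A minimal demonstration with std::string standing in for Tensor:

#include <cassert>
#include <string>

int main() {
  std::string input = "fp16-data";
  bool input_is_fp32 = false;
  std::string fallback; // plays the role of empty_tensor
  std::string &y = input_is_fp32 ? input : fallback;
  if (y.empty())
    y = input + "-as-fp32";     // plays the role of clone(FP32)
  assert(&y == &fallback);      // the conversion landed in the fallback
  assert(input == "fp16-data"); // the original input is untouched
  return 0;
}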
diff --git a/nntrainer/optimizers/adam.cpp b/nntrainer/optimizers/adam.cpp
index 18c0a0fcc1..530e7fdf31 100644
--- a/nntrainer/optimizers/adam.cpp
+++ b/nntrainer/optimizers/adam.cpp
@@ -36,7 +36,15 @@ Adam::~Adam() {}
 
 enum AdamParams { wm, wv };
 
 std::vector<TensorDim> Adam::getOptimizerVariableDim(const TensorDim &dim) {
-  return {dim, dim};
+  /**
+   * @note We assume that the optimizer parameters must be full precision to
+   * maintain accuracy even in mixed precision training.
+   */
+  TensorDim wm_dim(dim);
+  TensorDim wv_dim(dim);
+  wm_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+  wv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+  return {wm_dim, wv_dim};
 }
 
 void Adam::exportTo(Exporter &exporter,
@@ -64,7 +72,15 @@ double Adam::getUpdatedLearningRate(unsigned int iteration, double ll) const {
 }
 
 void Adam::applyGradient(RunOptimizerContext &context) {
-  Tensor &x_grad = context.getGradient();
+  Tensor empty_tensor;
+
+  Tensor &x_grad =
+    context.getGradient().getDataType() == ml::train::TensorDim::DataType::FP32
+      ? context.getGradient()
+      : empty_tensor;
+
+  if (x_grad.empty())
+    x_grad = context.getGradient().clone(ml::train::TensorDim::DataType::FP32);
 
   auto &beta1 = std::get<PropsB1>(adam_props).get();
   auto &beta2 = std::get<PropsB2>(adam_props).get();
@@ -91,7 +107,7 @@ void Adam::applyGradient(RunOptimizerContext &context) {
     denom.add_i(epsilon);
     wm.divide(denom, x_grad);
 
-    context.applyGradient(context.getLearningRate() / biasCorrection1);
+    context.applyGradient(context.getLearningRate() / biasCorrection1, x_grad);
   } else {
     std::function<double(double)> sqrtEps = [epsilon](double f) {
@@ -100,8 +116,9 @@ void Adam::applyGradient(RunOptimizerContext &context) {
 
     x_grad = wv.apply(sqrtEps, x_grad);
     x_grad.multiply_i(wm);
-    context.applyGradient(getUpdatedLearningRate(context.getIteration(),
-                                                 context.getLearningRate()));
+    context.applyGradient(
+      getUpdatedLearningRate(context.getIteration(), context.getLearningRate()),
+      x_grad);
   }
 }
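getOptimizerVariableDim() above pins both Adam moments to FP32 regardless of the weight's storage type, and applyGradient() feeds them an FP32 gradient. A scalar, dependency-free sketch of the resulting update (textbook Adam; the patch folds the bias corrections into the learning rate and an apply functor, so the exact epsilon placement differs slightly):

#include <cmath>
#include <iostream>

int main() {
  // Optimizer state stays in full precision no matter how the weight is
  // stored.
  double wm = 0.0, wv = 0.0;
  const double beta1 = 0.9, beta2 = 0.999, epsilon = 1.0e-7, lr = 0.1;
  const double grad = 0.5; // gradient already cloned to full precision
  const int iteration = 0;

  wm = beta1 * wm + (1.0 - beta1) * grad;
  wv = beta2 * wv + (1.0 - beta2) * grad * grad;

  const double biasCorrection1 = 1.0 - std::pow(beta1, iteration + 1);
  const double biasCorrection2 = 1.0 - std::pow(beta2, iteration + 1);
  const double step = (lr / biasCorrection1) * wm /
                      (std::sqrt(wv / biasCorrection2) + epsilon);
  std::cout << "weight -= " << step << '\n';
}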
diff --git a/nntrainer/optimizers/optimizer_context.cpp b/nntrainer/optimizers/optimizer_context.cpp
index da4cd1f7e9..f70ab773a9 100644
--- a/nntrainer/optimizers/optimizer_context.cpp
+++ b/nntrainer/optimizers/optimizer_context.cpp
@@ -42,4 +42,11 @@ Tensor &RunOptimizerContext::getOptimizerVariable(unsigned int idx) const {
 void RunOptimizerContext::applyGradient(double lr) const {
   weight->applyGradient(lr);
 }
+
+/**
+ * @brief Apply the gradient with the given learning rate and gradient
+ */
+void RunOptimizerContext::applyGradient(double lr, Tensor &updated_grad) const {
+  weight->applyGradient(lr, updated_grad);
+}
 } // namespace nntrainer
diff --git a/nntrainer/optimizers/optimizer_context.h b/nntrainer/optimizers/optimizer_context.h
index 62f9e0945d..6b4b983e35 100644
--- a/nntrainer/optimizers/optimizer_context.h
+++ b/nntrainer/optimizers/optimizer_context.h
@@ -35,9 +35,7 @@ class RunOptimizerContext {
    *
    */
   RunOptimizerContext(Weight *w = nullptr, size_t iter = 0, double lr = 0.0) :
-    weight(w),
-    iteration(iter),
-    learning_rate(lr) {}
+    weight(w), iteration(iter), learning_rate(lr) {}
 
   /**
    * @brief Get the Weight tensor object
@@ -75,6 +73,16 @@ class RunOptimizerContext {
    */
   void applyGradient(double lr) const;
 
+  /**
+   * @brief Apply the gradient with the given learning rate and updated
+   * gradient
+   *
+   * @param lr learning rate
+   * @param updated_grad gradient tensor which has been updated (usually it
+   * is fp32)
+   */
+  void applyGradient(double lr, Tensor &updated_grad) const;
+
   /**
    * @brief Get the current iteration value
   *
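The new overload is deliberately a thin forwarding shim: the optimizer prepares an FP32 gradient, the context passes it through, and the weight owns the unscale/update/quantize policy (weight.cpp below). A compilable sketch of that shape, with illustrative stand-ins rather than the nntrainer classes:

#include <iostream>

struct WeightSketch {
  float value = 1.0f;
  void applyGradient(double lr, const float &updated_grad) {
    value -= static_cast<float>(lr) * updated_grad;
  }
};

struct ContextSketch {
  WeightSketch *weight;
  // Forward only; no update policy lives in the context.
  void applyGradient(double lr, const float &updated_grad) const {
    weight->applyGradient(lr, updated_grad);
  }
};

int main() {
  WeightSketch w;
  ContextSketch ctx{&w};
  float fp32_grad = 2.0f; // e.g. an FP16 gradient cloned to FP32
  ctx.applyGradient(0.1, fp32_grad);
  std::cout << w.value << '\n'; // 0.8
}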
diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp
index 9a0d235ba9..14d710b3c0 100644
--- a/nntrainer/tensor/manager.cpp
+++ b/nntrainer/tensor/manager.cpp
@@ -414,7 +414,7 @@ std::vector<Weight *> Manager::requestWeights(
     // var_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
   }
 
-  Tensor *var = nullptr, *grad = nullptr;
+  Tensor *var = nullptr, *grad = nullptr, *var32 = nullptr;
   bool is_dependent = !shared_names.empty();
   if (is_dependent) {
     /// shared_name is used and the orignal name is discarded
@@ -431,6 +431,17 @@ std::vector<Weight *> Manager::requestWeights(
       grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
                                          dim_g, grad_exec_order, grad_ls,
                                          Tensor::Initializer::ZEROS);
+
+      if (var->getDataType() != ml::train::TensorDim::DataType::FP32) {
+        TensorDim var32_dim(dim_v);
+        var32_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+        std::vector<unsigned int> var32_exec_order;
+        var32_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
+
+        var32 = weight_pool.requestOrExtend(shared_name + ":var32", var32_dim,
+                                            var32_exec_order, var_ls,
+                                            Tensor::Initializer::ZEROS);
+      }
     }
   } else {
     /** case requesting fresh weights */
@@ -448,11 +459,21 @@ std::vector<Weight *> Manager::requestWeights(
       grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim_g,
                                  grad_exec_order, grad_ls,
                                  Tensor::Initializer::ZEROS, is_wgrad);
+      if (var->getDataType() != ml::train::TensorDim::DataType::FP32) {
+        TensorDim var32_dim(dim_v);
+        var32_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+        std::vector<unsigned int> var32_exec_order;
+        var32_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
+        var32 =
+          weight_pool.request(name + ":var32", var32_dim, var32_exec_order,
+                              var_ls, Tensor::Initializer::ZEROS);
+      }
     }
   }
 
   weights_v2.emplace_back(std::make_unique<Weight>(
-    var, grad, w_reg, w_reg_const, decay, is_dependent, clip_by_global_norm));
+    var, grad, var32, w_reg, w_reg_const, decay, is_dependent,
+    clip_by_global_norm, axis, loss_scale));
 }
 
 std::transform(weights_v2.begin() + current_size, weights_v2.end(),
@@ -668,15 +689,15 @@ bool Manager::isSecondLastAccess(const std::string &name,
  */
 std::vector<Tensor *> Manager::requestWeightOptimizerVariables(
   const std::vector<TensorDim> &dims, const std::string &name,
-  const TensorLifespan &lifespan, bool is_grad_clip,
-  Tensor::Initializer initializer) {
+  const std::string &suffix, const TensorLifespan &lifespan, bool is_grad_clip,
+  bool is_mixed_precision, Tensor::Initializer initializer) {
   std::vector<Tensor *> ret;
   ret.reserve(dims.size());
 
   std::vector<unsigned int> exec;
   exec.reserve(1);
 
-  if (is_grad_clip) {
+  if (is_grad_clip || is_mixed_precision) {
     exec.emplace_back(TensorPool::PERSIST_END_ORDER);
   } else {
     exec.emplace_back(getMinMaxTensorExecutionOrder(name, true).second);
@@ -685,7 +706,7 @@ std::vector<Tensor *> Manager::requestWeightOptimizerVariables(
   /// @note this is assuming weight optimizer variables is treated as weight,
   /// if not, there is room to optimize below behavior
   for (unsigned int idx = 0; idx < dims.size(); idx++)
-    ret.push_back(weight_pool.request(name + ":opt" + std::to_string(idx),
+    ret.push_back(weight_pool.request(name + suffix + std::to_string(idx),
                                       dims[idx], exec, lifespan, initializer));
 
   return ret;
diff --git a/nntrainer/tensor/manager.h b/nntrainer/tensor/manager.h
index ab1c018153..80ffb9d21d 100644
--- a/nntrainer/tensor/manager.h
+++ b/nntrainer/tensor/manager.h
@@ -224,7 +224,8 @@ class Manager {
    */
   std::vector<Tensor *> requestWeightOptimizerVariables(
     const std::vector<TensorDim> &dims, const std::string &name,
-    const TensorLifespan &lifespan, bool is_grad_clip,
+    const std::string &suffix, const TensorLifespan &lifespan,
+    bool is_grad_clip, bool is_mixed_type,
     Tensor::Initializer initializer = Tensor::Initializer::NONE);
 
   /**
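Every companion tensor is keyed off the weight's name plus a suffix, and the FP32 copies are requested with PERSIST_END_ORDER so they outlive a single iteration. Assuming ":grad" for Var_Grad::grad_suffix (defined elsewhere, not shown in this patch), the pool entries for one hypothetical FP16 weight would be named:

#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::string name = "fc:weight";
  const std::vector<std::string> pooled = {
    name,            // compute copy, FP16
    name + ":grad",  // gradient (assuming Var_Grad::grad_suffix == ":grad")
    name + ":var32", // FP32 master copy, persists across iterations
    name + ":opt0",  // first Adam moment, FP32, suffix now parameterized
    name + ":opt1",  // second Adam moment, FP32
  };
  for (const auto &n : pooled)
    std::cout << n << '\n';
}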
+ */ + + if (train && dim.getDataType() != ml::train::TensorDim::DataType::FP32) { + TensorDim var32_dim(dim); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + + var32 = std::make_shared(var32_dim, alloc_now_, init, var32_name); + } else { + var32 = std::make_shared(var32_name); + } } Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g, @@ -52,6 +74,94 @@ Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g, throw std::invalid_argument("Weight initializer cannot be none"); if (regularizer == WeightRegularizer::UNKNOWN) throw std::invalid_argument("Weight regularizer unknown"); + + std::string var32_suffix = ":fp32"; + std::string var32_name = name + var32_suffix; + + if (train && dim_v.getDataType() != ml::train::TensorDim::DataType::FP32) { + TensorDim var32_dim(dim_v); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + std::string var32_suffix = ":fp32"; + std::string var32_name = name + var32_suffix; + + var32 = std::make_shared(var32_dim, alloc_now_, init, var32_name); + } else { + var32 = std::make_shared(var32_name); + } +} + +Weight::Weight(const Tensor &v, const Tensor &g, const Tensor &v32, + const std::string &n, bool is_dependent, + unsigned int output_axis_) : + Var_Grad(v, g, n, is_dependent), + regularizer(WeightRegularizer::NONE), + regularizer_constant(1.0f), + decay(0.0f), + clip_by_global_norm(0.0f), + output_axis(output_axis_), + loss_scale(0.0), + var32(std::make_shared(n + ":fp32")) { + + if (!g.empty() && isMixedPrecision()) { + TensorDim var32_dim(v.getDim()); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + if (!v32.empty()) + var32 = std::make_shared( + v32.getSharedDataTensor(var32_dim, 0, false, n + ":fp32")); + } +} + +Weight::Weight(Tensor *v, Tensor *g, Tensor *v32, const WeightRegularizer reg, + const float reg_const, const float decay, bool is_dependent, + const float max_norm, unsigned int output_axis_, + float loss_scale_) : + Var_Grad(v, g, is_dependent), + regularizer(reg), + regularizer_constant(reg_const), + decay(decay), + clip_by_global_norm(max_norm), + output_axis(output_axis_), + loss_scale(loss_scale_), + var32(std::shared_ptr(v32, [](void *) {})) { + if (!v32) + var32 = std::make_shared(); +} + +void Weight::applyGradient(double lr, Tensor &updated_grad) { + if (isMixedPrecision() && + updated_grad.getDataType() == ml::train::TensorDim::DataType::FP32) { + updated_grad.divide(loss_scale); + var32->add_i(updated_grad, -lr); + quantizeWeight(); + return; + } + + return applyGradient(lr); +} + +void Weight::quantizeWeight() { + if (!isMixedPrecision()) + return; + + Tensor &var = getVariableRef(); + ml::train::TensorDim::DataType type = var.getDataType(); + switch (type) { + case ml::train::TensorDim::DataType::QINT4: + // NYI + break; + case ml::train::TensorDim::DataType::QINT8: + // NYI + break; + case ml::train::TensorDim::DataType::FP16: + getVariableRef().copy(getVariableFP32Ref()); + break; + case ml::train::TensorDim::DataType::FP32: + break; + default: + break; + } + + return; } } // namespace nntrainer diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h index 552f6d5739..5382c686e1 100644 --- a/nntrainer/tensor/weight.h +++ b/nntrainer/tensor/weight.h @@ -46,7 +46,7 @@ class Weight : public Var_Grad { decay(0.0f), clip_by_global_norm(0.0f), output_axis(3), - loss_scale(0.0) {} + loss_scale(1.0) {} /** * @brief Construct a new Weight object @@ -66,7 +66,7 @@ class Weight : public Var_Grad { const float reg_const = 1.0f, const float decay = 0.0f, const float 
diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h
index 552f6d5739..5382c686e1 100644
--- a/nntrainer/tensor/weight.h
+++ b/nntrainer/tensor/weight.h
@@ -46,7 +46,7 @@ class Weight : public Var_Grad {
     decay(0.0f),
     clip_by_global_norm(0.0f),
     output_axis(3),
-    loss_scale(0.0) {}
+    loss_scale(1.0) {}
 
   /**
    * @brief Construct a new Weight object
@@ -66,7 +66,7 @@ class Weight : public Var_Grad {
     const float reg_const = 1.0f, const float decay = 0.0f, const float
     clip_by_global_norm = 0.0f, bool ng = true, bool alloc_now = false,
     std::string name = "", unsigned int axis = 3,
-    float loss_scale_ = 0.0);
+    float loss_scale_ = 1.0);
 
   /**
    * @brief Construct a new Weight object
@@ -87,7 +87,7 @@ class Weight : public Var_Grad {
     const float reg_const = 1.0f, const float decay = 0.0f, const float
     clip_by_global_norm = 0.0f, bool ng = true, bool alloc_now = false,
     std::string name = "", unsigned int axis = 3,
-    float loss_scale_ = 0.0);
+    float loss_scale_ = 1.0);
 
   /**
    * @brief Construct a new Weight object
@@ -114,6 +114,7 @@ class Weight : public Var_Grad {
    *
    * @param v Already created variable object
    * @param g Already created gradient object
+   * @param v32 Already created FP32 variable object
    * @param n Name for this Weight
    *
    * @note This is primarily used to created wrapper of variable extracted from
@@ -123,35 +124,24 @@ class Weight : public Var_Grad {
    * uses only, as Weight does not own the tensors v and g, and can go invalid
    * if the owner of these tensors free the tensors.
    */
-  explicit Weight(const Tensor &v, const Tensor &g, const std::string &n = "",
-                  bool is_dependent = false, unsigned int output_axis_ = 3) :
-    Var_Grad(v, g, n, is_dependent),
-    regularizer(WeightRegularizer::NONE),
-    regularizer_constant(1.0f),
-    decay(0.0f),
-    clip_by_global_norm(0.0f),
-    output_axis(output_axis_),
-    loss_scale(0.0) {}
+  explicit Weight(const Tensor &v, const Tensor &g, const Tensor &v32,
+                  const std::string &n = "", bool is_dependent = false,
+                  unsigned int output_axis_ = 3);
 
   /**
    * @brief Construct a new Weight object
   *
    * @param v ptr to already created variable tensor
    * @param g ptr to already created gradient tensor
+   * @param v32 ptr to already created variable32 tensor
    * @param reg Regularizer for the weight
    * @param reg_const Constant multiplier for regularizer
    */
-  explicit Weight(Tensor *v, Tensor *g, const WeightRegularizer reg,
-                  const float reg_const, const float decay,
-                  bool is_dependent = false, const float max_norm = 0.0f,
-                  unsigned int output_axis_ = 3, float loss_scale_ = 0.0f) :
-    Var_Grad(v, g, is_dependent),
-    regularizer(reg),
-    regularizer_constant(reg_const),
-    decay(decay),
-    clip_by_global_norm(max_norm),
-    output_axis(output_axis_),
-    loss_scale(loss_scale_) {}
+  explicit Weight(Tensor *v, Tensor *g, Tensor *v32,
+                  const WeightRegularizer reg, const float reg_const,
+                  const float decay, bool is_dependent = false,
+                  const float max_norm = 0.0f, unsigned int output_axis_ = 3,
+                  float loss_scale_ = 1.0f);
 
   /**
    * @brief Swap for weight
   *
@@ -170,6 +160,7 @@ class Weight : public Var_Grad {
     swap(lhs.output_axis, rhs.output_axis);
     swap(lhs.opt_vars, rhs.opt_vars);
     swap(lhs.loss_scale, rhs.loss_scale);
+    swap(lhs.var32, rhs.var32);
   }
 
   /**
@@ -213,6 +204,8 @@ class Weight : public Var_Grad {
       w.var = std::make_shared<Tensor>(this->var->clone());
     if (!this->grad->empty())
      w.grad = std::make_shared<Tensor>(this->grad->clone());
+    if (!this->var32->empty())
+      w.var32 = std::make_shared<Tensor>(this->var32->clone());
 
     return w;
   }
@@ -294,6 +287,13 @@ class Weight : public Var_Grad {
    */
   void applyGradient(double lr) { var->add_i(*grad.get(), -lr); }
 
+  /**
+   * @brief Apply the gradient to the weight with the updated gradient
+   * @param[in] updated_grad gradient tensor updated in the optimizer; it may
+   * have a different data type from the weight's gradient, e.g. FP32
+   */
+  void applyGradient(double lr, Tensor &updated_grad);
+
   /**
    * @brief Check if the gradient is supposed to be clipped by global norm with
    * the given max_norm value
@@ -316,6 +316,16 @@ class Weight : public Var_Grad {
     return clip_by_global_norm > epsilon;
   }
 
+  /**
+   * @brief Check if the variable type is not full precision
+   *
+   * @return true if it is not full precision
+   * @return false otherwise
+   */
+  bool isMixedPrecision() const {
+    return var->getDataType() != ml::train::TensorDim::DataType::FP32;
+  }
+
   /**
    * @brief clip the gradient value based on the given global norm
    *
@@ -326,6 +336,19 @@ class Weight : public Var_Grad {
     grad->multiply_i(clip_by_global_norm / (global_norm + epsilon));
   }
 
+  /**
+   * @brief Get the variable FP32 tensor (by reference)
+   *
+   * @return Tensor Variable FP32 tensor
+   */
+  Tensor &getVariableFP32Ref() { return *var32.get(); }
+
+  /**
+   * @brief Quantize var32 to var
+   *
+   */
+  void quantizeWeight();
+
 private:
   static constexpr float epsilon = 1e-6; /**< epsilon for zero comparison */
   static constexpr float epsilon_decay =
@@ -337,7 +360,8 @@ class Weight : public Var_Grad {
   float clip_by_global_norm; /**< constant factor to clip gradient by L2 norm */
   unsigned int output_axis;
   float loss_scale;
-  std::vector<Tensor *> opt_vars; /**< optimizer variables */
+  std::vector<Tensor *>
+    opt_vars; /**< optimizer variables: We assume these are always full precision */
 
   std::shared_ptr<Tensor> var32;
   /**
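The loss_scale default changes from 0.0 to 1.0 throughout because Weight::applyGradient() divides the incoming gradient by loss_scale: 1.0 is the multiplicative identity (no scaling), whereas the old 0.0 default would turn the first mixed-precision update into inf. Two lines make the hazard concrete:

#include <iostream>

int main() {
  const float grad = 0.25f;
  std::cout << grad / 1.0f << '\n'; // 0.25 -- identity, sane default
  std::cout << grad / 0.0f << '\n'; // inf  -- the old default's failure mode
}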
diff --git a/test/unittest/models/meson.build b/test/unittest/models/meson.build
index 7166fc41ff..4a6e81e65d 100644
--- a/test/unittest/models/meson.build
+++ b/test/unittest/models/meson.build
@@ -11,6 +11,10 @@ models_targets = [
   # disable temperally
 ]
 
+if get_option('enable-fp16')
+  models_targets += 'unittest_models_mixed_precision.cpp'
+endif
+
 test_target += models_targets
 exe = executable(
   test_name,
diff --git a/test/unittest/models/unittest_models_mixed_precision.cpp b/test/unittest/models/unittest_models_mixed_precision.cpp
new file mode 100644
index 0000000000..becf11ff44
--- /dev/null
+++ b/test/unittest/models/unittest_models_mixed_precision.cpp
@@ -0,0 +1,54 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Jijoong Moon <jijoong.moon@samsung.com>
+ *
+ * @file unittest_models_mixed_precision.cpp
+ * @date 3 May 2024
+ * @brief unittest models to cover mixed precision
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Jijoong Moon <jijoong.moon@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+
+#include <memory>
+
+#include <gtest/gtest.h>
+
+#include <ini_wrapper.h>
+#include <neuralnet.h>
+#include <nntrainer_test_util.h>
+
+#include <models_golden_test.h>
+
+using namespace nntrainer;
+
+static std::unique_ptr<NeuralNetwork> fc_mixed_training() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty(
+    {"batch_size=2", "model_tensor_type=FP16-FP16", "loss_scale=128"});
+
+  auto graph = makeGraph({
+    {"input", {"name=in", "input_shape=1:1:3"}},
+    {"Fully_connected", {"name=fc", "input_layers=in", "unit=10"}},
+    {"mse", {"name=loss", "input_layers=fc"}},
+  });
+  for (auto &node : graph) {
+    nn->addLayer(node);
+  }
+
+  nn->setOptimizer(ml::train::createOptimizer("adam", {"learning_rate = 0.1"}));
+
+  return nn;
+}
+
+GTEST_PARAMETER_TEST(
+  MixedPrecision, nntrainerModelTest,
+  ::testing::ValuesIn({
+    mkModelTc_V2(fc_mixed_training, "fc_mixed_training",
+                 ModelTestOption::NO_THROW_RUN_V2),
+    /** ModelTestOption::ALL_V2),
+     * Disabled for now to check
+     */
+  }),
+  [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info)
+    -> const auto & { return std::get<1>(info.param); });
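The test exercises the full path with three properties. The reading assumed here (inferred from the patch, not spelled out in it) is that model_tensor_type=FP16-FP16 names the weight type and then the activation type, while loss_scale=128 multiplies the loss before backprop so the FP32 unscale in Weight::applyGradient() can divide it back out:

#include <iostream>
#include <string>

int main() {
  const std::string model_tensor_type = "FP16-FP16"; // from the test above
  const auto dash = model_tensor_type.find('-');
  std::cout << "weight type:     " << model_tensor_type.substr(0, dash) << '\n'
            << "activation type: " << model_tensor_type.substr(dash + 1)
            << '\n';
}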