diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp
index 68f5dc6c72..2d4cfdc769 100644
--- a/nntrainer/graph/network_graph.cpp
+++ b/nntrainer/graph/network_graph.cpp
@@ -869,7 +869,7 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
     const auto &w_specs = init_context.getWeightsSpec();
     for (auto i = 0u; i < w_specs.size(); ++i) {
-      shared_weight_names.emplace_back(std::get<7>(w_specs.at(i)));
+      shared_weight_names.emplace_back(std::get<8>(w_specs.at(i)));
     }
   }
 
@@ -1018,7 +1018,7 @@ NetworkGraph::refinalizeContext(const std::shared_ptr<LayerNode> &lnode,
     const auto &w_specs = init_context.getWeightsSpec();
     for (auto i = 0u; i < w_specs.size(); ++i) {
-      shared_weight_names.emplace_back(std::get<7>(w_specs.at(i)));
+      shared_weight_names.emplace_back(std::get<8>(w_specs.at(i)));
     }
   }
 
diff --git a/nntrainer/layers/common_properties.h b/nntrainer/layers/common_properties.h
index b7afcac81b..b1f12b3455 100644
--- a/nntrainer/layers/common_properties.h
+++ b/nntrainer/layers/common_properties.h
@@ -1367,6 +1367,16 @@ class ClipGradByGlobalNorm : public Property<float> {
   using prop_tag = float_prop_tag; /**< property type */
 };
 
+/**
+ * @brief properties for getting the loss scale value to mixed precision
+ *
+ */
+class LossScaleForMixed : public Property<float> {
+public:
+  static constexpr const char *key = "loss_scale"; /**< unique key to access */
+  using prop_tag = float_prop_tag;                 /**< property type */
+};
+
 /**
  * @brief Learning Rate props
  *
diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
index d51083ec67..fff2eb15ec 100644
--- a/nntrainer/layers/layer_context.cpp
+++ b/nntrainer/layers/layer_context.cpp
@@ -43,7 +43,8 @@ InitLayerContext::InitLayerContext(const std::vector<TensorDim> &dim,
                                    bool in_place_, const std::string &n,
                                    const std::string &prefix_,
                                    const float max_norm,
-                                   std::array<std::string, 3> tensor_type_) :
+                                   std::array<std::string, 3> tensor_type_,
+                                   const float loss_scale_) :
   input_dim(dim),
   in_place(in_place_),
   clip_by_global_norm(max_norm),
@@ -51,7 +52,8 @@ InitLayerContext::InitLayerContext(const std::vector<TensorDim> &dim,
   req_out_is_connected(req_out_connected),
   name(n),
   prefix(prefix_),
-  tensor_type(tensor_type_) {
+  tensor_type(tensor_type_),
+  loss_scale(loss_scale_) {
   NNTR_THROW_IF(!validate(), std::invalid_argument)
     << "Invalid init context name: " << name
     << " num inputs: " << getNumInputs();
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index ea38ecc5f7..e5c6759638 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -62,7 +62,8 @@ class InitLayerContext {
                    const std::string &n = "", const std::string &prefix_ = "",
                    const float max_norm = 0.0,
                    std::array<std::string, 3> tensor_type_ = {"NCHW", "FP32",
-                                                              "FP32"});
+                                                              "FP32"},
+                   const float loss_scale = 0.0);
   /**
    * @brief get Tensor Format of Layer
    *
@@ -171,7 +172,7 @@ class InitLayerContext {
   /**
    * @brief Request a new weight for the layer
    *
-   * @param dim dimension of the weight
+   * @param dim dimension of Variable of the weight
    * @param init initializer for the weight
    * @param reg regularizer for the weight
    * @param reg_const regularization constant for the weight
@@ -187,9 +188,16 @@ class InitLayerContext {
                              const WeightRegularizer reg, const float reg_const,
                              const float decay, const std::string &name,
                              bool trainable = true, unsigned int out_axis = 3) {
-    weights_spec.emplace_back(dim, init, reg, reg_const, decay,
+
+    /** @note : We assumes the gradient type is same with Activation data
+     * type.*/
+    TensorDim dim_g(dim);
+
+    dim_g.setDataType(getActivationDataType());
+
+    weights_spec.emplace_back(dim, dim_g, init, reg, reg_const, decay,
                               clip_by_global_norm, trainable,
-                              prefix + ":" + name, out_axis);
+                              prefix + ":" + name, out_axis, loss_scale);
     return weights_spec.size() - 1;
   }
 
@@ -356,6 +364,7 @@ class InitLayerContext {
   std::string name;   /**< name of the layer */
   std::string prefix; /**< prefix of the layer */
   std::array<std::string, 3> tensor_type;
+  float loss_scale; /**< loss_scale value */
 };
 
 /**
diff --git a/nntrainer/layers/layer_node.cpp b/nntrainer/layers/layer_node.cpp
index a7c5f049e4..8b18d80762 100644
--- a/nntrainer/layers/layer_node.cpp
+++ b/nntrainer/layers/layer_node.cpp
@@ -182,9 +182,10 @@ LayerNode::LayerNode(std::unique_ptr<nntrainer::Layer> &&l) :
   needs_calc_gradient(false),
   output_connections(),
   run_context(nullptr),
-  layer_node_props(new PropsType(
-    props::Name(), props::Distribute(), props::Trainable(), {}, {},
-    props::SharedFrom(), props::ClipGradByGlobalNorm(), props::Packed())),
+  layer_node_props(
+    new PropsType(props::Name(), props::Distribute(), props::Trainable(), {},
+                  {}, props::SharedFrom(), props::ClipGradByGlobalNorm(),
+                  props::Packed(), props::LossScaleForMixed())),
   layer_node_props_realization(
     new RealizationPropsType(props::Flatten(), props::Activation())),
   loss(new props::Loss()),
@@ -598,9 +599,13 @@ InitLayerContext LayerNode::finalize(const std::vector<TensorDim> &input_dims,
   const auto &scope = getSharedFrom().empty() ? getName() : getSharedFrom();
   float max_norm = 0.0;
+  float loss_scale = 0.0;
   if (!std::get<props::ClipGradByGlobalNorm>(*layer_node_props).empty())
     max_norm = std::get<props::ClipGradByGlobalNorm>(*layer_node_props).get();
 
+  if (!std::get<props::LossScaleForMixed>(*layer_node_props).empty())
+    loss_scale = std::get<props::LossScaleForMixed>(*layer_node_props).get();
+
   if (!std::get<props::Packed>(*layer_node_props).empty()) {
     bool isPacked = std::get<props::Packed>(*layer_node_props);
     if (!isPacked) {
@@ -622,7 +627,7 @@ InitLayerContext LayerNode::finalize(const std::vector<TensorDim> &input_dims,
   auto context = InitLayerContext(actual_input_dims, out_info,
                                   executeInPlace() != InPlace::NONE, getName(),
-                                  scope, max_norm, tensor_type);
+                                  scope, max_norm, tensor_type, loss_scale);
 
   layer->finalize(context);
 
diff --git a/nntrainer/layers/layer_node.h b/nntrainer/layers/layer_node.h
index c1068b0f56..93e7ac7069 100644
--- a/nntrainer/layers/layer_node.h
+++ b/nntrainer/layers/layer_node.h
@@ -52,6 +52,7 @@ class SharedFrom;
 class InputConnection;
 class ClipGradByGlobalNorm;
 class Packed;
+class LossScaleForMixed;
 } // namespace props
 
 /**
@@ -939,11 +940,11 @@ will also contain the properties of the layer. The properties will be
    * copied upon final creation. Editing properties of the layer after init
    * will not the properties in the context/graph unless intended.
    */
-  using PropsType =
-    std::tuple<props::Name, props::Distribute, props::Trainable,
-               std::vector<props::InputConnection>,
-               std::vector<props::InputShape>, props::SharedFrom,
-               props::ClipGradByGlobalNorm, props::Packed>;
+  using PropsType = std::tuple<props::Name, props::Distribute, props::Trainable,
+                               std::vector<props::InputConnection>,
+                               std::vector<props::InputShape>,
+                               props::SharedFrom, props::ClipGradByGlobalNorm,
+                               props::Packed, props::LossScaleForMixed>;
   using RealizationPropsType = std::tuple<props::Flatten, props::Activation>;
 
   /** these realization properties results in addition of new layers, hence
diff --git a/nntrainer/models/model_common_properties.cpp b/nntrainer/models/model_common_properties.cpp
index a1f560c49a..984cad662a 100644
--- a/nntrainer/models/model_common_properties.cpp
+++ b/nntrainer/models/model_common_properties.cpp
@@ -39,4 +39,6 @@ MemorySwapLookahead::MemorySwapLookahead(const unsigned int &value) {
 ModelTensorDataType::ModelTensorDataType(ModelTensorDataTypeInfo::Enum value) {
   set(value);
 }
 
+LossScale::LossScale(float value) { set(value); }
+
 } // namespace nntrainer::props
diff --git a/nntrainer/models/model_common_properties.h b/nntrainer/models/model_common_properties.h
index 791f9ed5d3..3776afefca 100644
--- a/nntrainer/models/model_common_properties.h
+++ b/nntrainer/models/model_common_properties.h
@@ -211,6 +211,17 @@ class ModelTensorDataType final : public EnumProperty<ModelTensorDataTypeInfo> {
                         ModelTensorDataTypeInfo::Enum::W32A32);
 };
 
+/**
+ * @brief LossScale property, loss is scaled by this value
+ *
+ */
+class LossScale : public Property<float> {
+public:
+  LossScale(float value = 0.0f);
+  static constexpr const char *key = "loss_scale"; /**< unique key to access */
+  using prop_tag = float_prop_tag;                 /**< property type */
+};
+
 } // namespace nntrainer::props
 
 #endif
diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp
index af719237da..d0e542825f 100644
--- a/nntrainer/models/neuralnet.cpp
+++ b/nntrainer/models/neuralnet.cpp
@@ -65,7 +65,8 @@ namespace nntrainer {
 
 NeuralNetwork::NeuralNetwork() :
-  model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
+  model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm(),
+              props::LossScale()),
   model_flex_props(
     props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
     props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
@@ -83,7 +84,8 @@ NeuralNetwork::NeuralNetwork() :
 }
 
 NeuralNetwork::NeuralNetwork(AppContext app_context_) :
-  model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
+  model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm(),
+              props::LossScale()),
   model_flex_props(
     props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
     props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
@@ -189,6 +191,9 @@ int NeuralNetwork::compile() {
         !prop.empty()) {
       node->setProperty({"clip_grad_by_norm=" + to_string(prop)});
     }
+    if (auto &prop = std::get<props::LossScale>(model_props); !prop.empty()) {
+      node->setProperty({"loss_scale=" + to_string(prop)});
+    }
     model_graph.addLayer(node);
   }
 
@@ -1018,6 +1023,7 @@ int NeuralNetwork::train_run(
   auto train_for_iteration = [this, stop_cb, stop_user_data](RunStats &stat,
                                                              DataBuffer &buffer) {
+    ml_loge("train for iteration");
     forwarding(true, stop_cb, stop_user_data);
     backwarding(iter++, stop_cb, stop_user_data);
 
diff --git a/nntrainer/models/neuralnet.h b/nntrainer/models/neuralnet.h
index 457b7d1e97..da1571a328 100644
--- a/nntrainer/models/neuralnet.h
+++ b/nntrainer/models/neuralnet.h
@@ -221,10 +221,11 @@ class NeuralNetwork : public ml::train::Model {
   /**
    * @brief Forward Propagation of the neural network
    */
-  sharedConstTensors forwarding(bool training = true,
-                                std::function<bool(void *)> stop_cb =
-                                  [](void *user_data) { return false; },
-                                void *user_data = nullptr);
+  sharedConstTensors forwarding(
+    bool training = true,
+    std::function<bool(void *)> stop_cb =
+      [](void *user_data) { return false; },
+    void *user_data = nullptr);
 
   /**
    * @brief Forward Propagation of the neural network
@@ -239,12 +240,11 @@ class NeuralNetwork : public ml::train::Model {
   /**
    * @brief Incremental forward Propagation of the neural network
    */
-  sharedConstTensors
-  incremental_forwarding(unsigned int from, unsigned int to,
-                         bool training = true,
-                         std::function<bool(void *)> stop_cb =
-                           [](void *user_data) { return false; },
-                         void *user_data = nullptr);
+  sharedConstTensors incremental_forwarding(
+    unsigned int from, unsigned int to, bool training = true,
+    std::function<bool(void *)> stop_cb =
+      [](void *user_data) { return false; },
+    void *user_data = nullptr);
 
   /**
    * @brief Incremental forward Propagation of the neural network
@@ -261,10 +261,11 @@
    * @brief Backward Propagation of the neural network
    * @param[in] iteration Iteration Number for the optimizer
    */
-  void backwarding(int iteration,
-                   std::function<bool(void *)> stop_cb =
-                     [](void *user_data) { return false; },
-                   void *user_data = nullptr);
+  void backwarding(
+    int iteration,
+    std::function<bool(void *)> stop_cb =
+      [](void *user_data) { return false; },
+    void *user_data = nullptr);
 
   /**
    * @copydoc Model::save(const std::string &file_path, ml::train::ModelFormat
@@ -329,13 +330,14 @@
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
-  int train(const std::vector<std::string> &values = {},
-            std::function<bool(void *)> stop_cb =
-              [](void *stop_user_data) { return false; },
-            void *stop_user_data = nullptr,
-            std::function<bool(void *)> epoch_complete_cb =
-              [](void *epoch_user_data) { return false; },
-            void *epoch_user_data = nullptr) override;
+  int train(
+    const std::vector<std::string> &values = {},
+    std::function<bool(void *)> stop_cb =
+      [](void *stop_user_data) { return false; },
+    void *stop_user_data = nullptr,
+    std::function<bool(void *)> epoch_complete_cb =
+      [](void *epoch_user_data) { return false; },
+    void *epoch_user_data = nullptr) override;
 
   /**
    * @brief Run NeuralNetwork inference
@@ -630,7 +632,8 @@ s * @retval shared_ptr
                  props::TensorFormat, props::ModelTensorDataType>;
   using RigidPropTypes =
     std::tuple<props::LossType, std::vector<props::InputConnection>,
-               std::vector<props::LabelLayer>, props::ClipGradByGlobalNorm>;
+               std::vector<props::LabelLayer>, props::ClipGradByGlobalNorm,
+               props::LossScale>;
 
   RigidPropTypes model_props;         /**< model props */
   FlexiblePropTypes model_flex_props; /**< model train props */
@@ -709,12 +712,12 @@ s * @retval shared_ptr
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
-  int train_run(std::function<bool(void *)> stop_cb =
-                  [](void *) { return false; },
-                void *user_data = nullptr,
-                std::function<bool(void *)> epoch_complete_cb =
-                  [](void *) { return false; },
-                void *data = nullptr);
+  int train_run(
+    std::function<bool(void *)> stop_cb = [](void *) { return false; },
+    void *user_data = nullptr,
+    std::function<bool(void *)> epoch_complete_cb =
+      [](void *) { return false; },
+    void *data = nullptr);
 
   /**
    * @brief Swap function for the class
diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp
index 4178330ebd..9a0d235ba9 100644
--- a/nntrainer/tensor/manager.cpp
+++ b/nntrainer/tensor/manager.cpp
@@ -52,10 +52,7 @@ namespace nntrainer {
 MMapedMemory::MMapedMemory(size_t size, bool allocate_fd_) :
-  fd(-1),
-  buf(nullptr),
-  buf_size(0),
-  allocate_fd(allocate_fd_) {
+  fd(-1), buf(nullptr), buf_size(0), allocate_fd(allocate_fd_) {
 
 #ifndef __ANDROID__
   if (allocate_fd) {
@@ -386,8 +383,9 @@ std::vector<Weight *> Manager::requestWeights(
   size_t current_size = weights_v2.size();
 
   for (unsigned int i = 0; i < weights_spec.size(); ++i) {
-    auto &[dim, t_initializer, w_reg, w_reg_const, decay, clip_by_global_norm,
-           need_gradient, name, axis] = weights_spec.at(i);
+    auto &[dim_v, dim_g, t_initializer, w_reg, w_reg_const, decay,
+           clip_by_global_norm, need_gradient, name, axis, loss_scale] =
+      weights_spec.at(i);
 
     std::vector<unsigned int> var_exec_order;
     for (auto order : default_var_exec_order) {
@@ -422,7 +420,7 @@ std::vector<Weight *> Manager::requestWeights(
       /// shared_name is used and the orignal name is discarded
       const auto &shared_name = shared_names.at(i);
       /** case when shared names are given */
-      var = weight_pool.requestOrExtend(shared_name, dim, var_exec_order,
+      var = weight_pool.requestOrExtend(shared_name, dim_v, var_exec_order,
                                         var_ls, t_initializer);
 
       if (trainable && need_gradient) {
@@ -431,13 +429,13 @@ std::vector<Weight *> Manager::requestWeights(
          * for each layer anymore and it is hard to overwritten.
          */
        grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
-                                           dim, grad_exec_order, grad_ls,
+                                           dim_g, grad_exec_order, grad_ls,
                                            Tensor::Initializer::ZEROS);
       }
     } else {
       /** case requesting fresh weights */
       var =
-        weight_pool.request(name, dim, var_exec_order, var_ls, t_initializer);
+        weight_pool.request(name, dim_v, var_exec_order, var_ls, t_initializer);
 
       if (trainable && need_gradient) {
        /** is_wgrad is the index which is true when it is the gradient tensor
@@ -447,7 +445,7 @@ std::vector<Weight *> Manager::requestWeights(
         bool is_wgrad = true;
         if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm))
           is_wgrad = false;
-        grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim,
+        grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim_g,
                                    grad_exec_order, grad_ls,
                                    Tensor::Initializer::ZEROS, is_wgrad);
       }
diff --git a/nntrainer/tensor/tensor_wrap_specs.h b/nntrainer/tensor/tensor_wrap_specs.h
index 732d377ab5..6a5195fef5 100644
--- a/nntrainer/tensor/tensor_wrap_specs.h
+++ b/nntrainer/tensor/tensor_wrap_specs.h
@@ -73,10 +73,11 @@ enum class TensorLifespan {
  *
  * @details The tuple values are dimension, initializer, regularizer,
  * regularizer_constant, decay, clip gradient constant, need_gradient property,
- * name and output axis of the tensor object.
+ * name, output axis of the tensor object and loss Scale Factor.
  */
-typedef std::tuple<TensorDim, Tensor::Initializer, WeightRegularizer, float,
-                   float, float, bool, const std::string, unsigned int>
+typedef std::tuple<TensorDim, TensorDim, Tensor::Initializer, WeightRegularizer,
+                   float, float, float, bool, const std::string, unsigned int,
+                   float>
   WeightSpec;
 
 /**
diff --git a/nntrainer/tensor/var_grad.cpp b/nntrainer/tensor/var_grad.cpp
index 5fc5d8930d..09dbf6267e 100644
--- a/nntrainer/tensor/var_grad.cpp
+++ b/nntrainer/tensor/var_grad.cpp
@@ -38,6 +38,27 @@ Var_Grad::Var_Grad(const TensorDim &dim, const Tensor::Initializer init,
     grad = std::make_shared<Tensor>(grad_name);
 }
 
+Var_Grad::Var_Grad(const TensorDim &dim_v, const TensorDim &dim_g,
+                   const Tensor::Initializer init, bool need_gradient,
+                   bool alloc_now, const std::string &name) :
+  is_dependent(false),
+  is_first_access_gradient(false),
+  is_last_access_gradient(false) {
+  var = std::make_shared<Tensor>(dim_v, alloc_now, init, name);
+
+  std::string grad_name = name + grad_suffix;
+  if (need_gradient)
+    /**
+     * @todo gradient initializer should be none, and then they should be set
+     * zero right before using by the user itself.
+     */
+
+    grad = std::make_shared<Tensor>(dim_g, alloc_now,
+                                    Tensor::Initializer::ZEROS, grad_name);
+  else
+    grad = std::make_shared<Tensor>(grad_name);
+}
+
 void Var_Grad::initializeVariable(const Tensor &preallocated) {
   /**
    * Making a new tensor is intentional here as this tensor is not shared
diff --git a/nntrainer/tensor/var_grad.h b/nntrainer/tensor/var_grad.h
index dfe1b9a0b3..52cabbc055 100644
--- a/nntrainer/tensor/var_grad.h
+++ b/nntrainer/tensor/var_grad.h
@@ -59,6 +59,20 @@ class Var_Grad {
                     bool ng = true, bool alloc_now = false,
                     const std::string &name = "");
 
+  /**
+   * @brief Construct a new Var_Grad object
+   *
+   * @param dim_v Variable tensor dimension
+   * @param dim_g Gradient tensor dimension
+   * @param ng If the variable is need_gradient
+   * @param alloc_now The memory for the var_grad tensors be allocated upon init
+   * @param name Name for this Var_Grad
+   */
+  explicit Var_Grad(const TensorDim &dim_v, const TensorDim &dim_g,
+                    const Tensor::Initializer init = Tensor::Initializer::NONE,
+                    bool ng = true, bool alloc_now = false,
+                    const std::string &name = "");
+
   /**
    * @brief Construct a new Var_Grad object
   *
diff --git a/nntrainer/tensor/weight.cpp b/nntrainer/tensor/weight.cpp
index 44f1f015b1..f98c8c8356 100644
--- a/nntrainer/tensor/weight.cpp
+++ b/nntrainer/tensor/weight.cpp
@@ -21,13 +21,33 @@ namespace nntrainer {
 Weight::Weight(const TensorDim &dim, const Tensor::Initializer init,
                const WeightRegularizer reg, const float reg_const,
                const float decay_const, const float max_norm, bool train,
-               bool alloc_now_, std::string name, unsigned int axis) :
+               bool alloc_now_, std::string name, unsigned int axis,
+               float loss_scale_) :
   Var_Grad(dim, init, train, alloc_now_, name),
   regularizer(reg),
   regularizer_constant(reg_const),
   decay(decay_const),
   clip_by_global_norm(max_norm),
-  output_axis(axis) {
+  output_axis(axis),
+  loss_scale(loss_scale_) {
+  if (init == Tensor::Initializer::NONE)
+    throw std::invalid_argument("Weight initializer cannot be none");
+  if (regularizer == WeightRegularizer::UNKNOWN)
+    throw std::invalid_argument("Weight regularizer unknown");
+}
+
+Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g,
+               const Tensor::Initializer init, const WeightRegularizer reg,
+               const float reg_const, const float decay_const,
+               const float max_norm, bool train, bool alloc_now_,
+               std::string name, unsigned int axis, float loss_scale_) :
+  Var_Grad(dim_v, dim_g, init, train, alloc_now_, name),
+  regularizer(reg),
+  regularizer_constant(reg_const),
+  decay(decay_const),
+  clip_by_global_norm(max_norm),
+  output_axis(axis),
+  loss_scale(loss_scale_) {
   if (init == Tensor::Initializer::NONE)
     throw std::invalid_argument("Weight initializer cannot be none");
   if (regularizer == WeightRegularizer::UNKNOWN)
diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h
index bd1651bd15..552f6d5739 100644
--- a/nntrainer/tensor/weight.h
+++ b/nntrainer/tensor/weight.h
@@ -45,7 +45,8 @@ class Weight : public Var_Grad {
     regularizer_constant(1.0f),
     decay(0.0f),
     clip_by_global_norm(0.0f),
-    output_axis(3) {}
+    output_axis(3),
+    loss_scale(0.0) {}
 
   /**
    * @brief Construct a new Weight object
@@ -64,7 +65,29 @@ class Weight : public Var_Grad {
     const WeightRegularizer reg = WeightRegularizer::NONE,
     const float reg_const = 1.0f, const float decay = 0.0f,
     const float clip_by_global_norm = 0.0f, bool ng = true,
-    bool alloc_now = false, std::string name = "", unsigned int axis = 3);
+    bool alloc_now = false, std::string name = "", unsigned int axis = 3,
+    float loss_scale_ = 0.0);
+
+  /**
+   * @brief Construct a new Weight object
+   *
+   * @param dim_v Variable and gradient tensor dimension
+   * @param dim_g Gradient tensor dimension
+   * @param init Initializer for the weight
+   * @param reg Regularizer for the weight
+   * @param reg_const Constant multiplier for regularizer
+   * @param ng If the variable needs gradient
+   * @param alloc_now The memory for the weight tensors be allocated upon init
+   * @param name Name for this weight
+   */
+  explicit Weight(
+    const TensorDim &dim_v, const TensorDim &dim_g,
+    const Tensor::Initializer init = Tensor::Initializer::XAVIER_UNIFORM,
+    const WeightRegularizer reg = WeightRegularizer::NONE,
+    const float reg_const = 1.0f, const float decay = 0.0f,
+    const float clip_by_global_norm = 0.0f, bool ng = true,
+    bool alloc_now = false, std::string name = "", unsigned int axis = 3,
+    float loss_scale_ = 0.0);
 
   /**
   * @brief Construct a new Weight object
@@ -72,16 +95,18 @@ class Weight : public Var_Grad {
    * @param spec Weight specification
    */
   explicit Weight(const Spec &spec, bool alloc_now = false) :
-    Weight(std::get<0>(spec), // TensorDim
-           std::get<1>(spec), // Tensor::Initializer
-           std::get<2>(spec), // WeightRegularizer
-           std::get<3>(spec), // WeightRegularizerConstant
-           std::get<4>(spec), // weight decay constant
-           std::get<5>(spec), // MaxNorm for clipping
-           std::get<6>(spec), // need_gradient
-           alloc_now,
-           std::get<7>(spec), // Name
-           std::get<8>(spec)  // out axis
+    Weight(std::get<0>(spec), // TensorDim for Variable
+           std::get<1>(spec), // TensorDim for Gradient
+           std::get<2>(spec), // Tensor::Initializer
+           std::get<3>(spec), // WeightRegularizer
+           std::get<4>(spec), // WeightRegularizerConstant
+           std::get<5>(spec), // weight decay constant
+           std::get<6>(spec), // MaxNorm for clipping
+           std::get<7>(spec), // need_gradient
+           alloc_now,
+           std::get<8>(spec), // Name
+           std::get<9>(spec), // out axis
+           std::get<10>(spec) // loss scale
     ) {}
 
   /**
@@ -105,7 +130,8 @@ class Weight : public Var_Grad {
     regularizer_constant(1.0f),
     decay(0.0f),
     clip_by_global_norm(0.0f),
-    output_axis(output_axis_) {}
+    output_axis(output_axis_),
+    loss_scale(0.0) {}
 
   /**
    * @brief Construct a new Weight object
@@ -118,13 +144,14 @@ class Weight : public Var_Grad {
   explicit Weight(Tensor *v, Tensor *g, const WeightRegularizer reg,
                   const float reg_const, const float decay,
                   bool is_dependent = false, const float max_norm = 0.0f,
-                  unsigned int output_axis_ = 3) :
+                  unsigned int output_axis_ = 3, float loss_scale_ = 0.0f) :
     Var_Grad(v, g, is_dependent),
     regularizer(reg),
    regularizer_constant(reg_const),
     decay(decay),
     clip_by_global_norm(max_norm),
-    output_axis(output_axis_) {}
+    output_axis(output_axis_),
+    loss_scale(loss_scale_) {}
 
   /**
    * @brief Swap for weight
@@ -142,6 +169,7 @@ class Weight : public Var_Grad {
     swap(lhs.clip_by_global_norm, rhs.clip_by_global_norm);
     swap(lhs.output_axis, rhs.output_axis);
     swap(lhs.opt_vars, rhs.opt_vars);
+    swap(lhs.loss_scale, rhs.loss_scale);
   }
 
   /**
@@ -308,7 +336,9 @@ class Weight : public Var_Grad {
   float decay;               /**< constant factor for the weight decay */
   float clip_by_global_norm; /**< constant factor to clip gradient by L2 norm */
   unsigned int output_axis;
+  float loss_scale;
   std::vector<Tensor *> opt_vars; /**< optimizer variables */
+  std::shared_ptr<Tensor> var32;
 
   /**
    * @brief Apply the weight decay to the weight
diff --git a/nntrainer/utils/node_exporter.cpp b/nntrainer/utils/node_exporter.cpp
index eabf9234f1..031d2c2fbf 100644
--- a/nntrainer/utils/node_exporter.cpp
+++ b/nntrainer/utils/node_exporter.cpp
@@ -91,7 +91,8 @@ void Exporter::saveTflResult(
   const std::tuple<props::Name, props::Distribute, props::Trainable,
                    std::vector<props::InputConnection>,
                    std::vector<props::InputShape>, props::SharedFrom,
-                   props::ClipGradByGlobalNorm, props::Packed> &props,
+                   props::ClipGradByGlobalNorm, props::Packed,
+                   props::LossScaleForMixed> &props,
   const LayerNode *self) {
   createIfNull(tf_node);
   tf_node->setLayerNode(*self);
diff --git a/nntrainer/utils/node_exporter.h b/nntrainer/utils/node_exporter.h
index 84c38894f1..de29cf77d9 100644
--- a/nntrainer/utils/node_exporter.h
+++ b/nntrainer/utils/node_exporter.h
@@ -234,6 +234,7 @@ class DisableBias;
 class Activation;
 class BatchNormalization;
 class Packed;
+class LossScaleForMixed;
 } // namespace props
 
 class LayerNode;
@@ -243,11 +244,11 @@ class LayerNode;
  */
 template <>
 void Exporter::saveTflResult(
-  const std::tuple<props::Name, props::Distribute, props::Trainable,
-                   std::vector<props::InputConnection>,
-                   std::vector<props::InputShape>, props::SharedFrom,
-                   props::ClipGradByGlobalNorm, props::Packed> &props,
+  const std::tuple<props::Name, props::Distribute, props::Trainable,
+                   std::vector<props::InputConnection>,
+                   std::vector<props::InputShape>, props::SharedFrom,
+                   props::ClipGradByGlobalNorm, props::Packed,
+                   props::LossScaleForMixed> &props,
   const LayerNode *self);
 
 class BatchNormalizationLayer;
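
Reviewer note — a minimal usage sketch, not part of the patch. It assumes the existing ccapi entry points (ml::train::createModel, createLayer, compile, initialize) and the usual ccapi headers; the only key taken from this diff is "loss_scale", which NeuralNetwork::compile() copies down to every LayerNode, and which LayerNode::finalize() then forwards through InitLayerContext into each requested WeightSpec.

```cpp
// Hypothetical example, for review purposes only (not added by this patch).
#include <layer.h>
#include <model.h>

int main() {
  // Model-level property: parsed into props::LossScale ("loss_scale").
  auto model = ml::train::createModel(ml::train::ModelType::NEURAL_NET,
                                      {"loss=mse", "loss_scale=1024"});

  // During compile(), each layer node receives the layer-level
  // props::LossScaleForMixed ("loss_scale") and later passes it on through
  // InitLayerContext into the WeightSpec of every weight it requests.
  model->addLayer(ml::train::createLayer("fully_connected",
                                         {"unit=10", "input_shape=1:1:32"}));

  model->compile();
  model->initialize();
  return 0;
}
```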