From 190d26fff720922e4a819fbdd74a94a81e2e1899 Mon Sep 17 00:00:00 2001
From: Jang Jiseob
Date: Mon, 7 Oct 2024 14:51:51 +0900
Subject: [PATCH] [onert] Apply softmax to CategoricalCrossEntropy automatically
 (#14105)

This commit applies softmax automatically when the CategoricalCrossEntropy
loss is used with a model that does not already apply softmax.

ONE-DCO-1.0-Signed-off-by: ragmani
---
 .../api/nnfw/include/nnfw_experimental.h      |  8 ++++-
 .../onert/backend/train/KernelGenerator.cc    |  5 +++-
 .../ops/LossCategoricalCrossentropyLayer.cc   | 30 ++++++++++++-------
 .../ops/LossCategoricalCrossentropyLayer.h    |  4 ++-
 .../core/include/ir/train/operation/Loss.h    |  4 ++-
 .../train/TrainableOperationConverter.cc      |  8 ++++-
 .../compiler/train/pass/LossInsertionPass.cc  |  8 ++++-
 .../core/src/ir/train/TrainableGraph.test.cc  |  8 ++++-
 .../core/src/ir/train/UseDefGenerator.test.cc |  8 ++++-
 .../onert/core/src/ir/train/operation/Loss.cc |  5 ++--
 10 files changed, 68 insertions(+), 20 deletions(-)

diff --git a/runtime/onert/api/nnfw/include/nnfw_experimental.h b/runtime/onert/api/nnfw/include/nnfw_experimental.h
index b8906eb4d97..27dc0c6f7dd 100644
--- a/runtime/onert/api/nnfw/include/nnfw_experimental.h
+++ b/runtime/onert/api/nnfw/include/nnfw_experimental.h
@@ -237,7 +237,13 @@ typedef struct nnfw_train_info
   float learning_rate = 0.001f;
   /** Batch size */
   uint32_t batch_size = 1;
-  /** loss info */
+  /** loss info
+   * Note that you do not need to check whether the model already ends with a softmax layer
+   * before selecting NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY. When this loss is used, the
+   * runtime guarantees that the predicted input of the loss is the result of applying
+   * softmax exactly once, regardless of whether the model's output has already passed
+   * through softmax.
+   */
   nnfw_loss_info loss_info{.loss = NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR,
                            .reduction_type = NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE};
   /** optimizer type */
diff --git a/runtime/onert/backend/train/KernelGenerator.cc b/runtime/onert/backend/train/KernelGenerator.cc
index abbe17145c5..a6046224187 100644
--- a/runtime/onert/backend/train/KernelGenerator.cc
+++ b/runtime/onert/backend/train/KernelGenerator.cc
@@ -421,9 +421,12 @@ void KernelGenerator::visit(const ir::train::operation::Loss &node)
     }
     case ir::train::LossCode::CategoricalCrossentropy:
    {
+      const auto y_pred_op_code = node.y_pred_op_code();
+      bool is_normalization_required = (y_pred_op_code != ir::OpCode::Softmax);
      auto fn = std::make_unique<ops::LossCategoricalCrossentropyLayer>();
      fn->configure(y_pred_tensor, y_true_tensor, output_tensor, back_prop_y_pred_tensor,
-                    reduction_type, loss_param.cce.axis, loss_param.cce.label_smoothing);
+                    reduction_type, loss_param.cce.axis, loss_param.cce.label_smoothing,
+                    is_normalization_required);
      _return_fn = std::move(fn);
      break;
    }
diff --git a/runtime/onert/backend/train/ops/LossCategoricalCrossentropyLayer.cc b/runtime/onert/backend/train/ops/LossCategoricalCrossentropyLayer.cc
index 3ccd17dda21..a751dd163dd 100644
--- a/runtime/onert/backend/train/ops/LossCategoricalCrossentropyLayer.cc
+++ b/runtime/onert/backend/train/ops/LossCategoricalCrossentropyLayer.cc
@@ -28,17 +28,16 @@ namespace train
 namespace ops
 {
 
-void LossCategoricalCrossentropyLayer::configure(const IPortableTensor *y_pred,
-                                                 const IPortableTensor *y_true,
-                                                 IPortableTensor *output,
-                                                 IPortableTensor *back_prop_y_pred,
-                                                 ir::train::LossReductionType reduction_type,
-                                                 int32_t axis, float label_smoothing)
+void LossCategoricalCrossentropyLayer::configure(
+  const IPortableTensor *y_pred, const IPortableTensor *y_true, IPortableTensor *output,
+  IPortableTensor *back_prop_y_pred, ir::train::LossReductionType reduction_type, int32_t axis,
+  float label_smoothing, bool is_normalization_required)
 {
   LossLayer::configure(y_pred, y_true, output, back_prop_y_pred, reduction_type);
 
   _axis = axis;
   _label_smoothing = label_smoothing;
+  _is_normalization_required = is_normalization_required;
 }
 
 void LossCategoricalCrossentropyLayer::forward(bool)
@@ -59,12 +58,23 @@ void LossCategoricalCrossentropyLayer::backward()
 {
   assert(_back_prop_y_pred != nullptr);
 
-  const auto reduction_type = convertLossReductionType(_reduction_type);
   if (_y_pred->data_type() == OperandType::FLOAT32)
   {
-    nnfw::cker::train::CategoricalCrossEntropyGrad(
-      getShape(_y_pred), getBuffer(_y_pred), getShape(_y_true), getBuffer(_y_true),
-      getShape(_back_prop_y_pred), getBuffer(_back_prop_y_pred), reduction_type);
+    const auto reduction_type = convertLossReductionType(_reduction_type);
+    if (_is_normalization_required)
+    {
+      // TODO Eliminate duplicate calculations for output
+      nnfw::cker::train::CategoricalCrossEntropyWithLogits(
+        getShape(_y_pred), getBuffer(_y_pred), getShape(_y_true), getBuffer(_y_true),
+        getShape(_output), getBuffer(_output), getShape(_back_prop_y_pred),
+        getBuffer(_back_prop_y_pred), reduction_type);
+    }
+    else
+    {
+      nnfw::cker::train::CategoricalCrossEntropyGrad(
+        getShape(_y_pred), getBuffer(_y_pred), getShape(_y_true), getBuffer(_y_true),
+        getShape(_back_prop_y_pred), getBuffer(_back_prop_y_pred), reduction_type);
+    }
   }
   else
   {
diff --git a/runtime/onert/backend/train/ops/LossCategoricalCrossentropyLayer.h b/runtime/onert/backend/train/ops/LossCategoricalCrossentropyLayer.h
index 7e3419a71fb..dfd5ef93167 100644
--- a/runtime/onert/backend/train/ops/LossCategoricalCrossentropyLayer.h
+++ b/runtime/onert/backend/train/ops/LossCategoricalCrossentropyLayer.h
@@ -36,13 +36,15 @@ class LossCategoricalCrossentropyLayer : public LossLayer
 
   void configure(const IPortableTensor *y_pred, const IPortableTensor *y_true,
                  IPortableTensor *output, IPortableTensor *back_prop_y_pred,
-                 ir::train::LossReductionType reduction_type, int32_t axis, float label_smoothing);
+                 ir::train::LossReductionType reduction_type, int32_t axis, float label_smoothing,
+                 bool is_normalization_required);
   void forward(bool training) override;
   void backward() override;
 
 private:
   int32_t _axis{-1};
   float _label_smoothing{0.0f};
+  bool _is_normalization_required{false};
 };
 
 } // namespace ops
diff --git a/runtime/onert/core/include/ir/train/operation/Loss.h b/runtime/onert/core/include/ir/train/operation/Loss.h
index 35b0b55ab70..adc2cd017d9 100644
--- a/runtime/onert/core/include/ir/train/operation/Loss.h
+++ b/runtime/onert/core/include/ir/train/operation/Loss.h
@@ -38,7 +38,7 @@ class Loss : public ir::operation::Loss, public TrainableOperation
   using OperationType = ir::operation::Loss;
 
 public:
-  Loss(const OperationType &operation, const LossInfo &info);
+  Loss(const OperationType &operation, const LossInfo &info, ir::OpCode y_pred_op_code);
 
 public:
   std::unique_ptr<ITrainableOperation> clone() const override;
@@ -49,9 +49,11 @@ class Loss : public ir::operation::Loss, public TrainableOperation
 
 public:
   const LossInfo &param() const { return _param; }
+  ir::OpCode y_pred_op_code() const { return _y_pred_op_code; }
 
 private:
   LossInfo _param;
+  ir::OpCode _y_pred_op_code; // The op code of the last node computing y_pred
 };
 
 } // namespace operation
diff --git a/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc
index 80ed05aa524..299e0fc66e1 100644
--- a/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc
+++ b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc
@@ -68,7 +68,13 @@ void TrainableOperationConverter::visit(const ir::operation::FullyConnected &nod
 
 void TrainableOperationConverter::visit(const ir::operation::Loss &node)
 {
-  _return_op = std::make_unique<ir::train::operation::Loss>(node, _training_info->lossInfo());
+  const auto &y_pred_index = node.getInputs().at(ir::operation::Loss::Input::Y_PRED);
+  const auto &y_pred = _tgraph.operands().at(y_pred_index);
+  const auto &y_pred_node = _tgraph.operations().at(y_pred.getDef());
+  const auto y_pred_op_code = y_pred_node.opcode();
+
+  _return_op =
+    std::make_unique<ir::train::operation::Loss>(node, _training_info->lossInfo(), y_pred_op_code);
 }
 
 void TrainableOperationConverter::visit(const ir::operation::Pad &node)
diff --git a/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc
index ea1f21e3091..baa0255d893 100644
--- a/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc
+++ b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc
@@ -64,8 +64,14 @@ void LossInsertionPass::run()
   auto output_index = _trainable_graph.addOperand(output_shape, float_op);
   ir::OperandIndexSequence outputs{output_index};
 
+  // The y_pred node information may be required in some loss layers (e.g.,
+  // CategoricalCrossEntropy(SoftmaxCrossEntropy));
+  const auto &y_pred_node = _trainable_graph.operations().at(y_pred.getDef());
+  const auto y_pred_op_code = y_pred_node.opcode();
+
   auto loss_op = std::make_unique<ir::operation::Loss>(inputs, outputs);
-  auto trainable_loss_op = std::make_unique<ir::train::operation::Loss>(*loss_op, loss_info);
+  auto trainable_loss_op =
+    std::make_unique<ir::train::operation::Loss>(*loss_op, loss_info, y_pred_op_code);
   trainable_loss_op->enableBackward();
 
   _trainable_graph.addOperation(std::move(trainable_loss_op));
diff --git a/runtime/onert/core/src/ir/train/TrainableGraph.test.cc b/runtime/onert/core/src/ir/train/TrainableGraph.test.cc
index 84df228907a..7b755dc9d18 100644
--- a/runtime/onert/core/src/ir/train/TrainableGraph.test.cc
+++ b/runtime/onert/core/src/ir/train/TrainableGraph.test.cc
@@ -62,8 +62,14 @@ OperationIndex addLossOperation(train::TrainableGraph &tgraph, const OperandInde
                                 const OperandIndexSequence outputs)
 {
   // Add "Loss" operation
+  const auto &y_pred_index = inputs.at(0);
+  const auto &y_pred = tgraph.operands().at(y_pred_index);
+  const auto &y_pred_node = tgraph.operations().at(y_pred.getDef());
+  const auto y_pred_op_code = y_pred_node.opcode();
+
   auto loss_op = operation::Loss(inputs, outputs);
-  return tgraph.addOperation(std::make_unique<train::operation::Loss>(loss_op, train::LossInfo{}));
+  return tgraph.addOperation(
+    std::make_unique<train::operation::Loss>(loss_op, train::LossInfo{}, y_pred_op_code));
 }
 
 TEST(TrainableGraph, topological_sort_linear)
diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.test.cc b/runtime/onert/core/src/ir/train/UseDefGenerator.test.cc
index a38ce0ac397..86bfcc3e824 100644
--- a/runtime/onert/core/src/ir/train/UseDefGenerator.test.cc
+++ b/runtime/onert/core/src/ir/train/UseDefGenerator.test.cc
@@ -70,8 +70,14 @@ OperationIndex addLossOperation(train::TrainableGraph &tgraph, const OperandInde
                                 const OperandIndexSequence outputs)
 {
   // Add "Loss" operation
+  const auto &y_pred_index = inputs.at(0);
+  const auto &y_pred = tgraph.operands().at(y_pred_index);
+  const auto &y_pred_node = tgraph.operations().at(y_pred.getDef());
+  const auto y_pred_op_code = y_pred_node.opcode();
+
   auto loss_op = operation::Loss(inputs, outputs);
-  return tgraph.addOperation(std::make_unique<train::operation::Loss>(loss_op, train::LossInfo{}));
+  return tgraph.addOperation(
+    std::make_unique<train::operation::Loss>(loss_op, train::LossInfo{}, y_pred_op_code));
 }
 
 train::UseDefChain createUseDefChain(const Operand &operand,
diff --git a/runtime/onert/core/src/ir/train/operation/Loss.cc b/runtime/onert/core/src/ir/train/operation/Loss.cc
index 3a89e0ff600..b95458dc4b5 100644
--- a/runtime/onert/core/src/ir/train/operation/Loss.cc
+++ b/runtime/onert/core/src/ir/train/operation/Loss.cc
@@ -36,8 +36,9 @@ void Loss::accept(OperationVisitor &v) const { v.visit(*this); }
 
 void Loss::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
 
-Loss::Loss(const OperationType &operation, const LossInfo &param)
-  : OperationType{operation.getInputs(), operation.getOutputs()}, _param{param}
+Loss::Loss(const OperationType &operation, const LossInfo &param, ir::OpCode y_pred_op_code)
+  : OperationType{operation.getInputs(), operation.getOutputs()}, _param{param},
+    _y_pred_op_code{y_pred_op_code}
 {
   // DO NOTHING
 }
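Note (not part of the patch itself): the sketch below illustrates how a client could rely on the behavior documented above through the experimental C API, training a classifier whose last operator is a plain FullyConnected with no Softmax. Only functions declared in nnfw.h / nnfw_experimental.h are used; the model path, tensor sizes, and the omission of error handling are placeholder assumptions for illustration.

// Minimal training-step sketch (assumptions: "model.circle" exists, the model has one
// 28x28 float input and a 10-class output, and every call returns NNFW_STATUS_NO_ERROR).
#include <nnfw.h>
#include <nnfw_experimental.h>
#include <vector>

int main()
{
  nnfw_session *session = nullptr;
  nnfw_create_session(&session);
  nnfw_load_model_from_file(session, "model.circle");

  // Start from the default training info quoted above, then switch the loss.
  nnfw_train_info info{};
  nnfw_train_get_traininfo(session, &info);
  info.loss_info.loss = NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY;
  info.loss_info.reduction_type = NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE;
  nnfw_train_set_traininfo(session, &info);

  nnfw_train_prepare(session);

  std::vector<float> input(28 * 28); // placeholder input batch
  std::vector<float> expected(10);   // placeholder one-hot label
  nnfw_train_set_input(session, 0, input.data(), nullptr);
  nnfw_train_set_expected(session, 0, expected.data(), nullptr);
  nnfw_train(session, true); // one step including the weight update

  float loss = 0.0f;
  nnfw_train_get_loss(session, 0, &loss);

  nnfw_close_session(session);
  return 0;
}

Because the Loss operation now records the op code of the node that produces y_pred, the same setup behaves identically whether or not the model already ends in Softmax: the loss kernel normalizes the logits itself only when the producer is not a Softmax operator.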