diff --git a/runtime/onert/backend/train/ops/MeanLayer.cc b/runtime/onert/backend/train/ops/MeanLayer.cc index fc2b6c65221..2be83e82ce0 100644 --- a/runtime/onert/backend/train/ops/MeanLayer.cc +++ b/runtime/onert/backend/train/ops/MeanLayer.cc @@ -48,26 +48,11 @@ void MeanLayer::forward(bool) { cpu::ops::MeanLayer::run(); } void MeanLayer::backward() { - nnfw::cker::Shape keep_dim_shape; - if (_keep_dims == false) - { - keep_dim_shape.ReplaceWith(getShape(_input)); - auto axes_vec = cpu::ops::getReducerAxes(_axes); - for (const auto &axis : axes_vec) - { - keep_dim_shape.SetDim(axis, 1); - } - } - else - { - keep_dim_shape.ReplaceWith(getShape(_back_prop_input)); - } - switch (_back_prop_output->data_type()) { case OperandType::FLOAT32: { - nnfw::cker::train::MeanGrad(keep_dim_shape, getBuffer(_back_prop_output), + nnfw::cker::train::MeanGrad(getShape(_back_prop_output), getBuffer(_back_prop_output), getShape(_back_prop_input), getBuffer(_back_prop_input)); break; }