[cker] Introduce CategoricalCrossEntropyWithLogits (#13938)
This commit introduces CategoricalCrossEntropyWithLogits, which computes the categorical cross-entropy (CCE) loss value and its gradient with normalization (softmax).
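
Concretely, for one row of logits x and one-hot labels y (notation added here for clarity; it restates the math implemented in xent_op.h below), the new kernel computes

  loss = \sum_j y_j \left( \log \sum_k e^{x_k - \max(x)} - (x_j - \max(x)) \right)
  grad_j = \mathrm{softmax}(x)_j - y_j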

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
ragmani authored Sep 10, 2024
1 parent 6823a1b commit dc1745f
Showing 4 changed files with 307 additions and 0 deletions.
151 changes: 151 additions & 0 deletions compute/cker/include/cker/eigen/xent_op.h
@@ -0,0 +1,151 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __NNFW_CKER_EIGEN_XENT_OPS_H__
#define __NNFW_CKER_EIGEN_XENT_OPS_H__

// From tensorflow/core/kernels/xent_op.cc
#define EIGEN_USE_THREADS

#include "unsupported/Eigen/CXX11/Tensor"
#include "cker/operation/Helper/Tensor.h"

// From tensorflow/core/kernels/xent_op.h
namespace nnfw
{
namespace cker
{
namespace xent_ops
{
namespace functor
{

// Functor used by XentOp to do the computations.
template <typename Device, typename T> struct XentFunctor
{
// Computes Cross Entropy loss and backprop.
//
// logits: batch_size, num_classes.
// labels: batch_size, num_classes.
// scratch: temporary tensor, dims: batch_size, 1
// loss: output tensor for the loss, dims: batch_size.
// backprop: output tensor for the backprop, dims: batch_size, num_classes.
void operator()(const Device &d, const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss,
typename TTypes<T>::Matrix backprop);
};

} // namespace functor
} // namespace xent_ops
} // namespace cker
} // namespace nnfw

// From tensorflow/core/kernels/xent_op.cc
namespace nnfw
{
namespace cker
{
namespace xent_ops
{

// Enable CPUDevice only for xent_ops
using CPUDevice = Eigen::ThreadPoolDevice;
using Index = Eigen::Index;

// Partial specialization for a CPUDevice; the Eigen-based computation is
// implemented inline in XentFunctorBase below.
namespace functor
{
template <typename Device, typename T> struct XentFunctorBase
{
void operator()(const Device &d, const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss,
typename TTypes<T>::Matrix backprop)
{
T *scratch_ptr = scratch.data();
T *backprop_ptr = backprop.data();

T *loss_ptr = loss.data();

int row_size = shape[1];

if (shape[0] > 0)
{
backprop.device(d) = logits.broadcast(logits_bcast);
scratch.device(d) = labels.broadcast(labels_bcast);
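// The broadcasts above stage logits in `backprop` and labels in `scratch`; the per-row
// worker below reads both back from those buffers (so this_logits aliases this_backprop).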
auto reductionWorker = [&](int64_t begin, int64_t end) -> void {
for (int i = begin; i < end; i++)
{
T *this_backprop = backprop_ptr + (i * row_size);
T *this_logits = backprop_ptr + (i * row_size);
T *this_labels = scratch_ptr + (i * row_size);
T max_logits = this_logits[0];

// Compute the row maximum of the logits (used to shift them for numerical stability)
for (int j = 1; j < row_size; j++)
{
max_logits = std::max(max_logits, this_logits[j]);
}

T sum = T(0);
T loss_sum = T(0);

for (int j = 0; j < row_size; j++)
{
// Note that this_logits and this_backprop point into the same buffer, so
// after this assignment this_logits should no longer be trusted.
this_backprop[j] = this_logits[j] - max_logits;
sum = sum + exp(this_backprop[j]);
}

// loss calculation
T log_sum = log(sum);
for (int j = 0; j < row_size; j++)
{
loss_sum += this_labels[j] * (log_sum - this_backprop[j]);
this_backprop[j] = (exp(this_backprop[j]) / sum) - this_labels[j];
}
loss_ptr[i] = loss_sum;
}
};
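// Per-row cost estimate for parallelFor: Eigen::TensorOpCost takes bytes loaded,
// bytes stored, and compute cycles; 50 cycles per element follows the TensorFlow
// xent_op heuristic this file was ported from.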
const int64_t compute_cycles = 50 * row_size;
const int64_t input_bytes = sizeof(T) * row_size;
const int64_t output_bytes = sizeof(T) * row_size;
const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);

d.parallelFor(shape[0], cost, reductionWorker);
}
}
};

template <typename T> struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T>
{
};

} // namespace functor
} // namespace xent_ops
} // namespace cker
} // namespace nnfw

#endif // __NNFW_CKER_EIGEN_XENT_OPS_H__
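
For reference, a minimal single-row scalar sketch of what XentFunctorBase parallelizes over the batch (an illustration only, not part of this commit; the function name is hypothetical):

#include <algorithm>
#include <cmath>

// Loss and gradient for one row of `row_size` logits/labels, using the same
// max-shift, exp-sum, softmax-minus-labels steps as the functor above.
template <typename T>
T XentRowReference(const T *logits, const T *labels, T *grad, int row_size)
{
  T max_logit = logits[0];
  for (int j = 1; j < row_size; ++j)
    max_logit = std::max(max_logit, logits[j]);

  T sum = T(0);
  for (int j = 0; j < row_size; ++j)
  {
    grad[j] = logits[j] - max_logit; // shifted logits, reused as scratch
    sum += std::exp(grad[j]);
  }

  T log_sum = std::log(sum);
  T loss = T(0);
  for (int j = 0; j < row_size; ++j)
  {
    loss += labels[j] * (log_sum - grad[j]);
    grad[j] = std::exp(grad[j]) / sum - labels[j]; // softmax - labels
  }
  return loss;
}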
4 changes: 4 additions & 0 deletions compute/cker/include/cker/operation/Helper/Tensor.h
@@ -157,6 +157,10 @@ struct Tensor
{
return typename TTypes<T>::ConstScalar(base<T>());
}

template <typename T> typename TTypes<T>::Vec vec() { return shaped<T, 1>(); }

template <typename T> typename TTypes<T>::Matrix matrix() { return shaped<T, 2>(); }
}; // Tensor

template <typename DSizes> Eigen::DSizes<Index32, DSizes::count> To32BitDims(const DSizes &in)
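
A usage sketch for the new vec()/matrix() helpers (not part of the diff; it assumes the Tensor struct's public shape/buffer members exactly as they are used by Loss.h in this commit):

#include <vector>
#include "cker/operation/Helper/Tensor.h"

void TensorViewExample()
{
  std::vector<float> data(2 * 3, 0.0f);

  nnfw::cker::Tensor t;
  nnfw::cker::Shape shape{2, 3};
  t.shape.ReplaceWith(shape.DimensionsCount(), shape.DimsData());
  t.buffer = data.data();

  auto m = t.matrix<float>(); // rank-2 Eigen view over `data`
  m(0, 0) = 1.0f;             // writes through to data[0]
}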
68 changes: 68 additions & 0 deletions compute/cker/include/cker/train/operation/Loss.h
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,7 +21,10 @@
#include <numeric>

#include "cker/Shape.h"
#include "cker/eigen/EigenSupport.h"
#include "cker/eigen/Utils.h"
#include "cker/eigen/xent_op.h"
#include "cker/operation/Helper/BCast.h"
#include "cker/train/Types.h"

namespace nnfw
@@ -135,6 +139,70 @@ inline void CategoricalCrossEntropyGrad(const Shape &y_pred_shape, const T *y_pr
grad = -(y_true.array() / y_pred.array().cwiseMax(log_threshold<T>()));
}

template <typename T>
void CategoricalCrossEntropyWithLogits(const Shape &logits_shape, const T *logits_data,
const Shape &y_true_shape, const T *y_true_data,
const Shape &loss_out_shape, T *loss_out_data,
const Shape &grad_shape, T *grad_data)
{
// TODO Enable sparse shapes
if (loss_out_shape.DimensionsCount() != 1)
throw std::runtime_error(
"cker::CategoricalCrossEntropyWithLogits: loss output dimension count should be 1");
if (logits_shape != y_true_shape)
throw std::runtime_error(
"cker::CategoricalCrossEntropyWithLogits: logits and y_true do not have the same shape");
if (loss_out_shape.Dims(0) != logits_shape.Dims(0))
throw std::runtime_error(
"cker::CategoricalCrossEntropyWithLogits: loss_out and logits do not have the same batch");
if (logits_shape != grad_shape)
throw std::runtime_error(
"cker::CategoricalCrossEntropyWithLogits: logits and grad do not have the same shape");

auto shape_in = logits_shape;

BCast bcast(BCast::FromShape(shape_in), BCast::FromShape(y_true_shape),
/*fewer_dims_optimization=*/false);

// loss is 1-D (one per example), and size is batch_size.

Tensor logits_in;
Tensor labels_in;
Tensor scratch;
Tensor loss_out;
Tensor back_out;

logits_in.shape.ReplaceWith(shape_in.DimensionsCount(), shape_in.DimsData());
logits_in.buffer = const_cast<T *>(logits_data);

labels_in.shape.ReplaceWith(y_true_shape.DimensionsCount(), y_true_shape.DimsData());
labels_in.buffer = const_cast<T *>(y_true_data);

scratch.shape.ReplaceWith(shape_in.DimensionsCount(), shape_in.DimsData());
std::vector<T> scratch_vec(shape_in.Dims(0) * shape_in.Dims(1), static_cast<T>(0));
scratch.buffer = scratch_vec.data();

Shape shape_loss_out{shape_in.Dims(0)};
loss_out.shape.ReplaceWith(shape_loss_out.DimensionsCount(), shape_loss_out.DimsData());
loss_out.buffer = loss_out_data;

back_out.shape.ReplaceWith(shape_in.DimensionsCount(), shape_in.DimsData());
back_out.buffer = grad_data;

if (shape_in.Dims(0) > 0)
{
const xent_ops::CPUDevice &device = *eigen_support::GetThreadPoolDevice();
xent_ops::functor::XentFunctor<xent_ops::CPUDevice, T> functor;
const Eigen::DSizes<Eigen::DenseIndex, 2> shape{shape_in.Dims(0), shape_in.Dims(1)};

functor(device, shape, BCast::ToIndexArray<2>(bcast.x_bcast()),
BCast::ToIndexArray<2>(bcast.y_bcast()),
logits_in.template shaped<const T, 2>(bcast.x_reshape()),
labels_in.template shaped<const T, 2>(bcast.y_reshape()), scratch.matrix<T>(),
loss_out.vec<T>(), back_out.matrix<T>());
}
}

} // namespace train
} // namespace cker
} // namespace nnfw
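
A minimal caller sketch for the new entry point (not part of this commit; the include paths and values are illustrative):

#include <vector>
#include "cker/Shape.h"
#include "cker/train/operation/Loss.h"

void CategoricalCrossEntropyWithLogitsExample()
{
  nnfw::cker::Shape in_shape{1, 4}; // batch_size x num_classes
  nnfw::cker::Shape out_shape{1};   // one loss value per example

  std::vector<float> logits = {2.0f, 1.0f, 3.0f, 0.5f};
  std::vector<float> y_true = {0.0f, 0.0f, 1.0f, 0.0f};
  std::vector<float> loss(out_shape.FlatSize());
  std::vector<float> grad(in_shape.FlatSize());

  // loss[0] becomes -log(softmax(logits)[2]); grad becomes softmax(logits) - y_true.
  nnfw::cker::train::CategoricalCrossEntropyWithLogits(in_shape, logits.data(), in_shape,
                                                       y_true.data(), out_shape, loss.data(),
                                                       in_shape, grad.data());
}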
84 changes: 84 additions & 0 deletions compute/cker/src/train/Loss.test.cc
@@ -86,6 +86,31 @@ template <typename T> class LossCCEVerifier
}
}

void verifyBackwardWithLogits(const std::vector<T> &logits, const std::vector<T> &y_true,
const std::vector<T> &expected_loss_out,
const std::vector<T> &expected_grad)
{
assert(logits.size() == y_true.size());
assert(logits.size() == expected_grad.size());

std::vector<T> loss_out(_out_shape.FlatSize());
std::vector<T> grad(_in_shape.FlatSize());

nnfw::cker::train::CategoricalCrossEntropyWithLogits(_in_shape, logits.data(), _in_shape,
y_true.data(), _out_shape, loss_out.data(),
_in_shape, grad.data());

for (int i = 0; i < loss_out.size(); ++i)
{
EXPECT_NEAR(loss_out[i], expected_loss_out[i], 1e-3f);
}

for (int i = 0; i < grad.size(); ++i)
{
EXPECT_NEAR(grad[i], expected_grad[i], 1e-3f);
}
}

void throwBackward(const std::vector<T> &y_pred, const std::vector<T> &y_true,
const std::vector<T> &expected)
{
@@ -99,6 +124,21 @@
_in_shape, y_pred.data(), _in_shape, y_true.data(), _out_shape, output.data()));
}

void throwBackwardWithLogits(const std::vector<T> &logits, const std::vector<T> &y_true,
const std::vector<T> &expected_loss_out,
const std::vector<T> &expected_grad)
{
assert(logits.size() == y_true.size());
assert(logits.size() == expected_grad.size());

std::vector<T> loss_out(_out_shape.FlatSize());
std::vector<T> grad(_in_shape.FlatSize());

EXPECT_ANY_THROW(nnfw::cker::train::CategoricalCrossEntropyWithLogits(
_in_shape, logits.data(), _in_shape, y_true.data(), _out_shape, loss_out.data(), _in_shape,
grad.data()));
}

private:
const Shape _in_shape;
const Shape _out_shape;
@@ -391,6 +431,33 @@ TEST(CKer_Operation, LossCategoricalCrossEntropyGrad)
LossCCEVerifier<float> verifier(in_shape, grad_shape);
verifier.verifyBackward(y_pred, y_true, expected);
}

{
nnfw::cker::Shape in_shape{1, 10};
nnfw::cker::Shape out_shape{1};

std::vector<float> logits = {1, 3, 5, 35, 4, 5, 28, 9, 4, 6};
std::vector<float> y_true = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1};
std::vector<float> expected_loss_out = {29.0009};
std::vector<float> expected_grad = {0, 0, 0, 0.9991, 0, 0, 0.0009, 0, 0, -1};
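// Hand check of the expected values (not from the commit): max(logits) = 35, so the
// significant shifted exponentials are e^0 (logit 35) and e^-7 (logit 28), giving
// sum ~= 1.000912 and log(sum) ~= 0.0009. The true class has logit 6, so
// loss ~= 0.0009 - (6 - 35) = 29.0009, and grad = softmax(logits) - y_true, i.e.
// ~0.9991 at index 3, ~0.0009 at index 6, and -1 at index 9.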

LossCCEVerifier<float> verifier(in_shape, out_shape);
verifier.verifyBackwardWithLogits(logits, y_true, expected_loss_out, expected_grad);
}

{
nnfw::cker::Shape in_shape{2, 10};
nnfw::cker::Shape out_shape{2};

std::vector<float> logits = {1, 3, 5, 35, 4, 5, 28, 9, 4, 6, 89, 3, 4, 5, 23, 1, 4, 5, 1, 101};
std::vector<float> y_true = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0};
std::vector<float> expected_loss_out = {29.0009, 12};
std::vector<float> expected_grad = {0, 0, 0, 0.9991, 0, 0, 0.0009, 0, 0, -1,
-1, 0, 0, 0, 0, 0, 0, 0, 0, 1};

LossCCEVerifier<float> verifier(in_shape, out_shape);
verifier.verifyBackwardWithLogits(logits, y_true, expected_loss_out, expected_grad);
}
}

TEST(CKer_Operation, neg_LossCategoricalCrossEntropyGrad)
@@ -408,3 +475,20 @@ TEST(CKer_Operation, neg_LossCategoricalCrossEntropyGrad)
verifier.throwBackward(y_pred, y_true, expected);
}
}

TEST(CKer_Operation, neg_LossCategoricalCrossEntropyWithLogits)
{
// Invalid out shape
{
nnfw::cker::Shape in_shape{1, 10};
nnfw::cker::Shape out_shape{1, 1};

std::vector<float> logits = {1, 3, 5, 35, 4, 5, 28, 9, 4, 6};
std::vector<float> y_true = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1};
std::vector<float> expected_loss_out = {29.0009};
std::vector<float> expected_grad = {0, 0, 0, 0.9991, 0, 0, 0.0009, 0, 0, -1};

LossCCEVerifier<float> verifier(in_shape, out_shape);
verifier.throwBackwardWithLogits(logits, y_true, expected_loss_out, expected_grad);
}
}
