Skip to content

Commit

Permalink
apply code review
Browse files Browse the repository at this point in the history
  • Loading branch information
zetwhite committed Jan 17, 2024
1 parent 99ec326 commit b7d9614
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions compute/cker/include/cker/train/operation/ReLU6.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ inline void ReLU6Grad(const Shape &output_shape, const float *output_data,
const auto incoming_map = MapAsVector(incoming_data, incoming_shape);
auto grad_map = MapAsVector(grad_data, grad_shape);

if (output_shape == incoming_shape && output_shape == grad_shape)
grad_map.array() =
incoming_map.array() *
(0.0f < output_map.array() && output_map.array() < 6.0f).template cast<float>();
else
throw std::runtime_error("cker::ReLU6Grad: Unsupported shape");
if (output_shape != incoming_shape || output_shape != grad_shape)
throw std::runtime_error{"cker::ReLU6Grad: Unsupported shape"};

grad_map.array() =
incoming_map.array() *
(0.0f < output_map.array() && output_map.array() < 6.0f).template cast<float>();
}

} // namespace train
Expand Down
2 changes: 1 addition & 1 deletion compute/cker/src/train/Relu6.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ TEST(CKer_Operation, ReLU6)

// clang-format off
// std::vector<float> input_fwd = {-2.0, -1.0, 2.0, 3.0, 6.0, 7.0};
std::vector<float> output_fwd = { 0.0, 0.0, 2.0, 3.0, 6.0, 7.0};
std::vector<float> output_fwd = { 0.0, 0.0, 2.0, 3.0, 6.0, 6.0};
std::vector<float> input_bwd = {-0.1, -0.2, 0.3, 0.4, -0.1, 0.5};
std::vector<float> expected_output_bwd = { 0.0, 0.0, 0.3, 0.4, 0.0, 0.0};
// clang-format on
Expand Down

0 comments on commit b7d9614

Please sign in to comment.