Skip to content

Commit

Permalink
Support Half/BFloat16 in nonzero (#7850)
Browse files Browse the repository at this point in the history
Partial fix for #7748.
  • Loading branch information
swolchok authored Jan 23, 2025
1 parent e000b22 commit 5db40f2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
7 changes: 3 additions & 4 deletions kernels/portable/cpu/op_nonzero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,9 @@ Tensor& nonzero_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {

ET_KERNEL_CHECK(ctx, check_nonzero_args(in, out), InvalidArgument, out);

ET_SWITCH_REAL_TYPES_AND(
Bool, in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] {
nonzero<CTYPE>(ctx, in, out);
});
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] {
nonzero<CTYPE>(ctx, in, out);
});

return out;
}
Expand Down
8 changes: 3 additions & 5 deletions kernels/test/op_nonzero_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ class OpNonzeroTest : public OperatorTest {
void test_dtype() {
TensorFactory<DTYPE> tf_input;
TensorFactory<ScalarType::Long> tf_long;
// clang-format off
Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 0,
2, 4});
// clang-format on
Tensor a = tf_input.make(
/*sizes=*/{2, 2}, /*data=*/{CTYPE(2), CTYPE(0), CTYPE(2), CTYPE(4)});
Tensor out = tf_long.zeros({3, 2});

op_nonzero_out(a, out);
Expand All @@ -45,7 +43,7 @@ class OpNonzeroTest : public OperatorTest {

TEST_F(OpNonzeroTest, AllDtypesSupported) {
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
ET_FORALL_REAL_TYPES(TEST_ENTRY);
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

Expand Down

0 comments on commit 5db40f2

Please sign in to comment.