[luci/service] Support BMM dynamic shape inference #13759
Conversation
I made some test cases with tf.matmul. Please note them when you make test cases.

Python Test Code:

import tensorflow as tf

# Example 1: Simple batch of 2D matrices
A = tf.constant([[[1, 2, 3], [4, 5, 6]],
                 [[7, 8, 9], [10, 11, 12]]], dtype=tf.float32)  # Shape: (2, 2, 3)
B = tf.constant([[[1, 4], [2, 5], [3, 6]],
                 [[7, 10], [8, 11], [9, 12]]], dtype=tf.float32)  # Shape: (2, 3, 2)
# Perform batch matrix multiplication using tf.matmul
result = tf.matmul(A, B)
print("Example 1 Result (Shape: {}):\n{}".format(result.shape, result.numpy()))

# Example 2: Broadcasting in batch dimensions
A = tf.constant([[[1, 2, 3], [4, 5, 6]],
                 [[7, 8, 9], [10, 11, 12]],
                 [[13, 14, 15], [16, 17, 18]]], dtype=tf.float32)  # Shape: (3, 2, 3)
B = tf.constant([[[1, 4], [2, 5], [3, 6]]], dtype=tf.float32)  # Shape: (1, 3, 2)
# Batch matrix multiplication with broadcasting: B's batch dim (1) broadcasts to 3
result = tf.matmul(A, B)
print("Example 2 Result (Shape: {}):\n{}".format(result.shape, result.numpy()))

# Example 3: Edge case with zero-sized batch dimensions on both inputs
A = tf.random.normal([0, 2, 3])  # Shape: (0, 2, 3)
B = tf.random.normal([0, 3, 2])  # Shape: (0, 3, 2)
# Matching zero batch dims are accepted; the result has shape (0, 2, 2)
result = tf.matmul(A, B)
print("Example 3 Result (Shape: {}):\n{}".format(result.shape, result.numpy()))

# Example 3-2: Edge case with rank-3 and rank-2 inputs
A = tf.random.normal([3, 2, 3])  # Shape: (3, 2, 3)
B = tf.random.normal([3, 2])     # Shape: (3, 2)
# The rank-2 operand is broadcast across the batch dimension
result = tf.matmul(A, B)
print("Example 3-2 Result (Shape: {}):\n{}".format(result.shape, result.numpy()))

# Example 4: Edge case with mismatched batch dimensions (2 vs 0)
A = tf.random.normal([2, 2, 3])  # Shape: (2, 2, 3)
B = tf.random.normal([0, 3, 2])  # Shape: (0, 3, 2)
# result = tf.matmul(A, B)  # raises an error: batch dims are not broadcastable
print("Example 4 Result EXCEPTION")

# Example 5: Edge case with rank-2 inputs (plain matmul, no batch dimension)
A = tf.random.normal([2, 3])  # Shape: (2, 3)
B = tf.random.normal([3, 2])  # Shape: (3, 2)
result = tf.matmul(A, B)
print("Example 5 Result (Shape: {}):\n{}".format(result.shape, result.numpy()))

# Example 6: Edge case with rank-2 and rank-1 inputs
A = tf.random.normal([2, 3])  # Shape: (2, 3)
B = tf.random.normal([3])     # Shape: (3,)
# result = tf.matmul(A, B)  # raises an error: tf.matmul requires rank >= 2
print("Example 6 Result EXCEPTION")
I left some comments, PTAL
=)
This PR supports dynamic shape inference for the BatchMatMul Op. ONE-DCO-1.0-Signed-off-by: SeungHui Youn <[email protected]>
LGTM with minor comment https://github.com/Samsung/ONE/pull/13759/files#r1733773189
LGTM
throw_unless((not contain_zero(x_shape)), "x_shape should NOT have 0");
throw_unless((not contain_zero(y_shape)), "y_shape should NOT have 0");
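For context, contain_zero is not shown in this hunk. A helper along these lines would satisfy both calls (a sketch, assuming the shapes are loco::TensorShape values; the PR's actual helper may differ):

bool contain_zero(const loco::TensorShape &shape)
{
  // A known dimension with value 0 makes the tensor empty;
  // unknown (dynamic) dimensions are not treated as zero.
  for (uint32_t i = 0; i < shape.rank(); ++i)
    if (shape.dim(i).known() && shape.dim(i).value() == 0)
      return true;
  return false;
}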
(note) to resolve https://github.com/Samsung/ONE/pull/13759/files#r1732922382
TEST(ShapeRuleTest, bmm_empty_input_NEG)
{
  luci::CircleInput input_x;
  luci::CircleInput input_y;
  luci::CircleBatchMatMul bmm;

  input_x.shape({0, 2});
  input_x.shape_status(luci::ShapeStatus::VALID);

  input_y.shape({2, 4});
  input_y.shape_status(luci::ShapeStatus::VALID);

  bmm.x(&input_x);
  bmm.y(&input_y);

  loco::TensorShape shape;
  luci::sinf::Rule shape_inf_rule;

  // (0, 2) x (2, 4)
  //  ^
  // => error, x should not be empty
  ASSERT_ANY_THROW(shape_inf_rule.infer(&bmm, shape));
}
(newly added) to handle empty input
TEST(ShapeRuleTest, bmm_broadcast_known_dim_2)
{
  luci::CircleInput input_x;
  luci::CircleInput input_y;
  luci::CircleBatchMatMul bmm;

  input_x.shape({5, 4, 3});
  input_x.shape_status(luci::ShapeStatus::VALID);

  input_y.shape({3, 8});
  input_y.shape_status(luci::ShapeStatus::VALID);

  bmm.x(&input_x);
  bmm.y(&input_y);

  loco::TensorShape shape;
  luci::sinf::Rule shape_inf_rule;

  ASSERT_TRUE(shape_inf_rule.infer(&bmm, shape));

  // (5, 4, 3) x (3, 8) -> (5, 4, 8)
  // output shape should be (5, 4, 8)
  ASSERT_EQ(3, shape.rank());
  ASSERT_TRUE(shape.dim(0).known());
  ASSERT_TRUE(shape.dim(1).known());
  ASSERT_TRUE(shape.dim(2).known());
  ASSERT_EQ(5, shape.dim(0).value());
  ASSERT_EQ(4, shape.dim(1).value());
  ASSERT_EQ(8, shape.dim(2).value());
}
@ys44kim I added the TC for the case you asked about. :)
LGTM
Thank you, LGTM!
=)
LGTM
Ping @ys44kim
LGTM
@@ -30,4 +76,72 @@ luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleBatchMatMul *no
  return cloned;
}

// BatchMatMulV2 supports broadcasting in the batch dimensions (BatchMatMul doesn't)
// TODO Distinguish BatchMatMul and BatchMatMulV2
loco::TensorShape sinf::Algorithm::visit(const luci::CircleBatchMatMul *node)
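The hunk is truncated at the function signature. As a rough sketch of the BatchMatMulV2-style batch broadcasting such a visit function has to perform (my own illustration; it ignores adj_x/adj_y and unknown dimensions and is not the PR's actual body):

#include <algorithm>
#include <stdexcept>

loco::TensorShape broadcast_bmm_shape(const loco::TensorShape &x, const loco::TensorShape &y)
{
  const uint32_t out_rank = std::max(x.rank(), y.rank());
  loco::TensorShape out;
  out.rank(out_rank);
  // Right-align the batch dimensions and broadcast them pairwise:
  // a missing or size-1 dimension broadcasts, equal sizes pass, anything else throws.
  for (uint32_t i = 0; i + 2 < out_rank; ++i)
  {
    const int32_t xi = static_cast<int32_t>(i) - static_cast<int32_t>(out_rank - x.rank());
    const int32_t yi = static_cast<int32_t>(i) - static_cast<int32_t>(out_rank - y.rank());
    const uint32_t xd = (xi < 0) ? 1 : x.dim(static_cast<uint32_t>(xi)).value();
    const uint32_t yd = (yi < 0) ? 1 : y.dim(static_cast<uint32_t>(yi)).value();
    if (xd != yd && xd != 1 && yd != 1)
      throw std::runtime_error("batch dimensions are not broadcastable");
    out.dim(i) = std::max(xd, yd);
  }
  // Matrix dimensions: rows come from x, columns from y.
  out.dim(out_rank - 2) = x.dim(x.rank() - 2);
  out.dim(out_rank - 1) = y.dim(y.rank() - 1);
  return out;
}

For the test above, (5, 4, 3) x (3, 8) right-aligns to batch dims 5 vs (missing), broadcasts to 5, and yields (5, 4, 8).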
@zetwhite, please check #13780 (review).
Can you apply it here too? Or, since getting approvals takes time and you've got all the ACKs, apply it in another PR?
If you go with another PR, you can apply the changes to all Ops (I think there aren't too many).
I'd like to merge this PR as is (to avoid waiting for approvals again).
If that's fine, I'll make another PR to fix it.
you can apply changes to all Ops (I think there aren't too many)
Oh, I see. I'll apply it all at once.
ah... commit
This PR supports dynamic shape inference for the BatchMatMul Op.
ONE-DCO-1.0-Signed-off-by: SeungHui Youn [email protected]
For: #13697