[luci/pass] Introduce FuseMulToFullyConnectedWeightsPass #13439
Conversation
This will introduce FuseMulToFullyConnectedWeightsPass, which fuses a Mul into the weights of the following FullyConnected node when possible. ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
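The fusion rests on a simple identity: if the Mul scale `s` broadcasts over the FC input dimension, then `FC(x * s, W, b) == FC(x, W * s, b)`, so the Mul node can be removed and the scale folded into each weight column. A minimal pure-Python sketch of that identity (shapes and values here are illustrative, not taken from the pass itself):

```python
# Illustrative shapes: input <4>, weights <6x4> stored as <out, in>, bias <6>.
IN, OUT = 4, 6

x = [0.5, -1.0, 2.0, 0.25]                       # one FC input row
s = [2.0, 0.5, -1.0, 3.0]                        # Mul scale per input element
W = [[(i * IN + j + 1) * 0.1 for j in range(IN)] for i in range(OUT)]
b = [0.01 * i for i in range(OUT)]

def fc(inp, weights, bias):
    # y[i] = sum_j inp[j] * weights[i][j] + bias[i]
    return [sum(inp[j] * weights[i][j] for j in range(IN)) + bias[i]
            for i in range(OUT)]

# Original graph: Mul, then FullyConnected
y_before = fc([x[j] * s[j] for j in range(IN)], W, b)

# After the pass: scale folded into each weight column, Mul node removed
W_fused = [[W[i][j] * s[j] for j in range(IN)] for i in range(OUT)]
y_after = fc(x, W_fused, b)

assert all(abs(a - c) < 1e-9 for a, c in zip(y_before, y_after))
```

The assertion holds because each output element is a sum of products `x[j] * s[j] * W[i][j]`, and it does not matter whether `s[j]` is attached to the input or to the weight.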
LGTM
I left some questions, PTAL =)
```cpp
_mul_s->size<DT>(3);
for (uint32_t i = 0; i < 3; ++i)
{
  _mul_s->at<DT>(0) = 1.0f;
```
The index `i` is not used in the `for` statement..!?
Did you intend the code below?

```diff
- _mul_s->at<DT>(0) = 1.0f;
+ _mul_s->at<DT>(i) = 1.0f;
```
```cpp
_fc_w->size<DT>(4 * 6);
for (uint32_t i = 0; i < 4 * 6; ++i)
{
  _fc_w->at<DT>(0) = 1.0f;
```
ditto..?
```cpp
_fc_b->size<DT>(6);
for (uint32_t i = 0; i < 6; ++i)
{
  _fc_b->at<DT>(0) = 1.0f;
```
ditto..?
```cpp
_fc_w->dim(0) = 3;
_fc_w->dim(1) = 4;
_fc_w->dtype(DT);
_fc_w->size<DT>(4 * 6);
for (uint32_t i = 0; i < 4 * 6; ++i)
```
The shape of `_fc_w` is <3x4>. Doesn't this size have to match the shape..?
```diff
  _fc_w->dim(0) = 3;
  _fc_w->dim(1) = 4;
  _fc_w->dtype(DT);
- _fc_w->size<DT>(4 * 6);
- for (uint32_t i = 0; i < 4 * 6; ++i)
+ _fc_w->size<DT>(3 * 4);
+ for (uint32_t i = 0; i < 3 * 4; ++i)
```
or
```diff
- _fc_w->dim(0) = 3;
- _fc_w->dim(1) = 4;
+ _fc_w->dim(0) = 4;
+ _fc_w->dim(1) = 6;
  _fc_w->dtype(DT);
  _fc_w->size<DT>(4 * 6);
  for (uint32_t i = 0; i < 4 * 6; ++i)
```
Please check the FullyConnected Op. With torch,

```python
import torch

tensor1 = torch.randn(3, 4)
FC = torch.nn.Linear(4, 6)
output = FC(tensor1)
print(output)
print(output.shape)
```

gives something like

```
tensor([[-0.8929,  0.2467, -0.1407, -0.9917, -0.3787,  0.6504],
        [-0.2442, -0.0493, -0.1175, -1.0433,  0.0100, -0.3347],
        [ 0.7240,  0.0195, -0.3735, -1.5789, -0.0348, -0.9252]],
       grad_fn=<AddmmBackward0>)
torch.Size([3, 6])
```
A simpler case is with a vector:

```python
tensor1 = torch.randn(4)
FC = torch.nn.Linear(4, 6)
output = FC(tensor1)
print(output)
print(output.shape)
```

gives

```
tensor([ 0.7355, -0.1311,  0.4749, -0.1821, -0.1245, -1.0594],
       grad_fn=<ViewBackward0>)
torch.Size([6])
```
ah, should be

```cpp
_fc_w->dim(0) = 6;
_fc_w->dim(1) = 4;
```
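The conclusion above can be checked directly: with weights stored as <out, in> (here <6x4>), multiplying a <3x4> input by the transposed weight matrix yields the <3x6> output seen in the torch run. A pure-Python sketch of that shape arithmetic (values are placeholders):

```python
ROWS, IN, OUT = 3, 4, 6

x = [[float(r * IN + c) for c in range(IN)] for r in range(ROWS)]   # <3x4> input
W = [[1.0] * IN for _ in range(OUT)]                                # <6x4> weights: <out, in>

# y = x @ W^T : each output row has OUT elements
y = [[sum(x[r][j] * W[i][j] for j in range(IN)) for i in range(OUT)]
     for r in range(ROWS)]

print(len(y), len(y[0]))  # 3 6 -> matches torch.Size([3, 6])
```

This is the same layout convention `torch.nn.Linear` uses: `Linear(4, 6).weight` has shape `(6, 4)`, i.e. `(out_features, in_features)`.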
```cpp
_mul->x(input());
_mul->y(_mul_s);
_fc->input(_mul);
_fc->weights(_fc_b);
```
Should it be `_fc_w` instead of `_fc_b`..?
```cpp
TEST_F(FuseMulToFullyConnectedWeightsPassS32Test, dtype_s32_NEG)
{
  _graph.init();

  EXPECT_FALSE(_pass.run(_graph.g()));
}
```
This test seems the same as the `fuse_mul_to_fc_weights` test, but it expects `false`. How could it be..!?
`FuseMulToFullyConnectedWeightsPassS32Test` is made with S32, a datatype we can't support.
Ah, it was S32. Thank you!
@seanshpark,
This will introduce FuseMulToFullyConnectedWeightsPass which will
fuse Mul to following FullyConnected weights if possible.