#14080: Preprocess weights for Conv2D on Device #16750
base: main
Conversation
@sankarmanoj-tt TODO: Re-enable transpose cast
ttnn/ttnn/__init__.py
HEIGHT_SHARDED_LAYOUT,
BLOCK_SHARDED_LAYOUT,
WIDTH_SHARDED_LAYOUT,
Please revert. You're not using this, and I'd prefer users access the enum via TensorMemoryLayout.HEIGHT_SHARDED_LAYOUT.
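For illustration, a minimal sketch of the suggested access pattern (the enum member name follows the reviewer's wording and is an assumption, not verified against the ttnn API):

```python
import ttnn

# Access the sharding layout through the enum class rather than a
# top-level re-export from ttnn/__init__.py.
layout = ttnn.TensorMemoryLayout.HEIGHT_SHARDED_LAYOUT  # member name assumed
```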
Reverted.
@@ -58,6 +58,7 @@ def run_conv(
    config_override,
    dilation=1,
    use_shallow_conv_variant=False,
    transpose_shards=True,  # TODO: Fails when set to False
Why? Should this be tracked in a GitHub issue?
It's a bug that needs to be fixed. I've created an issue here to track it.
template <typename T>
std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases_on_device(
    const ttnn::Tensor& weight_tensor,
There's a lot of conv code that is templated on the device type but shouldn't need to be. Are there any plans to clean it up?
What is the correct way to support MeshDevice and Device?
@@ -1212,7 +1220,7 @@ def test_resnet50_conv_wh_fp32(
)
@pytest.mark.parametrize(
    "weights_dtype",
    [ttnn.bfloat8_b],
revert or intentional change?
@@ -1354,7 +1362,7 @@ def test_sd_conv(
)
@pytest.mark.parametrize(
    "activations_dtype",
    [ttnn.bfloat16, ttnn.bfloat8_b],
revert or intentional change?
@@ -1495,7 +1503,7 @@ def test_sd_conv_wh(
)
@pytest.mark.parametrize(
    "weights_dtype",
    [ttnn.bfloat8_b],
revert or intentional change?
@@ -2007,6 +2020,7 @@ def test_halo_reshard_conv(
)


@skip_for_grayskull()
Why is this now skipped on GS?
This fails because the device ops used to prepare the weights don't support FP32 on Grayskull.
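A small illustrative sketch of how that rationale could be recorded inline with the skip. The import path and the reason-string argument are assumptions based on common tt-metal test utilities, and the reason text is my paraphrase, not taken from the PR:

```python
from models.utility_functions import skip_for_grayskull  # path assumed


@skip_for_grayskull("Weight-preparation device ops don't support FP32 on Grayskull")
def test_halo_reshard_conv():
    pass  # body elided; the decorator skips this test on Grayskull
```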
@@ -2618,7 +2633,7 @@ def test_conv_for_vanilla_unet(
)
@pytest.mark.parametrize(
    "weights_dtype",
    [ttnn.bfloat8_b, ttnn.bfloat16],
Revert, or is this an intentional change?
@@ -2855,6 +2873,7 @@ def test_shallow_conv_with_tiled_input(device):


# Tests running conv2d which maps to matmul w/o sharding the input tensor.
# Output tensor is in DRAM.
@skip_for_grayskull()
Why is this now skipped on GS?
On a high level
Ticket
#14080
Problem description
Currently, weights preprocessing takes place on the host, on a single thread. This is slow, especially for large weight matrices and when Debug mode is enabled.
What's changed
The weights are loaded onto the device in the same format as PyTorch. All other processing, including permute and padding, is done on the device.
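As a rough illustration of the intended flow (a sketch only: the exact ttnn.conv2d parameter names and return shape vary across ttnn versions and are assumptions, not code from this PR):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Weights stay in PyTorch's (out_channels, in_channels, kh, kw) layout and are
# sent to the device as-is; permute/pad/tilize now happen on-device.
torch_weight = torch.randn(64, 32, 3, 3, dtype=torch.bfloat16)
torch_input = torch.randn(1, 32, 32, 32, dtype=torch.bfloat16)
# conv2d consumes channels-last activations.
torch_input = torch.permute(torch_input, (0, 2, 3, 1))

weight = ttnn.from_torch(torch_weight, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

result = ttnn.conv2d(
    input_tensor=input_tensor,
    weight_tensor=weight,  # raw PyTorch-format weights, preprocessed on device
    device=device,
    in_channels=32,
    out_channels=64,
    batch_size=1,
    input_height=32,
    input_width=32,
    kernel_size=(3, 3),
    stride=(1, 1),
    padding=(1, 1),
)
ttnn.close_device(device)
```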
Checklist