
#14080: Preprocess weights for Conv2D on Device #16750

Open: wants to merge 26 commits into main from smanoj/conv_device_weights
Conversation

Contributor

@sankarmanoj-tt sankarmanoj-tt commented Jan 15, 2025

Ticket

#14080

Problem description

Currently, weight preprocessing takes place on the host, on a single thread. This is slow, especially for large weight matrices and when Debug mode is enabled.

What's changed

The weights are now loaded to the device in the same layout as PyTorch. All remaining processing, including permutation and padding, is done on the device.
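For intuition, the transformation described above can be sketched on the host with NumPy. This is an illustrative stand-in, not the actual ttnn device kernels: the exact permutation order, the tile size of 32, and the function name are assumptions made for the example.

```python
import numpy as np

# Hypothetical host-side equivalent of the on-device preprocessing (illustrative
# only). PyTorch conv2d weights are laid out as (out_channels, in_channels, kH, kW).

def preprocess_conv2d_weights(weights: np.ndarray, tile: int = 32) -> np.ndarray:
    """Permute OIHW weights and zero-pad channel dims up to a tile multiple."""
    out_c, in_c, kh, kw = weights.shape
    # Step 1: permute so the kernel spatial dims lead (an assumed target layout).
    w = weights.transpose(2, 3, 1, 0)          # (kH, kW, in_c, out_c)
    # Step 2: pad both channel dims up to the next multiple of the tile size.
    pad_in = (-in_c) % tile
    pad_out = (-out_c) % tile
    return np.pad(w, ((0, 0), (0, 0), (0, pad_in), (0, pad_out)))

w = np.ones((10, 3, 3, 3), dtype=np.float32)   # 10 out, 3 in, 3x3 kernel
prepared = preprocess_conv2d_weights(w)
print(prepared.shape)                           # (3, 3, 32, 32)
```

Doing these steps on device amortizes the cost across many cores instead of a single host thread, which is the motivation stated above.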

Checklist

  • Post commit CI passes
  • (For models and ops writers) Full new model tests pass
  • New/Existing tests provide coverage for changes

@sankarmanoj-tt sankarmanoj-tt force-pushed the smanoj/conv_device_weights branch from fe919e2 to e00b7ec Compare January 15, 2025 10:53
@sankarmanoj-tt sankarmanoj-tt force-pushed the smanoj/conv_device_weights branch 2 times, most recently from 6351705 to 7662eba Compare January 29, 2025 13:34
@sankarmanoj-tt sankarmanoj-tt marked this pull request as ready for review January 30, 2025 12:05
@sankarmanoj-tt sankarmanoj-tt force-pushed the smanoj/conv_device_weights branch from 7f3a9c0 to c5c4540 Compare January 31, 2025 08:38
Contributor Author

@sankarmanoj-tt TODO: Re-enable transpose cast

@sankarmanoj-tt sankarmanoj-tt force-pushed the smanoj/conv_device_weights branch from c5c4540 to 6c6da4d Compare February 3, 2025 10:12
@sankarmanoj-tt sankarmanoj-tt force-pushed the smanoj/conv_device_weights branch 2 times, most recently from a967aee to 644f04b Compare February 5, 2025 17:08
@sankarmanoj-tt sankarmanoj-tt force-pushed the smanoj/conv_device_weights branch from 667a8a4 to 687fc72 Compare February 13, 2025 04:40
Comment on lines 132 to 134
HEIGHT_SHARDED_LAYOUT,
BLOCK_SHARDED_LAYOUT,
WIDTH_SHARDED_LAYOUT,
Collaborator

Please revert. You're not using this, and I'd prefer users access the enum via TensorMemoryLayout.HEIGHT_SHARDED_LAYOUT.
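The reviewer's preference can be illustrated with a plain stdlib Enum; the class and member names below mirror the discussion but are a hypothetical sketch, not the real ttnn bindings.

```python
from enum import Enum

# Hypothetical sketch of the design point: keep enum members namespaced under
# their class instead of re-exporting them at module top level.
class TensorMemoryLayout(Enum):
    INTERLEAVED = 0
    HEIGHT_SHARDED = 1
    WIDTH_SHARDED = 2
    BLOCK_SHARDED = 3

# Preferred: access via the enum class, so the constant's origin is explicit.
layout = TensorMemoryLayout.HEIGHT_SHARDED

# Discouraged: flattening members into module scope (HEIGHT_SHARDED_LAYOUT, ...)
# duplicates names and loses the namespace.
```

Namespaced access also avoids collisions when several enums define similarly named members.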

Contributor Author

Reverted.

@@ -58,6 +58,7 @@ def run_conv(
config_override,
dilation=1,
use_shallow_conv_variant=False,
transpose_shards=True, # TODO: Fails when set to False
Collaborator

Why? Should this be tracked in a GitHub issue?

Contributor Author

It's a bug that needs to be fixed. I've created an issue here to track it.

template <typename T>
std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases_on_device(
const ttnn::Tensor& weight_tensor,
Collaborator

There's a lot of conv code that is templated on the device type but shouldn't need to be. Are there plans to clean it up?

Contributor Author

What is the correct way to support MeshDevice and Device?

@@ -1212,7 +1220,7 @@ def test_resnet50_conv_wh_fp32(
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat8_b],
Contributor

revert or intentional change?

@@ -1354,7 +1362,7 @@ def test_sd_conv(
)
@pytest.mark.parametrize(
"activations_dtype",
[ttnn.bfloat16, ttnn.bfloat8_b],
Contributor

revert or intentional change?

@@ -1495,7 +1503,7 @@ def test_sd_conv_wh(
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat8_b],
Contributor

revert or intentional change?

@@ -2007,6 +2020,7 @@ def test_halo_reshard_conv(
)


@skip_for_grayskull()
Contributor

Why is this now skipped on GS?

Contributor Author

This fails because the device ops used to prepare the weights don't support FP32 on Grayskull.

@@ -2618,7 +2633,7 @@ def test_conv_for_vanilla_unet(
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat8_b, ttnn.bfloat16],
Contributor

revert? or is this an intentional change?

@@ -2855,6 +2873,7 @@ def test_shallow_conv_with_tiled_input(device):

# Tests running conv2d which maps to matmul w/o sharding the input tensor.
# Output tensor is in DRAM.
@skip_for_grayskull()
Contributor

Why is this now skipped on GS?

@pavlejosipovic
Contributor

On a high level:

  1. This change doesn't allow the user to pass in a tensor that is already on device and have it preprocessed; instead we have two paths, both of which require the tensor to be on host.
  2. The new codepath, which performs more (but not all) preprocessing on device, is not the default, yet all our tests use only it. So the default codepath is not tested, and we have two codepaths for the same thing.
  3. Does this change extend the runtime of our tests on post-commit?

@mywoodstock any thoughts on the above.

@sankarmanoj-tt sankarmanoj-tt force-pushed the smanoj/conv_device_weights branch from 9f806ae to 0435cce Compare February 18, 2025 23:06