Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial framework for TKW #18

Merged
merged 1 commit into from
Jun 19, 2024
Merged

Add initial framework for TKW #18

merged 1 commit into from
Jun 19, 2024

Conversation

martin-luecke
Copy link
Contributor

TODO: expand PR description, this is currently for enabling commenting

Detailed description of this project in shark_turbine/kernel/functional/README.md

@stellaraccident
Copy link
Collaborator

I saw this branch in my latest fetch and was intrigued. Looks interesting

Copy link
Contributor

@harsh-nod harsh-nod left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pushing this. Just some minor comments but looks good overall.

shark_turbine/kernel/_support/nodes.py Outdated Show resolved Hide resolved
shark_turbine/kernel/_support/nodes.py Outdated Show resolved Hide resolved
shark_turbine/kernel/_support/nodes.py Outdated Show resolved Hide resolved
shark_turbine/kernel/_support/nodes.py Outdated Show resolved Hide resolved
shark_turbine/kernel/_support/nodes.py Outdated Show resolved Hide resolved
shark_turbine/kernel/_support/tracing.py Outdated Show resolved Hide resolved
shark_turbine/kernel/functional/README.md Outdated Show resolved Hide resolved
shark_turbine/kernel/functional/README.md Outdated Show resolved Hide resolved
shark_turbine/kernel/functional/README.md Outdated Show resolved Hide resolved
shark_turbine/kernel/functional/README.md Outdated Show resolved Hide resolved
@stellaraccident
Copy link
Collaborator

One naming nit: aside from some functional inspiration, the big idea here is a wave level programming model. If we're allocating the third letter of the prefix to what something is, probably tkw vs tkf. I doubt this will be the only thing that takes more of a functional view of things. And in general, that is a how, not a what.

@martin-luecke martin-luecke force-pushed the wave_kernel_language branch from 6178ac8 to 095511f Compare May 15, 2024 09:46
@martin-luecke martin-luecke changed the title Add initial framework for TKF Wave programming language Add initial framework for TKW May 15, 2024
@martin-luecke martin-luecke force-pushed the wave_kernel_language branch 2 times, most recently from e3b0abc to 40f090f Compare May 15, 2024 10:06
Copy link
Contributor Author

@martin-luecke martin-luecke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the comments, they should now all be addressed.
I think the renaming @stellaraccident suggested makes sense so I renamed the project to tkw for now.

shark_turbine/kernel/_support/nodes.py Outdated Show resolved Hide resolved
shark_turbine/kernel/_support/nodes.py Outdated Show resolved Hide resolved
shark_turbine/kernel/functional/README.md Outdated Show resolved Hide resolved
shark_turbine/kernel/functional/README.md Outdated Show resolved Hide resolved
@martin-luecke martin-luecke force-pushed the wave_kernel_language branch from a2fc46c to fa8819b Compare May 15, 2024 10:18
@martin-luecke
Copy link
Contributor Author

I updated the IREE version here but the CI still seems to pick up a version which does not expose the amdgpu dialect. Do you know where to adjust this? @harsh-nod @stellaraccident

@stellaraccident
Copy link
Collaborator

Look to the failing CI run?

@martin-luecke martin-luecke force-pushed the wave_kernel_language branch 2 times, most recently from 45f1de8 to bb4f0ba Compare May 15, 2024 15:24
@martin-luecke
Copy link
Contributor Author

Sorry, I rushed a bit to get to some other work.

I only now understood that the CI pulls nod-ai/SHARK and uses those requirements.
The CI here contains an sed command to adjust the SHARK requirements from iree-turbine/main to the currently tested commit.
The sed command is here:

sed -i 's/iree-turbine#/iree-turbine.git@${{github.sha}}#/g' requirements.txt

and seems to be supposed to match here

However, it does not. I am not sure whether that is intended to test everything with the requirements of iree-turbine/main or an oversight.
I modified it to do the replacement, but that currently messes up torch dependencies. Also it still seems to not pull the iree version indicated in iree-requirements.txt. I assumed this is the place to modify the requirement, similar to how it is done in this patch
But I also see that in earlier CI runs unrelated to this PR a different iree version from the iree-requirements.txt is used for testing. e.g. here which pulls version 20240410.859 instead of version 20240514.893 indicated in the iree-requirements.

I will have to revisit this later.

@martin-luecke martin-luecke force-pushed the wave_kernel_language branch 2 times, most recently from 3424b26 to 7dfabaa Compare June 10, 2024 12:07
@martin-luecke martin-luecke marked this pull request as ready for review June 10, 2024 12:08
tests/kernel/wave_gemm_test.py Outdated Show resolved Hide resolved
shark_turbine/kernel/_support/nodes.py Outdated Show resolved Hide resolved
shark_turbine/kernel/_support/tracing.py Outdated Show resolved Hide resolved
@martin-luecke martin-luecke force-pushed the wave_kernel_language branch from 7dfabaa to 1b114fe Compare June 11, 2024 14:47
Copy link
Contributor Author

@martin-luecke martin-luecke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the comments @harsh-nod
I will explore tomorrow autogenerating more of the operation boilerplate and in particular see whether the registration of the custom ops to the context can be done differently.

tests/kernel/wave_gemm_test.py Outdated Show resolved Hide resolved
shark_turbine/kernel/_support/nodes.py Outdated Show resolved Hide resolved
shark_turbine/kernel/wave/ops.py Outdated Show resolved Hide resolved
shark_turbine/kernel/_support/tracing.py Outdated Show resolved Hide resolved
@martin-luecke martin-luecke force-pushed the wave_kernel_language branch from 1b114fe to 0167aaf Compare June 11, 2024 14:58
@martin-luecke martin-luecke marked this pull request as draft June 12, 2024 15:02
@martin-luecke martin-luecke marked this pull request as draft June 12, 2024 15:02
@martin-luecke martin-luecke force-pushed the wave_kernel_language branch from 0167aaf to ec5b437 Compare June 13, 2024 16:05
@martin-luecke martin-luecke marked this pull request as ready for review June 13, 2024 16:11
@martin-luecke martin-luecke force-pushed the wave_kernel_language branch from ec5b437 to c4db712 Compare June 14, 2024 06:27
@martin-luecke
Copy link
Contributor Author

Oops, seems my VSCode plugin glitched and I accidentally closed. Now reopened.
I pushed the lit tests to their own top-level directory as suggested to avoid test tools fighting each other.

shark_turbine/kernel/ops/wave_ops.py Show resolved Hide resolved
tests/lit.cfg.py Outdated Show resolved Hide resolved
shark_turbine/kernel/ops/wave_ops.py Outdated Show resolved Hide resolved
shark_turbine/kernel/ops/wave_ops.py Outdated Show resolved Hide resolved
shark_turbine/kernel/wave/codegen.py Outdated Show resolved Hide resolved
tests/kernel/wave/tracing.py Outdated Show resolved Hide resolved
@harsh-nod
Copy link
Contributor

This is great Martin and apart from a few minor changes, seems good to land. Thanks for the push in getting this PR to this point. This will give us a great foundation to build on.

@martin-luecke martin-luecke force-pushed the wave_kernel_language branch 2 times, most recently from da97c1c to 0565259 Compare June 17, 2024 14:14
@martin-luecke
Copy link
Contributor Author

Thank you @harsh-nod

Inspired by your comment I added the add_to_graph to our custom ops and took the time to simplify the interface for adding them to the graph after tracing.
In particular I removed the fx_op field from CustomOp which was required for initialization before. It is more a technical detail and remnant from the previous implementation. Instead, there is now a _tracing_function field on the class, which is set automatically by the define_op decorater.

With that the API for creating nodes after the tracing looks like this:

[...]
read = Read(a).add_to_graph(graph)
write = Write(read, a, 4).add_to_graph(graph)

@harsh-nod
Copy link
Contributor

Thank you @harsh-nod

Inspired by your comment I added the add_to_graph to our custom ops and took the time to simplify the interface for adding them to the graph after tracing. In particular I removed the fx_op field from CustomOp which was required for initialization before. It is more a technical detail and remnant from the previous implementation. Instead, there is now a _tracing_function field on the class, which is set automatically by the define_op decorater.

With that the API for creating nodes after the tracing looks like this:

[...]
read = Read(a).add_to_graph(graph)
write = Write(read, a, 4).add_to_graph(graph)

Thanks Martin, that's awesome. Thank you for the push and this is a great first PR to build on top of.

@harsh-nod harsh-nod self-requested a review June 17, 2024 16:34
Copy link
Contributor

@harsh-nod harsh-nod left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm! thanks!

@martin-luecke martin-luecke force-pushed the wave_kernel_language branch 5 times, most recently from 0c2dc53 to 58fafc4 Compare June 18, 2024 12:23
- Readme describing TKW approach
- initial op set: read, write, mma, tiled_loop, register
- wave decorator for block-based programming model
- types for yet to be specified Memory and for Registers
- unit tests for new types
- initial gemm test that traces successfully
- unit tests for tracing the initial set of ops
- add lit and filecheck dependencies for unit tests
- interface on top of fx.Node to rely less on string matching when navigating the fx.Graph.
- small changes to tracing contexts to use op handlers from node interfaces

Signed-off-by: Martin Lücke <[email protected]>
@martin-luecke martin-luecke force-pushed the wave_kernel_language branch from 58fafc4 to 5ddb524 Compare June 19, 2024 05:40
@martin-luecke martin-luecke merged commit e76dfea into main Jun 19, 2024
6 of 7 checks passed
@martin-luecke martin-luecke deleted the wave_kernel_language branch June 19, 2024 05:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants