Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

softmax with EVT #177

Open
wants to merge 3 commits into
base: sycl-develop
Choose a base branch
from

Conversation

jiyang1011
Copy link
Collaborator

No description provided.

examples/sycl/pvc/pvc_gemm_with_epilogue_softmax.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_gemm_with_epilogue_softmax.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_gemm_with_epilogue_softmax.cpp Outdated Show resolved Hide resolved
include/cutlass/epilogue/fusion/xe_vistor_softmax.hpp Outdated Show resolved Hide resolved
include/cutlass/epilogue/fusion/xe_vistor_softmax.hpp Outdated Show resolved Hide resolved
include/cutlass/epilogue/fusion/xe_vistor_softmax.hpp Outdated Show resolved Hide resolved
include/cutlass/epilogue/fusion/xe_vistor_softmax.hpp Outdated Show resolved Hide resolved
include/cutlass/epilogue/fusion/xe_vistor_softmax.hpp Outdated Show resolved Hide resolved
include/cutlass/epilogue/fusion/xe_vistor_softmax.hpp Outdated Show resolved Hide resolved
include/cutlass/epilogue/fusion/xe_vistor_softmax.hpp Outdated Show resolved Hide resolved
@jiyang1011 jiyang1011 force-pushed the jiyang/softmax branch 3 times, most recently from 0651b94 to 7fdf576 Compare January 21, 2025 08:33
@jiyang1011 jiyang1011 changed the title softmax with EVT (draft) softmax with EVT Jan 26, 2025
@@ -312,7 +313,7 @@ class CollectiveEpilogue<
bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
Tensor trD = make_tensor<typename TiledMma::ValTypeD>(Shape<Int<FragmentSize>>{});
Tensor trD = make_tensor<typename TiledMma::ValTypeD>(Shape<Int<FragmentSize>, Int<FragsM>, Int<FragsN>>{});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a callback needs to store all the values for Int and Int, that should be done inside the callback instead of in the generic path.

}
copy(params.xe_store_d, trD, rw_coord(_, epi_m, epi_n));
cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD(_, epi_m, epi_n));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!!!

Comment on lines +387 to +393
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < FragsN; epi_n++) {
CUTLASS_PRAGMA_UNROLL
for (int epi_m = 0; epi_m < FragsM; epi_m++) {
copy(params.xe_store_d, trD(_, epi_m, epi_n), rw_coord(_, epi_m, epi_n));
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change will group all the store operations to run at the end.

In the SM90 implementation, there’s a condition around the copy operation to check if the D matrix is available. We can apply the same logic here.

If D isn't available, the softmax output should be stored in a different matrix, which will be added directly to the callback. This way, the generic path won’t know about that matrix. Since the reduce method is the only one that knows when the data is ready to be copied, I suggest we handle the copy operations within that method.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added the if condition to skip the copy in PR#199. Now, you should be able to revert the changes in the generic path for the epilogue and implement all the softmax functionalities inside the callback.

float cute_time = timer.seconds() / options.iterations;
double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12;
std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
printf("Cutlass GEMM Performance: [%4.3f]GB/s (%6.4f)ms\n", io / cute_time, cute_time*1000);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
printf("Cutlass GEMM Performance: [%4.3f]GB/s (%6.4f)ms\n", io / cute_time, cute_time*1000);
printf("Cutlass GEMM Performance: %4.3f GB/s , %4.3f TF/s , %6.4f ms\n", io / cute_time, tflops/cute_time, cute_time*1000);

Comment on lines +69 to +70
SYCL_DEVICE_OCL(float sub_group_reduce_add(float i));
SYCL_DEVICE_OCL(float sub_group_reduce_max(float i));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the reduce operation is not part of mma

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants