
Distributed GEMM #1907

Open · wants to merge 3 commits into main

Conversation

@alihassanijr (Contributor)

Adds experimental support for running tensor parallel GEMMs natively through CUTLASS.

Distributed GEMM (DistGEMM) implements communication-fused GEMMs using point-to-point communication, which allows for better pipelining and can, in theory, hide all communication behind computation. It also makes very few assumptions about the underlying kernel: it only adds a few barriers to the beginning of each GEMM kernel, and it either uses the epilogue source as the communication buffer or adds a memcpy branch to the CUDA graph, leaving the SMs free for GEMMs and the copy engine free for communication.

When benchmarked with Llama 70B and 405B training shapes, DistGEMM can reach 70-80% of peak performance.

A more detailed blog post on DistGEMM will be released soon.
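To make the pipelining idea concrete, here is a minimal conceptual sketch in plain CUDA of overlapping per-stage peer copies with local GEMMs. This is not the DistGEMM implementation (which fuses the synchronization into the GEMM kernels and drives everything through a CUDA graph, as described above); local_gemm, stage_bufs, and peer_slices are illustrative names only.

  #include <cuda_runtime.h>
  #include <vector>

  // Hypothetical stand-in for launching the CUTLASS GEMM of one pipeline stage.
  inline void local_gemm(int /*stage*/, cudaStream_t /*stream*/) { /* kernel launch elided */ }

  // Each of the TP stages multiplies the local operand shard by a slice that
  // lives on a peer GPU. While the compute stream runs stage s, the copy
  // engine prefetches the slice needed by stage s+1 over NVLink.
  void tensor_parallel_gemm_pipeline(int device_idx, int TP,
                                     void* const* stage_bufs,   // stage_bufs[s]: landing buffer for stage s's slice
                                     void* const* peer_slices,  // peer_slices[d]: slice resident on device d
                                     size_t slice_bytes,
                                     cudaStream_t compute_stream,
                                     cudaStream_t copy_stream) {
    std::vector<cudaEvent_t> copy_done(TP);
    for (int s = 1; s < TP; ++s) {
      cudaEventCreateWithFlags(&copy_done[s], cudaEventDisableTiming);
    }
    for (int s = 0; s < TP; ++s) {
      if (s + 1 < TP) {
        // Prefetch the next stage's slice on the copy engine, so the SMs stay
        // free for the GEMM launched below.
        int peer = (device_idx + s + 1) % TP;
        cudaMemcpyPeerAsync(stage_bufs[s + 1], device_idx,
                            peer_slices[peer], peer, slice_bytes, copy_stream);
        cudaEventRecord(copy_done[s + 1], copy_stream);
      }
      if (s > 0) {
        // Stage s's GEMM may only start once its operand slice has arrived.
        cudaStreamWaitEvent(compute_stream, copy_done[s], 0);
      }
      local_gemm(s, compute_stream);  // stage 0 uses the locally resident slice
    }
    for (int s = 1; s < TP; ++s) {
      cudaEventDestroy(copy_done[s]);
    }
  }

In the design described above, the per-stage dependency is enforced by barriers at the start of the GEMM kernels rather than by host-side stream events, and the copies are captured in the CUDA graph.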

@alihassanijr (Contributor, Author)

@thakkarV @hwu36

/* ProcessorMappingA_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
/* ProcessorMappingB_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
/* ProcessorMappingC_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
/* ProcessorMappingD_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
@ccecka

The "ProcessorMapping" is the same in all of these schedules and are used in-code very trivially.

(1) What do these parameterize and what other versions of this parameter do you expect to support?
(2) Why do they need to be CuTe Layouts? The sizes, shapes, ranks, strides are never used.
(3) The comments next to these parameters do not explain the domain of this function or the codomain.
(4) The size of the "bias mode" is always 1 and the stride is always 0, yet these are indexed with a coordinate 1 in that mode always. This seems like a poor parameterization and Layouts are not what you actually want.
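(For reference, a minimal standalone sketch of what this layout evaluates to under standard CuTe semantics: a coordinate is dotted with the strides, so the second mode, whose stride is _0, never contributes, and the map is the identity on the device index. TP = 8 and the variable name mapping are illustrative.)

  #include <cstdio>
  #include <cute/layout.hpp>

  int main() {
    using namespace cute;
    // Same shape/stride as the ProcessorMapping parameters above, for TP = 8.
    auto mapping = make_layout(make_shape(Int<8>{}, Int<1>{}),
                               make_stride(Int<1>{}, Int<0>{}));
    for (int device_idx = 0; device_idx < 8; ++device_idx) {
      // mapping(d, j) = d * 1 + j * 0 = d
      printf("mapping(%d, 0) = %d\n", device_idx, int(mapping(device_idx, 0)));
    }
    return 0;
  }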

@alihassanijr (Contributor, Author)

You're right, ProcessorMapping is definitely unnecessary at this point. It was originally there in case the first iteration has a remote buffer, in which case they would map device index to the peer device's index.

I'll get rid of this; but just to clarify this is only in reference to ProcessorMapping and not IterationMapping?

@ccecka (Oct 31, 2024)

I'm still picking apart "IterationMapping"... it seems redundant with "PeerDeviceMapping", but that's still unclear to me. It also appears that the relationships between MNKL are not respected because each of these "Layout"s refers to A|B|C|D instead... I suspect there are a lot of implied invariants in this representation and those should be eliminated.

@alihassanijr (Contributor, Author)

PeerDeviceMapping can technically be represented with a boolean with the schedules implemented right now, but I didn't want to assume that it would necessarily stay that way with other schedules that we may add.

And correct, it's a little difficult to solely rely on the relationships of MNKL because operands are sharded and rotated differently in different schedules. Because DistGEMM is also supposed to have a separate buffer space for remote tensors and switch between them, it wasn't possible to maintain that behavior anymore.
I'll try and think of a better way to do this, but at some point, references to ABCD will be inevitable just because schedules shard them differently.

@ccecka

And my point is that what is important is how the processors are moving through MNK space, which can then be translated to actions on ABC.

@alihassanijr (Contributor, Author)

I see, thanks for explaining. I think we could just replace IterationMapping{A,B,C,D} with mappings to MNKL tile coordinates. That would make sense given that the tile shape (IterationTiler) is already set up as size-4 tuples corresponding to MNKL. Would moving to IterationMapping{M,N,K,L} (or, if I can figure out the mapping for it, just one IterationMappingMNKL layout) be a step in the right direction?

@thakkarV (Collaborator, Oct 31, 2024)

I think we briefly discussed this before your internship ended -- this is the same idea as the tile schedulers we have at the kernel layer, but at the scale of the NVLink system. I think calling this a DistributedTileScheduler makes sense, with the same abstract class hierarchy as the tile schedulers we already have. The job of this scheduler is to map a given tile coordinate (in this case, a tile of the global problem layout) to a given physical processor (in this case, the GPU ID). This will let you generalize this to 2D and 3D TP in the future. Distributed schedules (patterns, as you called them) are then just different TV layouts where the T mode can be 1, 2, or 3D depending on the TP strategy.

@alihassanijr (Contributor, Author)

Of course; that's the end goal, but at the same time it's non-trivial to come up with an API that closely resembles the kernel-layer tile schedulers.
Right now it handles both cross-device and on-device tiling, and maps device index and pipeline stage / iteration to tile coordinates. Are you saying it should be broken up into two components, one handling the tiling and one handling the mapping to tile coordinates?

@thakkarV (Collaborator)

Partitioning is the job of the TV layout that maps the GPU rank to the values it extracts from the global coordinate. Scheduling adds the temporal component on top, which tells each GPU which coordinate to work on at which step of the computation. Separating them out into two makes the most sense. @ccecka agree?
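To make the proposed separation concrete, here is a hypothetical sketch (these class and member names are illustrative, not existing CUTLASS API): the partitioner owns the TV layout that maps (rank, value) to an MNKL tile coordinate, and the scheduler adds the temporal component on top, shown here as a simple rotation.

  #include <cute/layout.hpp>

  // Hypothetical: spatial partitioning only -- which tiles does a rank own?
  template <class TVLayout>
  struct DistributedPartitioner {
    TVLayout tv;  // maps (rank, value) -> linearized MNKL tile coordinate

    CUTE_HOST_DEVICE auto tile_of(int rank, int value) const {
      return tv(rank, value);
    }
  };

  // Hypothetical: adds the temporal component -- which of its tiles does a
  // rank work on at a given pipeline step? A rotating schedule is shown;
  // other distributed schedules would plug in different mappings here.
  template <class Partitioner>
  struct DistributedTileScheduler {
    Partitioner partitioner;
    int TP;

    CUTE_HOST_DEVICE auto tile_at(int rank, int step) const {
      return partitioner.tile_of(rank, (rank + step) % TP);
    }
  };

With a 1D T mode this corresponds to the kind of rotating schedules discussed above; a 2D or 3D T mode would cover 2D/3D TP as suggested.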

Remove ProcessorMapping{A,B,C,D} and IterationMapping{A,B,C,D},
and use IterationMapping{M,N,K,L} instead.

This PR has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates. This PR will be labeled inactive-90d if there is no activity in the next 60 days.

@dskhudia commented Dec 2, 2024

@alihassanijr : Thanks for writing the detailed writeup on this. https://blog.shi-labs.com/distributed-gemm-88be6a481e2b

I was wondering if it can work with 12.4 toolkit? Is 12.6 toolkit absolutely required?

@alihassanijr (Contributor, Author)

@dskhudia thank you.

Re: CTK version, I'm actually not sure; I know 12.3 and earlier don't work because there are a few CUDA graph APIs that DistGEMM relies on. I haven't been able to check exactly which CTK release is the earliest that supports them, but I'll try to get to that this week.
It's simple enough to test if you've already set it up; if it's incompatible, you'll get compilation errors from those missing CUDA graph APIs.

@dskhudia commented Dec 5, 2024

@alihassanijr : Thanks for the suggestion. I tried it with CTK 12.4 and it worked fine, with the following results, which roughly match your results from the blog. So I think you can relax the constraint in this example.

   // Some necessary cuda graph APIs were only introduced in CUDA 12.6.
-  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 6)) {
+  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
     std::cerr << "This example requires CUDA 12 or newer." << std::endl;

  running DistGEMM...
  running DistGEMM finished without runtime errors

  Disposition (eps: 0): Passed
  Warming up for 10 iterations.
  Profiling for 100 iterations.

  TP: 8
  Problem Size: 16384 x 106496 x 16384 x 1
  Local GEMM Problem Size: 2048 x 13312 x 16384 x 1
  Avg runtime: 9.16037 ms
  TFLOPS: 780.19

However, TP=4 segfaults (I did change using TP = _8 to using TP = _4), so maybe TP=8 is assumed somewhere.

@alihassanijr (Contributor, Author)

I think CTK 12.4 should be fine then if there are no compile errors.

But the TP=4 issue is weird; we've tested different TP values extensively. I'll try it on my end and get back to you.

@alihassanijr (Contributor, Author) commented Dec 5, 2024

@dskhudia TP=4 runs fine for me. I'll see if I can set up CTK 12.4 somewhere and try to reproduce.

In the meantime, could you also confirm your CUDA driver version?
Also, can you confirm you're using the default kernel and schedule (are the only changes TP=4 and the CTK version check?)

$ ./64_distributed_gemm --m=16384 --n=106496 --k=16384

  running DistGEMM...
  running DistGEMM finished without runtime errors

  Disposition (eps: 0): Passed
  Warming up for 10 iterations.
  Profiling for 100 iterations.

  TP: 4
  Problem Size: 16384 x 106496 x 16384 x 1
  Local GEMM Problem Size: 4096 x 26624 x 16384 x 1
  Avg runtime: 16.8623 ms
  TFLOPS: 847.67

Update: I just remembered that at least one of the CUDA graph APIs can result in a segfault on an earlier CTK.
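(As a sanity check on the numbers above, assuming the printed TFLOPS is per-GPU throughput: each GPU performs 2*M*N*K / TP FLOPs over the run, so for TP=4 that is 2 * 16384 * 106496 * 16384 / 4 ≈ 14.29 TFLOP, and 14.29 TFLOP / 16.8623 ms ≈ 847.7 TFLOPS; the same arithmetic gives ≈ 780.2 TFLOPS for the TP=8 run, matching both outputs.)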

@dskhudia commented Dec 5, 2024

The driver version I have is 535.183.01.

Yes. Those (TP=4 and removal of the version-check restriction) are the only changes.

@dskhudia commented Dec 5, 2024

I tried on another machine that has Driver Version: 550.54.14 and it worked fine :-)

@alihassanijr (Contributor, Author) commented Dec 5, 2024

Yes I was just about to post this:

I think your driver could also be the problem. The CTK version corresponding to 535 is 12.2.

I just tried it with CTK 12.5 (still trying to get it working with 12.4 as well) on a different machine, and this one luckily has driver version 550, which corresponds to CTK 12.4, and it still works fine on my end.


What's very strange to me, though, is that TP=8 worked with 535...

@dskhudia commented Dec 9, 2024

@alihassanijr : Will this work for FP8 as well? Anything extra needed (other than wrapper code)?

@alihassanijr (Contributor, Author) commented Dec 9, 2024

@dskhudia yes it does support FP8, and in theory should support any CUTLASS 3.X dense GEMM.

Although I'd add that the kernel configuration in that example was chosen by profiling the specific problem shapes from the blog post with FP16, so I wouldn't necessarily expect the same kernel config to be the best choice for FP8 or for different TP values.

@dskhudia commented Dec 9, 2024

Thanks :) Will try with fp8.

@whatdhack

At a superficial level, there appears to be overlap with framework-level (e.g. PyTorch) Distributed Tensor. Any comment on how each can benefit from the other?

@alihassanijr (Contributor, Author)

@whatdhack I don't have too much information about Distributed Tensor, but my understanding is that it is a framework-level primitive for model parallelism techniques like Tensor Parallelism (TP). Distributed GEMM is about how the tensor-parallel GEMMs themselves are implemented, so it would serve more as a backend, with things like Distributed Tensor sitting at the IR level.

@christindbose

Can this work on an 8xV100 machine, assuming I use CUDA 12.4? @alihassanijr

@alihassanijr (Contributor, Author) commented Dec 16, 2024

@christindbose Sorry, no; it's only implemented for CUTLASS 3.X kernels right now, which primarily limits it to Hopper.
Even if it were extended to CUTLASS 2.X, Ampere/Volta machines with an all-to-all NVLink network topology would be required.

@dskhudia commented Dec 17, 2024

@alihassanijr : I was going through the code; where is self_flag_ptr or self_flag_ptr_ used? Is this needed, or is it captured by the graph edge dependencies you have?

edit: nvm, it's in the kernel wrapper.

dskhudia added a commit to dskhudia/cutlass that referenced this pull request Jan 7, 2025