Distributed GEMM #1907
base: main
Conversation
Adds experimental support for running tensor-parallel GEMMs natively through CUTLASS.
Distributed GEMM (DistGEMM) implements communication-fused GEMMs using point-to-point communication, which allows for better pipelining and can, in theory, hide all communication behind computation. It makes very few assumptions about the underlying kernel: it only adds a few barriers to the beginning of each GEMM kernel, and it either uses the epilogue source as the communication buffer or adds a memcpy branch to the CUDA graph, leaving SMs free for GEMMs and the copy engine free for communication.
When benchmarked with Llama 70B and 405B training shapes, DistGEMM reaches 70-80% of peak performance.
A more detailed blog post on DistGEMM will be released soon.
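To make the pipelining idea concrete, here is a minimal host-side sketch of a rotating-shard schedule (illustrative only; the TP width, ring direction, and shard-rotation pattern are assumptions for illustration, not the actual DistGEMM schedule): each rank computes one partial GEMM per stage with the operand shard it currently holds, while the next shard arrives over a point-to-point link.

// Illustrative sketch only -- not the actual DistGEMM code. It shows the general
// idea of a pipelined, point-to-point tensor-parallel GEMM: each rank starts with
// its local shard of A, and at every stage computes one partial GEMM while
// (conceptually) forwarding its current shard to the next rank.
#include <cstdio>
#include <vector>

int main() {
  constexpr int TP = 4;  // hypothetical tensor-parallel width

  // which A-shard each rank holds at the start (identity mapping)
  std::vector<int> held_shard(TP);
  for (int r = 0; r < TP; ++r) held_shard[r] = r;

  for (int stage = 0; stage < TP; ++stage) {
    for (int rank = 0; rank < TP; ++rank) {
      // Compute: GEMM on the shard currently held; in the real pipeline this
      // overlaps with the P2P copy of the next shard (copy engine vs. SMs).
      std::printf("stage %d: rank %d computes with A-shard %d\n",
                  stage, rank, held_shard[rank]);
    }
    // Communicate: rotate shards ring-wise (rank r receives from rank r-1).
    std::vector<int> next(TP);
    for (int rank = 0; rank < TP; ++rank) {
      next[rank] = held_shard[(rank + TP - 1) % TP];
    }
    held_shard = next;
  }
  return 0;
}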
include/cutlass/experimental/distributed/schedules/dist_gemm_base_schedule.hpp
/* ProcessorMappingA_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
/* ProcessorMappingB_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
/* ProcessorMappingC_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
/* ProcessorMappingD_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
The "ProcessorMapping" is the same in all of these schedules and are used in-code very trivially.
(1) What do these parameterize and what other versions of this parameter do you expect to support?
(2) Why do they need to be CuTe Layouts? The sizes, shapes, ranks, strides are never used.
(3) The comments next to these parameters do not explain the domain of this function or the codomain.
(4) The size of the "bias mode" is always 1
and the stride is always 0
, yet these are indexed with a coordinate 1
in that mode always. This seems like a poor parameterization and Layout
s are not what you actually want.
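To make point (4) concrete, here is a minimal CuTe sketch (assuming CuTe headers are available and TP_ = _8): because the second mode has size 1 and stride 0, evaluating the layout at (device_idx, 1) reduces to device_idx * 1 + 1 * 0 = device_idx, so the Layout machinery is just an identity map on the first mode.

// Sketch only: evaluates the ProcessorMapping-style layout for a fixed TP = 8.
#include <cstdio>
#include <cute/layout.hpp>

int main() {
  using namespace cute;
  // Same shape/stride pattern as ProcessorMappingA_ above, with TP_ = _8.
  auto mapping = Layout<Shape<_8, _1>, Stride<_1, _0>>{};

  for (int device_idx = 0; device_idx < 8; ++device_idx) {
    // The second coordinate contributes nothing (stride 0), so this is identity.
    int out = mapping(device_idx, 1);
    std::printf("device %d -> %d\n", device_idx, out);
  }
  return 0;
}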
You're right, ProcessorMapping is definitely unnecessary at this point. It was originally there in case the first iteration has a remote buffer, in which case it would map the device index to the peer device's index.
I'll get rid of this; but just to clarify, this is only in reference to ProcessorMapping and not IterationMapping?
I'm still picking apart "IterationMapping"... it seems redundant with "PeerDeviceMapping", but that's still unclear to me. It also appears that the relationships between MNKL are not respected because each of these "Layout"s refers to A|B|C|D instead... I suspect there are a lot of implied invariants in this representation and those should be eliminated.
PeerDeviceMapping can technically be represented with a boolean with the schedules implemented right now, but I didn't want to assume that it would necessarily stay that way with other schedules that we may add.
And correct, it's a little difficult to solely rely on the relationships of MNKL because operands are sharded and rotated differently in different schedules. Because DistGEMM is also supposed to have a separate buffer space for remote tensors and switch between them, it wasn't possible to maintain that behavior anymore.
I'll try and think of a better way to do this, but at some point, references to ABCD will be inevitable just because schedules shard them differently.
And my point is that what is important is how the processors are moving through MNK space, which can then be translated to actions on ABC.
I see, thanks for explaining. I think we could just replace IterationMapping{A,B,C,D} with mappings to MNKL tile coordinates. It would only make sense given that the tile shape (IterationTiler) is set up in size-4 tuples corresponding to MNKL anyway. Would moving to IterationMapping{M,N,K,L} (or, if I can figure out the mapping for it, just one IterationMappingMNKL layout) be a step in the right direction?
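A minimal sketch of what such a single MNKL mapping could look like (all names and the particular rotation pattern here are hypothetical, not taken from this PR): each (device, stage) pair maps to one MNKL tile coordinate, and the per-operand coordinates follow from the standard GEMM relationships, A -> (M, K, L), B -> (N, K, L), C/D -> (M, N, L), so no separate per-operand layouts are needed.

// Hypothetical sketch: derive per-operand tile coordinates from one MNKL
// coordinate, instead of keeping independent IterationMapping{A,B,C,D} layouts.
#include <cstdio>

struct TileCoordMNKL { int m, n, k, l; };

// Hypothetical mapping for a 1D schedule over TP ranks: every rank owns one
// N shard and walks over K shards, rotating with the pipeline stage.
TileCoordMNKL iteration_to_mnkl(int device_idx, int stage, int TP) {
  TileCoordMNKL coord;
  coord.m = 0;                          // M not partitioned in this sketch
  coord.n = device_idx;                 // each rank owns one N shard
  coord.k = (device_idx + stage) % TP;  // K shard rotates with the stage
  coord.l = 0;
  return coord;
}

int main() {
  constexpr int TP = 4;
  for (int device = 0; device < TP; ++device) {
    for (int stage = 0; stage < TP; ++stage) {
      auto c = iteration_to_mnkl(device, stage, TP);
      // Operand coordinates follow directly from the single MNKL coordinate.
      std::printf("device %d stage %d: A=(m=%d,k=%d) B=(n=%d,k=%d) C/D=(m=%d,n=%d)\n",
                  device, stage, c.m, c.k, c.n, c.k, c.m, c.n);
    }
  }
  return 0;
}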
I think we briefly discussed this before your internship ended -- this is the same idea as the tile schedulers we have at the kernel layer, just at the scale of the NVLink system. I think calling this a DistributedTileScheduler makes sense, with the same abstract class hierarchy as the tile schedulers we already have. The job of this scheduler is to map a given tile coordinate (in this case, a tile of the global problem layout) to a physical processor (in this case, the GPU ID). This will let you generalize to 2D and 3D TP in the future. Distributed schedules (patterns, as you called them) are then just different TV layouts where the T mode can be 1, 2, or 3D depending on the TP strategy.
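For the simple 1D-TP case, such a TV layout could look like the following sketch (assuming CuTe headers; the 4-rank, 2-tiles-per-rank blocked assignment is an arbitrary example, not a schedule from this PR): the T mode indexes the GPU rank, the V mode indexes the tiles that rank owns, and the layout maps (rank, value) to a global tile index.

// Sketch: a TV layout for 1D tensor parallelism, mapping (gpu_rank, value)
// to a global tile index. Assumes CuTe headers; TP = 4 ranks, 2 tiles each.
#include <cstdio>
#include <cute/layout.hpp>

int main() {
  using namespace cute;
  // T mode: 4 GPU ranks. V mode: 2 tiles per rank.
  // Rank r owns global tiles {2*r, 2*r + 1} with this blocked assignment.
  auto tv = Layout<Shape<_4, _2>, Stride<_2, _1>>{};

  for (int rank = 0; rank < 4; ++rank) {
    for (int v = 0; v < 2; ++v) {
      std::printf("rank %d, value %d -> global tile %d\n", rank, v, int(tv(rank, v)));
    }
  }
  return 0;
}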
Of course; that's the end goal, but at the same time it's non-trivial to come up with an API that closely resembles the kernel-layer tile schedulers.
Right now it handles both cross-device and on-device tiling, and maps the device index and pipeline stage / iteration to tile coordinates. Are you saying that it should be broken up into two different components, one handling the tiling, and one handling the mapping to tile coordinates?
Partitioning is the job of the TV layout that maps the GPU rank to the values it extracts from the global coordinate. Scheduling adds the temporal component on top, which tells each GPU which coordinate to work on at which step of the computation. Separating them out into two makes the most sense. @ccecka agree?
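A rough sketch of that two-component split (all type names here are hypothetical, not an API in this PR): a partitioner that owns the spatial rank-to-tiles mapping, and a scheduler that adds the temporal component on top of it.

// Hypothetical split of the distributed scheduling concern into two pieces:
// a partitioner (spatial: which tiles belong to which rank) and a scheduler
// (temporal: which of those tiles a rank works on at a given stage).
#include <cstdio>

// Partitioner: rank -> the global tile indices it owns (blocked 1D TP).
struct DistributedPartitioner {
  int tiles_per_rank;
  int global_tile(int rank, int value) const {
    return rank * tiles_per_rank + value;
  }
};

// Scheduler: adds the temporal component on top of the partitioner.
struct DistributedScheduler {
  DistributedPartitioner partitioner;
  int tile_at_stage(int rank, int stage) const {
    // Trivial policy for illustration: walk the owned tiles in order.
    return partitioner.global_tile(rank, stage % partitioner.tiles_per_rank);
  }
};

int main() {
  DistributedScheduler sched{DistributedPartitioner{/*tiles_per_rank=*/2}};
  for (int rank = 0; rank < 4; ++rank)
    for (int stage = 0; stage < 2; ++stage)
      std::printf("rank %d stage %d -> tile %d\n",
                  rank, stage, sched.tile_at_stage(rank, stage));
  return 0;
}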
Remove ProcessorMapping{A,B,C,D} and IterationMapping{A,B,C,D}, and use IterationMapping{M,N,K,L} instead.
@alihassanijr : Thanks for the detailed writeup on this (https://blog.shi-labs.com/distributed-gemm-88be6a481e2b). I was wondering whether it can work with the 12.4 toolkit, or is the 12.6 toolkit absolutely required?
@dskhudia thank you. Re: CTK version, I'm actually not sure; I know 12.3 and earlier don't work because there are a few CUDA graph APIs that DistGEMM relies on. I haven't been able to check exactly which CTK release is the earliest that supports those; I'll try and get that this week.
@alihassanijr : Thanks for the suggestion. I tried it with CTK 12.4 and it worked fine, with the following results, which roughly match your results from the blog. So I think you can relax the constraint in this example.
However, TP=4 segfaults (I did change
I think CTK 12.4 should be fine then if there are no compile errors. But the TP=4 issue is weird; we've tested different TP values extensively. I'll try it on my end and get back to you.
@dskhudia TP=4 runs fine for me. I'll see if I can set up CTK 12.4 somewhere and try it. In the meantime, could you also confirm your CUDA driver version?
Update: I just remembered that at least one CUDA graphs API can result in a segfault if earlier CTK
The driver version I have is
Yes. Those (TP=4 and removal of the version check restriction) are the only changes.
I tried on another machine that has
Yes, I was just about to post this: I think your driver could also be the problem. The CTK version corresponding to driver 535 is 12.2. I just tried it with CTK 12.5 (still trying to get it working with 12.4 as well) on a different machine, and this one luckily has driver version 550, which corresponds to CTK 12.4, and it still works fine on my end. What's very strange to me, though, is that TP=8 worked with 535...
@alihassanijr : Will this work for FP8 as well? Anything extra needed (other than wrapper code)?
@dskhudia yes, it does support FP8, and in theory it should support any CUTLASS 3.X dense GEMM. Although I'd add that the kernel configuration in that example was specifically chosen as a result of profiling with the specific problem shapes in the blog post and FP16, so I wouldn't necessarily expect the same kernel config to be the best choice for FP8 or for different TP values.
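For reference, a hedged sketch of what switching the element types to FP8 in a CUTLASS 3.x CollectiveBuilder-based mainloop might look like (the tile shape, cluster shape, and auto schedules below are placeholders, not the tuned configuration from the example):

// Sketch only: an FP8 (e4m3) mainloop built with the CUTLASS 3.x
// CollectiveBuilder. Tile shape, cluster shape, and schedules are placeholders;
// the example's tuned FP16 config will likely not be optimal for FP8.
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/layout.hpp"

using ElementA = cutlass::float_e4m3_t;   // FP8 inputs
using ElementB = cutlass::float_e4m3_t;
using ElementAccumulator = float;

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;

// 16-byte alignment => 16 FP8 elements.
static constexpr int AlignmentA = 16;
static constexpr int AlignmentB = 16;

using TileShape    = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;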
Thanks :) Will try with FP8.
At a superficial level, there appears to be overlap with framework-level (e.g. PyTorch) Distributed Tensor. Any comment on how each can benefit from the other?
@whatdhack I don't have too much information about Distributed Tensor, but based on my understanding, it would be a framework-level primitive for model parallelism techniques like Tensor Parallelism (TP). Distributed GEMM is more about how the tensor-parallel GEMMs are implemented, so it would serve more as a backend, with things like Distributed Tensor sitting at the IR level.
Can this work on an 8xV100 machine assuming I use CUDA 12.4? @alihassanijr
@christindbose sorry, no; it's only implemented for CUTLASS 3.X kernels right now, which primarily limits it to Hopper.
@alihassanijr : edit: nvm. It's in the kernel wrapper.