Add tensor parallelism for RWKV #1237
base: main
Conversation
* add asserts and fix post training readme; precommit (Co-authored-by: Quentin Anthony <[email protected]>)
* fix typo; fix neoxargs usage test; skip conversion test due to multiprocessing issue; precommit (Co-authored-by: Quentin Anthony <[email protected]>)
* Add ERROR logging prefix and sort alphabetically; fix comment
configs/neox_arguments.md (Outdated)
@@ -843,6 +843,29 @@ Model Arguments

- **dim_att**: int
We should either have unified args (across mamba, rwkv, transformers) for these, or prepend these args with whatever block type they're targeting (e.g. rwkv_dim_att).
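As a sketch of the second option only (these prefixed names are hypothetical and not part of this PR), the RWKV-specific sizes could be namespaced in the argument dataclass along these lines:

from dataclasses import dataclass


@dataclass
class RWKVArgsSketch:
    # Hypothetical block-type-prefixed variants of the RWKV-specific sizes;
    # in the real codebase these would live on NeoXArgsModel.
    rwkv_dim_att: int = None
    rwkv_dim_ffn: int = None
    rwkv_head_size: int = None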
"num_layers": 24, | ||
"hidden_size": 1024, | ||
"num_attention_heads": 16, # head_size = dim_att / num_attention_heads. |
Similar comment here. Calling these attention heads is highly misleading.
I kind of disagree, as RWKV code generally references time mixing as attention, and the RWKV kernel is often called a type of "linear attention." But I can add a bunch of configs to decouple RWKV and transformer config options; this will just create a lot of config args that have essentially the same purpose, in my opinion.
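For concreteness, the head-size relationship in the config above works out as follows (assuming dim_att defaults to hidden_size when not set explicitly, which is the usual RWKV convention):

hidden_size = 1024
num_attention_heads = 16
dim_att = hidden_size  # assumption: dim_att falls back to hidden_size when unset
head_size = dim_att // num_attention_heads
print(head_size)  # 64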
except ModuleNotFoundError:
    print(
        "Unable to import RWKV FLA kernels. Install them from our requirements/requirements-rwkv.txt, \
        or directly from https://github.com/sustcsonglin/flash-linear-attention.git, or use CUDA kernels."
This last point, "or use CUDA kernels", is confusing. Can you add a "by doing xyz" so that users know what you mean?
reminder^
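For illustration, the fallback the message hints at could be spelled out roughly like this; the import target and the wiring to the PR's rwkv_fla flag are assumptions, not the PR's actual code:

HAVE_FLA = True
try:
    # Illustrative import; the exact fla entry points depend on the installed version.
    import fla  # noqa: F401
except ModuleNotFoundError:
    HAVE_FLA = False
    print(
        "Unable to import RWKV FLA kernels. Install them from our "
        "requirements/requirements-rwkv.txt, or directly from "
        "https://github.com/sustcsonglin/flash-linear-attention.git, "
        "or set rwkv_fla: false in your config to use the CUDA kernels."
    )


def use_fla(neox_args):
    # Fall back to the CUDA kernels when FLA is unavailable or disabled.
    return HAVE_FLA and neox_args.rwkv_fla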
@@ -104,7 +126,7 @@ class RWKV_TimeMix(nn.Module):
TODO: fix jit compiling.
Is this based on the parser issue we discussed? I think it's worth testing just-jit, reordered jit, and the heuristics I suggested before merging with this TODO.
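For reference, the upstream RWKV code gates TorchScript behind an environment flag roughly like the sketch below (not this PR's code); the reordered-jit and heuristic variants mentioned above would replace the simple toggle:

import os

import torch
import torch.nn as nn

# Sketch: toggle TorchScript compilation via an env var so the jit path can be
# switched off while the compile issue in the TODO is being debugged.
JIT_ON = os.environ.get("RWKV_JIT_ON", "0") == "1"
MyModule = torch.jit.ScriptModule if JIT_ON else nn.Module
MyFunction = torch.jit.script_method if JIT_ON else (lambda f: f)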
self.ffn = RWKV_ChannelMix(neox_args, layer_number)
self.ffn = ParallelRWKV_ChannelMix(neox_args, layer_number, init_method=init_method)

if neox_args.attention_dropout > 0:
Another attention arg for RWKV. Can we decouple attention dropout from RWKV?
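One possible decoupling, sketched with a hypothetical rwkv_dropout argument that is not part of this PR:

import torch.nn as nn


def build_rwkv_dropout(neox_args):
    # 'rwkv_dropout' is hypothetical; fall back to the shared attention_dropout
    # when it is not set, so existing configs keep working.
    p = getattr(neox_args, "rwkv_dropout", None)
    if p is None:
        p = neox_args.attention_dropout
    return nn.Dropout(p=p) if p > 0 else nn.Identity()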
megatron/neox_arguments/arguments.py (Outdated)
WARNING = f"{YELLOW}[WARNING]{END}" | ||
|
||
### Formatted logging prefixes ### | ||
ERROR = f"{RED}[ERROR]{END} " |
I don't think we've properly merged this branch onto upstream main, since this is tracking as a change. Please do this.
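For context on this hunk, the YELLOW/RED/END names are presumably ANSI escape sequences along these lines (a sketch of the convention, not the file's actual definitions):

# Common ANSI escape codes used for colored logging prefixes.
YELLOW = "\033[93m"
RED = "\033[91m"
END = "\033[0m"

WARNING = f"{YELLOW}[WARNING]{END}"
ERROR = f"{RED}[ERROR]{END} "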
@@ -277,6 +277,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
}
"""

rwkv_fla: bool = False
Regen neox_arguments.md, since this isn't showing up there.
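Assuming neox_arguments.md is regenerated from per-field docstrings (the convention visible in the hunk above), the new flag would also need one; suggested wording only:

rwkv_fla: bool = False
"""
Use the Triton flash-linear-attention (FLA) kernels for RWKV instead of the
CUDA kernels.
"""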
megatron/training.py (Outdated)
@@ -406,6 +406,9 @@ def get_batch(neox_args, data_iterator):
datatype=datatype,
)
elif neox_args.train_impl == "kto":
assert (
I think these will also go away with a proper rebase onto latest main
…ype' option was removed (#1309): fix 'intermediate_size' in Llama configuration files after the 'mlp_type' option was removed; config adjustments for llama and gated activations; pre-commit (Co-authored-by: jahatef <[email protected]>, Quentin Anthony <[email protected]>)
* Python 3.10 support: Python 3.10 support was added in #1122; update wording on torch and python (Co-authored-by: Quentin Anthony <[email protected]>)
* adds pyproject files and tests; formatting and add dev packages to dev req files; improve req testing (Co-authored-by: Quentin Anthony <[email protected]>)
Adds a tensor parallel implementation for RWKV, and support for the Triton FLA (flash-linear-attention) implementation in GPT-NeoX.
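To make the tensor-parallel part concrete, below is a minimal sketch of how an RWKV channel mix can be split across ranks with Megatron-style column/row-parallel linears. It assumes upstream GPT-NeoX's mpu.ColumnParallelLinear/RowParallelLinear interface (taking neox_args and returning an (output, bias) tuple), omits RWKV's token-shift mixing, and is not the PR's actual ParallelRWKV_ChannelMix.

import torch
import torch.nn as nn

from megatron import mpu  # assumption: upstream GPT-NeoX's model-parallel utilities


class ParallelChannelMixSketch(nn.Module):
    """Tensor-parallel RWKV channel mix (token-shift mixing omitted).

    The key projection is split column-wise across model-parallel ranks and the
    value projection row-wise, so the only communication in the forward pass is
    the all-reduce inside RowParallelLinear.
    """

    def __init__(self, neox_args, init_method):
        super().__init__()
        # 'dim_ffn' may or may not exist in the config; fall back to 4x hidden.
        ffn_dim = getattr(neox_args, "dim_ffn", None) or 4 * neox_args.hidden_size

        self.key = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=ffn_dim,
            bias=False,
            init_method=init_method,
            gather_output=False,  # keep the per-rank shard local
        )
        self.value = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=ffn_dim,
            output_size=neox_args.hidden_size,
            bias=False,
            init_method=init_method,
            input_is_parallel=True,  # consumes the column-parallel shard directly
        )
        # Kept unsharded here for simplicity; a full implementation could
        # parallelize the receptance projection as well.
        self.receptance = nn.Linear(
            neox_args.hidden_size, neox_args.hidden_size, bias=False
        )

    def forward(self, x):
        k, _ = self.key(x)               # [..., ffn_dim / world_size]
        k = torch.square(torch.relu(k))  # RWKV's squared-ReLU activation
        kv, _ = self.value(k)            # all-reduced back to [..., hidden_size]
        return torch.sigmoid(self.receptance(x)) * kv

Splitting the key projection column-wise and the value projection row-wise mirrors how Megatron parallelizes the transformer MLP, which keeps the forward pass down to a single all-reduce per channel mix.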