
Add tensor parallelism for RWKV #1237

Open · jahatef wants to merge 30 commits into main
Conversation

@jahatef (Collaborator) commented Jun 19, 2024

Adds a tensor-parallel implementation of RWKV, and support for the Triton FLA (flash-linear-attention) kernels in GPT-NeoX.
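For readers skimming the diff, here is a minimal, hedged sketch of the column/row-parallel split that tensor parallelism typically applies to a channel-mix-style projection. It is written in plain PyTorch with illustrative names; the actual PR builds on GPT-NeoX's model-parallel utilities and its RWKV modules, not this code.

import torch
import torch.nn as nn
import torch.distributed as dist

# Sketch only (not the PR's code): split a channel-mix-style projection
# across tensor-parallel ranks. The "key" projection is column-parallel
# (each rank owns a slice of the FFN features) and the "value" projection
# is row-parallel, so a single all-reduce recombines the partial outputs.
class ToyParallelChannelMix(nn.Module):
    def __init__(self, hidden_size: int, ffn_size: int, world_size: int):
        super().__init__()
        assert ffn_size % world_size == 0
        local_ffn = ffn_size // world_size
        self.key = nn.Linear(hidden_size, local_ffn, bias=False)    # column-parallel slice
        self.value = nn.Linear(local_ffn, hidden_size, bias=False)  # row-parallel slice

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        k = torch.relu(self.key(x)) ** 2   # RWKV-style squared-ReLU activation
        out = self.value(k)                # partial output on this rank
        if dist.is_initialized():
            dist.all_reduce(out, op=dist.ReduceOp.SUM)  # sum partials across ranks
        return out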

jahatef and others added 17 commits June 4, 2024 11:22
* add asserts and fix post training readme

* precommit

---------

Co-authored-by: Quentin Anthony <[email protected]>
* fix typo

* fix neoxargs usage test

* skip conversion test due to multiprocessing issue

* precommit

---------

Co-authored-by: Quentin Anthony <[email protected]>
* Add ERROR logging prefix and sort alphabetically

* fix comment
@jahatef marked this pull request as ready for review on November 7, 2024 05:28
@@ -843,6 +843,29 @@ Model Arguments



- **dim_att**: int
Member:

we should either have unified args (across mamba, rwkv, transformers) for these, or prepend these args with whatever block type they're targeting (e.g. rwkv_dim_att).


"num_layers": 24,
"hidden_size": 1024,
"num_attention_heads": 16, # head_size = dim_att / num_attention_heads.
Member:

Similar comment here. Calling these attention heads is highly misleading.

Collaborator (author):

I kind of disagree: the RWKV code generally refers to time mixing as attention, and the RWKV kernel is often described as a form of "linear attention." I can add configs to decouple the rwkv and transformer options, but in my opinion that would just create a lot of config args that serve essentially the same purpose.
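Purely to illustrate the prefixed-arg naming suggested above, decoupled RWKV options might look something like this; these keys are hypothetical and are not defined by this PR:

# Hypothetical block-prefixed config keys (illustration only):
rwkv_model_config = {
    "num_layers": 24,
    "hidden_size": 1024,
    "rwkv_dim_att": 1024,  # would replace the shared dim_att
    "rwkv_head_size": 64,  # named explicitly instead of reusing num_attention_heads
}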

except ModuleNotFoundError:
print(
"Unable to import RWKV FLA kernels. Install them from our requirements/requirements-rwkv.txt, \
or directly from https://github.com/sustcsonglin/flash-linear-attention.git, or use CUDA kernels."
Member:

This last point "or use CUDA kernels" is confusing. Can you add a "by doing xyz" so that users know what you mean?

Member:

reminder^
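A possible way to make that last clause actionable, assuming the fallback is selected via the rwkv_fla flag introduced later in this diff (the wiring around the message is a sketch, not the PR's code):

# Sketch only: a more actionable version of the import-failure message.
try:
    import fla  # the flash-linear-attention package
    HAVE_FLA = True
except ModuleNotFoundError:
    HAVE_FLA = False
    print(
        "Unable to import RWKV FLA kernels. Install them from "
        "requirements/requirements-rwkv.txt or from "
        "https://github.com/sustcsonglin/flash-linear-attention.git, "
        "or fall back to the CUDA kernels by setting `rwkv_fla: false` "
        "in your model config."
    )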

@@ -104,7 +126,7 @@ class RWKV_TimeMix(nn.Module):
TODO: fix jit compiling.
Member:

Is this based on the parser issue we discussed? I think it's worth testing just-jit, reordered jit, and the heuristics I suggested, before merging with this TODO in place.


self.ffn = RWKV_ChannelMix(neox_args, layer_number)
self.ffn = ParallelRWKV_ChannelMix(neox_args, layer_number, init_method=init_method)

if neox_args.attention_dropout > 0:
Member:

another attention arg for rwkv. Can we decouple attn dropout from rwkv?
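One way to decouple it, sketched with a hypothetical rwkv_dropout argument rather than any existing NeoX arg:

import torch.nn as nn

# Hypothetical: gate RWKV dropout on its own arg instead of attention_dropout.
def build_rwkv_dropout(neox_args) -> nn.Module:
    p = getattr(neox_args, "rwkv_dropout", 0.0)
    return nn.Dropout(p=p) if p > 0 else nn.Identity()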

WARNING = f"{YELLOW}[WARNING]{END}"

### Formatted logging prefixes ###
ERROR = f"{RED}[ERROR]{END} "
Member:

I don't think we've properly merged this branch onto upstream main, since this is tracking as a change. Please do this.

@@ -277,6 +277,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
}
"""

rwkv_fla: bool = False
Member:

regen neox_arguments.md, since this isn't showing up there.

@@ -406,6 +406,9 @@ def get_batch(neox_args, data_iterator):
datatype=datatype,
)
elif neox_args.train_impl == "kto":
assert (
Member:

I think these will also go away with a proper rebase onto latest main

- do not create a fake head dim and split the 'mixed_x_layer' into QKV layers directly.
tiandeyu-cs and others added 10 commits November 29, 2024 15:44
fix 'intermediate_size' in Llama configuration files after the 'mlp_type' option was removed (#1309)

* fix 'intermediate_size' in Llama configuration files after the 'mlp_type' option was removed

* config adjustments for llama and gated activations

* pre-commit

---------

Co-authored-by: jahatef <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
* Python 3.10 support

Python 3.10 support was added in issue #1122

* update wording on torch and python

---------

Co-authored-by: Quentin Anthony <[email protected]>
* adds pyproject files and tests

* formatting and add dev packages to dev req files

* improve req testing

---------

Co-authored-by: Quentin Anthony <[email protected]>