
fix: propagate precision correctly to enable non-bf16 inference #165

Open

wants to merge 3 commits into main
Conversation


@Icedgarr Icedgarr commented Jun 5, 2024

This PR fixes some incompatibilities that I encountered when instantiating TTS from fam/llm/fast_inference.py on older and less powerful GPUs (e.g. the Google Colab T4).

fam/llm/fast_inference_utils.py was moving the model to the device (cuda) with dtype torch.bfloat16 instead of using the precision parameter that holds the selected dtype (by default float16 or bfloat16, depending on the GPU architecture).

The linear layers of the Attention class in fam/llm/fast_model.py were also created without a dtype argument instead of using the one provided in the config.
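
For illustration, a minimal sketch of the kind of change this describes (the variable names are illustrative, not the exact code in the repo):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Pick the precision based on what the hardware supports: bf16 needs Ampere or
# newer, while older cards such as the Colab T4 only support fp16.
if device == "cuda" and torch.cuda.is_bf16_supported():
    precision = torch.bfloat16
else:
    precision = torch.float16

model = nn.Linear(16, 16)  # stand-in for the real model

# Before the fix (hard-coded): model.to(device=device, dtype=torch.bfloat16)
# After the fix: propagate the selected precision instead.
model = model.to(device=device, dtype=precision)
print(next(model.parameters()).dtype)
```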

Member

@vatsalaggarwal vatsalaggarwal left a comment


thanks a lot for this! one minor comment, ready to merge o/w!

@@ -188,8 +188,8 @@ def __init__(self, config: ModelArgs):
         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
         # key, query, value projections for all heads, but in a batch
-        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False, dtype=config.dtype)
Member

why're these required? I think with the fix here, this shouldn't be needed?

Author

@Icedgarr Icedgarr Jun 5, 2024

The following line was throwing an error due to mixed dtypes between q and k/v (in the traceback below, q is float32 while k and v are float16):
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

Error: TorchRuntimeError: Failed running call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(2, 16, s0, 128)), FakeTensor(..., device='cuda:0', size=(2, 16, 2048, 128), dtype=torch.float16), FakeTensor(..., device='cuda:0', size=(2, 16, 2048, 128), dtype=torch.float16)), **{'attn_mask': FakeTensor(..., device='cuda:0', size=(1, 1, s0, 2048), dtype=torch.bool), 'dropout_p': 0.0}): Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: c10::Half and value.dtype: c10::Half instead.
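
The dtype requirement itself is easy to reproduce in isolation; here is a standalone sketch (unrelated to the repository code; bfloat16 is used so it also runs on CPU):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 16, 8, 128, dtype=torch.float32)
k = torch.randn(2, 16, 8, 128, dtype=torch.bfloat16)
v = torch.randn(2, 16, 8, 128, dtype=torch.bfloat16)

try:
    # Fails: query, key and value must share a dtype.
    F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
except RuntimeError as e:
    print(e)

# Casting everything to one dtype (what dtype=config.dtype achieves for the
# projection layers) makes the call succeed.
out = F.scaled_dot_product_attention(q.to(torch.bfloat16), k, v, dropout_p=0.0)
print(out.dtype)  # torch.bfloat16
```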

However, I have just checked and this alone did not solve the issue; it worked only after I also ran the code with torch dynamo disabled (export TORCHDYNAMO_DISABLE=1). It may be one of the operations applied to the key and value tensors (I suspect k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1), because it is the only one performed on k and v but not on q).

If you prefer, I can remove this change from the PR; if I or someone else finds the root cause and a way to solve it, we can open another PR.

Member

@vatsalaggarwal vatsalaggarwal Jun 5, 2024

So queries are the right dtype but keys and values are not? That sounds like it might be related to kv-cache not being the right dtype ... but we seem to be setting it correctly here... did you already have a look there?

Author

@Icedgarr Icedgarr Jun 5, 2024

That's right. I see that the kv-cache runs when I execute the code, so it is likely what changes the dtypes, which, according to the code you reference, should not happen. If you agree, I'll remove this part of the PR and investigate a bit further tomorrow to try to fix this other issue.
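
For context, a kv-cache typically pre-allocates its key/value buffers with an explicit dtype, so a mismatch there would explain k and v coming back in a different precision than q. A generic sketch of how such a cache usually works (a hypothetical class, not the implementation in fam/llm/fast_model.py):

```python
import torch
import torch.nn as nn

class KVCache(nn.Module):
    """Generic pre-allocated key/value cache used during autoregressive decoding."""

    def __init__(self, max_batch, n_heads, max_seq, head_dim, dtype=torch.float16):
        super().__init__()
        shape = (max_batch, n_heads, max_seq, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k, v):
        # Writing into the buffers silently casts k/v to the cache dtype, and
        # attention then reads them back in that dtype; if it differs from the
        # dtype of q, scaled_dot_product_attention raises the error above.
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache

cache = KVCache(max_batch=2, n_heads=16, max_seq=2048, head_dim=128)
k, v = cache.update(torch.tensor([0]), torch.randn(2, 16, 1, 128), torch.randn(2, 16, 1, 128))
print(k.dtype)  # torch.float16, regardless of the dtype k was produced in
```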

@vatsalaggarwal vatsalaggarwal changed the title Fix b16 incompatibilities with older gpu fix: propagate precision correctly to enable non-bf16 inference Jun 5, 2024
Author

@Icedgarr Icedgarr commented Jun 6, 2024

I have reverted the last commit since it was not required for this fix.

@Icedgarr Icedgarr requested a review from vatsalaggarwal June 6, 2024 09:25
@RahulVadisetty91 RahulVadisetty91 left a comment

This needs modification to add support for other devices or more robust error handling.
