Fix bug in llama decode and add tests for direct/paged KVCache #143
Conversation
Force-pushed from 04a4f31 to 9bbb389 (Compare)
Yay! Tests!
self.rope_dimension_count = 128
self.max_seq_len = 4096
self.start_positions = torch.tensor([8])
self.bs = 1
We still have a TODO to implement batch size 1 in the llama model, I think: #40
I wonder if it would help to run this test parameterized (https://stackoverflow.com/a/34094) over bs == 1, bs == 2, and bs == 4?
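For reference, a parameterized version of the test could look like the sketch below, here using the standard library's `unittest` with `subTest` (a `pytest.mark.parametrize` decorator would work equally well). `fake_decode` is a hypothetical stand-in for the real decode path, not the sharktank API; the real test would build the llama model and both KV caches as this PR does.

```python
import unittest


class DecodeBatchSizeTest(unittest.TestCase):
    def fake_decode(self, bs):
        # Hypothetical stand-in for the decode step under test:
        # pretend each sequence in the batch decodes one token.
        return [0] * bs

    def test_decode_over_batch_sizes(self):
        # Run the same assertions for bs == 1, 2, and 4, as suggested above.
        # Each subTest is reported separately if it fails.
        for bs in (1, 2, 4):
            with self.subTest(bs=bs):
                tokens = self.fake_decode(bs)
                self.assertEqual(len(tokens), bs)
```

The advantage over three copy-pasted tests is that a failure report names the exact `bs` value that broke, and new batch sizes are a one-line change.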
I see; I'm not sure how that would work yet. I can take a look at that later; for now I just want to get the fix and a couple of basic tests in.
I think it's okay to get in as is for now, but file an issue for yourself so you can circle back to it later.
Force-pushed from dec7e9e to f89fa99 (Compare)
Signed-off-by: aviator19941 <[email protected]>
…efill Signed-off-by: aviator19941 <[email protected]>
Force-pushed from 665ee32 to b2f2b60 (Compare)
LGTM. Might want someone more comfortable with Python/modeling to also review though.
Minor nits, but great work.
I don't see a way to debug the Windows failure outside of trying it yourself. Do you have access to a Windows machine?
Good catch on the Windows failure: https://github.com/nod-ai/sharktank/actions/runs/10563521551/job/29263925200?pr=143#step:6:118
I develop primarily on Windows so I could help as needed... but I can't spare the time for that right now. Would be nice to at least mark as XFAIL or skip on the bot.
I can mark it as XFAIL for now and circle back to it later.
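One way to do that: `unittest` has no conditional expected-failure decorator (`expectedFailure` takes no condition), so the usual patterns are a platform-conditional skip via `unittest.skipIf`, or `pytest.mark.xfail(sys.platform == "win32", ...)` if a true XFAIL is wanted. A minimal sketch with the stdlib approach; the test body here is a placeholder, not the real decode assertions:

```python
import sys
import unittest


class PagedDecodeTest(unittest.TestCase):
    # Skipped on Windows until the decode failure there is understood;
    # see the CI failure linked in this thread.
    @unittest.skipIf(
        sys.platform == "win32",
        "decode test fails on Windows; TODO: fix and re-enable",
    )
    def test_decode(self):
        # Placeholder for the real direct-vs-paged decode assertions.
        self.assertEqual(len([0, 0]), 2)
```

A skip hides the failure entirely, while an xfail still runs the test and reports XPASS if it unexpectedly starts passing, which makes it harder to forget the TODO.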
TODO: Fix decode test for Windows
Signed-off-by: aviator19941 <[email protected]>
Fixes the decode bug for the paged KV cache and adds a couple of tests comparing direct vs. paged KV cache results.
TODO: Fix the skipped test for decode.
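The direct-vs-paged comparison those tests make can be sketched abstractly: write the same K/V entries through both cache layouts, read them back, and assert the results agree. The `DirectCache` and `PagedCache` classes below are illustrative toys (flat list vs. fixed-size pages), not the sharktank implementations:

```python
class DirectCache:
    """Toy cache: stores one K/V entry per position in a flat list."""

    def __init__(self, max_seq_len):
        self.slots = [None] * max_seq_len

    def write(self, pos, kv):
        self.slots[pos] = kv

    def read(self, pos):
        return self.slots[pos]


class PagedCache:
    """Toy cache: stores K/V entries in fixed-size pages."""

    def __init__(self, max_seq_len, page_size=4):
        self.page_size = page_size
        n_pages = (max_seq_len + page_size - 1) // page_size  # ceil division
        self.pages = [[None] * page_size for _ in range(n_pages)]

    def write(self, pos, kv):
        self.pages[pos // self.page_size][pos % self.page_size] = kv

    def read(self, pos):
        return self.pages[pos // self.page_size][pos % self.page_size]


def caches_agree(seq):
    """Write seq through both layouts and check every read matches."""
    direct, paged = DirectCache(len(seq)), PagedCache(len(seq))
    for pos, kv in enumerate(seq):
        direct.write(pos, kv)
        paged.write(pos, kv)
    return all(direct.read(p) == paged.read(p) for p in range(len(seq)))
```

The real tests would run prefill and decode through the model with each cache type and compare logits, but the invariant being checked is the same: the paged indexing must be a pure re-layout with no effect on results.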