Does each token require KNN search during inference? #7

noanti opened this issue Jul 11, 2023 · 3 comments

noanti commented Jul 11, 2023

If I use Faiss as a memory, then during inference, generating each token requires 3 kNN searches (because there are 3 memory attention layers), right? Will the generation speed become very slow?

noanti commented Jul 14, 2023

@CStanKonrad Is there a practical example that uses external memory?

CStanKonrad (Owner) commented

Regarding the question: the suggested implementation of kNN retrieves, for each query in a memory layer, the k best-matching keys from the memory cache. In the 3B model there are 3 memory layers, each with 32 heads, which gives 96 retrievals per token. In general, we recommend the brute-force approach (full attention, no kNN; an example of this approach is implemented in this repository) for memories that fit on the GPU. However, if you want to use Faiss, you will need to tune the index manually (note that the faster Faiss indexes have a training stage and let you balance speed against retrieval accuracy). We currently do not provide practical examples with Faiss.
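
For reference, here is a minimal, self-contained sketch (not code from this repository; all names and sizes are made up for illustration) contrasting full attention over the memory cache with per-query kNN retrieval via Faiss, for a single head and a single token:

```python
# A sketch only: shapes, sizes, and variable names are illustrative,
# not taken from this repository.
import numpy as np
import faiss  # assumes faiss-cpu or faiss-gpu is installed

d, n_mem, k = 128, 65536, 16  # head dim, cached (key, value) pairs, neighbours
rng = np.random.default_rng(0)
mem_keys = rng.standard_normal((n_mem, d), dtype=np.float32)
mem_values = rng.standard_normal((n_mem, d), dtype=np.float32)
query = rng.standard_normal((1, d), dtype=np.float32)  # one head's query for one token

# Brute force ("full attention, no kNN"): softmax over every cached key.
scores = query @ mem_keys.T                   # (1, n_mem)
weights = np.exp(scores - scores.max())
weights /= weights.sum()
out_full = weights @ mem_values               # (1, d)

# kNN variant: retrieve only the k best-matching keys, then attend to those.
# IndexFlatIP is an exact inner-product index; the faster approximate indexes
# (e.g. faiss.IndexIVFFlat) have a training stage and trade accuracy for speed.
index = faiss.IndexFlatIP(d)
index.add(mem_keys)
sims, ids = index.search(query, k)            # one such search per head, per token
w = np.exp(sims - sims.max())
w /= w.sum()
out_knn = w @ mem_values[ids[0]]              # (1, d)
```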

Example times obtained on a 40 GB A100 GPU with bfloat16 precision, using code from this repository (populating the memory cache takes around 17 s in this case):

- process 64k tokens, then generate 100 tokens: ~23 s
- process 64k tokens, then generate 200 tokens: ~29 s
- process 64k tokens, then generate 300 tokens: ~36 s
- process 64k tokens, then generate 400 tokens: ~43 s
- process 64k tokens, then generate 500 tokens: ~50 s

So with a 64k-token memory, generating one token takes <= 0.07 s (note that if you generate a lot, this time will increase, since the memory grows during generation).
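
As a quick sanity check, the per-token figure follows from the numbers above after subtracting the one-off ~17 s cache-population cost; a few lines of plain Python make this explicit:

```python
# Total times from the list above, keyed by number of generated tokens.
timings_s = {100: 23, 200: 29, 300: 36, 400: 43, 500: 50}
cache_population_s = 17  # one-off cost of populating the memory cache

for n_generated, total_s in timings_s.items():
    per_token_s = (total_s - cache_population_s) / n_generated
    print(f"{n_generated:3d} tokens: {per_token_s:.3f} s/token")
# Every row comes out at ~0.060-0.066 s/token, consistent with <= 0.07 s.
```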

noanti commented Jul 19, 2023

Got it, thanks!
