
TriePagedAttentionCache #632

Merged
renxida merged 23 commits into nod-ai:main on Dec 4, 2024

Conversation

@renxida (Contributor) commented Dec 2, 2024

feat: Add TriePagedAttentionCache with initial implementation

Added TriePagedAttentionCache as an optional prefix sharing algorithm, selectable via:
config["paged_kv_cache"]["prefix_sharing_algorithm"] = "trie"

Current Status:

  • Basic implementation and unit tests complete
  • Integration test cases for both the Base and Trie implementations, with the trie implementation xfailed due to pending cache allocation improvements
  • BasePagedAttentionCache remains the default

Next Steps:
To achieve full functionality, we need to support cache re-allocations to extend the associated tokens & pages.
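
For context, here is a minimal, hypothetical sketch of the trie-based prefix-sharing idea (the class and function names below are illustrative, not the actual shark-ai implementation): sequences are split into page-sized token chunks, each chunk becomes a trie edge, and a new request reuses the cached pages along its longest matching prefix.

```python
import itertools
from dataclasses import dataclass, field

_page_ids = itertools.count()  # hypothetical page-id allocator


@dataclass
class Node:
    """One node per published page of tokens (illustrative only)."""
    page_id: int = -1
    children: dict = field(default_factory=dict)  # token chunk -> Node


root = Node()


def publish(tokens, page_size):
    """Insert a sequence, one page-sized chunk of tokens per trie edge."""
    node = root
    for i in range(0, len(tokens) // page_size * page_size, page_size):
        key = tuple(tokens[i : i + page_size])
        if key not in node.children:
            node.children[key] = Node(page_id=next(_page_ids))
        node = node.children[key]


def lookup(tokens, page_size):
    """Return the page ids already cached for the longest shared prefix."""
    node, pages = root, []
    for i in range(0, len(tokens) // page_size * page_size, page_size):
        key = tuple(tokens[i : i + page_size])
        if key not in node.children:
            break
        node = node.children[key]
        pages.append(node.page_id)
    return pages


publish(list(range(32)), page_size=16)               # request A fills pages 0 and 1
print(lookup(list(range(32)) + [99], page_size=16))   # request B reuses [0, 1]
```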

@renxida renxida marked this pull request as ready for review December 2, 2024 22:24
@renxida renxida requested a review from stbaione December 2, 2024 22:24
@renxida (Contributor, Author) commented Dec 2, 2024

[image attachment]

:D

@stbaione (Contributor) left a comment:

I think this could benefit from another set of eyes, but the overall interface and operations make sense to me

@renxida renxida enabled auto-merge (squash) December 4, 2024 04:56
@renxida renxida merged commit de4d2fe into nod-ai:main Dec 4, 2024
16 of 20 checks passed
@@ -0,0 +1,432 @@
import pytest

A Member commented:

Source files should have copyright + license header comments
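
For reference, other Python files in the repository use a header along these lines (the exact year and license wording here are my assumption, based on the project's usual convention rather than on this PR):

```python
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
```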

# Try to allocate new sequence - should evict least recently used unpublished sequence
new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE))
print(f"\nAttempting to allocate new sequence: {new_tokens}")
new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0)

A Member commented:

This test is failing on Windows: https://github.com/nod-ai/shark-ai/actions/runs/12164613667/job/33926704492?pr=635#step:11:3315

(The Windows shortfin build had been broken until #635; these test failures slipped in just before I got it passing.)

================================== FAILURES ===================================
____________________________ test_lru_eviction[1] _____________________________

trie_cache = <shortfin_apps.llm.components.kvcache.trie_attention_cache.TriePagedAttentionCache object at 0x000002464792FBF0>
access_count = 1

    @pytest.mark.parametrize(
        "access_count", [1, TEST_POOL_CAPACITY // 2, TEST_POOL_CAPACITY - 1]
    )
    def test_lru_eviction(trie_cache, access_count):
        """Test LRU eviction with different access patterns"""
        print(f"\nStarting test_lru_eviction with access_count={access_count}")
    
        # Create mix of published and unpublished sequences
        keep_published = 3  # Number of sequences to keep published
        sequences = []
    
        # First add some sequences we'll keep published
        print("\nPublishing sequences to keep active:")
        for i in range(keep_published):
            tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE))
            alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0)
            alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE])
            sequences.append(tokens)
            print(f"Published sequence {i} (keeping active)")
            print_tree_state(trie_cache, "  ")
    
        # Then add sequences we'll publish but release (evictable)
        print("\nAdding releasable sequences:")
        for i in range(keep_published, TEST_POOL_CAPACITY):
            tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE))
            alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0)
            alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE])
            alloc.release_pages()  # These can be evicted
            sequences.append(tokens)
            print(f"Added releasable sequence {i}")
            print_tree_state(trie_cache, "  ")
    
        print("\nCache state before accessing sequences:")
        print_tree_state(trie_cache, "  ")
    
        # Access some sequences to update their LRU status
        print(f"\nAccessing {access_count} sequences to update LRU order:")
        for i in range(access_count):
            print(f"\nAccessing sequence {i}:")
            alloc = trie_cache.acquire_pages_for_tokens(sequences[i], extra_token_slots=0)
            print_tree_state(trie_cache, "  ")
            alloc.release_pages()
            print(f"After releasing allocation {i}:")
            print_tree_state(trie_cache, "  ")
    
        print("\nCache state before attempting new allocation:")
        print_tree_state(trie_cache, "  ")
        print("\nAvailable pages in pool:", len(trie_cache.page_pool.available_pages))
    
        # Try to allocate new sequence - should evict least recently used unpublished sequence
        new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE))
        print(f"\nAttempting to allocate new sequence: {new_tokens}")
>       new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0)

tests\apps\llm\components\kvcache\trie_attention_cache_test.py:303: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
python\shortfin_apps\llm\components\kvcache\trie_attention_cache.py:371: in acquire_pages_for_tokens
    self._evict_pages(n_empty_pages - len(self.page_pool.available_pages))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <shortfin_apps.llm.components.kvcache.trie_attention_cache.TriePagedAttentionCache object at 0x000002464792FBF0>
max_pages = 1

    def _evict_pages(self, max_pages: int) -> int:
        """Evict up to max_pages pages using LRU strategy.
    
        Evicts from unreferenced leaf nodes first, working up the trie
        as nodes become childless.
    
        Args:
            max_pages: Maximum number of pages to evict
    
        Returns:
            Number of pages actually evicted
        """
        pages_to_evict = []
    
        # Initialize heap with unreferenced leaves
        unused_leaf_heap = [
            (leaf.access_time, leaf)
            for leaf in self.leaves
            if leaf.ref_count.is_empty()
        ]
>       heapq.heapify(unused_leaf_heap)
E       TypeError: '<' not supported between instances of 'TrieNode' and 'TrieNode'

python\shortfin_apps\llm\components\kvcache\trie_attention_cache.py:407: TypeError
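
The TypeError arises because, when two leaves share the same access_time, heapq falls back to comparing the TrieNode objects themselves, and TrieNode defines no ordering. A common fix, sketched below under that assumption (not necessarily how the PR was ultimately resolved), is to put an always-unique tiebreaker between the access time and the node in each heap entry:

```python
import heapq
import itertools


class TrieNode:
    """Minimal stand-in for the real TrieNode; defines no ordering methods."""
    def __init__(self, access_time):
        self.access_time = access_time


counter = itertools.count()  # monotonically increasing tiebreaker
leaves = [TrieNode(access_time=5), TrieNode(access_time=5)]

# With bare (access_time, node) tuples, equal access times would make heapq
# compare the TrieNode instances and raise TypeError. The unique counter
# value decides ties first, so the nodes are never compared.
unused_leaf_heap = [(leaf.access_time, next(counter), leaf) for leaf in leaves]
heapq.heapify(unused_leaf_heap)  # no longer raises
```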

ScottTodd added a commit that referenced this pull request Dec 6, 2024
monorimet pushed a commit that referenced this pull request Dec 13, 2024