Proposition integration LlamaIndex #380

Closed
353 changes: 353 additions & 0 deletions examples/data/paul_graham_essay/paul_graham_essay.txt

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions examples/llama_index_example.py
@@ -0,0 +1,27 @@
"""
This example intends to show the use of a simple case of using llama_index with outlines
It relies on one of the examples proposed by llama_index: https://github.com/run-llama/llama_index/tree/main/examples/paul_graham_essay
"""
import outlines.text.generate as generate
import outlines.models as models
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.response_synthesizers import (
    ResponseMode,
    get_response_synthesizer,
)
from outlines.tools.llama_index import LlamaIndexOutlinesLLM

# llama_index setup
documents = SimpleDirectoryReader("data/paul_graham_essay").load_data()
index = VectorStoreIndex.from_documents(documents=documents)
service_context = ServiceContext.from_defaults(llm=LlamaIndexOutlinesLLM())
response_synthesizer = get_response_synthesizer(
    response_mode=ResponseMode.SIMPLE_SUMMARIZE,
    service_context=service_context,
)
query_engine = index.as_query_engine(response_synthesizer=response_synthesizer)

model = models.transformers("gpt2", llama_index_engine=query_engine)
prompt = "What did the author do after he left YC? Choose one among the following choices: Painting, Running"
answer = generate.choice(model, ["Painting", "Running"])(prompt)
print(answer)
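
For comparison, here is a minimal sketch of the same constrained choice without the llama_index routing. It only uses the outlines API already shown above; since no llama_index_engine is passed, generation runs directly on the raw prompt with no context retrieved from the essay.

import outlines.models as models
import outlines.text.generate as generate

# No llama_index_engine here, so the sequence falls through to the plain
# generation path and nothing is retrieved from the essay.
model = models.transformers("gpt2")
prompt = (
    "What did the author do after he left YC? "
    "Choose one among the following choices: Painting, Running"
)
answer = generate.choice(model, ["Painting", "Running"])(prompt)
print(answer)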
21 changes: 21 additions & 0 deletions outlines/models/base.py
@@ -0,0 +1,21 @@
from typing import TYPE_CHECKING, Callable, Optional

from llama_index.callbacks import CallbackManager
from llama_index.indices.prompt_helper import PromptHelper
from llama_index.llms.base import LLM
from outlines.tools.llama_index import set_llama_index_model_function

if TYPE_CHECKING:
    from llama_index.core import BaseQueryEngine


class BaseModel:
    def __init__(self, llama_index_engine: Optional["BaseQueryEngine"] = None, *args, **kwargs):
        self.llama_index_engine = llama_index_engine

    def run_with_llama_index(self, prompt: str, func: Callable) -> str:
        """Run the outlines generation function through the llama_index engine with the user prompt."""
        set_llama_index_model_function(func)
        response = self.llama_index_engine.query(prompt)
        return response
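
To make the callback mechanism concrete, here is a hedged sketch of how run_with_llama_index is meant to be used, assuming llama_index is installed and this branch of outlines is importable. The FakeQueryEngine stub is hypothetical (not part of this PR); it only imitates what a real llama_index query engine does once LlamaIndexOutlinesLLM is plugged in: the outlines generation callable is registered globally, the engine augments the prompt, and the callable is invoked on the augmented prompt.

from outlines.models.base import BaseModel
from outlines.tools.llama_index import get_llama_index_model_function


class FakeQueryEngine:
    """Hypothetical stand-in for a llama_index query engine, for illustration only."""

    def query(self, prompt: str) -> str:
        # A real engine retrieves context and lets LlamaIndexOutlinesLLM.complete()
        # call the registered outlines function on the augmented prompt.
        func = get_llama_index_model_function()
        return func(f"Context: <retrieved text>\n\nQuestion: {prompt}")


model = BaseModel(llama_index_engine=FakeQueryEngine())
print(model.run_with_llama_index("Ping?", lambda augmented_prompt: augmented_prompt.upper()))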
5 changes: 4 additions & 1 deletion outlines/models/openai.py
@@ -7,21 +7,24 @@

import outlines
from outlines.caching import cache
+from outlines.models.base import BaseModel

__all__ = ["OpenAIAPI", "openai"]

if TYPE_CHECKING:
    from openai import AsyncOpenAI


-class OpenAIAPI:
+class OpenAIAPI(BaseModel):
    def __init__(
        self,
        model_name: str,
        api_key: Optional[str] = os.getenv("OPENAI_API_KEY"),
        temperature: float = 1.0,
        max_retries: int = 6,
+        **kwargs,
    ):
+        super().__init__(**kwargs)
        try:
            import openai
        except ImportError:
8 changes: 6 additions & 2 deletions outlines/models/transformers.py
@@ -3,6 +3,7 @@
import torch

from outlines.models.tokenizer import Tokenizer
+from outlines.models.base import BaseModel

if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -55,14 +56,16 @@ class CodeLlamaTokenizerFast:  # type: ignore
    )


-class Transformers:
+class Transformers(BaseModel):
    """Represents a `transformers` model."""

    def __init__(
        self,
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
+        **kwargs,
    ):
+        super().__init__(**kwargs)
        self.device = model.device
        self.model = model
        self.tokenizer = tokenizer
@@ -182,6 +185,7 @@ def transformers(
    device: Optional[str] = None,
    model_kwargs: dict = {},
    tokenizer_kwargs: dict = {},
+    **kwargs,
):
    """Instantiate a model from the `transformers` library and its tokenizer.

@@ -217,4 +221,4 @@ def transformers(
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    tokenizer = TransformersTokenizer(model_name, **tokenizer_kwargs)

-    return Transformers(model, tokenizer)
+    return Transformers(model, tokenizer, **kwargs)
31 changes: 29 additions & 2 deletions outlines/text/generate/sequence.py
@@ -2,7 +2,6 @@
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch

from outlines.models import OpenAIAPI

if TYPE_CHECKING:
@@ -140,7 +139,6 @@ def expand_attention_mask(
        )
        return attention_mask

-    @torch.inference_mode()
    def __call__(
        self,
        prompt: Union[str, List[str]],
@@ -162,6 +160,35 @@ def __call__(

"""

if self.model.llama_index_engine:
result = self.model.run_with_llama_index(prompt, self.run)
else:
result = self.run(prompt, samples, rng)

return result

@torch.inference_mode()
def run(
self,
prompt: Union[str, List[str]],
samples: int = 1,
rng: Optional[torch.Generator] = None,
) -> Union[str, List[str]]:
"""Generate a new sequence given a prompt.

Parameters
----------
prompt
The input prompt.
samples
The number of samples to generate for each prompt.

Returns
-------
The full sequence that contains the prompts and the generated string.

"""

token_ids, attention_mask = self.model.tokenizer.encode(prompt)

token_ids = token_ids.squeeze(0)
94 changes: 94 additions & 0 deletions outlines/tools/llama_index.py
@@ -0,0 +1,94 @@
from typing import Any, Optional

from llama_index.callbacks import CallbackManager
from llama_index.llms.base import LLM, CompletionResponse, LLMMetadata


llama_index_model_function = None

MODELS_DEFAULT_PARAMS = {
    "default": {
        "context_window": 1024,
        "num_output": 64,
    },
    "gpt2": {
        "context_window": 1024,
        "num_output": 64,
    },
}


def set_llama_index_model_function(func):
    global llama_index_model_function
    llama_index_model_function = func


def get_llama_index_model_function():
    return llama_index_model_function


class LlamaIndexOutlinesLLM:
    def __init__(self, model_name: str = "default", context_window: Optional[int] = None, num_output: Optional[int] = None):
        self.callback_manager = CallbackManager()
        try:
            # Copy so the module-level defaults are not mutated by the updates below.
            metadata = dict(MODELS_DEFAULT_PARAMS[model_name])
        except KeyError:
            raise ValueError(f"Invalid model_name: {model_name}")
        metadata.update({"context_window": context_window} if context_window else {})
        metadata.update({"num_output": num_output} if num_output else {})
        self._metadata = LLMMetadata(**metadata)

    @property
    def metadata(self):
        """Values used by llama_index to compute the size of the context text chunks to use in the queries."""
        return self._metadata

    def complete(self, prompt: str, **kwargs) -> CompletionResponse:
        """Called by llama_index to run the query through the LLM; here it invokes the registered outlines function."""
        func = get_llama_index_model_function()
        if func:
            response = func(prompt)
            return CompletionResponse(text=response)
        else:
            raise RuntimeError("The outlines function has not been set")

    # The following methods mirror the abstract methods of llama_index's LLM interface.

    def chat(self, messages, **kwargs: Any):
        """Chat endpoint for LLM."""
        pass

    def stream_chat(self, messages, **kwargs: Any):
        """Streaming chat endpoint for LLM."""
        pass

    def stream_complete(self, prompt: str, **kwargs: Any):
        """Streaming completion endpoint for LLM."""
        pass

    async def achat(self, messages, **kwargs: Any):
        """Async chat endpoint for LLM."""
        pass

    async def acomplete(self, prompt: str, **kwargs: Any):
        """Async completion endpoint for LLM."""
        pass

    async def astream_chat(self, messages, **kwargs: Any):
        """Async streaming chat endpoint for LLM."""
        pass

    async def astream_complete(self, prompt: str, **kwargs: Any):
        """Async streaming completion endpoint for LLM."""
        pass
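
A short usage sketch of the class above, assuming llama_index is installed and this branch of outlines is importable; the lambda is only a stand-in for a real outlines generation function. The metadata overrides control how llama_index sizes its context chunks, and complete() delegates to whichever function was last registered with set_llama_index_model_function.

from outlines.tools.llama_index import (
    LlamaIndexOutlinesLLM,
    set_llama_index_model_function,
)

# Override the defaults so llama_index can plan for a larger context window.
llm = LlamaIndexOutlinesLLM(model_name="gpt2", context_window=2048, num_output=128)
print(llm.metadata.context_window, llm.metadata.num_output)  # 2048 128

# complete() calls back into the registered outlines function.
set_llama_index_model_function(lambda prompt: "Painting")
print(llm.complete("What did the author do after he left YC?").text)  # Painting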