Text generation inference integration #12

Open · wants to merge 60 commits into main

Commits (60)
7f9d021
adding TGI Dockerfile
andrewramsay Feb 22, 2024
5823d56
Add tgi service to docker-compose.yml
andrewramsay Feb 22, 2024
9f5383b
Add env var for TGI to llm_functionalities
andrewramsay Feb 22, 2024
7f70244
Updating llm_functionalities
andrewramsay Feb 22, 2024
8f72440
Initial attempt to integrate TGI
andrewramsay Feb 22, 2024
d715048
Have llm_functionalities exit if connection fails
andrewramsay Feb 22, 2024
6d06e78
adding timeout parameter when creating InferenceClient
andrewramsay Feb 26, 2024
26e22d8
Some more TGI updates
andrewramsay Feb 26, 2024
c4cfbdc
Adding connection attempt parameters for TGI
andrewramsay Mar 8, 2024
f670a12
Use a smaller default model for testing
andrewramsay Mar 8, 2024
cdc624a
Add TGI-specific summarization protos/RPCs
andrewramsay Mar 8, 2024
5754857
Use current source version of huggingface_hub
andrewramsay Mar 8, 2024
d7cd49b
Update LLMRunner connection behaviour
andrewramsay Mar 8, 2024
5c63e9a
Add a _check_connectivity method
andrewramsay Mar 8, 2024
b77fa41
Set pool size based on number of requests
andrewramsay Mar 8, 2024
4e1e86a
Initial attempt at adding some new TGI endpoints
andrewramsay Mar 8, 2024
b5edea1
Merge branch 'main' into text_generation_inference
andrewramsay Mar 8, 2024
702b076
Merge branch 'main' into text_generation_inference
andrewramsay Apr 4, 2024
c519af8
Merge branch 'main' into text_generation_inference
andrewramsay Apr 26, 2024
c92a20f
pin version of TGI
andrewramsay Apr 26, 2024
5a86228
remove TGI summary protos and RPCs
andrewramsay Apr 26, 2024
6d8e9fd
Fix session ID field name for tests
andrewramsay Apr 26, 2024
c08d061
Add a couple of JSON-format TaskMaps for tests
andrewramsay Apr 26, 2024
2e91a94
Add HUGGING_FACE_HUB_TOKEN to tgi service
andrewramsay Apr 26, 2024
7361b08
Fix session ID field name
andrewramsay Apr 26, 2024
3eb171a
Add set_source calls for each code path
andrewramsay Apr 26, 2024
bc3cfbf
Bump CUDA image version to avoid warnings
andrewramsay Apr 26, 2024
03cf959
Fix warnings about f-strings with no placeholders
andrewramsay Apr 26, 2024
51a96e9
Adding new methods to AbstractDB
andrewramsay May 3, 2024
590f498
Adding support for deleting TaskMaps and Sessions
andrewramsay May 3, 2024
67e83d9
Adding a timeout field to LLM request protos
andrewramsay May 3, 2024
e72770a
Removing old generate_summary methods
andrewramsay May 3, 2024
4d98c5a
Update LLM timeout/error handling
andrewramsay May 3, 2024
5f08def
Formatting/linting
andrewramsay May 3, 2024
69548a1
Update LLM timeout/error handling
andrewramsay May 3, 2024
77157b4
Latest version of LLMRunner
andrewramsay May 3, 2024
26d0296
Small round of linting/formatting
andrewramsay May 3, 2024
cd4ff58
Small bugfix
andrewramsay May 3, 2024
b8cac13
Adding some WIP LLM tests
andrewramsay May 3, 2024
3b71947
Adding a delete_stagedoutput to make a test work
andrewramsay May 10, 2024
2d4bff4
Update set_source call to add a message
andrewramsay May 10, 2024
88e07d8
Update sample taskmap
andrewramsay May 10, 2024
b857c9d
Current version of LLM tests
andrewramsay May 10, 2024
d19c338
Updating docker-compose.yml
andrewramsay May 23, 2024
d3f823f
Add some extra logging around enhancing taskmaps
andrewramsay May 23, 2024
e17f40c
Update LLM ingredient substitution test
andrewramsay May 23, 2024
747050d
Refactoring LLMRunner connection handling
andrewramsay May 23, 2024
38f33a7
Adding dnspython as a dependency
andrewramsay May 23, 2024
62bb76f
Update LLM ingredient substitution
andrewramsay May 23, 2024
09591f6
Prevent a crash if LLM response is empty
andrewramsay May 23, 2024
90ed0fb
Fix parsing of LLM responses
andrewramsay May 23, 2024
5787c5c
move llm downloads.toml to TGI
andrewramsay May 23, 2024
2d412c9
Sync folder name with service name
andrewramsay May 23, 2024
c12cf13
Move downloads.toml back llm_functionalities
andrewramsay May 23, 2024
b76aad2
Download alpaca_llm to a TGI Docker volume
andrewramsay May 23, 2024
39e00f9
Fix path to tgi Dockerfile
andrewramsay May 23, 2024
88e21c8
Update TGI Docker volumes
andrewramsay May 23, 2024
196707d
Add auto-downloading of "local" models
andrewramsay May 23, 2024
da5e20b
Remove the connection attempt in __init__
andrewramsay May 23, 2024
0416bb5
fix path in TGI dockerfile
andrewramsay May 23, 2024
49 changes: 46 additions & 3 deletions docker-compose.yml
@@ -1,4 +1,3 @@
version: "3.9" # optional since v1.27.0
services:

builder:
@@ -293,20 +292,64 @@ services:
- NEURAL_FUNCTIONALITIES_URL=llm_functionalities:8000
- FUNCTIONALITIES_URL=functionalities:8000
- EXTERNAL_FUNCTIONALITIES_URL=external_functionalities:8000
# the location of the TGI endpoint (note that it's using the internal
# Docker network, so it's port 80 rather than 8081). this could also be
# set to a remote location not managed by the current Compose instance,
# in which case the URL should reference port 8081 instead.
- INFERENCE_ENDPOINT_URL=tgi:80
networks:
- internal
- external
depends_on:
builder:
condition: service_started
oat_common:
tgi:
condition: service_started

tgi:
platform: linux/x86_64
container_name: tgi
build:
context: ./
dockerfile: tgi/Dockerfile
ports:
# this isn't needed for llm_functionalities (it uses the internal network),
# but might be useful to be able to submit requests to the instance from
# external scripts for testing/debugging, or if the TGI service is running
# in a different location than the rest of the system.
- "8081:80"
volumes:
# if the model is downloaded from huggingface, the files will be stored
# in this location
- ./shared/file_system/downloads/tgi/huggingface:/data
# this volume is used to allow the use of the existing Alpaca LLM
# that OAT is currently set up to use. setting MODEL_ID below to a
# path instead of a huggingface model ID will cause TGI to load model
# files from that location
- ./shared/file_system/downloads/tgi/local:/models
environment:
# setting the value of MODEL_ID is equivalent to passing "--model-id" parameter
# to the TGI launcher. the default is to use the Alpaca LLM model that OAT
# has previously used
- MODEL_ID=${MODEL_ID:-/models/alpaca_llm/}
# any other TGI launcher parameters can be set in this env var, e.g.:
# TGI_PARAMS="--param1 param1_value --param2 param2_value" docker compose up
- TGI_PARAMS=${TGI_PARAMS:-}
# this is required for Mistral or other gated models. you can find your
# tokens by browsing to https://huggingface.co/settings/tokens
- HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-}
networks:
- internal
- external
# larger sharded models will need this increased from the Docker default of 64MB
shm_size: ${SHM_SIZE:-1gb}
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
# set the number of GPUs available to the container, default to 1
count: ${GPU_COUNT:-1}
capabilities: [ gpu ]

networks:
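
Since the compose file publishes the TGI container's port 80 on host port 8081, one quick way to sanity-check the service from outside the Compose network is to point a huggingface_hub InferenceClient at it, mirroring how llm_functionalities reaches tgi:80 internally. A minimal debugging sketch, assuming the stack is running locally with the default port mapping and huggingface_hub installed; the prompt and token limit are arbitrary:

# Debugging sketch only: query the TGI container exposed on host port 8081.
# Assumes `docker compose up tgi` has finished loading the model.
from huggingface_hub import InferenceClient

# timeout mirrors the InferenceClient timeout parameter added on this branch
client = InferenceClient(model="http://localhost:8081", timeout=60)

# plain text-generation call against the TGI instance
text = client.text_generation(
    "### Instruction: List three pancake toppings.\n### Response:",
    max_new_tokens=50,
)
print(text)
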
8 changes: 8 additions & 0 deletions external_functionalities/database/abstract_db.py
@@ -12,10 +12,18 @@ def load_session(self, session_id: str) -> Session:
def save_session(self, session_id: str, session: Session) -> None:
pass

@abstractmethod
def delete_session(self, session_id: str) -> None:
pass

@abstractmethod
def load_taskmap(self, taskmap_id: str) -> TaskMap:
pass

@abstractmethod
def save_taskmap(self, session_id: str, session: Session) -> None:
pass

@abstractmethod
def delete_taskmap(self, taskmap_id: str) -> None:
pass
21 changes: 21 additions & 0 deletions external_functionalities/database/dynamo_db.py
@@ -53,19 +53,40 @@ def save_session(self, session_id: str) -> None:

def load_session(self, session_id: str) -> Session:
session = self.session_db.get(session_id)

if self.taskmap_enhancer.trigger(session):
# this is triggered if the taskmap has a wikihow URL and it hasn't already been enhanced
# (i.e. session.task.state.enhanced is False).
# Calling .enhance here with questions_only defaulting to False will perform two types
# of enhancement:
# - step summarization
# - proactive question enhancement
logger.info("Taskmap Enhancement triggered: Summarization and Proactive QA")
return self.taskmap_enhancer.enhance(session)
elif session.task.state.enhanced is False and session.task.taskmap.taskmap_id != "":
# in this case (not a wikihow URL, not already enhanced, non-empty taskmap ID),
# calling enhance with questions_only set to True will *only* run the
# proactive_question_enhancement part
logger.info("Taskmap Enhancement triggered: Proactive QA only")
return self.taskmap_enhancer.enhance(session, questions_only=True)
else:
return session

def delete_session(self, session_id: str) -> None:
self.session_db.delete(session_id)

def save_taskmap(self, session_id: str, session: Session) -> None:
self.taskmap_db.put(session)

def load_taskmap(self, taskmap_id: str) -> TaskMap:
return self.taskmap_db.get(taskmap_id)

def delete_taskmap(self, taskmap_id: str) -> None:
self.taskmap_db.delete(taskmap_id)

def delete_stagedoutput(self, taskmap_id: str) -> None:
self.taskmap_enhancer.db.delete(taskmap_id)

def save_search_log(self, search_log: SearchLog) -> None:
self.search_logs_db.put(search_log)

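The new delete helpers also support the test work on this branch (see the "Adding a delete_stagedoutput to make a test work" commit): clearing the staged output ensures a later load_session() re-runs enhancement instead of returning a cached result. A hedged teardown sketch, assuming a DynamoDB instance built as in dynamo_db.py above; the IDs are placeholders:

# Illustrative test-teardown sketch, not taken from the PR's test files.
# `db` is assumed to be a DynamoDB instance as defined in dynamo_db.py above.
def cleanup_llm_test_artifacts(db, session_id: str, taskmap_id: str) -> None:
    # drop the cached enhanced TaskMap first so the next load_session()
    # re-triggers summarization / proactive QA instead of reusing stale output
    db.delete_stagedoutput(taskmap_id)
    db.delete_taskmap(taskmap_id)
    db.delete_session(session_id)
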
1 change: 0 additions & 1 deletion external_functionalities/database/llm_taskmap_enhancer.py
@@ -51,7 +51,6 @@ def __get_indices_with_criteria(length):
return list(indices)

def proactive_question_enhancement(self, taskmap, staged_db):
logger.info("GENERATE QUESTIONS")
request: ProactiveQuestionGenerationRequest = ProactiveQuestionGenerationRequest()
indices = self.__get_indices_with_criteria(len(taskmap.steps))

12 changes: 12 additions & 0 deletions external_functionalities/database/servicer.py
@@ -17,13 +17,25 @@ def save_session(self, request, context) -> None:
self.instance.save_session(request.id, request.session)
return Void()

def delete_session(self, request, context) -> None:
self.instance.delete_session(request.id)
return Void()

def load_taskmap(self, request, context) -> TaskMap:
return self.instance.get_taskmap(request.id)

def save_taskmap(self, request, context) -> None:
self.instance.save_session(request.id, request.taskmap)
return Void()

def delete_taskmap(self, request, context) -> None:
self.instance.delete_taskmap(request.id)
return Void()

def delete_stagedoutput(self, request, context) -> None:
self.instance.delete_stagedoutput(request.id)
return Void()

def save_search_logs(self, request, context) -> None:
self.instance.save_search_log(request)
return Void()
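
For completeness, the new RPCs can be exercised from any service over the internal gRPC network. The sketch below is hypothetical: the stub and request message names are assumptions rather than names taken from the repo's protos, and only the use of an `id` field is grounded in the servicer code above.

# Hypothetical client-side call, for illustration only. The proto module,
# stub and message names below are assumed, not taken from the repo.
import grpc
import os

from compiled_protobufs.database_pb2 import SessionID          # assumed name
from compiled_protobufs.database_pb2_grpc import DatabaseStub  # assumed name

channel = grpc.insecure_channel(os.environ["EXTERNAL_FUNCTIONALITIES_URL"])
db = DatabaseStub(channel)

# each new RPC mirrors a servicer method above: the request carries an `id`
db.delete_session(SessionID(id="session_123"))
db.delete_stagedoutput(SessionID(id="taskmap_abc"))
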
3 changes: 3 additions & 0 deletions external_functionalities/database/staged_enhance.py
@@ -73,6 +73,8 @@ def _enhance_desc(self, session_db, taskmap_db, session: Session) -> None:
taskmaps_to_enhance = []
request: LLMMultipleDescriptionGenerationRequest = LLMMultipleDescriptionGenerationRequest()

logger.info("Enhancing descriptions")

for i, candidate in enumerate(session.task_selection.candidates_union):
# avoid categories
if candidate.HasField('category'):
@@ -99,6 +101,7 @@ def _enhance_desc(self, session_db, taskmap_db, session: Session) -> None:
return

# Generate Descriptions and update session
logger.info(f"Enhancing descriptions: {len(request.task_title)} candidates found")
descriptions = self.llm_desc_generator.generate_descriptions(request)

for taskmap, description in zip(taskmaps_to_enhance, descriptions.description):
@@ -1,28 +1,22 @@
import grpc
import os
import json
import time

from compiled_protobufs.llm_pb2 import ExecutionSearchResponse, ExecutionSearchRequest, ModelResponse, ModelRequest
from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub
from compiled_protobufs.taskmap_pb2 import Session

from grpc._channel import _InactiveRpcError
from concurrent.futures import TimeoutError, ThreadPoolExecutor

from utils import logger, SEARCH_AGAIN_QUESTION


class ExecutionSearchManager:
def __init__(self):
llm_channel = grpc.insecure_channel(os.environ["LLM_FUNCTIONALITIES_URL"])
self.llm = LLMRunnerStub(llm_channel)
self.valid_intents = [
'recommend_new_search', 'continue_current_task', 'ask_clarifying_question'
]
self.valid_intents = ["recommend_new_search", "continue_current_task", "ask_clarifying_question"]

def extract_classification_response(self, generated_answer) -> dict:
logger.info(f'Raw LLM response for classification: {generated_answer}')
logger.info(f"Raw LLM response for classification: {generated_answer}")
valid_dict = {}
start_token = "{"
end_token = "}"
@@ -33,84 +27,63 @@ def extract_classification_response(self, generated_answer) -> dict:
generated_answer = "{" + generated_answer.split(start_token)[-1]

try:
generated_answer = generated_answer.replace(' \n ', '').replace('\n', '')
generated_answer = generated_answer.replace(" \n ", "").replace("\n", "")
valid_dict = json.loads(generated_answer)
logger.info(f'Managed to parse classification: {valid_dict}')
logger.info(f"Managed to parse classification: {valid_dict}")

if "intent" in list(valid_dict.keys()) and valid_dict["intent"] in self.valid_intents:
return valid_dict
else:
logger.info(f'Dictionary contents not valid: {valid_dict}')
logger.info(f"Dictionary contents not valid: {valid_dict}")
except Exception as e:
logger.info(f'Could not parse response >{generated_answer}<: {e}')
logger.info(f"Could not parse response >{generated_answer}<: {e}")
return valid_dict

def build_prompt(self, request: ExecutionSearchRequest) -> ModelRequest:
model_request: ModelRequest = ModelRequest()

task_type = f"making {request.taskmap.title}" if request.domain == Session.Domain.COOKING \
task_type = (
f"making {request.taskmap.title}"
if request.domain == Session.Domain.COOKING
else f"with {request.taskmap.title}"
)
# instruction
prompt = f"### Instruction: Imagine you are an AI assistant currently helping a user with {request.taskmap.title}. " \
f"You are able to switch to a new task, so based on the last user request you need to decide " \
f"whether to continue with {request.taskmap.title}, recommend that you can start new search based " \
f"on the last user request, or asking a clarifying question. " \
f"Choose one of the intent options and follow this format for your response: {{\"intent\": \"\"}}\n"
prompt = (
f"### Instruction: Imagine you are an AI assistant currently helping a user with {request.taskmap.title}. "
f"You are able to switch to a new task, so based on the last user request you need to decide "
f"whether to continue with {request.taskmap.title}, recommend that you can start new search based "
f"on the last user request, or asking a clarifying question. "
f'Choose one of the intent options and follow this format for your response: {{"intent": ""}}\n'
)
# input
prompt += f"### Input:\nCurrent task: {request.taskmap.title}\nConversation history:\n"
if any([request.last_agent_response == prompt for prompt in
SEARCH_AGAIN_QUESTION]) and "step" in request.last_last_agent_response.lower():
if (
any([request.last_agent_response == prompt for prompt in SEARCH_AGAIN_QUESTION])
and "step" in request.last_last_agent_response.lower()
):
prompt += f"You: {request.last_last_agent_response}\n"
else:
prompt += f"You: {request.last_agent_response}\n"
prompt += f"Last user request: {request.user_question}\n"
prompt += f"Intent options:{str(self.valid_intents)}\n\n"
# response
prompt += "### Response: Your response:{\"intent\":"
prompt += '### Response: Your response:{"intent":'

model_request.formatted_prompt = prompt
model_request.max_tokens = 20
return model_request

def llm_generate_search_decision(self, request: ExecutionSearchRequest) -> ExecutionSearchResponse:
def generate_decision(self, request: ExecutionSearchRequest, default_timeout: int = 1000) -> ExecutionSearchResponse:
default_timeout = default_timeout if request.timeout == 0 else request.timeout
model_request = self.build_prompt(request)

llm_response: ModelResponse = self.llm.call_model(model_request)

llm_classification: ExecutionSearchRequest = ExecutionSearchResponse()
llm_classification: ExecutionSearchResponse = ExecutionSearchResponse()
parsed_result = self.extract_classification_response(llm_response.text)
if parsed_result == {}:
return llm_classification

llm_classification.intent_classification = parsed_result.get("intent", "")
llm_classification.ai_response = parsed_result.get("ai_response", "")
return llm_classification

def generate_decision(self, request: ExecutionSearchRequest, default_timeout=1000) -> ExecutionSearchResponse:
default_timeout = default_timeout if request.timeout == 0 else request.timeout

with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(self.llm_generate_search_decision, request)
timeout: float = default_timeout / 1000 + time.monotonic()
try:
if future.done() or timeout - time.monotonic() > 0:
response = future.result(timeout=timeout - time.monotonic())
return response

else:
future.cancel()
logger.warning(f"Timeout for Execution Search Decision Generation")
response = ExecutionSearchResponse()
return response

except TimeoutError as e:
future.cancel()
logger.warning("TimeoutError while running Execution Search Decision Generation", exc_info=e)
response = ExecutionSearchResponse()
return response

except _InactiveRpcError as e:
future.cancel()
logger.warning("Execution Search Decision Generation Channel is down")
response = ExecutionSearchResponse()
return response
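
Putting the refactored pieces together: generate_decision() now takes its timeout from the request (falling back to default_timeout when request.timeout is 0, matching the new timeout field on the LLM protos) and returns an empty ExecutionSearchResponse whenever the model output cannot be parsed into one of the valid intents. A minimal caller sketch, assuming the compiled protobufs are importable and LLM_FUNCTIONALITIES_URL points at a running llm_functionalities instance; the field values are invented:

# Illustrative caller, not taken from the repo's tests.
from compiled_protobufs.llm_pb2 import ExecutionSearchRequest
from compiled_protobufs.taskmap_pb2 import Session

request = ExecutionSearchRequest()
request.taskmap.title = "pancakes"                       # current task (made-up value)
request.domain = Session.Domain.COOKING
request.last_agent_response = "The next step is to whisk the eggs."
request.user_question = "actually, can you find me a waffle recipe?"
request.timeout = 500                                    # ms; 0 falls back to default_timeout

manager = ExecutionSearchManager()   # requires LLM_FUNCTIONALITIES_URL in the environment
decision = manager.generate_decision(request)
print(decision.intent_classification)  # e.g. "recommend_new_search", or "" on a parse failure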