diff --git a/docker-compose.yml b/docker-compose.yml index 99f2f3e..d30cb7a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,3 @@ -version: "3.9" # optional since v1.27.0 services: builder: @@ -293,20 +292,64 @@ services: - NEURAL_FUNCTIONALITIES_URL=llm_functionalities:8000 - FUNCTIONALITIES_URL=functionalities:8000 - EXTERNAL_FUNCTIONALITIES_URL=external_functionalities:8000 + # the location of the TGI endpoint (note that it's using the internal + # Docker network, so it's port 80 rather than 8081). this could also be + # set to a remote location not managed by the current Compose instance, + # in which case the URL should reference port 8081 instead. + - INFERENCE_ENDPOINT_URL=tgi:80 networks: - internal - external depends_on: builder: condition: service_started - oat_common: + tgi: condition: service_started + + tgi: + platform: linux/x86_64 + container_name: tgi + build: + context: ./ + dockerfile: tgi/Dockerfile + ports: + # this isn't needed for llm_functionalities (it uses the internal network), + # but might be useful to be able to submit requests to the instance from + # external scripts for testing/debugging, or if the TGI service is running + # in a different location than the rest of the system. + - "8081:80" + volumes: + # if the model is downloaded from huggingface, the files will be stored + # in this location + - ./shared/file_system/downloads/tgi/huggingface:/data + # this volume is used to allow the use of the existing Alpaca LLM + # that OAT is currently set up to use. setting MODEL_ID below to a + # path instead of a huggingface model ID will cause TGI to load model + # files from that location + - ./shared/file_system/downloads/tgi/local:/models + environment: + # setting the value of MODEL_ID is equivalent to passing "--model-id" parameter + # to the TGI launcher. the default is to use the Alpaca LLM model that OAT + # has previously used + - MODEL_ID=${MODEL_ID:-/models/alpaca_llm/} + # any other TGI launcher parameters can be set in this env var, e.g.: + # TGI_PARAMS="--param1 param1_value --param2 param2_value" docker compose up + - TGI_PARAMS=${TGI_PARAMS:-} + # this is required for Mistral or other gated models. 
you can find your + # tokens by browsing to https://huggingface.co/settings/tokens + - HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-} + networks: + - internal + - external + # larger sharded models will need this increased from the Docker default of 64MB + shm_size: ${SHM_SIZE:-1gb} deploy: resources: reservations: devices: - driver: nvidia - count: 1 + # set the number of GPUs available to the container, default to 1 + count: ${GPU_COUNT:-1} capabilities: [ gpu ] networks: diff --git a/external_functionalities/database/abstract_db.py b/external_functionalities/database/abstract_db.py index 5512c23..0b918f0 100644 --- a/external_functionalities/database/abstract_db.py +++ b/external_functionalities/database/abstract_db.py @@ -12,6 +12,10 @@ def load_session(self, session_id: str) -> Session: def save_session(self, session_id: str, session: Session) -> None: pass + @abstractmethod + def delete_session(self, session_id: str) -> None: + pass + @abstractmethod def load_taskmap(self, taskmap_id: str) -> TaskMap: pass @@ -19,3 +23,7 @@ def load_taskmap(self, taskmap_id: str) -> TaskMap: @abstractmethod def save_taskmap(self, session_id: str, session: Session) -> None: pass + + @abstractmethod + def delete_taskmap(self, taskmap_id: str) -> None: + pass diff --git a/external_functionalities/database/dynamo_db.py b/external_functionalities/database/dynamo_db.py index 0f4ec65..4d489b4 100644 --- a/external_functionalities/database/dynamo_db.py +++ b/external_functionalities/database/dynamo_db.py @@ -53,19 +53,40 @@ def save_session(self, session_id: str, session: Session) -> None: def load_session(self, session_id: str) -> Session: session = self.session_db.get(session_id) + if self.taskmap_enhancer.trigger(session): + # this is triggered if the taskmap has a wikihow URL and it hasn't already been enhanced + # (i.e. session.task.state.enhanced is False). 
+ # Calling .enhance here with questions_only defaulting to False will peform two types + # of enhancement: + # - step summarization + # - proactive question enhancement + logger.info("Taskmap Enhancement triggered: Summarization and Proactive QA") return self.taskmap_enhancer.enhance(session) elif session.task.state.enhanced is False and session.task.taskmap.taskmap_id != "": + # in this case (not a wikihow URL, not already enhanced, non-empty taskmap ID), + # calling enhance with questions_only set to True will *only* run the + # proactive_question_enhancement part + logger.info("Taskmap Enhancement triggered: Proactive QA only") return self.taskmap_enhancer.enhance(session, questions_only=True) else: return session + def delete_session(self, session_id:str) -> None: + self.session_db.delete(session_id) + def save_taskmap(self, session_id: str, session: Session) -> None: self.taskmap_db.put(session) def load_taskmap(self, taskmap_id: str) -> TaskMap: return self.taskmap_db.get(taskmap_id) + def delete_taskmap(self, taskmap_id: str) -> None: + self.taskmap_db.delete(taskmap_id) + + def delete_stagedoutput(self, taskmap_id: str) -> None: + self.taskmap_enhancer.db.delete(taskmap_id) + def save_search_log(self, search_log: SearchLog) -> None: self.search_logs_db.put(search_log) diff --git a/external_functionalities/database/llm_taskmap_enhancer.py b/external_functionalities/database/llm_taskmap_enhancer.py index b1670f2..1b06a44 100644 --- a/external_functionalities/database/llm_taskmap_enhancer.py +++ b/external_functionalities/database/llm_taskmap_enhancer.py @@ -51,7 +51,6 @@ def __get_indices_with_criteria(length): return list(indices) def proactive_question_enhancement(self, taskmap, staged_db): - logger.info("GENERATE QUESTIONS") request: ProactiveQuestionGenerationRequest = ProactiveQuestionGenerationRequest() indices = self.__get_indices_with_criteria(len(taskmap.steps)) diff --git a/external_functionalities/database/servicer.py b/external_functionalities/database/servicer.py index 28519f9..68125d9 100644 --- a/external_functionalities/database/servicer.py +++ b/external_functionalities/database/servicer.py @@ -17,6 +17,10 @@ def save_session(self, request, context) -> None: self.instance.save_session(request.id, request.session) return Void() + def delete_session(self, request, context) -> None: + self.instance.delete_session(request.id) + return Void() + def load_taskmap(self, request, context) -> TaskMap: return self.instance.get_taskmap(request.id) @@ -24,6 +28,14 @@ def save_taskmap(self, request, context) -> None: self.instance.save_session(request.id, request.taskmap) return Void() + def delete_taskmap(self, request, context) -> None: + self.instance.delete_taskmap(request.id) + return Void() + + def delete_stagedoutput(self, request, context) -> None: + self.instance.delete_stagedoutput(request.id) + return Void() + def save_search_logs(self, request, context) -> None: self.instance.save_search_log(request) return Void() diff --git a/external_functionalities/database/staged_enhance.py b/external_functionalities/database/staged_enhance.py index 478e15c..60c75c8 100644 --- a/external_functionalities/database/staged_enhance.py +++ b/external_functionalities/database/staged_enhance.py @@ -73,6 +73,8 @@ def _enhance_desc(self, session_db, taskmap_db, session: Session) -> None: taskmaps_to_enhance = [] request: LLMMultipleDescriptionGenerationRequest = LLMMultipleDescriptionGenerationRequest() + logger.info("Enhancing descriptions") + for i, candidate in 
enumerate(session.task_selection.candidates_union): # avoid categories if candidate.HasField('category'): @@ -99,6 +101,7 @@ def _enhance_desc(self, session_db, taskmap_db, session: Session) -> None: return # Generate Descriptions and update session + logger.info(f"Enhancing descriptions: {len(request.task_title)} candidates found") descriptions = self.llm_desc_generator.generate_descriptions(request) for taskmap, description in zip(taskmaps_to_enhance, descriptions.description): diff --git a/functionalities/execution_search_manager/execution_search_manager.py b/functionalities/execution_search_manager/execution_search_manager.py index c341355..5c5269c 100644 --- a/functionalities/execution_search_manager/execution_search_manager.py +++ b/functionalities/execution_search_manager/execution_search_manager.py @@ -1,15 +1,11 @@ import grpc import os import json -import time from compiled_protobufs.llm_pb2 import ExecutionSearchResponse, ExecutionSearchRequest, ModelResponse, ModelRequest from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub from compiled_protobufs.taskmap_pb2 import Session -from grpc._channel import _InactiveRpcError -from concurrent.futures import TimeoutError, ThreadPoolExecutor - from utils import logger, SEARCH_AGAIN_QUESTION @@ -17,12 +13,10 @@ class ExecutionSearchManager: def __init__(self): llm_channel = grpc.insecure_channel(os.environ["LLM_FUNCTIONALITIES_URL"]) self.llm = LLMRunnerStub(llm_channel) - self.valid_intents = [ - 'recommend_new_search', 'continue_current_task', 'ask_clarifying_question' - ] + self.valid_intents = ["recommend_new_search", "continue_current_task", "ask_clarifying_question"] def extract_classification_response(self, generated_answer) -> dict: - logger.info(f'Raw LLM response for classification: {generated_answer}') + logger.info(f"Raw LLM response for classification: {generated_answer}") valid_dict = {} start_token = "{" end_token = "}" @@ -33,51 +27,59 @@ def extract_classification_response(self, generated_answer) -> dict: generated_answer = "{" + generated_answer.split(start_token)[-1] try: - generated_answer = generated_answer.replace(' \n ', '').replace('\n', '') + generated_answer = generated_answer.replace(" \n ", "").replace("\n", "") valid_dict = json.loads(generated_answer) - logger.info(f'Managed to parse classification: {valid_dict}') + logger.info(f"Managed to parse classification: {valid_dict}") if "intent" in list(valid_dict.keys()) and valid_dict["intent"] in self.valid_intents: return valid_dict else: - logger.info(f'Dictionary contents not valid: {valid_dict}') + logger.info(f"Dictionary contents not valid: {valid_dict}") except Exception as e: - logger.info(f'Could not parse response >{generated_answer}<: {e}') + logger.info(f"Could not parse response >{generated_answer}<: {e}") return valid_dict def build_prompt(self, request: ExecutionSearchRequest) -> ModelRequest: model_request: ModelRequest = ModelRequest() - task_type = f"making {request.taskmap.title}" if request.domain == Session.Domain.COOKING \ + task_type = ( + f"making {request.taskmap.title}" + if request.domain == Session.Domain.COOKING else f"with {request.taskmap.title}" + ) # instruction - prompt = f"### Instruction: Imagine you are an AI assistant currently helping a user with {request.taskmap.title}. 
" \ - f"You are able to switch to a new task, so based on the last user request you need to decide " \ - f"whether to continue with {request.taskmap.title}, recommend that you can start new search based " \ - f"on the last user request, or asking a clarifying question. " \ - f"Choose one of the intent options and follow this format for your response: {{\"intent\": \"\"}}\n" + prompt = ( + f"### Instruction: Imagine you are an AI assistant currently helping a user with {request.taskmap.title}. " + f"You are able to switch to a new task, so based on the last user request you need to decide " + f"whether to continue with {request.taskmap.title}, recommend that you can start new search based " + f"on the last user request, or asking a clarifying question. " + f'Choose one of the intent options and follow this format for your response: {{"intent": ""}}\n' + ) # input prompt += f"### Input:\nCurrent task: {request.taskmap.title}\nConversation history:\n" - if any([request.last_agent_response == prompt for prompt in - SEARCH_AGAIN_QUESTION]) and "step" in request.last_last_agent_response.lower(): + if ( + any([request.last_agent_response == prompt for prompt in SEARCH_AGAIN_QUESTION]) + and "step" in request.last_last_agent_response.lower() + ): prompt += f"You: {request.last_last_agent_response}\n" else: prompt += f"You: {request.last_agent_response}\n" prompt += f"Last user request: {request.user_question}\n" prompt += f"Intent options:{str(self.valid_intents)}\n\n" # response - prompt += "### Response: Your response:{\"intent\":" + prompt += '### Response: Your response:{"intent":' model_request.formatted_prompt = prompt model_request.max_tokens = 20 return model_request - def llm_generate_search_decision(self, request: ExecutionSearchRequest) -> ExecutionSearchResponse: + def generate_decision(self, request: ExecutionSearchRequest, default_timeout: int = 1000) -> ExecutionSearchResponse: + default_timeout = default_timeout if request.timeout == 0 else request.timeout model_request = self.build_prompt(request) llm_response: ModelResponse = self.llm.call_model(model_request) - llm_classification: ExecutionSearchRequest = ExecutionSearchResponse() + llm_classification: ExecutionSearchResponse = ExecutionSearchResponse() parsed_result = self.extract_classification_response(llm_response.text) if parsed_result == {}: return llm_classification @@ -85,32 +87,3 @@ def llm_generate_search_decision(self, request: ExecutionSearchRequest) -> Execu llm_classification.intent_classification = parsed_result.get("intent", "") llm_classification.ai_response = parsed_result.get("ai_response", "") return llm_classification - - def generate_decision(self, request: ExecutionSearchRequest, default_timeout=1000) -> ExecutionSearchResponse: - default_timeout = default_timeout if request.timeout == 0 else request.timeout - - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(self.llm_generate_search_decision, request) - timeout: float = default_timeout / 1000 + time.monotonic() - try: - if future.done() or timeout - time.monotonic() > 0: - response = future.result(timeout=timeout - time.monotonic()) - return response - - else: - future.cancel() - logger.warning(f"Timeout for Execution Search Decision Generation") - response = ExecutionSearchResponse() - return response - - except TimeoutError as e: - future.cancel() - logger.warning("TimeoutError while running Execution Search Decision Generation", exc_info=e) - response = ExecutionSearchResponse() - return response - - except 
_InactiveRpcError as e: - future.cancel() - logger.warning("Execution Search Decision Generation Channel is down") - response = ExecutionSearchResponse() - return response diff --git a/functionalities/llm_chit_chat/llm_chit_chat.py b/functionalities/llm_chit_chat/llm_chit_chat.py index 0eecb98..737f7ee 100644 --- a/functionalities/llm_chit_chat/llm_chit_chat.py +++ b/functionalities/llm_chit_chat/llm_chit_chat.py @@ -1,18 +1,15 @@ import grpc import os import re -import time from compiled_protobufs.chitchat_classifier_pb2 import ChitChatResponse from compiled_protobufs.llm_pb2 import ModelRequest, ModelResponse, LLMChitChatRequest from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub -from grpc._channel import _InactiveRpcError -from concurrent.futures import TimeoutError, ThreadPoolExecutor from utils import logger -def extract_qa_answer(generated_answer): +def extract_qa_answer(generated_answer: str) -> str: start_token = "Your response:" end_token_1 = "Human" end_token_2 = "#" @@ -59,35 +56,34 @@ def build_prompt(request: LLMChitChatRequest) -> ModelRequest: if request.task_title != "": prompt += f'You are a friendly AI assistant who is assisting a human making {request.task_title}. ' else: - prompt += f'You are a friendly AI assistant who is assisting a human. ' + prompt += 'You are a friendly AI assistant who is assisting a human. ' prompt += 'Respond to the human or ask a question back. Try to not repeat what you said previously. ' - prompt += f'You specialize in cooking, arts & crafts, and DIY. You do not reveal your name in the spirit of ' \ - f'fair competition. You cannot play music, games or quizzes. You are not able to read the news, ' \ - f'turn on a light, or give recommendations for things outside cooking and DIY domains \n\n' + prompt += 'You specialize in cooking, arts & crafts, and DIY. You do not reveal your name in the spirit of ' \ + 'fair competition. You cannot play music, games or quizzes. You are not able to read the news, ' \ + 'turn on a light, or give recommendations for things outside cooking and DIY domains \n\n' if request.last_intent == "ChitChatIntent": - prompt += f"### Input: \n" + prompt += "### Input: \n" prompt += f'Human: {request.user_question} \n' else: if request.last_intent == "QuestionIntent": prompt += f"You just said: {request.last_agent_response}. Answer the given user question. 
\n" - prompt += f"### Input: \n" + prompt += "### Input: \n" prompt += f'Human: {request.user_question} \n' else: - prompt += f"### Input: \n" + prompt += "### Input: \n" prompt += f'You: {request.last_agent_response} \n' \ f'Human: {request.user_question} \n' - prompt += f'\n### Response: Your response:' + prompt += '\n### Response: Your response:' model_request.formatted_prompt = prompt model_request.max_tokens = 30 return model_request - def call_chit_chat_model(self, request: LLMChitChatRequest) -> ChitChatResponse: + def generate_chit_chat(self, request: LLMChitChatRequest) -> ChitChatResponse: model_request = self.build_prompt(request) llm_response: ModelResponse = self.llm.call_model(model_request) - logger.info(llm_response.text) agent_response: ChitChatResponse = ChitChatResponse() agent_response.text = process_response_text(extract_qa_answer(llm_response.text)) @@ -95,43 +91,3 @@ def call_chit_chat_model(self, request: LLMChitChatRequest) -> ChitChatResponse: logger.info(f'EXTRACTED LLM RESPONSE: {agent_response.text}') return agent_response - - def generate_chit_chat(self, request: LLMChitChatRequest, default_timeout=2000) -> ChitChatResponse: - # CALL CHIT CHAT LLM SERVICE FROM HERE - # default timeout is 10000, aka 1000 milliseconds aka 1 second - - default_timeout = default_timeout if request.timeout == 0 else request.timeout - logger.info(default_timeout) - - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(self.call_chit_chat_model, request) - timeout: float = default_timeout / 1000 + time.monotonic() - try: - if future.done() or timeout - time.monotonic() > 0: - response = future.result(timeout=timeout - time.monotonic()) - return response - - else: - future.cancel() - logger.warning(f"Timeout for LLM Chit Chat") - response = ChitChatResponse() - response.text = "" - return response - - except TimeoutError as e: - future.cancel() - logger.warning("TimeoutError while running LLM Chit Chat", exc_info=e) - response = ChitChatResponse() - response.text = "" - return response - - except _InactiveRpcError as e: - future.cancel() - logger.warning("LLM Channel is down") - response = ChitChatResponse() - response.text = "" - return response - - except Exception as e: - future.cancel() - logger.info(type(e)) diff --git a/functionalities/llm_description_generation/llm_description_generation.py b/functionalities/llm_description_generation/llm_description_generation.py index 5bba723..ce4c758 100644 --- a/functionalities/llm_description_generation/llm_description_generation.py +++ b/functionalities/llm_description_generation/llm_description_generation.py @@ -7,7 +7,6 @@ ) from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub -from grpc._channel import _InactiveRpcError from utils import logger @@ -47,11 +46,11 @@ def generate_descriptions(self, request: LLMMultipleDescriptionGenerationRequest model_batch_request: ModelBatchRequest = ModelBatchRequest() model_batch_request.max_tokens = 64 - for title, ingredients, domain in zip(request.task_title, request.ingredients, request.domains): + for title, _, domain in zip(request.task_title, request.ingredients, request.domains): if domain == "wikihow": domain = "" else: - domain = f"Domain: Cooking\n" + domain = "Domain: Cooking\n" prompt = f"### Instruction:\n" \ f"Generate a 2 sentence description. It should be fun entertaining, sell the task and make the " \ f"user wanna start it. 
Imagine it being the intro that just sells the task.\n\n" \ @@ -60,12 +59,7 @@ def generate_descriptions(self, request: LLMMultipleDescriptionGenerationRequest model_batch_request.formatted_prompts.append(str(prompt)) - try: - llm_responses: ModelBatchResponse = self.llm.batch_call_model(model_batch_request) - except _InactiveRpcError as e: - logger.info('LLM Channel is down during description generation') - llm_responses = ModelBatchResponse() - + llm_responses: ModelBatchResponse = self.llm.batch_call_model(model_batch_request) llm_descriptions: MultipleDescriptionGenerationResponse = MultipleDescriptionGenerationResponse() for text in llm_responses.text: llm_descriptions.description.append(extract_description(text)) diff --git a/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py b/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py index 02c609c..3ebd1b1 100644 --- a/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py +++ b/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py @@ -10,7 +10,6 @@ ) from compiled_protobufs.taskmap_pb2 import Session, Task, ReplacedIngredient -from grpc._channel import _InactiveRpcError from utils import logger from spacy.matcher import Matcher from pyserini.analysis import Analyzer, get_lucene_analyzer @@ -41,51 +40,9 @@ def extract_ingredient_name(self, question_text: str, requirements_list) -> str: return str_item return "" - def extract_response(self, generated_answer) -> str: - valid_dict = {} - start_token = "{" - end_token = "}" - if end_token not in generated_answer: - return valid_dict - - if start_token in generated_answer: - generated_answer = "{" + generated_answer.split(start_token)[-1] - - try: - generated_answer = generated_answer.replace(' \n ', '').replace('\n', '') - valid_dict = json.loads(generated_answer) - logger.info(f'Managed to parse rewritten step text: {valid_dict}') - - if "step_text" in list(valid_dict.keys()): - return valid_dict - else: - logger.info(f'Dictionary contents not valid: {valid_dict}') - except Exception as e: - logger.info(f'Could not parse response >{generated_answer}<: {e}') - return valid_dict - - def process_response_text(self, response_text: str) -> str: - # remove whitespace - text = ' '.join(response_text.split("\n")) - if not re.search(r'[.!?]', text[-1]): - # If not, add a period (.) to the end of the text - text += '.' - - # split at punctuation - sentences = re.split(r'(?<=[.!?])\s+', text) - - complete_sentences = [] - for sentence in sentences: - if sentence.endswith(('.', '!', '?')): - # remove numbered lists - sentence = re.sub(r'\d+\.', '', sentence) - complete_sentences.append(sentence.strip()) - return ' '.join(complete_sentences) - def build_prompt(self, task_title: str, step_text: str, ingredient: ReplacedIngredient) -> str: model_input = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. - Follow this format: {{\"step_text\": \"\"}} ### Instruction: Adjust the step text to replace the original ingredient with the replacement ingredient. 
@@ -95,7 +52,7 @@ def build_prompt(self, task_title: str, step_text: str, ingredient: ReplacedIngr Replacement ingredient:{ingredient.replacement.amount} {ingredient.replacement.name} Step: {step_text} - ### Response: {{\"step_text\": \"""" + ### Response: """ return model_input def adjust_step_texts(self, request: Session) -> Session: @@ -132,6 +89,7 @@ def adjust_step_texts(self, request: Session) -> Session: if len(adjusting_request.step) == 0: logger.info('Nothing rewritten because no matches in steps') else: + logger.info("Rewriting steps via LLM") adjusted_step_response: AdjustedStepGenerationResponse = self.rewrite_steps(adjusting_request) for step, id in zip(adjusted_step_response.step_text, adjusted_step_response.ids): rewritten_steps[id] = step @@ -189,18 +147,14 @@ def rewrite_steps(self, request: AdjustedStepGenerationRequest) -> AdjustedStepG llm_step_texts: AdjustedStepGenerationResponse = AdjustedStepGenerationResponse() - try: - for step, ingredient in zip(request.step, request.ingredient): - model_input = self.build_prompt(request.task_title, step.response.speech_text, ingredient) - ids.append(step.unique_id) - model_batch_request.formatted_prompts.append(model_input) - - llm_responses = self.llm.batch_call_model(model_batch_request) - - for idx, text in enumerate(llm_responses.text): - llm_step_texts.step_text.append(self.extract_response(text).get('step_text', '')) - llm_step_texts.ids.append(ids[idx]) - return llm_step_texts - except _InactiveRpcError as e: - logger.warning("Step Text Rewriter Channel is down") - return llm_step_texts + for step, ingredient in zip(request.step, request.ingredient): + model_input = self.build_prompt(request.task_title, step.response.speech_text, ingredient) + ids.append(step.unique_id) + model_batch_request.formatted_prompts.append(model_input) + + llm_responses = self.llm.batch_call_model(model_batch_request) + + for idx, text in enumerate(llm_responses.text): + llm_step_texts.step_text.append(text) + llm_step_texts.ids.append(ids[idx]) + return llm_step_texts diff --git a/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py b/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py index a28ffeb..7deda0f 100644 --- a/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py +++ b/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py @@ -1,39 +1,38 @@ import grpc import os import json -import time from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub from compiled_protobufs.llm_pb2 import ( - IngredientReplacementRequest, IngredientReplacementResponse, ModelRequest, ModelResponse + IngredientReplacementRequest, + IngredientReplacementResponse, + ModelRequest, + ModelResponse, ) from compiled_protobufs.taskmap_pb2 import Ingredient -from grpc._channel import _InactiveRpcError -from concurrent.futures import TimeoutError, ThreadPoolExecutor - from utils import logger -def extract_replacement_response(generated_answer, original_ing: Ingredient) -> dict: +def extract_replacement_response(generated_answer: str, original_ing: Ingredient) -> dict: valid_dict = {} - start_token = "{" end_token = "}" + # the part of the response format included in the prompt + prefix = "{\"name\"" if end_token not in generated_answer: return valid_dict - if start_token in generated_answer: - generated_answer = "{" + generated_answer.split(start_token)[-1] + generated_answer = f"{prefix}{generated_answer}" try: - generated_answer = 
generated_answer.replace(' \n ', '').replace('\n', '') + generated_answer = generated_answer.replace(" \n ", "").replace("\n", "") valid_dict = json.loads(generated_answer) - logger.info(f'Managed to parse ingredient: {valid_dict}') + logger.info(f"Managed to parse ingredient: {valid_dict}") if "name" in list(valid_dict.keys()) and "amount" in list(valid_dict.keys()): return valid_dict else: - logger.debug(f'Dictionary contents not valid: {valid_dict}') + logger.debug(f"Dictionary contents not valid: {valid_dict}") except Exception as e: if '"amount"' in generated_answer and "name" in generated_answer: @@ -43,20 +42,20 @@ def extract_replacement_response(generated_answer, original_ing: Ingredient) -> if "}" in generated_answer: second_part = generated_answer.split('"amount"')[1].replace(":", "").replace("}", "").strip() valid_dict["amount"] = second_part - logger.info(f'Managed to parse ingredient (2nd try): {valid_dict}') + logger.info(f"Managed to parse ingredient (2nd try): {valid_dict}") return valid_dict else: second_part = generated_answer.split('"amount"')[1] if second_part.replace(":", "").strip() in original_ing.name: valid_dict["amount"] = original_ing.name - logger.info(f'Managed to parse ingredient (3rd try): {valid_dict}') + logger.info(f"Managed to parse ingredient (3rd try): {valid_dict}") return valid_dict if second_part.replace(":", "").strip() in original_ing.amount: valid_dict["amount"] = original_ing.amount - logger.info(f'Managed to parse ingredient (3rd try): {valid_dict}') + logger.info(f"Managed to parse ingredient (3rd try): {valid_dict}") return valid_dict except Exception as e: - logger.debug(f'Second parsing did not work.') + logger.debug(f"Second parsing did not work.") return valid_dict @@ -75,22 +74,24 @@ def build_prompt(request: IngredientReplacementRequest) -> ModelRequest: original_ing = "" orginal_ing_amount = "" if request.original_ingredient.amount == "" else request.original_ingredient.amount - prompt = f"### Instruction: You are a friendly AI assistant who is assisting a human making " \ - f"{request.task_title}. You are helping the user to replace an ingredient in the recipe. " \ - f"If the amount is not specified in your earlier suggestion, extract the amount from the " \ - f"original ingredient. If your earlier suggestion responds that this substitution is not possible, " \ - f"please respond with an empty ingredient.\n " \ - f"Respond in this format: {{\"name\": \"\", \"amount\": \"{orginal_ing_amount}\"}}\n\n " \ - f"### Input:\n{original_ing}" \ - f"User: {request.user_question}\n" \ - f"Your suggestion: {request.agent_response}\n\n" \ - f"### Response:{{\"name\"" + prompt = ( + f"### Instruction: You are a friendly AI assistant who is assisting a human making " + f"{request.task_title}. You are helping the user to replace an ingredient in the recipe. " + f"If the amount is not specified in your earlier suggestion, extract the amount from the " + f"original ingredient. 
If your earlier suggestion responds that this substitution is not possible, " + f"please respond with an empty ingredient.\n " + f'Respond in this format: {{"name": "", "amount": "{orginal_ing_amount}"}}\n\n ' + f"### Input:\n{original_ing}" + f"User: {request.user_question}\n" + f"Your suggestion: {request.agent_response}\n\n" + f'### Response:{{"name"' + ) model_request.formatted_prompt = prompt model_request.max_tokens = 20 return model_request - def llm_generate_search_decision(self, request: IngredientReplacementRequest) -> IngredientReplacementResponse: + def generate_replacement(self, request: IngredientReplacementRequest) -> IngredientReplacementResponse: model_request = self.build_prompt(request) llm_response: ModelResponse = self.llm.call_model(model_request) @@ -106,31 +107,3 @@ def llm_generate_search_decision(self, request: IngredientReplacementRequest) -> llm_replacement.new_ingredient.MergeFrom(ingredient) return llm_replacement - - def generate_replacement(self, request: IngredientReplacementRequest, default_timeout=1000) -> \ - IngredientReplacementResponse: - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(self.llm_generate_search_decision, request) - timeout: float = default_timeout / 1000 + time.monotonic() - try: - if future.done() or timeout - time.monotonic() > 0: - response = future.result(timeout=timeout - time.monotonic()) - return response - - else: - future.cancel() - logger.warning(f"Timeout for Ingredient Replacement Generation") - response = IngredientReplacementResponse() - return response - - except TimeoutError as e: - future.cancel() - logger.warning("TimeoutError while running Ingredient Replacement Generation", exc_info=e) - response = IngredientReplacementResponse() - return response - - except _InactiveRpcError as e: - future.cancel() - logger.warning("Ingredient Replacement Generation Channel is down") - response = IngredientReplacementResponse() - return response diff --git a/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py b/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py index b261850..341f8de 100644 --- a/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py +++ b/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py @@ -10,7 +10,6 @@ from taskmap_pb2 import ExtraInfo from utils import logger -from grpc._channel import _InactiveRpcError def extract_question(generated_question): @@ -27,6 +26,10 @@ def extract_question(generated_question): def process_response_text(response_text: str) -> str: + if len(response_text) == 0: + logger.warn("Empty response from LLMProactiveQuestionGeneration") + return "" + # remove whitespace text = ' '.join(response_text.split("\n")) if not re.search(r'[.!?]', text[-1]): @@ -72,11 +75,8 @@ def generate_proactive_questions(self, f"or how the results turned out.\n\n" \ f"### Response: " model_batch_request.formatted_prompts.append(model_input) - try: - llm_responses: ModelBatchResponse = self.llm.batch_call_model(model_batch_request) - except _InactiveRpcError as e: - logger.info('LLM Channel is down during proactive question generation') - llm_responses = ModelBatchResponse() + + llm_responses: ModelBatchResponse = self.llm.batch_call_model(model_batch_request) llm_questions: ProactiveQuestionGenerationResponse = ProactiveQuestionGenerationResponse() for text in llm_responses.text: diff --git 
a/functionalities/llm_summary_generation/llm_summary_generation.py b/functionalities/llm_summary_generation/llm_summary_generation.py index 51e8b21..c366315 100644 --- a/functionalities/llm_summary_generation/llm_summary_generation.py +++ b/functionalities/llm_summary_generation/llm_summary_generation.py @@ -2,30 +2,42 @@ import os import re -from compiled_protobufs.llm_pb2 import ModelRequest, ModelResponse, SummaryGenerationRequest, SummaryGenerationResponse, \ - ModelBatchRequest, SummaryGenerationRequest, MultipleSummaryGenerationRequest, MultipleSummaryGenerationResponse +from compiled_protobufs.llm_pb2 import ( + SummaryGenerationRequest, + SummaryGenerationResponse, + MultipleSummaryGenerationRequest, + MultipleSummaryGenerationResponse, + ModelRequest, + ModelResponse, + ModelBatchRequest, + ModelBatchResponse, +) from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub from utils import logger def process_response_text(response_text: str) -> str: + if len(response_text) == 0: + logger.warn("Empty response from LLMSummaryGeneration") + return "" + # remove whitespace - text = ' '.join(response_text.split("\n")) - if not re.search(r'[.!?]', text[-1]): + text = " ".join(response_text.split("\n")) + if not re.search(r"[.!?]", text[-1]): # If not, add a period (.) to the end of the text - text += '.' + text += "." # split at punctuation - sentences = re.split(r'(?<=[.!?])\s+', text) + sentences = re.split(r"(?<=[.!?])\s+", text) complete_sentences = [] for sentence in sentences: - if sentence.endswith(('.', '!', '?')): + if sentence.endswith((".", "!", "?")): # remove numbered lists - sentence = re.sub(r'\d+\.', '', sentence) + sentence = re.sub(r"\d+\.", "", sentence) complete_sentences.append(sentence.strip()) - return ' '.join(complete_sentences) + return " ".join(complete_sentences) def truncate_details(details: str) -> str: @@ -34,16 +46,16 @@ def truncate_details(details: str) -> str: return " ".join(details_list[:n]) -def extract_qa_answer(generated_answer): +def extract_qa_answer(generated_answer: str) -> str: start_token = "### Response:" end_token = "#" if start_token in generated_answer: generated_answer = generated_answer.split(start_token)[1] if end_token in generated_answer: generated_answer = generated_answer.split(end_token)[0] - generated_answer = generated_answer.replace('\n', " ") + generated_answer = generated_answer.replace("\n", " ") if ". " in generated_answer: - generated_answer = ". ".join(generated_answer.split('.')[:-1]) + generated_answer = ". ".join(generated_answer.split(".")[:-1]) return generated_answer.strip() @@ -53,14 +65,16 @@ def __init__(self): self.llm = LLMRunnerStub(llm_channel) def generate_summary(self, request: SummaryGenerationRequest) -> SummaryGenerationResponse: - model_request: ModelRequest = ModelRequest() + model_request = ModelRequest() step_text = request.step_text title = request.task_title details = request.more_details details = truncate_details(details) - summary_generation_prompt = "Create a brief 2-sentence text by condensing key points from both the Details " \ - "and the Step, omitting insignificant details." + summary_generation_prompt = ( + "Create a brief 2-sentence text by condensing key points from both the Details " + "and the Step, omitting insignificant details." + ) model_input = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 
@@ -81,18 +95,20 @@ def generate_summary(self, request: SummaryGenerationRequest) -> SummaryGenerati model_request.formatted_prompt = model_input model_request.max_tokens = 100 - llm_response = self.llm.call_model(model_request) + llm_response: ModelResponse = self.llm.call_model(model_request) - summary_response: SummaryGenerationResponse = SummaryGenerationResponse() + summary_response = SummaryGenerationResponse() summary_response.summary = process_response_text(extract_qa_answer(llm_response.text)) return summary_response def generate_summaries(self, request: MultipleSummaryGenerationRequest) -> MultipleSummaryGenerationResponse: - model_batch_request: ModelBatchRequest = ModelBatchRequest() - model_batch_request.max_tokens = 100 + model_request = ModelBatchRequest() + model_request.max_tokens = 100 - summary_generation_prompt = "Create a brief 2-sentence text by condensing key points from both the Details " \ - "and the Step, omitting insignificant details." + summary_generation_prompt = ( + "Create a brief 2-sentence text by condensing key points from both the Details " + "and the Step, omitting insignificant details." + ) for title, step, details in zip(request.task_title, request.step_text, request.more_details): model_input = f"""Below is an instruction that describes a task, paired with an input that provides @@ -110,12 +126,13 @@ def generate_summaries(self, request: MultipleSummaryGenerationRequest) -> Multi ### Response: """ - model_batch_request.formatted_prompts.append(model_input) + model_request.formatted_prompts.append(model_input) - llm_responses = self.llm.batch_call_model(model_batch_request) + llm_responses: ModelBatchResponse = self.llm.batch_call_model(model_request) - llm_summaries: MultipleSummaryGenerationResponse = MultipleSummaryGenerationResponse() + llm_summaries = MultipleSummaryGenerationResponse() for text in llm_responses.text: + logger.info(f"Response: {text}") llm_summaries.summary.append(process_response_text(extract_qa_answer(text))) return llm_summaries diff --git a/functionalities/qa/llm_qa.py b/functionalities/qa/llm_qa.py index 70e8c3e..9cf5781 100644 --- a/functionalities/qa/llm_qa.py +++ b/functionalities/qa/llm_qa.py @@ -1,13 +1,10 @@ import os -import grpc -import time import random import re -import spacy +from typing import Optional -from grpc._channel import _InactiveRpcError -from concurrent.futures import TimeoutError, ThreadPoolExecutor -from typing import List +import spacy +import grpc from taskmap_pb2 import Session, Task, ExecutionStep from qa_pb2 import QAQuery, QARequest, QAResponse, DocumentList @@ -15,7 +12,6 @@ from llm_pb2_grpc import LLMRunnerStub from .abstract_qa import AbstractQA -from chitchat_classifier_pb2 import ChitChatResponse from task_graph import TaskGraph from utils import ( @@ -44,7 +40,7 @@ def __init__(self, environ_var: str): logger.info('LLM QA SpaCy initialized') @staticmethod - def __get_helpful_prompt(user_query: str) -> str: + def __get_helpful_prompt(user_query: str) -> Optional[str]: for word in user_query.split(" "): for keyword, answer in HELPFUL_PROMPT_PAIRS: if word == keyword: @@ -54,6 +50,8 @@ def __get_helpful_prompt(user_query: str) -> str: if keyword in user_query: return answer + return None + def process_last_sentence(self, last_sentence: str) -> str: if last_sentence.endswith(('.', '!', '?')): @@ -72,11 +70,11 @@ def process_last_sentence(self, last_sentence: str) -> str: return "" if last_sentence.endswith((',', ":", ";")): - last_sentence[-1] = "." 
+ last_sentence = last_sentence[:-1] + "." return f" {last_sentence}".strip() - def extract_qa_answer(self, generated_answer): + def extract_qa_answer(self, generated_answer: str) -> str: start_token = "Your response: " end_token_1 = "Human" end_token_2 = "#" @@ -102,11 +100,10 @@ def domain_retrieve(self, query: QAQuery) -> DocumentList: pass @staticmethod - def strip_newlines(text): + def strip_newlines(text: str) -> str: return text.replace('\n', " ").replace(" ", " ") def __build_context_task_selected(self, task_graph: TaskGraph, query: str, question_type: str) -> str: - context = [] if task_graph.author != "": @@ -208,7 +205,6 @@ def __build_context_general(self, request) -> str: return " ".join(context) def build_prompt(self, request: QARequest) -> ModelRequest: - task_graph: TaskGraph = TaskGraph(request.query.taskmap) user_question: str = request.query.text model_request: ModelRequest = ModelRequest() @@ -283,7 +279,7 @@ def build_prompt(self, request: QARequest) -> ModelRequest: return model_request - def call_llm_qa(self, request: QAQuery) -> ChitChatResponse: + def generate_qa_response(self, request: QARequest, default_timeout: int = 1500) -> QAResponse: model_request = self.build_prompt(request) llm_response: ModelResponse = self.llm.call_model(model_request) logger.info(llm_response.text) @@ -296,45 +292,6 @@ def call_llm_qa(self, request: QAQuery) -> ChitChatResponse: return agent_response - def generate_qa_response(self, request: QARequest, default_timeout=1500) -> QAResponse: - # CALL LLM SERVICE FROM HERE - - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(self.call_llm_qa, request) - timeout: float = default_timeout / 1000 + time.monotonic() - try: - if future.done() or timeout - time.monotonic() > 0: - response = future.result(timeout=timeout - time.monotonic()) - return response - - else: - future.cancel() - logger.warning(f"Timeout for LLM QA") - response = QAResponse() - response.text = "" - return response - - except TimeoutError as e: - future.cancel() - logger.warning("TimeoutError while running LLM QA", exc_info=e) - response = QAResponse() - response.text = "" - return response - - except _InactiveRpcError as e: - future.cancel() - logger.warning("LLM Channel is down") - response = QAResponse() - response.text = "" - return response - - except Exception as e: - future.cancel() - logger.info(e) - response = QAResponse() - response.text = "" - return response - def synth_response(self, request: QARequest) -> QAResponse: qa_response = self.generate_qa_response(request) user_question = request.query.text diff --git a/functionalities/qa/substitution_helper.py b/functionalities/qa/substitution_helper.py index 6371347..077a218 100644 --- a/functionalities/qa/substitution_helper.py +++ b/functionalities/qa/substitution_helper.py @@ -1,20 +1,19 @@ import grpc import os import random -import spacy +from typing import Tuple -from utils import logger, REPLACE_SUGGESTION, jaccard_sim, NOT_POSSIBLE, indri_stop_words +import spacy +from spacy.matcher import Matcher +from pyserini.analysis import Analyzer, get_lucene_analyzer from qa_pb2 import QARequest, QAResponse from llm_pb2 import IngredientReplacementRequest, IngredientReplacementResponse from llm_pb2_grpc import LLMReplacementGenerationStub from taskmap_pb2 import ReplacedIngredient, Ingredient -from grpc._channel import _InactiveRpcError -from spacy.matcher import Matcher -from pyserini.analysis import Analyzer, get_lucene_analyzer -from typing import Tuple +from utils import logger, 
REPLACE_SUGGESTION, jaccard_sim, NOT_POSSIBLE, indri_stop_words class SubstitutionHelper: def __init__(self): @@ -101,16 +100,12 @@ def generate_replacement(self, request: QARequest, response: QAResponse) -> Tupl original_req = self.get_original_ingredient(request) if not self.includes_amount(response): replacement_request.original_ingredient.MergeFrom(original_req) - try: - replacement: IngredientReplacementResponse = self.replacement_generator.generate_replacement( - replacement_request) - return replacement.new_ingredient, original_req - except _InactiveRpcError as e: - logger.info(e) - logger.warning("Replacement LLM Channel is down") - return Ingredient(), Ingredient() - - def create_substitution_idea(self, request, qa_response): + + replacement: IngredientReplacementResponse = self.replacement_generator.generate_replacement( + replacement_request) + return replacement.new_ingredient, original_req + + def create_substitution_idea(self, request: QARequest, qa_response: QAResponse) -> QAResponse: replacement_ingredient, original_ingredient = self.generate_replacement(request, qa_response) if replacement_ingredient.name != "": if original_ingredient.name.lower() != replacement_ingredient.name.lower(): diff --git a/llm_functionalities/Dockerfile b/llm_functionalities/Dockerfile index 2bd14ec..a2b0ae2 100644 --- a/llm_functionalities/Dockerfile +++ b/llm_functionalities/Dockerfile @@ -1,45 +1,5 @@ # syntax=docker/dockerfile:1.3 -FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 - -COPY oat_common/requirements.txt /requirements.txt - -ENV TZ="Europe/London" -RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime \ - && echo $TZ > /etc/timezone \ - && apt-get update \ - && apt-get install -y --no-install-recommends \ - software-properties-common \ - gnupg \ - ca-certificates \ - build-essential \ - git \ - wget \ - locales \ - unzip \ - curl \ - && add-apt-repository -y ppa:deadsnakes/ppa \ - && apt-get install -y --no-install-recommends \ - python3.9 \ - python3-pip \ - python3-setuptools \ - python3.9-distutils \ - python3.9-dev \ - && update-alternatives --install /usr/bin/python3 python /usr/bin/python3.9 5 \ - && locale-gen en_US.UTF-8 \ - && pip3 install --upgrade pip \ - && pip3 install -r /requirements.txt \ - # Install Rust for M1 Compatibility - && curl https://sh.rustup.rs -sSf | bash -s -- -y \ - # removes about 500MB of docs - && rm -rf /root/.rustup/toolchains/stable-x86_64-unknown-linux-gnu/share/doc - -ENV PATH="/root/.cargo/bin:${PATH}" -ENV LANG en_US.UTF-8 -ENV LANGUAGE en_US:en -# Use a common location in the volume for downloaded models from the -# huggingface.co transformers module -# https://huggingface.co/transformers/v4.0.1/installation.html?highlight=transformers_cache#caching-models -ENV TRANSFORMERS_CACHE /shared/file_system/cache/huggingface +FROM oat_common:latest COPY llm_functionalities/requirements.txt /source/requirements.txt diff --git a/llm_functionalities/downloads.toml b/llm_functionalities/downloads.toml index 4a84c6e..41e64e4 100644 --- a/llm_functionalities/downloads.toml +++ b/llm_functionalities/downloads.toml @@ -1,6 +1,6 @@ name = "llm_functionalities" -base_path = "/shared/file_system/downloads/llm_functionalities" +base_path = "/shared/file_system/downloads/tgi/local/" # from model_requirements.txt [[sources]] diff --git a/llm_functionalities/llm_runner/__init__.py b/llm_functionalities/llm_runner/__init__.py index 073bb27..ca86582 100644 --- a/llm_functionalities/llm_runner/__init__.py +++ b/llm_functionalities/llm_runner/__init__.py @@ -1,6 +1,6 @@ from 
.llm_runner import LLMRunner as DefaultLLMRunner +from compiled_protobufs.llm_pb2_grpc import add_LLMRunnerServicer_to_server as add_to_server from .llm_runner_servicer import ( Servicer, - add_LLMRunnerServicer_to_server as add_to_server -) \ No newline at end of file +) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index c33edbd..5908795 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -1,73 +1,161 @@ +import os +import sys +import time +from concurrent.futures import TimeoutError, ThreadPoolExecutor +from typing import Optional, Callable, Dict, Any -import torch - -from transformers import AutoModelForCausalLM, AutoTokenizer -from torch.cuda import OutOfMemoryError +import dns.resolver +from huggingface_hub import InferenceClient from utils import logger, Downloader -from compiled_protobufs.llm_pb2 import ModelRequest, ModelResponse, ModelBatchRequest, ModelBatchResponse +from compiled_protobufs.llm_pb2 import ( + ModelRequest, + ModelResponse, + ModelBatchRequest, + ModelBatchResponse, +) + +# default timeout for requests that don't set a timeout explicitly +# TODO: allow override by env var? +DEFAULT_TGI_TIMEOUT = 5 +DEFAULT_TGI_TIMEOUT_MS = DEFAULT_TGI_TIMEOUT * 1000 class LLMRunner: def __init__(self): - if torch.cuda.is_available(): - artefact_id = "alpaca_llm" - downloader = Downloader() - downloader.download([artefact_id]) - model_name = downloader.get_artefact_path(artefact_id) - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16, - device_map="auto", - max_memory={i: '24000MB' for i in range(torch.cuda.device_count())}, - ) - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - - self.batch_tokenizer = AutoTokenizer.from_pretrained(model_name) - self.batch_tokenizer.padding_side = "left" - self.batch_tokenizer.pad_token = self.batch_tokenizer.eos_token - - logger.info("Finished loading Alpaca model") - else: - logger.info('No GPU available, not loading LLM...') - exit(1) + downloader = Downloader() + logger.info(f"Found {len(downloader.sources)} enabled download sources, retrieving them...") + # just download any defined sources + downloader.download() + + endpoint_url = os.environ.get("INFERENCE_ENDPOINT_URL", None) + if endpoint_url is None: + logger.error("No INFERENCE_ENDPOINT_URL defined, container will exit") + sys.exit(-1) + + # InferenceClient seems to require an http:// prefix for the URL + if not endpoint_url.startswith("http://"): + endpoint_url = f"http://{endpoint_url}" + + self.endpoint_url = endpoint_url + self.client = None + logger.info("llm_functionalities initialized") + + def _connect_to_endpoint(self) -> bool: + """Attempt to make a connection to the configured TGI endpoint. + + Simply creating an InferenceClient object with the endpoint URL doesn't trigger + a connection, so to force that to happen this just submits a small text_generation + query. If the client object already exists it's taken to mean the connection + has already been successfully established. + + Returns: + True if a connection exists/was created, False otherwise. + """ + + # a client has been instantiated, assume everything is OK + if self.client is not None: + return True + + try: + # First, need to check if we can actually resolve the TGI endpoint + # hostname. If this isn't available, it will cause urllib3 and the Python + # socket library to spend >10s waiting for a result. 
It seems to be hard + # to do a time-limited host lookup using the stdlib, + # but dnspython makes it easy (lifetime is the timeout parameter here): + dns.resolver.resolve(self.endpoint_url[7:].split(":")[0], lifetime=1) + + # If the hostname is resolved, we then setup a longer timeout which + # TGI will use internally when making HTTP requests via request.post() etc + client = InferenceClient(model=self.endpoint_url, timeout=DEFAULT_TGI_TIMEOUT) + + # Finally, creating the object doesn't appear to actually make a connection, + # so try something that requires a valid connection to succeed + response = client.text_generation(prompt="hello?", max_new_tokens=10) + except Exception as e: + logger.warn(f"Failed to connect to TGI at {self.endpoint_url} ({e})") + return False + + if len(response) == 0: + logger.warn(f"Failed to connect to TGI at {self.endpoint_url} (timed out)") + return False + + self.client = client + return True + + def _call_with_timeout(self, timeout_ms: int, callable: Callable, params: Dict[str, Any]) -> str: + """Call a TGI endpoint with a timeout applied. + + Since we want to avoid potentially lengthy LLM computations delaying the system's responses, + we need to enforce a timeout on each TGI call. This method handles that using a + ThreadPoolExecutor and a future object. + + Returns: + The text response from the LLM, which will be empty if a timeout or other error + occurs (str). + """ + logger.info(f"call_with_timeout, timeout={timeout_ms}, params={params}") + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(callable, **params) + try: + response = future.result(timeout=timeout_ms / 1000) + logger.info("TGI endpoint returned a response within timeout period") + return response + except TimeoutError as _: + future.cancel() + executor.shutdown(wait=False, cancel_futures=True) + logger.warning(f"A call to the TGI endpoint has timed out after {timeout_ms} ms") + return "" + except Exception as e: + future.cancel() + executor.shutdown(wait=False, cancel_futures=True) + logger.warning(f"A call to the TGI endpoint failed with an unexpected exception: {str(e)}") + # assume this means a problem with the endpoint and connect again + # on a subsequent request + self.client = None + return "" def call_model(self, model_request: ModelRequest) -> ModelResponse: + """Submit a single call to the LLM endpoint""" model_response: ModelResponse = ModelResponse() - try: - formatted_prompt = model_request.formatted_prompt - max_tokens = model_request.max_tokens - inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to("cuda:0") - outputs = self.model.generate(inputs=inputs.input_ids, max_new_tokens=max_tokens) - response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + if not self._connect_to_endpoint(): + return model_response - model_response.text = str(response_text) + timeout_ms = DEFAULT_TGI_TIMEOUT_MS if model_request.timeout == 0 else model_request.timeout - except Exception as e: - logger.info(f'Running LLM failed: {e}') + args = { + "prompt": model_request.formatted_prompt, + "max_new_tokens": model_request.max_tokens, + } + response = self._call_with_timeout(timeout_ms, self.client.text_generation, args) + logger.info(f"LLM response text: {response}") + model_response.text = response return model_response def batch_call_model(self, model_request: ModelBatchRequest) -> ModelBatchResponse: + """Submit a batch of calls to the LLM endpoint""" model_responses: ModelBatchResponse = ModelBatchResponse() - try: - formatted_prompts 
= list(model_request.formatted_prompts) - max_tokens = model_request.max_tokens + if not self._connect_to_endpoint(): + return model_responses + + timeout_ms = DEFAULT_TGI_TIMEOUT_MS if model_request.timeout == 0 else model_request.timeout - encodings = self.batch_tokenizer(formatted_prompts, padding=True, return_tensors='pt').to("cuda:0") + formatted_prompts = list(model_request.formatted_prompts) + max_tokens = model_request.max_tokens + params = [{"prompt": p, "max_new_tokens": max_tokens} for p in formatted_prompts] - with torch.no_grad(): - generated_ids = self.model.generate(**encodings, max_new_tokens=max_tokens, do_sample=False) - generated_texts = self.batch_tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + logger.info(f"Submitting a batch of {len(params)} calls to TGI") + # since TGI handles batching automatically in its backend, the recommended approach + # seems to be simply submitting multiple requests in parallel as in the test fixtures: + # https://github.com/huggingface/text-generation-inference/blob/main/integration-tests/conftest.py#L434 + with ThreadPoolExecutor(max_workers=len(params)) as pool: + results = pool.map(lambda p: self._call_with_timeout(timeout_ms, self.client.text_generation, p), params) - for text in generated_texts: - model_responses.text.append(text) + for response in results: + logger.info(f"LLM response text: {response}") + model_responses.text.append(response) - except OutOfMemoryError as e: - logger.info(f'We ran out of GPU memory: {e}') - torch.cuda.empty_cache() - exit(1) - return model_responses diff --git a/llm_functionalities/llm_runner/llm_runner_servicer.py b/llm_functionalities/llm_runner/llm_runner_servicer.py index f84776f..a9a7498 100644 --- a/llm_functionalities/llm_runner/llm_runner_servicer.py +++ b/llm_functionalities/llm_runner/llm_runner_servicer.py @@ -1,15 +1,21 @@ -from compiled_protobufs.llm_pb2 import ModelRequest, ModelResponse, ModelBatchRequest, ModelBatchResponse -from compiled_protobufs.llm_pb2_grpc import LLMRunnerServicer, add_LLMRunnerServicer_to_server +from compiled_protobufs.llm_pb2 import ( + ModelRequest, + ModelResponse, + ModelBatchRequest, + ModelBatchResponse, +) +from compiled_protobufs.llm_pb2_grpc import ( + LLMRunnerServicer, +) from . 
import DefaultLLMRunner class Servicer(LLMRunnerServicer): - def __init__(self): self.model = DefaultLLMRunner() def call_model(self, query: ModelRequest, context) -> ModelResponse: return self.model.call_model(query) - + def batch_call_model(self, query: ModelBatchRequest, context) -> ModelBatchResponse: return self.model.batch_call_model(query) diff --git a/llm_functionalities/main.py b/llm_functionalities/main.py index 60d439b..8b080ac 100644 --- a/llm_functionalities/main.py +++ b/llm_functionalities/main.py @@ -18,8 +18,6 @@ def serve(): add_llm_runner_to_server(LLM_Runner_Servicer(), server) - logger.info('Finished loading all LLM functionalities') - server.add_insecure_port("[::]:8000") server.start() server.wait_for_termination() diff --git a/llm_functionalities/requirements.txt b/llm_functionalities/requirements.txt index e9e4342..d007041 100644 --- a/llm_functionalities/requirements.txt +++ b/llm_functionalities/requirements.txt @@ -1,4 +1,3 @@ -peft==0.4.0 -accelerate==0.21.0 -transformers==4.31.0 -scipy \ No newline at end of file +# summarization seems to be broken in the most recently released versions +git+https://github.com/huggingface/huggingface_hub.git +dnspython==2.6.1 diff --git a/neural_functionalities/Dockerfile b/neural_functionalities/Dockerfile index a0a281c..c69fb42 100644 --- a/neural_functionalities/Dockerfile +++ b/neural_functionalities/Dockerfile @@ -1,5 +1,4 @@ -# syntax=docker/dockerfile:1.3 -FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 +FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04 COPY oat_common/requirements.txt /requirements.txt diff --git a/orchestrator/policy/chitchat_policy/chitchat_policy.py b/orchestrator/policy/chitchat_policy/chitchat_policy.py index 39758ad..c1d23a5 100644 --- a/orchestrator/policy/chitchat_policy/chitchat_policy.py +++ b/orchestrator/policy/chitchat_policy/chitchat_policy.py @@ -138,7 +138,7 @@ def step(self, session: Session) -> Tuple[Session, OutputInteraction]: output.speech_text = f'{helpful_prompt} ' else: output.speech_text = f'{chitchat_response.text}{helpful_prompt} ' - + set_source(output) elif got_llm_response: logger.info(f'CHIT CHAT LLM RESPONSE GIVEN: {chitchat_response.text}') keyword_helpful_prompt = self.__get_helpful_prompt(chitchat_request.text) @@ -149,9 +149,10 @@ def step(self, session: Session) -> Tuple[Session, OutputInteraction]: output.speech_text = f'{chitchat_response.text} {random.choice(transition_options)} {keyword_helpful_prompt.lower()}' else: output.speech_text = f'{chitchat_response.text} {random.choice(transition_options)} {helpful_prompt.lower()}' + set_source(output, msg="chitchat from LLM") elif chitchat_response.text != "": output.speech_text = chitchat_response.text - + set_source(output) if output.speech_text == "": # chitchat responses are bad, so fallback to confused session.turn[-1].user_request.interaction.intents.append("ConfusedIntent") @@ -160,6 +161,6 @@ def step(self, session: Session) -> Tuple[Session, OutputInteraction]: raise PhaseChangeException() else: output = repeat_screen_response(session, output) + set_source(output) - set_source(output) return session, output diff --git a/shared/protobufs/database.proto b/shared/protobufs/database.proto index d1ce06d..a5efe10 100644 --- a/shared/protobufs/database.proto +++ b/shared/protobufs/database.proto @@ -32,9 +32,13 @@ message QueryList{ service Database{ rpc load_taskmap(TaskMapRequest) returns (TaskMap) {} rpc save_taskmap(TaskMapRequest) returns (Void) {} + rpc delete_taskmap(TaskMapRequest) returns (Void) {} rpc 
load_session(SessionRequest) returns (Session) {} rpc save_session(SessionRequest) returns (Void) {} + rpc delete_session(SessionRequest) returns (Void) {} + + rpc delete_stagedoutput(TaskMapRequest) returns (Void) {} rpc save_search_logs(SearchLog) returns(Void) {} rpc save_asr_logs(ASRLog) returns(Void) {} @@ -44,4 +48,4 @@ service Database{ rpc get_theme(ThemeMapping) returns (ThemeMapping) {} rpc get_queries(Void) returns (QueryList) {} -} \ No newline at end of file +} diff --git a/shared/protobufs/llm.proto b/shared/protobufs/llm.proto index 63585c2..8c352ab 100644 --- a/shared/protobufs/llm.proto +++ b/shared/protobufs/llm.proto @@ -5,6 +5,7 @@ import "taskmap.proto"; message ModelRequest { string formatted_prompt = 1; int32 max_tokens = 2; + int32 timeout = 3; } message ModelResponse { @@ -14,6 +15,7 @@ message ModelResponse { message ModelBatchRequest { repeated string formatted_prompts = 1; int32 max_tokens = 2; + int32 timeout = 3; } message ModelBatchResponse { @@ -145,4 +147,4 @@ service LLMReplacementGeneration { service LLMRunner { rpc call_model(ModelRequest) returns (ModelResponse) {} rpc batch_call_model(ModelBatchRequest) returns (ModelBatchResponse) {} -} \ No newline at end of file +} diff --git a/shared/test_data/sample_taskmap_seriouseats.json b/shared/test_data/sample_taskmap_seriouseats.json new file mode 100644 index 0000000..53641ff --- /dev/null +++ b/shared/test_data/sample_taskmap_seriouseats.json @@ -0,0 +1 @@ +{"taskmapId": "fe1cfe172799b69d125e6fd3af16241e", "title": "Raw Fig\u00a0Bars", "date": "2014-12-14", "sourceUrl": "https://food52.com/recipes/32540-roasted-strawberry-brioche-pudding", "description": "Raw dried fig bars are not often found and these are topped with dark chocolate drizzle and a pinch of salt which make them hard to resist. 
\u2014Sylvie Taylor @ Roamingtaste", "thumbnailUrl": "https://images.food52.com/oOnHcjzuF1_CkEJa4Go7Vca7vaE=/1000x1000/52f804f9-0298-4b41-9b53-cb5d9a7d66a4--raw_fig_bars.jpg", "ratingOut100": 100, "tags": ["Pudding", "Candy", "Australian/New Zealander", "Milk/Cream", "Strawberry", "Chocolate", "Fruit", "Vegan", "Dessert", "Snack"], "requirementList": [{"uniqueId": "4829157f-7a04-47d7-b430-7e8f8e4b7f85", "name": " dried figs, ends removed and sliced into quarters", "amount": "3 cups"}, {"uniqueId": "40a0ab87-69c5-48ba-8837-2062e06ee4a7", "name": " walnuts", "amount": "1/2 cup"}, {"uniqueId": "69ef2860-7354-4125-b828-5fca4ed449f4", "name": "cup brazil nuts", "amount": "1/3"}, {"uniqueId": "ed3e1a75-c2e8-43aa-8224-fa19ab39f2ec", "name": " chia seeds", "amount": "1 tablespoon"}, {"uniqueId": "beba33df-f12b-4709-b898-16b7ce481f8f", "name": " linseeds", "amount": "1 tablespoon"}, {"uniqueId": "7f55ef57-d350-41c0-8769-9a4f5434a52f", "name": " dark chocolate", "amount": "30 grams"}, {"uniqueId": "ba9e7441-c3ac-42cc-be66-eea42c043568", "name": " coconut oil", "amount": "1/2 teaspoon"}, {"uniqueId": "5428c96f-97f9-4587-b4d5-c7333cc9e577", "name": "Pinch course sea salt", "amount": " "}], "serves": "\nServes\n 8\n ", "steps": [{"uniqueId": "65d7ca29-c84c-4303-9da6-d35be303e206", "response": {"speechText": "Place the quartered figs into a bowl and place in a double boiler, cover and add two tablespoons boiling water until the figs are warmed through, approximately 5 minutes.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Place the quartered figs into a bowl and place in a double boiler, cover and add two tablespoons boiling water until the figs are warmed through, approximately 5 minutes."], "imageList": [{"path": "https://images.food52.com/bvKhL9d3y6F8WKxDcMCmvjIKLYg=/1000x1000/a8a61e51-8734-40d6-92ef-98c5e86a0bba--Toffee_1.JPG"}], "requirements": [" dried figs, ends removed and sliced into quarters"]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "19d73de6-f991-469a-ad77-a6cc98f187dd", "response": {"speechText": "Meanwhile, place the nuts and seeds into a blender and pulse until roughly chopped. Set aside.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Meanwhile, place the nuts and seeds into a blender and pulse until roughly chopped. Set aside."], "imageList": [{"path": "https://images.food52.com/ZL62D1IGpUJJIWaujlJVplYLfQQ=/1000x1000/e701cb7c-0002-4607-8de7-792ec99f9814--Quince_cheese_004.JPG"}], "requirements": [" chia seeds"], "extraInformation": [{"type": "JOKE", "text": "Much like a baseball game, a blender includes a base and a pitcher. 
(And sometimes, a batter!)", "keyword": "blender", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_blender_2.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "d09ba3ac-46a9-4050-8021-0d7ded4cea0f", "response": {"speechText": "Place the warmed figs in a food processor and pulse on and off until roughly chopped and sticky (your food processor may become rather warm and you will have to clear out the blades every 30 seconds to a minute of pulsing, so give it time to cool down if necessary).", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Place the warmed figs in a food processor and pulse on and off until roughly chopped and sticky (your food processor may become rather warm and you will have to clear out the blades every 30 seconds to a minute of pulsing, so give it time to cool down if necessary)."], "imageList": [{"path": "https://images.food52.com/tq8Tlgmiy2cY2Xpa5CyJZHtq4yw=/1000x1000/fb01978e-59b5-4296-90b7-da1a40be3ac7--power_bar.jpg"}], "requirements": [" dried figs, ends removed and sliced into quarters"]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "60b13db3-d932-446f-8345-53c63dc35146", "response": {"speechText": "Place the figs in the bowl with nuts and seeds and stir until well combined.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Place the figs in the bowl with nuts and seeds and stir until well combined."], "imageList": [{"path": "https://images.food52.com/ZL62D1IGpUJJIWaujlJVplYLfQQ=/1000x1000/e701cb7c-0002-4607-8de7-792ec99f9814--Quince_cheese_004.JPG"}], "requirements": [" dried figs, ends removed and sliced into quarters", " chia seeds"], "extraInformation": [{"type": "FUNFACT", "text": "If you take a bowl and put it upside down, it becomes a poml.", "keyword": "bowl", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/facts/fact_bowl.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "241279a8-170f-4d73-9c64-d9796e0a3468", "response": {"speechText": "Line a loaf dish with greasproof paper and spoon the mixture in, pressing down evenly on all sides. Refrigerate for a minimum of an hour.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Line a loaf dish with greasproof paper and spoon the mixture in, pressing down evenly on all sides. Refrigerate for a minimum of an hour."], "imageList": [{"path": "https://images.food52.com/bvKhL9d3y6F8WKxDcMCmvjIKLYg=/1000x1000/a8a61e51-8734-40d6-92ef-98c5e86a0bba--Toffee_1.JPG"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "f1bbbb36-89e7-4f8a-81d0-227bd9201cb5", "response": {"speechText": "Melt the chocolate and coconut oil, stirring until smooth.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Melt the chocolate and coconut oil, stirring until smooth."], "imageList": [{"path": "https://images.food52.com/bvKhL9d3y6F8WKxDcMCmvjIKLYg=/1000x1000/a8a61e51-8734-40d6-92ef-98c5e86a0bba--Toffee_1.JPG"}], "requirements": [" coconut oil", " dark chocolate"], "extraInformation": [{"type": "JOKE", "text": "Did you know that it took over eight years to develop the recipe for milk chocolate? 
Be glad you weren't one of those testers.", "keyword": "chocolate", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_chocolate.png"}, {"type": "FUNFACT", "text": "Chocolate is a one of the most popular forms of apology.", "keyword": "chocolate", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/facts/fact_chocolate.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "4b07f2bd-767c-4b0b-b3a9-8890d7fe00ce", "response": {"speechText": "Remove the chilled bars from the fridge, slice and drizzle the chocolate over the top.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Remove the chilled bars from the fridge, slice and drizzle the chocolate over the top."], "imageList": [{"path": "https://images.food52.com/tq8Tlgmiy2cY2Xpa5CyJZHtq4yw=/1000x1000/fb01978e-59b5-4296-90b7-da1a40be3ac7--power_bar.jpg"}], "requirements": [" dark chocolate"], "extraInformation": [{"type": "JOKE", "text": "Did you know that it took over eight years to develop the recipe for milk chocolate? Be glad you weren't one of those testers.", "keyword": "chocolate", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_chocolate.png"}, {"type": "FUNFACT", "text": "Chocolate is a one of the most popular forms of apology.", "keyword": "chocolate", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/facts/fact_chocolate.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "f9f47158-d828-43a7-90e4-7b06647ec4ed", "response": {"speechText": "Sprinkle the pinch of course sea salt over the top.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Sprinkle the pinch of course sea salt over the top."], "imageList": [{"path": "https://images.food52.com/bvKhL9d3y6F8WKxDcMCmvjIKLYg=/1000x1000/a8a61e51-8734-40d6-92ef-98c5e86a0bba--Toffee_1.JPG"}], "requirements": ["Pinch course sea salt"], "extraInformation": [{"type": "JOKE", "text": "Careful with this one. 
Add too much salt, and it's going to taste all salty.", "keyword": "salt", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_salt_2.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "4848afc8-de32-4321-8633-afc9e2ce4ce5", "response": {"speechText": "Keep refrigerated until serving.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Keep refrigerated until serving."], "imageList": [{"path": "https://images.food52.com/bvKhL9d3y6F8WKxDcMCmvjIKLYg=/1000x1000/a8a61e51-8734-40d6-92ef-98c5e86a0bba--Toffee_1.JPG"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}], "connectionList": [{"idFrom": "19d73de6-f991-469a-ad77-a6cc98f187dd", "idTo": "d09ba3ac-46a9-4050-8021-0d7ded4cea0f"}, {"idFrom": "f9f47158-d828-43a7-90e4-7b06647ec4ed", "idTo": "4848afc8-de32-4321-8633-afc9e2ce4ce5"}, {"idFrom": "241279a8-170f-4d73-9c64-d9796e0a3468", "idTo": "f1bbbb36-89e7-4f8a-81d0-227bd9201cb5"}, {"idFrom": "60b13db3-d932-446f-8345-53c63dc35146", "idTo": "241279a8-170f-4d73-9c64-d9796e0a3468"}, {"idFrom": "4b07f2bd-767c-4b0b-b3a9-8890d7fe00ce", "idTo": "f9f47158-d828-43a7-90e4-7b06647ec4ed"}, {"idFrom": "f1bbbb36-89e7-4f8a-81d0-227bd9201cb5", "idTo": "4b07f2bd-767c-4b0b-b3a9-8890d7fe00ce"}, {"idFrom": "d09ba3ac-46a9-4050-8021-0d7ded4cea0f", "idTo": "60b13db3-d932-446f-8345-53c63dc35146"}, {"idFrom": "65d7ca29-c84c-4303-9da6-d35be303e206", "idTo": "19d73de6-f991-469a-ad77-a6cc98f187dd"}], "dataset": "common-crawl", "author": "Sylvie Taylor @ Roamingtaste", "domainName": "food52.com"} \ No newline at end of file diff --git a/shared/test_data/sample_taskmap_wikihow.json b/shared/test_data/sample_taskmap_wikihow.json new file mode 100644 index 0000000..b03fbe6 --- /dev/null +++ b/shared/test_data/sample_taskmap_wikihow.json @@ -0,0 +1 @@ +{"taskmapId": "751e1886e67bf9bd4b29f6401dd38911", "title": "How to Make Red Velvet Cake", "date": "2022-07-31", "sourceUrl": "https://www.wikihow.com/Make-Red-Velvet-Cake", "description": "A red velvet cake can be a delicious dessert for almost any occasion, and it doesn't have to be a hard cake to make either! 
Simple and full of flavor, this is a showstopping dessert worth sharing with friends.", "thumbnailUrl": "https://www.wikihow.com/images/thumb/b/bf/Make-Red-Velvet-Cake-Step-17-Version-2.jpg/v4-460px-Make-Red-Velvet-Cake-Step-17-Version-2.jpg", "tags": ["Chocolate Cakes"], "requirementList": [{"uniqueId": "215f0a51-494c-4b5d-afd5-5988109d9e27", "name": "\u00bd cup (102 grams) shortening (or butter, margarine, or any type of fat used for baking and making pastries)", "amount": "102 grams"}, {"uniqueId": "c77e334d-3b54-4be5-96e1-e1536679b1db", "name": "cups (300 grams) sugar", "amount": "300 grams"}, {"uniqueId": "4eaa46cb-ec54-45a0-88fa-8ec414703556", "name": "3 eggs", "amount": "3 eggs"}, {"uniqueId": "6f3bd176-ab28-4bce-b3d0-fbd256e94df0", "name": "cocoa", "amount": "2 tablespoons"}, {"uniqueId": "a65afaf3-84e5-4d3a-b1f7-9adcdb03a1a8", "name": "ounces red food coloring", "amount": "1 \u00bd"}, {"uniqueId": "62dce040-a70a-44e6-aa64-1d63e7d76cfc", "name": "salt", "amount": "1 teaspoon"}, {"uniqueId": "2580cb6a-acc8-41d0-b9e3-c07919e7db88", "name": "cups (315 grams) flour", "amount": "315 grams"}, {"uniqueId": "39c07933-3ce3-4759-a68e-5534171b4de2", "name": "vanilla", "amount": "1 teaspoon"}, {"uniqueId": "dc94ab8c-e65c-4ee2-a495-35b3d2b8df56", "name": "(240 ml) buttermilk", "amount": "240 ml"}, {"uniqueId": "cd4d223d-d0c4-4b75-ba04-7b3c83b0e5f2", "name": "baking soda", "amount": "1 teaspoon"}, {"uniqueId": "1db558e3-2751-4188-9877-b73f7cd17a49", "name": "vinegar", "amount": "1 tablespoon"}, {"uniqueId": "d5c1501e-49ec-4382-a476-ee54e64c8f6a", "name": "cream cheese, softened", "amount": "16 ounces"}, {"uniqueId": "4080d5e8-69cd-4c3c-9c39-5987229d978e", "name": "(400 grams) powdered sugar", "amount": "400 grams"}, {"uniqueId": "fd4bee6b-26fc-4dbc-8257-80f7d65c4032", "name": "vanilla", "amount": "1 teaspoon"}, {"uniqueId": "672de4cf-bc0f-4764-b12a-cf8b460e7c45", "name": "butter (approx 110 grams)", "amount": "110 grams"}, {"uniqueId": "522186b8-97d0-4555-97c0-944e663107c2", "name": "Cake pan", "amount": " "}, {"uniqueId": "313272d7-ec66-4b95-9ca2-acfebbc4d714", "name": "Mixing bowls", "amount": " "}, {"uniqueId": "e52596c7-cd1b-4309-91d0-6a41fb06cb6e", "name": "Cooking utensils", "amount": " "}], "steps": [{"uniqueId": "e422e3b4-a864-484f-ab46-0fd4333eeeb9", "response": {"speechText": "Gather and measure out all of the ingredients", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Gather and measure out all of the ingredients.", "Good bakers know that moving quickly and efficiently in the kitchen leads to better cakes and smaller messes. Measuring out ahead of time makes it possible."], "footer": "Gather and measure out all of the ingredients", "imageList": [{"path": "https://www.wikihow.com/images/thumb/e/e6/Make-Red-Velvet-Cake-Step-1-Version-4.jpg/v4-460px-Make-Red-Velvet-Cake-Step-1-Version-4.jpg"}]}, "description": "Good bakers know that moving quickly and efficiently in the kitchen leads to better cakes and smaller messes. Measuring out ahead of time makes it possible."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "24879201-13a5-4e99-94f4-e07be041ab62", "response": {"speechText": "Cream shortening and gradually add in sugar", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Cream shortening and gradually add in sugar.", "Use an electric mixer set to medium speed. 
Add the sugar along the edges and slowly work it in to avoid sugar splatter."], "footer": "Cream shortening and gradually add in sugar", "imageList": [{"path": "https://www.wikihow.com/images/thumb/3/3e/Make-Red-Velvet-Cake-Step-2-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-2-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/3/3a/Make Red Velvet Cake Step 2 Version 2.360p.mp4"}, "requirements": ["(400 grams) powdered sugar", "cups (300 grams) sugar", "\u00bd cup (102 grams) shortening (or butter, margarine, or any type of fat used for baking and making pastries)"], "extraInformation": [{"type": "JOKE", "text": "Sugar never spoils, but if we put our minds together, I bet we can find a way to make it inedible.", "keyword": "sugar", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_sugar.png"}]}, "description": "Use an electric mixer set to medium speed. Add the sugar along the edges and slowly work it in to avoid sugar splatter."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "f89f5e5a-2dd4-4e0c-bb69-9e25d6192c6f", "response": {"speechText": "Add eggs one at a time, beating well after each egg is added", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Add eggs one at a time, beating well after each egg is added.", "Mix in well, keeping your beater moving. It is okay if you add them both at once, too."], "footer": "Add eggs one at a time, beating well after each egg is added", "imageList": [{"path": "https://www.wikihow.com/images/thumb/f/fc/Make-Red-Velvet-Cake-Step-3-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-3-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/4/47/Make Red Velvet Cake Step 3 Version 2.360p.mp4"}, "requirements": ["3 eggs"]}, "description": "Mix in well, keeping your beater moving. It is okay if you add them both at once, too."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "f038335a-c335-4386-bf9c-69e5fded5740", "response": {"speechText": "Make a paste out of the cocoa and food coloring, then add to cream", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Make a paste out of the cocoa and food coloring, then add to cream.", "In a separate bowl, use a whisk to blend the food coloring into the cocoa. Fun Fact: The original red velvet cakes got their color because imported cocoa was actually red tinted. The food coloring came later."], "footer": "Make a paste out of the cocoa and food coloring, then add to cream", "imageList": [{"path": "https://www.wikihow.com/images/thumb/2/2a/Make-Red-Velvet-Cake-Step-4-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-4-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/d/d5/Make Red Velvet Cake Step 4 Version 2.360p.mp4"}, "requirements": ["ounces red food coloring"], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}, "description": "In a separate bowl, use a whisk to blend the food coloring into the cocoa. Fun Fact: The original red velvet cakes got their color because imported cocoa was actually red tinted. 
The food coloring came later."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "37bb9be1-1f02-458f-b6d6-24373b68a47e", "response": {"speechText": "Add salt, flour, baking soda, vanilla, and buttermilk, beating well after each ingredient is added", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Add salt, flour, baking soda, vanilla, and buttermilk, beating well after each ingredient is added.", "You can also mix the flour, salt, and baking soda in a separate bowl, adding at once. You can also add it all together into the batter one at a time, using the electric mixer to get a nice consistent batter. Add the flour slowly to avoid splatter. It can help to add it with the buttermilk."], "footer": "Add salt, flour, baking soda, vanilla, and buttermilk, beating well after each ingredient is added", "imageList": [{"path": "https://www.wikihow.com/images/thumb/2/28/Make-Red-Velvet-Cake-Step-5-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-5-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/1/12/Make Red Velvet Cake Step 5 Version 2.360p.mp4"}, "requirements": ["baking soda", "salt", "vanilla", "cups (315 grams) flour", "(240 ml) buttermilk"]}, "description": "You can also mix the flour, salt, and baking soda in a separate bowl, adding at once. You can also add it all together into the batter one at a time, using the electric mixer to get a nice consistent batter. Add the flour slowly to avoid splatter. It can help to add it with the buttermilk."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "0ee4da68-4b9b-491d-a281-986bd18e0c5b", "response": {"speechText": "Pour vinegar over the batter", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Pour vinegar over the batter.", "Just a splash is all you need -- it will give it a nice subtle tanginess."], "footer": "Pour vinegar over the batter", "imageList": [{"path": "https://www.wikihow.com/images/thumb/e/ec/Make-Red-Velvet-Cake-Step-6-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-6-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/4/42/Make Red Velvet Cake Step 6 Version 2.360p.mp4"}, "requirements": ["vinegar"], "extraInformation": [{"type": "JOKE", "text": "Vinegar can also be used as a cleaning agent. Keep that in mind if you're making a mess.", "keyword": "vinegar", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_vinegar_v2.png"}, {"type": "FUNFACT", "text": "Vinegar can also be used as a cleaning agent: I just wouldn't recommend it for brushing your teeth.", "keyword": "vinegar", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/facts/fact_vinegar.png"}]}, "description": "Just a splash is all you need -- it will give it a nice subtle tanginess."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "ad4a46f8-1624-4c74-99e3-6a3e47473212", "response": {"speechText": "Stir until well mixed", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Stir until well mixed.", "You want a thin, consistent batter with not chunks of flour or dry ingredients left. Want a little more red color? 
Add a few more drops of red food coloring."], "footer": "Stir until well mixed", "imageList": [{"path": "https://www.wikihow.com/images/thumb/0/08/Make-Red-Velvet-Cake-Step-7-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-7-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/c/cf/Make Red Velvet Cake Step 7 Version 2.360p.mp4"}}, "description": "You want a thin, consistent batter with not chunks of flour or dry ingredients left. Want a little more red color? Add a few more drops of red food coloring."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "3563062f-837e-4159-bfd3-5264e03fedc6", "response": {"speechText": "Pour the cake in a large cake pan or 2 layer cake pans and bake in a 350\u00baF", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Pour the cake in a large cake pan or 2 layer cake pans and bake in a 350\u00baF.", "The cake should take roughly an hour to cook. When it is done, you'll know through the toothpick test -- stab the center of the cake with a knife or skewer-- if it comes out without still-wet batter on it, it's done."], "footer": "Pour the cake in a large cake pan or 2 layer cake pans and bake in a 350\u00baF", "imageList": [{"path": "https://www.wikihow.com/images/thumb/b/bd/Make-Red-Velvet-Cake-Step-8-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-8-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/c/ce/Make Red Velvet Cake Step 8 Version 2.360p.mp4"}, "requirements": ["Cake pan"]}, "description": "The cake should take roughly an hour to cook. When it is done, you'll know through the toothpick test -- stab the center of the cake with a knife or skewer-- if it comes out without still-wet batter on it, it's done."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "823c8451-1d93-4ce1-899d-b72b27c78ed0", "response": {"speechText": "Wait about 20 minutes to cool before frosting", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Wait about 20 minutes to cool before frosting.", "After 5 minutes, remove from the pan and let cool on a wire rack. Don't frost a hot cake -- the warmth with thin out the frosting and make it difficult, if not impossible, to add smoothly."], "footer": "Wait about 20 minutes to cool before frosting", "imageList": [{"path": "https://www.wikihow.com/images/thumb/0/05/Make-Red-Velvet-Cake-Step-9-Version-3.jpg/v4-460px-Make-Red-Velvet-Cake-Step-9-Version-3.jpg"}]}, "description": "After 5 minutes, remove from the pan and let cool on a wire rack. Don't frost a hot cake -- the warmth with thin out the frosting and make it difficult, if not impossible, to add smoothly."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "515e8948-f64c-47e9-8878-75ba41ad1087", "response": {"speechText": "Set the butter and cream cheese on the counter to warm to room temperature", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Set the butter and cream cheese on the counter to warm to room temperature.", "You'll be whipping the butter and cream cheese up, but that only works if it is soft enough to whip! Set the two dairy products out for 15-30 minutes to get soft. In a pinch, you can gently microwave them to speed things up, but keep it very brief. 
You don't want liquid."], "footer": "Set the butter and cream cheese on the counter to warm to room temperature", "imageList": [{"path": "https://www.wikihow.com/images/thumb/1/19/Make-Red-Velvet-Cake-Step-10-Version-4.jpg/v4-460px-Make-Red-Velvet-Cake-Step-10-Version-4.jpg"}], "requirements": ["cream cheese, softened"], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}, "description": "You'll be whipping the butter and cream cheese up, but that only works if it is soft enough to whip! Set the two dairy products out for 15-30 minutes to get soft. In a pinch, you can gently microwave them to speed things up, but keep it very brief. You don't want liquid."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "288894e4-55e2-463e-a239-a983cf11c608", "response": {"speechText": "Combine butter and cream cheese", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Combine butter and cream cheese.", "An electric mixer is best, as it will make quick work of the dairies, but a wooden spoon and whisk work well too. They don't have to be perfectly combined, just well mixed."], "footer": "Combine butter and cream cheese", "imageList": [{"path": "https://www.wikihow.com/images/thumb/5/55/Make-Red-Velvet-Cake-Step-11-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-11-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/6/6d/Make Red Velvet Cake Step 11 Version 2.360p.mp4"}, "requirements": ["butter (approx 110 grams)", "\u00bd cup (102 grams) shortening (or butter, margarine, or any type of fat used for baking and making pastries)", "cream cheese, softened"], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}, "description": "An electric mixer is best, as it will make quick work of the dairies, but a wooden spoon and whisk work well too. They don't have to be perfectly combined, just well mixed."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "2cadd7f5-b3a8-4448-a2e4-e95bb801bffe", "response": {"speechText": "Add powdered sugar slowly, keeping the mixing going throughout", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Add powdered sugar slowly, keeping the mixing going throughout.", "Powdered sugar will want to poof and fly out as you mix it. 
To avoid a mess, add it in 3-4 parts, mixing almost all of the first part in before adding the second."], "footer": "Add powdered sugar slowly, keeping the mixing going throughout", "imageList": [{"path": "https://www.wikihow.com/images/thumb/f/f1/Make-Red-Velvet-Cake-Step-12-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-12-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/3/32/Make Red Velvet Cake Step 12 Version 2.360p.mp4"}, "requirements": ["(400 grams) powdered sugar", "cups (300 grams) sugar"], "extraInformation": [{"type": "JOKE", "text": "Sugar never spoils, but if we put our minds together, I bet we can find a way to make it inedible.", "keyword": "sugar", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_sugar.png"}]}, "description": "Powdered sugar will want to poof and fly out as you mix it. To avoid a mess, add it in 3-4 parts, mixing almost all of the first part in before adding the second."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "92c6c3b3-e299-4213-9ea0-c903873b6fad", "response": {"speechText": "Add vanilla and whip until creamy", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Add vanilla and whip until creamy.", "Keep the mixer (or your mixing hand) going until the frosting is nice smooth texture. If you want to thin it out a bit so it spreads better, add 2 tablespoons of cold milk."], "footer": "Add vanilla and whip until creamy", "imageList": [{"path": "https://www.wikihow.com/images/thumb/e/e8/Make-Red-Velvet-Cake-Step-13-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-13-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/8/81/Make Red Velvet Cake Step 13 Version 2.360p.mp4"}, "requirements": ["vanilla"], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}, "description": "Keep the mixer (or your mixing hand) going until the frosting is nice smooth texture. If you want to thin it out a bit so it spreads better, add 2 tablespoons of cold milk."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "53eaeb70-24db-4a4d-b035-669d5a69c7cb", "response": {"speechText": "Cut the cake into layers and frost", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Cut the cake into layers and frost.", "Place a little pat of icing on the bottom of our plate or cake dish to keep the bottom layer from sliding around. Then frost it and stack another layer on top, frosting the top of that. Don't worry too much about the sides yet. Do not try to frost the cake while it is still hot. Let it cool completely."], "footer": "Cut the cake into layers and frost", "imageList": [{"path": "https://www.wikihow.com/images/thumb/5/58/Make-Red-Velvet-Cake-Step-14-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-14-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/0/04/Make Red Velvet Cake Step 14 Version 2.360p.mp4"}, "requirements": ["Cake pan"]}, "description": "Place a little pat of icing on the bottom of our plate or cake dish to keep the bottom layer from sliding around. Then frost it and stack another layer on top, frosting the top of that. Don't worry too much about the sides yet. Do not try to frost the cake while it is still hot. 
Let it cool completely."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "df328633-b041-4149-b5eb-19f6425fc338", "response": {"speechText": "Assemble the layers and continue frosting", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Assemble the layers and continue frosting.", "Stack the layers up high, frosting between each layer with a 1/4\" of frosting or so, adding more to taste."], "footer": "Assemble the layers and continue frosting", "imageList": [{"path": "https://www.wikihow.com/images/thumb/9/90/Make-Red-Velvet-Cake-Step-15-Version-3.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-15-Version-3.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/6/6c/Make Red Velvet Cake Step 15 Version 2.360p.mp4"}}, "description": "Stack the layers up high, frosting between each layer with a 1/4\" of frosting or so, adding more to taste."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "ea85bd8d-48bb-401d-a56b-03d26c516479", "response": {"speechText": "Frost the cake and enjoy!", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Frost the cake and enjoy!.", "For bakery quality frosting, keep the knife clean after every pass, using a little warm water to ensure your frosting knife applies the icing smoothly and evenly. Use big globs of icing at a time and don't try to spread it too thin. By working in small areas, not spreading too thin, and cleaning the knife regularly, you can get a quality frosted cake. If you've got time, real pros will \"double frost.\" Start with a thin layer of frosting everywhere -- it is okay if it pulls up crumbs. Then freeze the cake for 15 minutes, pull it out, and frost \"for real.\" You'll be astonished how easily it goes on!"], "footer": "Frost the cake and enjoy!", "imageList": [{"path": "https://www.wikihow.com/images/thumb/1/1a/Make-Red-Velvet-Cake-Step-16-Version-3.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-16-Version-3.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/f/fb/Make Red Velvet Cake Step 16 Version 2.360p.mp4"}, "requirements": ["Cake pan"]}, "description": "For bakery quality frosting, keep the knife clean after every pass, using a little warm water to ensure your frosting knife applies the icing smoothly and evenly. Use big globs of icing at a time and don't try to spread it too thin. By working in small areas, not spreading too thin, and cleaning the knife regularly, you can get a quality frosted cake. If you've got time, real pros will \"double frost.\" Start with a thin layer of frosting everywhere -- it is okay if it pulls up crumbs. 
Then freeze the cake for 15 minutes, pull it out, and frost \"for real.\" You'll be astonished how easily it goes on!"}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "ced2b2e2-657e-4df3-8bd2-a2b86a287d95", "response": {"screen": {"format": "TEXT_IMAGE", "paragraphs": [""], "footer": "Finished", "imageList": [{"path": "https://www.wikihow.com/images/thumb/b/bf/Make-Red-Velvet-Cake-Step-17-Version-2.jpg/v4-460px-Make-Red-Velvet-Cake-Step-17-Version-2.jpg"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}], "connectionList": [{"idFrom": "0ee4da68-4b9b-491d-a281-986bd18e0c5b", "idTo": "ad4a46f8-1624-4c74-99e3-6a3e47473212"}, {"idFrom": "24879201-13a5-4e99-94f4-e07be041ab62", "idTo": "f89f5e5a-2dd4-4e0c-bb69-9e25d6192c6f"}, {"idFrom": "3563062f-837e-4159-bfd3-5264e03fedc6", "idTo": "823c8451-1d93-4ce1-899d-b72b27c78ed0"}, {"idFrom": "df328633-b041-4149-b5eb-19f6425fc338", "idTo": "ea85bd8d-48bb-401d-a56b-03d26c516479"}, {"idFrom": "288894e4-55e2-463e-a239-a983cf11c608", "idTo": "2cadd7f5-b3a8-4448-a2e4-e95bb801bffe"}, {"idFrom": "ea85bd8d-48bb-401d-a56b-03d26c516479", "idTo": "ced2b2e2-657e-4df3-8bd2-a2b86a287d95"}, {"idFrom": "f89f5e5a-2dd4-4e0c-bb69-9e25d6192c6f", "idTo": "f038335a-c335-4386-bf9c-69e5fded5740"}, {"idFrom": "f038335a-c335-4386-bf9c-69e5fded5740", "idTo": "37bb9be1-1f02-458f-b6d6-24373b68a47e"}, {"idFrom": "2cadd7f5-b3a8-4448-a2e4-e95bb801bffe", "idTo": "92c6c3b3-e299-4213-9ea0-c903873b6fad"}, {"idFrom": "e422e3b4-a864-484f-ab46-0fd4333eeeb9", "idTo": "24879201-13a5-4e99-94f4-e07be041ab62"}, {"idFrom": "37bb9be1-1f02-458f-b6d6-24373b68a47e", "idTo": "0ee4da68-4b9b-491d-a281-986bd18e0c5b"}, {"idFrom": "823c8451-1d93-4ce1-899d-b72b27c78ed0", "idTo": "515e8948-f64c-47e9-8878-75ba41ad1087"}, {"idFrom": "515e8948-f64c-47e9-8878-75ba41ad1087", "idTo": "288894e4-55e2-463e-a239-a983cf11c608"}, {"idFrom": "53eaeb70-24db-4a4d-b035-669d5a69c7cb", "idTo": "df328633-b041-4149-b5eb-19f6425fc338"}, {"idFrom": "92c6c3b3-e299-4213-9ea0-c903873b6fad", "idTo": "53eaeb70-24db-4a4d-b035-669d5a69c7cb"}, {"idFrom": "ad4a46f8-1624-4c74-99e3-6a3e47473212", "idTo": "3563062f-837e-4159-bfd3-5264e03fedc6"}], "dataset": "common-crawl", "author": "Mathew Rice", "domainName": "wikihow"} \ No newline at end of file diff --git a/shared/utils/aws/new_proto_db.py b/shared/utils/aws/new_proto_db.py index 1810b66..544ae25 100644 --- a/shared/utils/aws/new_proto_db.py +++ b/shared/utils/aws/new_proto_db.py @@ -140,6 +140,16 @@ def get(self, item_id: str, decode: bool = True) -> Message: else: return proto_dict + def delete(self, item_id: str) -> None: + try: + self.__table.delete_item( + Key={ + self.primary_key: item_id, + } + ) + except Exception as e: + pass + def batch_put(self, proto_obj_list: List[Message], check_for_changes: bool = True) -> List[str]: diff --git a/tester/tests/conftest.py b/tester/tests/conftest.py index 4deaf89..38553c1 100644 --- a/tester/tests/conftest.py +++ b/tester/tests/conftest.py @@ -1,4 +1,5 @@ import os +import json import uuid import pytest @@ -44,8 +45,7 @@ def pytest_collection_modifyitems(config, items): def new_session_obj(): test_id = "test_" + str(uuid.uuid4()) session = { - "text": "", - "id": test_id, + "session_id": test_id, "headless": False, } return session @@ -79,3 +79,13 @@ def invalid_downloads_path() -> str: @pytest.fixture def missing_values_downloads_path() -> str: return "/shared/test_data/missing_values_downloads.toml" + +@pytest.fixture +def sample_taskmap_json_wikihow() -> str: + # returns a JSON dump of a TaskMap extracted 
from the default index + return open("/shared/test_data/sample_taskmap_wikihow.json", "r").read() + +@pytest.fixture +def sample_taskmap_json_seriouseats() -> str: + # returns a JSON dump of a TaskMap extracted from the default index + return open("/shared/test_data/sample_taskmap_seriouseats.json", "r").read() diff --git a/tester/tests/integration_tests/interaction.py b/tester/tests/integration_tests/interaction.py index bc5f14d..384c871 100644 --- a/tester/tests/integration_tests/interaction.py +++ b/tester/tests/integration_tests/interaction.py @@ -135,7 +135,7 @@ def run(self, input_text: str, session: dict, intent: str = '') -> dict: Returns: a JSON dict generated from the HTTP response """ - response = self.send_request(input_text, intent, session['id'], session['headless']) + response = self.send_request(input_text, intent, session['session_id'], session['headless']) response_json = response.json() print(f'USER: {input_text}') try: diff --git a/tester/tests/integration_tests/test_llm.py b/tester/tests/integration_tests/test_llm.py new file mode 100644 index 0000000..4785ea0 --- /dev/null +++ b/tester/tests/integration_tests/test_llm.py @@ -0,0 +1,212 @@ +import os +import json +import time +import uuid +import random + +import grpc +from google.protobuf.json_format import Parse, ParseDict, MessageToDict +import pytest + +from utils import INTRO_PROMPTS, REPLACE_SUGGESTION +from taskmap_pb2 import TaskMap, Session, ConversationTurn, Task, SessionState, ReplacedIngredient +from database_pb2 import SessionRequest, TaskMapRequest +from database_pb2_grpc import DatabaseStub +from llm_pb2 import ( + MultipleSummaryGenerationRequest, + MultipleSummaryGenerationResponse, +) +from llm_pb2_grpc import ( + LLMSummaryGenerationStub, +) + +# TODO: test timeout behaviour +# TODO: test TGI being unavailable + + +def test_llm_chit_chat(new_session_obj, interaction_obj): + """ + Simple test of LLM chit chat functionality. + + To trigger this code, we just create a new Session and send + a couple of messages that will trigger the ChitChatPolicy. + + To check that the response comes from the LLM successfully + the "source" object associated with the response needs to + have a message that's set in the appropriate block of code + in chitchat_policy.py + """ + session = new_session_obj + + # start things off with a "hi" + response = interaction_obj.run("hi", session) + # check that this response comes from the domain_policy as expected, + # and ends with one of the INTRO_PROMPT values + assert response["source"]["policy"] == "domain_policy" + interaction_obj.check_prompts(response, INTRO_PROMPTS) + + # next, reply with "what's up" - this should be classed as a ChitChatIntent and + # trigger a call to the LLM chit chat service + response = interaction_obj.run("what's up", session) + + # the response should now come from a particular code block in + # chitchat_policy.py if it successfully retrieved an LLM response + assert response["source"]["policy"] == "chitchat_policy" + assert response["source"]["message"] == "chitchat from LLM" + + # TODO check output + + +@pytest.mark.slow +def test_llm_enhancement(new_session_obj, sample_taskmap_json_wikihow: str): + """ + Tests the TaskMap enhancement performed by external_functionalities. + + Specifically this is the process triggered by loading/saving a session in + external_functionalities/database/dynamo_db.py. 
+ + This test sets up a minimal valid Session with a Wikihow taskmap, and deletes any + existing copies of the TaskMap and its StagedOutput representation from the + DynamoDB instance as a first step. This is to force the enhancement to run + every time as normally the results are cached in the database. + + Doing this should exercise multiple different LLM components in functionalities. + When a suitable session is loaded in DynamoDB.load_session (one which has a + Wikihow URL and is not marked as already enhanced), it will call the enhance() + method of the StagedEnhance class. For TaskMaps that haven't yet been enhanced, + it spawns a thread which will end up calling LLMTaskMapEnhancer.enhance_taskmap. + + This does 2 things: + 1. It makes some llm_summary_generation requests based on groups of steps + in the TaskMap + 2. It makes an llm_proactive_question_answering request + + i.e. both of these components are exercised by this single test. + """ + + # create the objects required + session = new_session_obj + taskmap = TaskMap() + channel = grpc.insecure_channel(os.environ["EXTERNAL_FUNCTIONALITIES_URL"]) + db = DatabaseStub(channel) + + # parse the TaskMap JSON into a protobuf + Parse(sample_taskmap_json_wikihow, taskmap) + + # construct a protobuf Session object from the basic JSON representation in "session" + # and copy in the sample TaskMap + session_proto = Session() + ParseDict(session, session_proto) + session_proto.task.taskmap.CopyFrom(taskmap) + + # construct a SessionRequest to make an RPC to external_functionalities + session_request: SessionRequest = SessionRequest() + session_request.id = session["session_id"] + session_request.session.CopyFrom(session_proto) + + # delete any stored copies of the TaskMap and its StagedOutput (enhanced) representation + taskmap_request = TaskMapRequest() + taskmap_request.id = taskmap.taskmap_id + db.delete_taskmap(taskmap_request) + db.delete_stagedoutput(taskmap_request) + + # save the new session and then load it to trigger the enhancement + # (saving will also trigger some enhancement, but not the full process) + db.save_session(session_request) + session = db.load_session(session_request) + + # because the database object spawns a worker thread on the remote side, + # the RPC will return immediately. wait here until reloading the session + # shows it marked as enhanced (via session.task.state.enhanced == True), + # or we timeout. Repeated calls to load_session won't trigger repeated + # enhancement requests so this is OK to do. + timeout = 60 + delay = 10 + enhanced = False + while not enhanced and timeout > 0: + time.sleep(delay) + timeout -= delay + session = db.load_session(session_request) + enhanced = session.task.state.enhanced + + assert enhanced + + +def test_llm_ingredient_substitution_step_text(interaction_obj, sample_taskmap_json_seriouseats: str): + """ + Tests the llm_ingredient_substitution step rewriting functionality. 
+ + This can be triggered by orchestrator/policy/validation_policy/validation_policy.py + in certain conditions: + - the session.task.phase must be Task.TaskPhase.VALIDATING (to trigger the policy) + - session.task.state.requirements_displayed must be True + - session.turn[-2].user_request.interaction must contain a QuestionIntent + - session.turn[-2].agent_response.interaction.speech_text must contain a REPLACE_SUGGESTION + - session.turn[-1].user_request.interaction must contain a YesIntent + + If all these conditions are satisfied, a call is made to adjust_step_texts inside + LLMIngredientStepTextRewriter. + + (this will also trigger the proactive question answering component) + """ + + channel = grpc.insecure_channel(os.environ["EXTERNAL_FUNCTIONALITIES_URL"]) + db = DatabaseStub(channel) + taskmap = TaskMap() + + Parse(sample_taskmap_json_seriouseats, taskmap) + + # make sure our TaskMap isn't already in the database + taskmap_request = TaskMapRequest() + taskmap_request.id = taskmap.taskmap_id + db.delete_taskmap(taskmap_request) + db.delete_stagedoutput(taskmap_request) + + # construct a new protobuf Session object directly (this test doesn't use the JSON dict fixture) + session_proto = Session() + session_proto.session_id = f"test_{str(uuid.uuid4())}" + + # configure the session state required to trigger this LLM functionality + session_proto.state = SessionState.RUNNING + session_proto.task.phase = Task.TaskPhase.VALIDATING + session_proto.task.state.requirements_displayed = True + + # set up a replaced ingredient (this is normally done via QA policy) + ri = ReplacedIngredient() + ri.original.name = "chocolate" + ri.original.amount = "30 grams" + ri.replacement.name = "cheese" + ri.replacement.amount = "100 grams" + session_proto.task.taskmap.CopyFrom(taskmap) + session_proto.task.taskmap.replaced_ingredients.append(ri) + + # need to add a turn where the user_request.interaction contains a QuestionIntent + t1 = ConversationTurn() + t1.user_request.interaction.intents.extend(["QuestionIntent"]) + # also need to have an ingredient mentioned in the interaction text + t1.user_request.interaction.text = "can you replace the chocolate in this recipe?" 
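+ # note: "chocolate" matches the " dark chocolate" (30 grams) requirement in the seriouseats sample TaskMap parsed above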
+ t1.id = f"test_turn_{uuid.uuid4()}" + + # and the agent must respond with a REPLACE_SUGGESTION prompt + t1.agent_response.interaction.speech_text = f"{random.choice(REPLACE_SUGGESTION)} {ri.replacement.name}" + session_proto.turn.extend([t1]) + + # the following turn should also contain a YesIntent, but that will be generated + # by setting the appropriate user input text below + + # save the new session first, so that the orchestrator will load it + # with the correct state when we submit the interaction request + session_request: SessionRequest = SessionRequest() + session_request.id = session_proto.session_id + session_request.session.CopyFrom(session_proto) + db.save_session(session_request) + + + # working with a session via interaction_obj requires a dict + s = MessageToDict(session_proto, including_default_value_fields=True) + s["session_id"] = session_proto.session_id + # use speech text that will generate a YesIntent, required to trigger the substitution + new_session_dict = interaction_obj.run("yes ok", s) + + # check that the requirements now include the replacement ingredient + assert any(ri.replacement.name in req for req in new_session_dict["screen"]["requirements"]) diff --git a/tgi/Dockerfile b/tgi/Dockerfile new file mode 100644 index 0000000..174b527 --- /dev/null +++ b/tgi/Dockerfile @@ -0,0 +1,8 @@ +FROM ghcr.io/huggingface/text-generation-inference:1.4.5 +RUN apt-get update && apt-get install -y curl + +# wrapper script for the TGI launcher which just unpacks +# parameters passed in using the TGI_PARAMS env var via +# docker-compose.yml +COPY tgi/tgi-wrapper.sh /tmp +ENTRYPOINT ["bash", "/tmp/tgi-wrapper.sh"] diff --git a/tgi/tgi-wrapper.sh b/tgi/tgi-wrapper.sh new file mode 100755 index 0000000..e761255 --- /dev/null +++ b/tgi/tgi-wrapper.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# read the value of $TGI_PARAMS into an array, splitting on spaces +# TODO: might break if there are parameter values that also contain spaces
IFS=' ' read -ra PARAM_ARRAY <<< "${TGI_PARAMS}" + +#echo "Parameters: ${PARAM_ARRAY[@]}" + +# Pass the parameters on to the launcher. +# (only default arg in the original Dockerfile is --json-output) +text-generation-launcher "${PARAM_ARRAY[@]}"
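+
+# a possible (untested) workaround for the TODO above: letting the shell re-parse the
+# string with eval would preserve quoted parameter values, at the cost of evaluating
+# arbitrary shell syntax placed in TGI_PARAMS:
+#   eval "text-generation-launcher ${TGI_PARAMS}"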