From 7f9d021c8d21a7e6cbcfdaaace3ca3b613cbd065 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 22 Feb 2024 15:13:17 +0000 Subject: [PATCH 01/57] adding TGI Dockerfile --- hg_tgi/Dockerfile | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 hg_tgi/Dockerfile diff --git a/hg_tgi/Dockerfile b/hg_tgi/Dockerfile new file mode 100644 index 0000000..94e36c8 --- /dev/null +++ b/hg_tgi/Dockerfile @@ -0,0 +1,2 @@ +FROM ghcr.io/huggingface/text-generation-inference:latest +RUN apt-get update && apt-get install -y curl From 5823d5662078403653a50b5d09792e0b9e0b8a2f Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 22 Feb 2024 15:14:53 +0000 Subject: [PATCH 02/57] Add tgi service to docker-compose.yml --- docker-compose.yml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docker-compose.yml b/docker-compose.yml index d043736..10e99ce 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -300,6 +300,20 @@ services: condition: service_started oat_common: condition: service_started + tgi: + platform: linux/x86_64 + container_name: tgi + build: + context: ./ + dockerfile: hg_tgi/Dockerfile + volumes: + # the container downloads weights and other files to this path + - ./shared/file_system/tgi:/data + environment: + - MODEL_ID=google/flan-t5-base + networks: + - internal + - external deploy: resources: reservations: From 9f5383b82558c8313ee9896df0640fbd01cc6ec9 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 22 Feb 2024 15:15:11 +0000 Subject: [PATCH 03/57] Add env var for TGI to llm_functionalities --- docker-compose.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index 10e99ce..71492bc 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -292,14 +292,16 @@ services: - NEURAL_FUNCTIONALITIES_URL=llm_functionalities:8000 - FUNCTIONALITIES_URL=functionalities:8000 - EXTERNAL_FUNCTIONALITIES_URL=external_functionalities:8000 + - INFERENCE_ENDPOINT_URL=tgi:80 networks: - internal - external depends_on: builder: condition: service_started - oat_common: + tgi: condition: service_started + tgi: platform: linux/x86_64 container_name: tgi From 7f70244b438e2ec95f92ce14fcdfae8b67efdb53 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 22 Feb 2024 15:15:47 +0000 Subject: [PATCH 04/57] Updating llm_functionalities - Remove existing requirements and replace with huggingface_hub - Use oat_common as a base since we don't need CUDA support now --- llm_functionalities/Dockerfile | 42 +--------------------------- llm_functionalities/requirements.txt | 5 +--- 2 files changed, 2 insertions(+), 45 deletions(-) diff --git a/llm_functionalities/Dockerfile b/llm_functionalities/Dockerfile index 2bd14ec..a2b0ae2 100644 --- a/llm_functionalities/Dockerfile +++ b/llm_functionalities/Dockerfile @@ -1,45 +1,5 @@ # syntax=docker/dockerfile:1.3 -FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 - -COPY oat_common/requirements.txt /requirements.txt - -ENV TZ="Europe/London" -RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime \ - && echo $TZ > /etc/timezone \ - && apt-get update \ - && apt-get install -y --no-install-recommends \ - software-properties-common \ - gnupg \ - ca-certificates \ - build-essential \ - git \ - wget \ - locales \ - unzip \ - curl \ - && add-apt-repository -y ppa:deadsnakes/ppa \ - && apt-get install -y --no-install-recommends \ - python3.9 \ - python3-pip \ - python3-setuptools \ - python3.9-distutils \ - python3.9-dev \ - && update-alternatives --install /usr/bin/python3 python 
/usr/bin/python3.9 5 \
-    && locale-gen en_US.UTF-8 \
-    && pip3 install --upgrade pip \
-    && pip3 install -r /requirements.txt \
-    # Install Rust for M1 Compatibility
-    && curl https://sh.rustup.rs -sSf | bash -s -- -y \
-    # removes about 500MB of docs
-    && rm -rf /root/.rustup/toolchains/stable-x86_64-unknown-linux-gnu/share/doc
-
-ENV PATH="/root/.cargo/bin:${PATH}"
-ENV LANG en_US.UTF-8
-ENV LANGUAGE en_US:en
-# Use a common location in the volume for downloaded models from the
-# huggingface.co transformers module
-# https://huggingface.co/transformers/v4.0.1/installation.html?highlight=transformers_cache#caching-models
-ENV TRANSFORMERS_CACHE /shared/file_system/cache/huggingface
+FROM oat_common:latest
 
 COPY llm_functionalities/requirements.txt /source/requirements.txt
 
diff --git a/llm_functionalities/requirements.txt b/llm_functionalities/requirements.txt
index e9e4342..ae61678 100644
--- a/llm_functionalities/requirements.txt
+++ b/llm_functionalities/requirements.txt
@@ -1,4 +1 @@
-peft==0.4.0
-accelerate==0.21.0
-transformers==4.31.0
-scipy
\ No newline at end of file
+huggingface-hub==0.20.3

From 8f7244016e16b0185291814eb4ea964a79c7f893 Mon Sep 17 00:00:00 2001
From: Andrew Ramsay
Date: Thu, 22 Feb 2024 15:17:15 +0000
Subject: [PATCH 05/57] Initial attempt to integrate TGI

I've replaced the existing model-loading code with the creation of an
`InferenceClient` object using the endpoint URL defined in the
docker-compose.yml file. Creating the object doesn't trigger any
connection, so the service currently submits a simple query to check
that the TGI endpoint is actually available. The endpoint might not be
ready immediately (e.g. if it's still downloading or loading a model),
so there is currently a basic retry setup, but it might need some more
thought put into it.

The `call_model` and `batch_call_model` methods are updated to call the
`.text_generation` method on the `InferenceClient`. For the batch
method, the requests are submitted in parallel using a
ThreadPoolExecutor; TGI doesn't offer a batch-specific endpoint, but
according to the docs it should automatically batch the requests
internally.
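For reference, the TGI endpoint can also be exercised manually with the
same client API, which can be handy for debugging from outside the
containers. This is only a minimal sketch: the host/port and a local
huggingface-hub install are assumptions, and depend on how the tgi
service ends up being exposed (e.g. via a ports mapping in
docker-compose.yml).

    # pip install huggingface-hub
    from huggingface_hub import InferenceClient

    # point the client at the TGI instance; constructing the client does
    # not open a connection, only the text_generation() call below does
    client = InferenceClient(model="http://localhost:8081", timeout=10.0)

    # the same call used by call_model/batch_call_model
    print(client.text_generation(prompt="hello?", max_new_tokens=10))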
--- llm_functionalities/llm_runner/llm_runner.py | 112 +++++++++++-------- llm_functionalities/main.py | 2 - 2 files changed, 66 insertions(+), 48 deletions(-) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index c33edbd..9198a5f 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -1,73 +1,93 @@ +import os +import sys +import concurrent.futures +import time -import torch +from huggingface_hub import InferenceClient -from transformers import AutoModelForCausalLM, AutoTokenizer -from torch.cuda import OutOfMemoryError - -from utils import logger, Downloader -from compiled_protobufs.llm_pb2 import ModelRequest, ModelResponse, ModelBatchRequest, ModelBatchResponse +from utils import logger +from compiled_protobufs.llm_pb2 import ( + ModelRequest, + ModelResponse, + ModelBatchRequest, + ModelBatchResponse, +) class LLMRunner: def __init__(self): - if torch.cuda.is_available(): - artefact_id = "alpaca_llm" - downloader = Downloader() - downloader.download([artefact_id]) - model_name = downloader.get_artefact_path(artefact_id) - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16, - device_map="auto", - max_memory={i: '24000MB' for i in range(torch.cuda.device_count())}, - ) - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - - self.batch_tokenizer = AutoTokenizer.from_pretrained(model_name) - self.batch_tokenizer.padding_side = "left" - self.batch_tokenizer.pad_token = self.batch_tokenizer.eos_token - - logger.info("Finished loading Alpaca model") - else: - logger.info('No GPU available, not loading LLM...') - exit(1) + endpoint_url = os.environ.get("INFERENCE_ENDPOINT_URL", None) + if endpoint_url is None: + logger.error("No INFERENCE_ENDPOINT_URL defined, container will exit") + sys.exit(-1) + + if not endpoint_url.startswith("http://"): + endpoint_url = f"http://{endpoint_url}" + + self.client = None + retries = 0 + while retries < 10: + client = self._connect_to_endpoint(endpoint_url) + if client is None: + logger.info(f"LLMRunner retrying connection to {endpoint_url}") + time.sleep(5) + retries += 1 + else: + logger.info("LLMRunner connected to endpoint!") + self.client = client + break + + def _connect_to_endpoint(self, endpoint_url: str) -> InferenceClient: + client = InferenceClient(model=endpoint_url) + try: + # creating the object doesn't appear to actually make a connection, so + # try something that will fail if it can't connect + client.text_generation(prompt="hello?", max_new_tokens=10) + except Exception: + return None + return client def call_model(self, model_request: ModelRequest) -> ModelResponse: model_response: ModelResponse = ModelResponse() - try: - formatted_prompt = model_request.formatted_prompt - max_tokens = model_request.max_tokens - inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to("cuda:0") - outputs = self.model.generate(inputs=inputs.input_ids, max_new_tokens=max_tokens) - response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + if self.client is None: + raise Exception("llm_functionalities isn't connected to an endpoint!") - model_response.text = str(response_text) + try: + response = self.client.text_generation( + prompt=model_request.formatted_prompt, + max_new_tokens=model_request.max_tokens, + ) + logger.info(f"LLM response text: {response}") + model_response.text = response except Exception as e: - logger.info(f'Running LLM failed: {e}') + 
logger.warning(f"Call to inference endpoint failed: {e}") return model_response def batch_call_model(self, model_request: ModelBatchRequest) -> ModelBatchResponse: model_responses: ModelBatchResponse = ModelBatchResponse() + if self.client is None: + raise Exception("llm_functionalities isn't connected to an endpoint!") + try: formatted_prompts = list(model_request.formatted_prompts) max_tokens = model_request.max_tokens + params = [ + {"prompt": p, "max_new_tokens": max_tokens} for p in formatted_prompts + ] - encodings = self.batch_tokenizer(formatted_prompts, padding=True, return_tensors='pt').to("cuda:0") + logger.info(f"Submitting a batch of {len(params)} calls to TGI") + with concurrent.futures.ThreadPoolExecutor(max_workers=12) as pool: + results = pool.map(lambda p: self.client.text_generation(**p), params) - with torch.no_grad(): - generated_ids = self.model.generate(**encodings, max_new_tokens=max_tokens, do_sample=False) - generated_texts = self.batch_tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + for response in results: + logger.info(f"LLM response text: {response}") + model_responses.text.append(response) - for text in generated_texts: - model_responses.text.append(text) + except Exception as e: + logger.warning(f"Call to inference endpoint failed: {e}") - except OutOfMemoryError as e: - logger.info(f'We ran out of GPU memory: {e}') - torch.cuda.empty_cache() - exit(1) - return model_responses diff --git a/llm_functionalities/main.py b/llm_functionalities/main.py index 60d439b..8b080ac 100644 --- a/llm_functionalities/main.py +++ b/llm_functionalities/main.py @@ -18,8 +18,6 @@ def serve(): add_llm_runner_to_server(LLM_Runner_Servicer(), server) - logger.info('Finished loading all LLM functionalities') - server.add_insecure_port("[::]:8000") server.start() server.wait_for_termination() From d715048dc046145c998bfde6c981da47673ccfe1 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 22 Feb 2024 15:28:40 +0000 Subject: [PATCH 06/57] Have llm_functionalities exit if connection fails --- llm_functionalities/llm_runner/llm_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index 9198a5f..8209dc9 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -37,6 +37,10 @@ def __init__(self): self.client = client break + if self.client is None: + logger.error(f"LLMRunner failed to connect to the endpoint at {endpoint_url}") + sys.exit(-1) + def _connect_to_endpoint(self, endpoint_url: str) -> InferenceClient: client = InferenceClient(model=endpoint_url) try: From 6d06e78da952f43c7835b9ea37fa85b2ec363e37 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Mon, 26 Feb 2024 10:52:47 +0000 Subject: [PATCH 07/57] adding timeout parameter when creating InferenceClient --- llm_functionalities/llm_runner/llm_runner.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index 8209dc9..0898c56 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -1,7 +1,7 @@ import os import sys -import concurrent.futures import time +import concurrent.futures from huggingface_hub import InferenceClient @@ -38,11 +38,13 @@ def __init__(self): break if self.client is None: - logger.error(f"LLMRunner failed to connect to the endpoint at {endpoint_url}") + 
logger.error( + f"LLMRunner failed to connect to the endpoint at {endpoint_url}" + ) sys.exit(-1) def _connect_to_endpoint(self, endpoint_url: str) -> InferenceClient: - client = InferenceClient(model=endpoint_url) + client = InferenceClient(model=endpoint_url, timeout=10.0) try: # creating the object doesn't appear to actually make a connection, so # try something that will fail if it can't connect From 26e22d8ec2f80473099ea6a90ebd07fb7a31919b Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Mon, 26 Feb 2024 10:58:41 +0000 Subject: [PATCH 08/57] Some more TGI updates - Add some env vars with default values to the tgi service definition to allow easy control of some options that might need changed depending on model/hardware - Add a wrapper script as the entrypoint for the tgi container to allow passing in extra CLI parameters using a TGI_PARAMS env var --- docker-compose.yml | 17 +++++++++++++++-- hg_tgi/Dockerfile | 6 ++++++ hg_tgi/tgi-wrapper.sh | 11 +++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) create mode 100755 hg_tgi/tgi-wrapper.sh diff --git a/docker-compose.yml b/docker-compose.yml index 71492bc..349e4fa 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -308,20 +308,33 @@ services: build: context: ./ dockerfile: hg_tgi/Dockerfile + ports: + # this isn't needed for llm_functionalities (it uses the internal network), + # but might be useful to be able to submit requests to the instance from + # external scripts for testing/debugging + - "8081:80" volumes: # the container downloads weights and other files to this path - ./shared/file_system/tgi:/data environment: - - MODEL_ID=google/flan-t5-base + # setting the value of MODEL_ID is equivalent to passing "--model-id" parameter + # to the TGI launcher + - MODEL_ID=${MODEL_ID:-google/flan-t5-xxl} + # any other TGI launcher parameters can be set in this env var, e.g.: + # TGI_PARAMS="--param1 param1_value --param2 param2_value" docker compose up + - TGI_PARAMS=${TGI_PARAMS:-} networks: - internal - external + # larger sharded models will need this increased from the default (usually 64MB) + shm_size: ${SHM_SIZE:-2gb} deploy: resources: reservations: devices: - driver: nvidia - count: 1 + # set the number of GPUs available to the container, default to 1 + count: ${GPU_COUNT:-1} capabilities: [ gpu ] networks: diff --git a/hg_tgi/Dockerfile b/hg_tgi/Dockerfile index 94e36c8..ed11e25 100644 --- a/hg_tgi/Dockerfile +++ b/hg_tgi/Dockerfile @@ -1,2 +1,8 @@ FROM ghcr.io/huggingface/text-generation-inference:latest RUN apt-get update && apt-get install -y curl + +# wrapper script for the TGI launcher which just unpacks +# parameters passed in using the TGI_PARAMS env var via +# docker-compose.yml +COPY hg_tgi/tgi-wrapper.sh /tmp +ENTRYPOINT ["bash", "/tmp/tgi-wrapper.sh"] diff --git a/hg_tgi/tgi-wrapper.sh b/hg_tgi/tgi-wrapper.sh new file mode 100755 index 0000000..e761255 --- /dev/null +++ b/hg_tgi/tgi-wrapper.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# read the value of $TGI_PARAMS into an array, splitting on spaces +# TODO: might break if there are parameter values that also contain spaces +IFS=' ' read -ra PARAM_ARRAY <<< "${TGI_PARAMS}" + +#echo "Parameters: ${PARAM_ARRAY[@]}" + +# Pass the parameters on to the launcher. 
+# (only default arg in the original Dockerfile is --json-output) +text-generation-launcher "${PARAM_ARRAY[@]}" From c4cfbdc871a4e1fa78084d3adbfef40ee0bf4df9 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 8 Mar 2024 15:54:17 +0000 Subject: [PATCH 09/57] Adding connection attempt parameters for TGI These 2 env vars can be used to adjust the number of retries llm_functionalities will make when attempting to connect to the TGI endpoint, and the delay between successive retries. --- docker-compose.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docker-compose.yml b/docker-compose.yml index 349e4fa..62c17a8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -292,7 +292,16 @@ services: - NEURAL_FUNCTIONALITIES_URL=llm_functionalities:8000 - FUNCTIONALITIES_URL=functionalities:8000 - EXTERNAL_FUNCTIONALITIES_URL=external_functionalities:8000 + # the location of the TGI endpoint (note that it's using the internal + # Docker network, so it's port 80 rather than 8081) - INFERENCE_ENDPOINT_URL=tgi:80 + # these values can be used to control how long this service waits + # for the TGI endpoint to become available. this might take a significant + # amount of time in some cases, e.g. if it has to download a large model + # Number of retries (default 10) + - TGI_CONNECTION_RETRY_LIMIT=${TGI_CONNECTION_RETRY_LIMIT:-10} + # delay between retries in seconds + - TGI_CONNECTION_RETRY_DELAY=${TGI_CONNECTION_RETRY_DELAY:-10} networks: - internal - external From f670a1203c947abb793d0669f8975ab4bdf2bf5d Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 8 Mar 2024 15:55:25 +0000 Subject: [PATCH 10/57] Use a smaller default model for testing --- docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index 62c17a8..89aeae0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -328,7 +328,7 @@ services: environment: # setting the value of MODEL_ID is equivalent to passing "--model-id" parameter # to the TGI launcher - - MODEL_ID=${MODEL_ID:-google/flan-t5-xxl} + - MODEL_ID=${MODEL_ID:-google/flan-t5-large} # any other TGI launcher parameters can be set in this env var, e.g.: # TGI_PARAMS="--param1 param1_value --param2 param2_value" docker compose up - TGI_PARAMS=${TGI_PARAMS:-} From cdc624adb70c1482b31f749fa0a5632411bfeef5 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 8 Mar 2024 15:56:29 +0000 Subject: [PATCH 11/57] Add TGI-specific summarization protos/RPCs --- shared/protobufs/llm.proto | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/shared/protobufs/llm.proto b/shared/protobufs/llm.proto index 63585c2..4cf5643 100644 --- a/shared/protobufs/llm.proto +++ b/shared/protobufs/llm.proto @@ -68,6 +68,22 @@ message MultipleSummaryGenerationResponse { repeated string summary = 1; } +message TGISummaryRequest { + string input_text = 1; +} + +message TGISummaryResponse { + string summary_text = 1; +} + +message TGIMultipleSummaryRequest { + repeated string input_text = 1; +} + +message TGIMultipleSummaryResponse { + repeated string summary_text = 1; +} + message ProactiveQuestionGenerationRequest { repeated string task_title = 1; repeated string previous_steps = 2; @@ -145,4 +161,6 @@ service LLMReplacementGeneration { service LLMRunner { rpc call_model(ModelRequest) returns (ModelResponse) {} rpc batch_call_model(ModelBatchRequest) returns (ModelBatchResponse) {} -} \ No newline at end of file + rpc generate_summary(TGISummaryRequest) returns 
(TGISummaryResponse) {} + rpc generate_summaries(TGIMultipleSummaryRequest) returns (TGIMultipleSummaryResponse) {} +} From 5754857d5175a17f275e50e5f098eda367c6d264 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 8 Mar 2024 15:58:54 +0000 Subject: [PATCH 12/57] Use current source version of huggingface_hub There seems to be a bug in the `InferenceClient.summarization` method in the recent officially released versions, it's fixed in the current development version. --- llm_functionalities/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llm_functionalities/requirements.txt b/llm_functionalities/requirements.txt index ae61678..1b081ae 100644 --- a/llm_functionalities/requirements.txt +++ b/llm_functionalities/requirements.txt @@ -1 +1,2 @@ -huggingface-hub==0.20.3 +# summarization seems to be broken in the recent released versions +git+https://github.com/huggingface/huggingface_hub.git From d7cd49b9e8bfac343a7b9b19547a253cb996fa23 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 8 Mar 2024 16:01:10 +0000 Subject: [PATCH 13/57] Update LLMRunner connection behaviour Use the new env vars from `docker-compose.yml` to control the connection attempts to the TGI endpoint --- llm_functionalities/llm_runner/llm_runner.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index 0898c56..44c13ea 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -26,11 +26,19 @@ def __init__(self): self.client = None retries = 0 - while retries < 10: + retry_limit = int(os.environ.get("TGI_CONNECTION_RETRY_LIMIT", 10)) + retry_delay = int(os.environ.get("TGI_CONNECTION_RETRY_DELAY", 10)) + logger.info( + f"Connecting to TGI (max {retry_limit} connections, {retry_delay}s apart)" + ) + + # might have to wait for the TGI container to finish starting up, especially if it + # needs to download model files first + while retries < retry_limit: client = self._connect_to_endpoint(endpoint_url) if client is None: logger.info(f"LLMRunner retrying connection to {endpoint_url}") - time.sleep(5) + time.sleep(retry_delay) retries += 1 else: logger.info("LLMRunner connected to endpoint!") From 5c63e9a1c024cb74fff223fc2bfef449f5effada Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 8 Mar 2024 16:02:15 +0000 Subject: [PATCH 14/57] Add a _check_connectivity method --- llm_functionalities/llm_runner/llm_runner.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index 44c13ea..804a91c 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -61,11 +61,14 @@ def _connect_to_endpoint(self, endpoint_url: str) -> InferenceClient: return None return client + def _check_connectivity(self) -> None: + if self.client is None: + raise Exception("llm_functionalities isn't connected to an endpoint!") + def call_model(self, model_request: ModelRequest) -> ModelResponse: model_response: ModelResponse = ModelResponse() - if self.client is None: - raise Exception("llm_functionalities isn't connected to an endpoint!") + self._check_connectivity() try: response = self.client.text_generation( @@ -83,8 +86,7 @@ def call_model(self, model_request: ModelRequest) -> ModelResponse: def batch_call_model(self, model_request: ModelBatchRequest) -> 
ModelBatchResponse: model_responses: ModelBatchResponse = ModelBatchResponse() - if self.client is None: - raise Exception("llm_functionalities isn't connected to an endpoint!") + self._check_connectivity() try: formatted_prompts = list(model_request.formatted_prompts) From b77fa41b8ed76d05440562e39f58747b1687d280 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 8 Mar 2024 16:03:41 +0000 Subject: [PATCH 15/57] Set pool size based on number of requests --- llm_functionalities/llm_runner/llm_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index 804a91c..88c1907 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -96,7 +96,7 @@ def batch_call_model(self, model_request: ModelBatchRequest) -> ModelBatchRespon ] logger.info(f"Submitting a batch of {len(params)} calls to TGI") - with concurrent.futures.ThreadPoolExecutor(max_workers=12) as pool: + with concurrent.futures.ThreadPoolExecutor(max_workers=len(params)) as pool: results = pool.map(lambda p: self.client.text_generation(**p), params) for response in results: From 4e1e86aedc76edd8a1bedc229346a2bd9251b9c1 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 8 Mar 2024 16:04:04 +0000 Subject: [PATCH 16/57] Initial attempt at adding some new TGI endpoints This just adds TGI equivalents for the `generate_summary` and `generate_summaries` methods that pass requests through to the `InferenceClient.summarization` method (https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.summarization) --- llm_functionalities/llm_runner/__init__.py | 4 +- llm_functionalities/llm_runner/llm_runner.py | 45 +++++++++++++++++++ .../llm_runner/llm_runner_servicer.py | 26 +++++++++-- 3 files changed, 69 insertions(+), 6 deletions(-) diff --git a/llm_functionalities/llm_runner/__init__.py b/llm_functionalities/llm_runner/__init__.py index 073bb27..ca86582 100644 --- a/llm_functionalities/llm_runner/__init__.py +++ b/llm_functionalities/llm_runner/__init__.py @@ -1,6 +1,6 @@ from .llm_runner import LLMRunner as DefaultLLMRunner +from compiled_protobufs.llm_pb2_grpc import add_LLMRunnerServicer_to_server as add_to_server from .llm_runner_servicer import ( Servicer, - add_LLMRunnerServicer_to_server as add_to_server -) \ No newline at end of file +) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index 88c1907..0706125 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -11,6 +11,10 @@ ModelResponse, ModelBatchRequest, ModelBatchResponse, + TGISummaryRequest, + TGISummaryResponse, + TGIMultipleSummaryRequest, + TGIMultipleSummaryResponse, ) @@ -107,3 +111,44 @@ def batch_call_model(self, model_request: ModelBatchRequest) -> ModelBatchRespon logger.warning(f"Call to inference endpoint failed: {e}") return model_responses + + def generate_summary(self, request: TGISummaryRequest) -> TGISummaryResponse: + response = TGISummaryResponse() + + self._check_connectivity() + + logger.info(f"generating summary from: {request.input_text}") + + try: + summarization_output = self.client.summarization(text=str(request.input_text)) + if summarization_output.summary_text is None: + # TODO raise exception? different string response? 
+ response.summary_text = "" + else: + response.summary_text = summarization_output.summary_text + except Exception as e: + logger.warning(f"Call to summarization failed: {e}") + + return response + + def generate_summaries( + self, request: TGIMultipleSummaryRequest + ) -> TGIMultipleSummaryResponse: + response = TGIMultipleSummaryResponse() + + self._check_connectivity() + + try: + params = list(request.input_text) + with concurrent.futures.ThreadPoolExecutor(max_workers=len(params)) as pool: + results = pool.map(lambda p: self.client.summarization(p), params) + + for result in results: + if result.summary_text is None: + response.summary_text.append("") # TODO see comment above + else: + response.summary_text.append(result.summary_text) + except Exception as e: + logger.warning(f"Call to summarization failed: {e}") + + return response diff --git a/llm_functionalities/llm_runner/llm_runner_servicer.py b/llm_functionalities/llm_runner/llm_runner_servicer.py index f84776f..45bfcd9 100644 --- a/llm_functionalities/llm_runner/llm_runner_servicer.py +++ b/llm_functionalities/llm_runner/llm_runner_servicer.py @@ -1,15 +1,33 @@ -from compiled_protobufs.llm_pb2 import ModelRequest, ModelResponse, ModelBatchRequest, ModelBatchResponse -from compiled_protobufs.llm_pb2_grpc import LLMRunnerServicer, add_LLMRunnerServicer_to_server +from compiled_protobufs.llm_pb2 import ( + ModelRequest, + ModelResponse, + ModelBatchRequest, + ModelBatchResponse, + TGISummaryRequest, + TGISummaryResponse, + TGIMultipleSummaryRequest, + TGIMultipleSummaryResponse, +) +from compiled_protobufs.llm_pb2_grpc import ( + LLMRunnerServicer, +) from . import DefaultLLMRunner class Servicer(LLMRunnerServicer): - def __init__(self): self.model = DefaultLLMRunner() def call_model(self, query: ModelRequest, context) -> ModelResponse: return self.model.call_model(query) - + def batch_call_model(self, query: ModelBatchRequest, context) -> ModelBatchResponse: return self.model.batch_call_model(query) + + def generate_summary(self, query: TGISummaryRequest, context) -> TGISummaryResponse: + return self.model.generate_summary(query) + + def generate_summaries( + self, query: TGIMultipleSummaryRequest, context + ) -> TGIMultipleSummaryResponse: + return self.model.generate_summaries(query) From c92a20ff44170e5fb9db8d9c216056ff34b3ab53 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 26 Apr 2024 14:46:03 +0100 Subject: [PATCH 17/57] pin version of TGI --- hg_tgi/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hg_tgi/Dockerfile b/hg_tgi/Dockerfile index ed11e25..493a93f 100644 --- a/hg_tgi/Dockerfile +++ b/hg_tgi/Dockerfile @@ -1,4 +1,4 @@ -FROM ghcr.io/huggingface/text-generation-inference:latest +FROM ghcr.io/huggingface/text-generation-inference:1.4.5 RUN apt-get update && apt-get install -y curl # wrapper script for the TGI launcher which just unpacks From 5a862287e3e7d02d5653e4c18c96af292154770e Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 26 Apr 2024 14:46:24 +0100 Subject: [PATCH 18/57] remove TGI summary protos and RPCs --- shared/protobufs/llm.proto | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/shared/protobufs/llm.proto b/shared/protobufs/llm.proto index 4cf5643..6537056 100644 --- a/shared/protobufs/llm.proto +++ b/shared/protobufs/llm.proto @@ -68,22 +68,6 @@ message MultipleSummaryGenerationResponse { repeated string summary = 1; } -message TGISummaryRequest { - string input_text = 1; -} - -message TGISummaryResponse { - string summary_text = 1; -} - 
-message TGIMultipleSummaryRequest { - repeated string input_text = 1; -} - -message TGIMultipleSummaryResponse { - repeated string summary_text = 1; -} - message ProactiveQuestionGenerationRequest { repeated string task_title = 1; repeated string previous_steps = 2; @@ -161,6 +145,4 @@ service LLMReplacementGeneration { service LLMRunner { rpc call_model(ModelRequest) returns (ModelResponse) {} rpc batch_call_model(ModelBatchRequest) returns (ModelBatchResponse) {} - rpc generate_summary(TGISummaryRequest) returns (TGISummaryResponse) {} - rpc generate_summaries(TGIMultipleSummaryRequest) returns (TGIMultipleSummaryResponse) {} } From 6d8e9fde48f241520b04a73bf2a632d67955a91b Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 26 Apr 2024 14:47:13 +0100 Subject: [PATCH 19/57] Fix session ID field name for tests --- tester/tests/conftest.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tester/tests/conftest.py b/tester/tests/conftest.py index 4deaf89..a6d66e2 100644 --- a/tester/tests/conftest.py +++ b/tester/tests/conftest.py @@ -44,8 +44,7 @@ def pytest_collection_modifyitems(config, items): def new_session_obj(): test_id = "test_" + str(uuid.uuid4()) session = { - "text": "", - "id": test_id, + "session_id": test_id, "headless": False, } return session From c08d061bb476ca58ba8ecb4419a295d4638a35f7 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 26 Apr 2024 14:47:49 +0100 Subject: [PATCH 20/57] Add a couple of JSON-format TaskMaps for tests --- shared/test_data/sample_taskmap_seriouseats.json | 1 + shared/test_data/sample_taskmap_wikihow.json | 1 + tester/tests/conftest.py | 11 +++++++++++ 3 files changed, 13 insertions(+) create mode 100644 shared/test_data/sample_taskmap_seriouseats.json create mode 100644 shared/test_data/sample_taskmap_wikihow.json diff --git a/shared/test_data/sample_taskmap_seriouseats.json b/shared/test_data/sample_taskmap_seriouseats.json new file mode 100644 index 0000000..0adef02 --- /dev/null +++ b/shared/test_data/sample_taskmap_seriouseats.json @@ -0,0 +1 @@ +{"taskmapId": "075f4135afedcb5870a371e9e85eac77", "title": "Chicken Maple Sausage Pigs in a Blanket with Maple\u00a0Cream", "date": "2018-11-15", "sourceUrl": "https://food52.com/recipes/78217-chicken-maple-sausage-pigs-in-a-blanket-with-maple-cream", "description": "A more sophisticated take on an old classic. 
\u2014My Stir Crazy Kitchen", "thumbnailUrl": "https://images.food52.com/PxRPfj1W6qHiGd4ta0b0EM-PM3g=/1000x1000/b22b1adc-e3d1-4005-b305-cbe602cb002c--7L7A1053.JPG", "totalTimeMinutes": "10", "ratingOut100": 100, "tags": ["American", "Serves a Crowd", "Appetizer"], "requirementList": [{"uniqueId": "630cf6ff-f6f1-4644-a98d-1d6431b7655b", "name": " chicken maple sausage links", "amount": "8 pieces"}, {"uniqueId": "11f97441-0874-4e0f-ba20-8e2c535a925b", "name": "packet crescent rolls", "amount": "1"}, {"uniqueId": "15960d6b-ca07-47ac-8b29-aade24aa30f2", "name": " heavy cream", "amount": "6 tablespoons"}, {"uniqueId": "4265f16f-dfb9-4ff2-a316-85cce7f540ed", "name": " maple syrup", "amount": "3 tablespoons"}, {"uniqueId": "a6b93747-cf41-4499-b5eb-2c5adb2f3771", "name": " kosher salt", "amount": "1 teaspoon"}], "serves": "\nMakes\n 8\n ", "steps": [{"uniqueId": "0483c625-5b1e-4893-9761-0c0869200ea7", "response": {"speechText": "In a small bowl, stir to combine the cream, maple syrup and salt.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["In a small bowl, stir to combine the cream, maple syrup and salt."], "imageList": [{"path": "https://www.seriouseats.com/thmb/fsxoprpah3N9VyuUcm2S7i4RlfI=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__images__2012__01__20120117-187614_GFTues_PigsInABlanket_610-c075c91ac6c84089934469d43382cbb0.jpg"}], "requirements": [" heavy cream", " kosher salt", " maple syrup"], "extraInformation": [{"type": "FUNFACT", "text": "If you take a bowl and put it upside down, it becomes a poml.", "keyword": "bowl", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/facts/fact_bowl.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "345b60f3-9ec1-496c-b578-12774723e682", "response": {"speechText": "Preheat oven to 350F.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Preheat oven to 350F."], "imageList": [{"path": "https://www.seriouseats.com/thmb/fsxoprpah3N9VyuUcm2S7i4RlfI=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__images__2012__01__20120117-187614_GFTues_PigsInABlanket_610-c075c91ac6c84089934469d43382cbb0.jpg"}], "extraInformation": [{"type": "JOKE", "text": "Give the oven a hand. 
It's the real cook, we're just it's assistants.", "keyword": "oven", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_oven.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "391ddab7-b1b4-478d-b5ae-c32939f4d905", "response": {"speechText": "Lay out each crescent triangle and top with a sausage link.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Lay out each crescent triangle and top with a sausage link."], "imageList": [{"path": "https://www.seriouseats.com/thmb/fsxoprpah3N9VyuUcm2S7i4RlfI=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__images__2012__01__20120117-187614_GFTues_PigsInABlanket_610-c075c91ac6c84089934469d43382cbb0.jpg"}], "requirements": [" chicken maple sausage links"]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "14ada727-31c1-418f-8c84-6e6764de95e9", "response": {"speechText": "Roll the crescent rolls around the links.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Roll the crescent rolls around the links."], "imageList": [{"path": "https://images.food52.com/PxRPfj1W6qHiGd4ta0b0EM-PM3g=/1000x1000/b22b1adc-e3d1-4005-b305-cbe602cb002c--7L7A1053.JPG"}], "requirements": [" chicken maple sausage links"]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "1e99067c-d101-41ba-8641-04b0b1384664", "response": {"speechText": "Brush with maple cream sauce.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Brush with maple cream sauce."], "imageList": [{"path": "https://www.seriouseats.com/thmb/fsxoprpah3N9VyuUcm2S7i4RlfI=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__images__2012__01__20120117-187614_GFTues_PigsInABlanket_610-c075c91ac6c84089934469d43382cbb0.jpg"}], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "204a681b-23b8-4138-8de1-7307f0f14083", "response": {"speechText": "Bake for 12-14 minutes until golden brown. Serve with remaining maple cream sauce.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Bake for 12-14 minutes until golden brown. Serve with remaining maple cream sauce."], "imageList": [{"path": "https://www.seriouseats.com/thmb/fsxoprpah3N9VyuUcm2S7i4RlfI=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__images__2012__01__20120117-187614_GFTues_PigsInABlanket_610-c075c91ac6c84089934469d43382cbb0.jpg"}], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. 
One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}], "connectionList": [{"idFrom": "1e99067c-d101-41ba-8641-04b0b1384664", "idTo": "204a681b-23b8-4138-8de1-7307f0f14083"}, {"idFrom": "0483c625-5b1e-4893-9761-0c0869200ea7", "idTo": "345b60f3-9ec1-496c-b578-12774723e682"}, {"idFrom": "14ada727-31c1-418f-8c84-6e6764de95e9", "idTo": "1e99067c-d101-41ba-8641-04b0b1384664"}, {"idFrom": "391ddab7-b1b4-478d-b5ae-c32939f4d905", "idTo": "14ada727-31c1-418f-8c84-6e6764de95e9"}, {"idFrom": "345b60f3-9ec1-496c-b578-12774723e682", "idTo": "391ddab7-b1b4-478d-b5ae-c32939f4d905"}], "dataset": "common-crawl", "author": "My Stir Crazy Kitchen", "domainName": "food52.com"} \ No newline at end of file diff --git a/shared/test_data/sample_taskmap_wikihow.json b/shared/test_data/sample_taskmap_wikihow.json new file mode 100644 index 0000000..b03fbe6 --- /dev/null +++ b/shared/test_data/sample_taskmap_wikihow.json @@ -0,0 +1 @@ +{"taskmapId": "751e1886e67bf9bd4b29f6401dd38911", "title": "How to Make Red Velvet Cake", "date": "2022-07-31", "sourceUrl": "https://www.wikihow.com/Make-Red-Velvet-Cake", "description": "A red velvet cake can be a delicious dessert for almost any occasion, and it doesn't have to be a hard cake to make either! Simple and full of flavor, this is a showstopping dessert worth sharing with friends.", "thumbnailUrl": "https://www.wikihow.com/images/thumb/b/bf/Make-Red-Velvet-Cake-Step-17-Version-2.jpg/v4-460px-Make-Red-Velvet-Cake-Step-17-Version-2.jpg", "tags": ["Chocolate Cakes"], "requirementList": [{"uniqueId": "215f0a51-494c-4b5d-afd5-5988109d9e27", "name": "\u00bd cup (102 grams) shortening (or butter, margarine, or any type of fat used for baking and making pastries)", "amount": "102 grams"}, {"uniqueId": "c77e334d-3b54-4be5-96e1-e1536679b1db", "name": "cups (300 grams) sugar", "amount": "300 grams"}, {"uniqueId": "4eaa46cb-ec54-45a0-88fa-8ec414703556", "name": "3 eggs", "amount": "3 eggs"}, {"uniqueId": "6f3bd176-ab28-4bce-b3d0-fbd256e94df0", "name": "cocoa", "amount": "2 tablespoons"}, {"uniqueId": "a65afaf3-84e5-4d3a-b1f7-9adcdb03a1a8", "name": "ounces red food coloring", "amount": "1 \u00bd"}, {"uniqueId": "62dce040-a70a-44e6-aa64-1d63e7d76cfc", "name": "salt", "amount": "1 teaspoon"}, {"uniqueId": "2580cb6a-acc8-41d0-b9e3-c07919e7db88", "name": "cups (315 grams) flour", "amount": "315 grams"}, {"uniqueId": "39c07933-3ce3-4759-a68e-5534171b4de2", "name": "vanilla", "amount": "1 teaspoon"}, {"uniqueId": "dc94ab8c-e65c-4ee2-a495-35b3d2b8df56", "name": "(240 ml) buttermilk", "amount": "240 ml"}, {"uniqueId": "cd4d223d-d0c4-4b75-ba04-7b3c83b0e5f2", "name": "baking soda", "amount": "1 teaspoon"}, {"uniqueId": "1db558e3-2751-4188-9877-b73f7cd17a49", "name": "vinegar", "amount": "1 tablespoon"}, {"uniqueId": "d5c1501e-49ec-4382-a476-ee54e64c8f6a", "name": "cream cheese, softened", "amount": "16 ounces"}, {"uniqueId": "4080d5e8-69cd-4c3c-9c39-5987229d978e", "name": "(400 grams) powdered sugar", "amount": "400 grams"}, {"uniqueId": "fd4bee6b-26fc-4dbc-8257-80f7d65c4032", "name": "vanilla", "amount": "1 teaspoon"}, {"uniqueId": "672de4cf-bc0f-4764-b12a-cf8b460e7c45", "name": "butter (approx 110 grams)", "amount": "110 grams"}, {"uniqueId": "522186b8-97d0-4555-97c0-944e663107c2", "name": "Cake pan", "amount": " "}, {"uniqueId": 
"313272d7-ec66-4b95-9ca2-acfebbc4d714", "name": "Mixing bowls", "amount": " "}, {"uniqueId": "e52596c7-cd1b-4309-91d0-6a41fb06cb6e", "name": "Cooking utensils", "amount": " "}], "steps": [{"uniqueId": "e422e3b4-a864-484f-ab46-0fd4333eeeb9", "response": {"speechText": "Gather and measure out all of the ingredients", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Gather and measure out all of the ingredients.", "Good bakers know that moving quickly and efficiently in the kitchen leads to better cakes and smaller messes. Measuring out ahead of time makes it possible."], "footer": "Gather and measure out all of the ingredients", "imageList": [{"path": "https://www.wikihow.com/images/thumb/e/e6/Make-Red-Velvet-Cake-Step-1-Version-4.jpg/v4-460px-Make-Red-Velvet-Cake-Step-1-Version-4.jpg"}]}, "description": "Good bakers know that moving quickly and efficiently in the kitchen leads to better cakes and smaller messes. Measuring out ahead of time makes it possible."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "24879201-13a5-4e99-94f4-e07be041ab62", "response": {"speechText": "Cream shortening and gradually add in sugar", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Cream shortening and gradually add in sugar.", "Use an electric mixer set to medium speed. Add the sugar along the edges and slowly work it in to avoid sugar splatter."], "footer": "Cream shortening and gradually add in sugar", "imageList": [{"path": "https://www.wikihow.com/images/thumb/3/3e/Make-Red-Velvet-Cake-Step-2-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-2-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/3/3a/Make Red Velvet Cake Step 2 Version 2.360p.mp4"}, "requirements": ["(400 grams) powdered sugar", "cups (300 grams) sugar", "\u00bd cup (102 grams) shortening (or butter, margarine, or any type of fat used for baking and making pastries)"], "extraInformation": [{"type": "JOKE", "text": "Sugar never spoils, but if we put our minds together, I bet we can find a way to make it inedible.", "keyword": "sugar", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_sugar.png"}]}, "description": "Use an electric mixer set to medium speed. Add the sugar along the edges and slowly work it in to avoid sugar splatter."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "f89f5e5a-2dd4-4e0c-bb69-9e25d6192c6f", "response": {"speechText": "Add eggs one at a time, beating well after each egg is added", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Add eggs one at a time, beating well after each egg is added.", "Mix in well, keeping your beater moving. It is okay if you add them both at once, too."], "footer": "Add eggs one at a time, beating well after each egg is added", "imageList": [{"path": "https://www.wikihow.com/images/thumb/f/fc/Make-Red-Velvet-Cake-Step-3-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-3-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/4/47/Make Red Velvet Cake Step 3 Version 2.360p.mp4"}, "requirements": ["3 eggs"]}, "description": "Mix in well, keeping your beater moving. 
It is okay if you add them both at once, too."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "f038335a-c335-4386-bf9c-69e5fded5740", "response": {"speechText": "Make a paste out of the cocoa and food coloring, then add to cream", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Make a paste out of the cocoa and food coloring, then add to cream.", "In a separate bowl, use a whisk to blend the food coloring into the cocoa. Fun Fact: The original red velvet cakes got their color because imported cocoa was actually red tinted. The food coloring came later."], "footer": "Make a paste out of the cocoa and food coloring, then add to cream", "imageList": [{"path": "https://www.wikihow.com/images/thumb/2/2a/Make-Red-Velvet-Cake-Step-4-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-4-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/d/d5/Make Red Velvet Cake Step 4 Version 2.360p.mp4"}, "requirements": ["ounces red food coloring"], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}, "description": "In a separate bowl, use a whisk to blend the food coloring into the cocoa. Fun Fact: The original red velvet cakes got their color because imported cocoa was actually red tinted. The food coloring came later."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "37bb9be1-1f02-458f-b6d6-24373b68a47e", "response": {"speechText": "Add salt, flour, baking soda, vanilla, and buttermilk, beating well after each ingredient is added", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Add salt, flour, baking soda, vanilla, and buttermilk, beating well after each ingredient is added.", "You can also mix the flour, salt, and baking soda in a separate bowl, adding at once. You can also add it all together into the batter one at a time, using the electric mixer to get a nice consistent batter. Add the flour slowly to avoid splatter. It can help to add it with the buttermilk."], "footer": "Add salt, flour, baking soda, vanilla, and buttermilk, beating well after each ingredient is added", "imageList": [{"path": "https://www.wikihow.com/images/thumb/2/28/Make-Red-Velvet-Cake-Step-5-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-5-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/1/12/Make Red Velvet Cake Step 5 Version 2.360p.mp4"}, "requirements": ["baking soda", "salt", "vanilla", "cups (315 grams) flour", "(240 ml) buttermilk"]}, "description": "You can also mix the flour, salt, and baking soda in a separate bowl, adding at once. You can also add it all together into the batter one at a time, using the electric mixer to get a nice consistent batter. Add the flour slowly to avoid splatter. 
It can help to add it with the buttermilk."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "0ee4da68-4b9b-491d-a281-986bd18e0c5b", "response": {"speechText": "Pour vinegar over the batter", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Pour vinegar over the batter.", "Just a splash is all you need -- it will give it a nice subtle tanginess."], "footer": "Pour vinegar over the batter", "imageList": [{"path": "https://www.wikihow.com/images/thumb/e/ec/Make-Red-Velvet-Cake-Step-6-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-6-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/4/42/Make Red Velvet Cake Step 6 Version 2.360p.mp4"}, "requirements": ["vinegar"], "extraInformation": [{"type": "JOKE", "text": "Vinegar can also be used as a cleaning agent. Keep that in mind if you're making a mess.", "keyword": "vinegar", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_vinegar_v2.png"}, {"type": "FUNFACT", "text": "Vinegar can also be used as a cleaning agent: I just wouldn't recommend it for brushing your teeth.", "keyword": "vinegar", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/facts/fact_vinegar.png"}]}, "description": "Just a splash is all you need -- it will give it a nice subtle tanginess."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "ad4a46f8-1624-4c74-99e3-6a3e47473212", "response": {"speechText": "Stir until well mixed", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Stir until well mixed.", "You want a thin, consistent batter with not chunks of flour or dry ingredients left. Want a little more red color? Add a few more drops of red food coloring."], "footer": "Stir until well mixed", "imageList": [{"path": "https://www.wikihow.com/images/thumb/0/08/Make-Red-Velvet-Cake-Step-7-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-7-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/c/cf/Make Red Velvet Cake Step 7 Version 2.360p.mp4"}}, "description": "You want a thin, consistent batter with not chunks of flour or dry ingredients left. Want a little more red color? Add a few more drops of red food coloring."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "3563062f-837e-4159-bfd3-5264e03fedc6", "response": {"speechText": "Pour the cake in a large cake pan or 2 layer cake pans and bake in a 350\u00baF", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Pour the cake in a large cake pan or 2 layer cake pans and bake in a 350\u00baF.", "The cake should take roughly an hour to cook. When it is done, you'll know through the toothpick test -- stab the center of the cake with a knife or skewer-- if it comes out without still-wet batter on it, it's done."], "footer": "Pour the cake in a large cake pan or 2 layer cake pans and bake in a 350\u00baF", "imageList": [{"path": "https://www.wikihow.com/images/thumb/b/bd/Make-Red-Velvet-Cake-Step-8-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-8-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/c/ce/Make Red Velvet Cake Step 8 Version 2.360p.mp4"}, "requirements": ["Cake pan"]}, "description": "The cake should take roughly an hour to cook. 
When it is done, you'll know through the toothpick test -- stab the center of the cake with a knife or skewer-- if it comes out without still-wet batter on it, it's done."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "823c8451-1d93-4ce1-899d-b72b27c78ed0", "response": {"speechText": "Wait about 20 minutes to cool before frosting", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Wait about 20 minutes to cool before frosting.", "After 5 minutes, remove from the pan and let cool on a wire rack. Don't frost a hot cake -- the warmth with thin out the frosting and make it difficult, if not impossible, to add smoothly."], "footer": "Wait about 20 minutes to cool before frosting", "imageList": [{"path": "https://www.wikihow.com/images/thumb/0/05/Make-Red-Velvet-Cake-Step-9-Version-3.jpg/v4-460px-Make-Red-Velvet-Cake-Step-9-Version-3.jpg"}]}, "description": "After 5 minutes, remove from the pan and let cool on a wire rack. Don't frost a hot cake -- the warmth with thin out the frosting and make it difficult, if not impossible, to add smoothly."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "515e8948-f64c-47e9-8878-75ba41ad1087", "response": {"speechText": "Set the butter and cream cheese on the counter to warm to room temperature", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Set the butter and cream cheese on the counter to warm to room temperature.", "You'll be whipping the butter and cream cheese up, but that only works if it is soft enough to whip! Set the two dairy products out for 15-30 minutes to get soft. In a pinch, you can gently microwave them to speed things up, but keep it very brief. You don't want liquid."], "footer": "Set the butter and cream cheese on the counter to warm to room temperature", "imageList": [{"path": "https://www.wikihow.com/images/thumb/1/19/Make-Red-Velvet-Cake-Step-10-Version-4.jpg/v4-460px-Make-Red-Velvet-Cake-Step-10-Version-4.jpg"}], "requirements": ["cream cheese, softened"], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}, "description": "You'll be whipping the butter and cream cheese up, but that only works if it is soft enough to whip! Set the two dairy products out for 15-30 minutes to get soft. In a pinch, you can gently microwave them to speed things up, but keep it very brief. You don't want liquid."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "288894e4-55e2-463e-a239-a983cf11c608", "response": {"speechText": "Combine butter and cream cheese", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Combine butter and cream cheese.", "An electric mixer is best, as it will make quick work of the dairies, but a wooden spoon and whisk work well too. 
They don't have to be perfectly combined, just well mixed."], "footer": "Combine butter and cream cheese", "imageList": [{"path": "https://www.wikihow.com/images/thumb/5/55/Make-Red-Velvet-Cake-Step-11-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-11-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/6/6d/Make Red Velvet Cake Step 11 Version 2.360p.mp4"}, "requirements": ["butter (approx 110 grams)", "\u00bd cup (102 grams) shortening (or butter, margarine, or any type of fat used for baking and making pastries)", "cream cheese, softened"], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}, "description": "An electric mixer is best, as it will make quick work of the dairies, but a wooden spoon and whisk work well too. They don't have to be perfectly combined, just well mixed."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "2cadd7f5-b3a8-4448-a2e4-e95bb801bffe", "response": {"speechText": "Add powdered sugar slowly, keeping the mixing going throughout", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Add powdered sugar slowly, keeping the mixing going throughout.", "Powdered sugar will want to poof and fly out as you mix it. To avoid a mess, add it in 3-4 parts, mixing almost all of the first part in before adding the second."], "footer": "Add powdered sugar slowly, keeping the mixing going throughout", "imageList": [{"path": "https://www.wikihow.com/images/thumb/f/f1/Make-Red-Velvet-Cake-Step-12-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-12-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/3/32/Make Red Velvet Cake Step 12 Version 2.360p.mp4"}, "requirements": ["(400 grams) powdered sugar", "cups (300 grams) sugar"], "extraInformation": [{"type": "JOKE", "text": "Sugar never spoils, but if we put our minds together, I bet we can find a way to make it inedible.", "keyword": "sugar", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_sugar.png"}]}, "description": "Powdered sugar will want to poof and fly out as you mix it. To avoid a mess, add it in 3-4 parts, mixing almost all of the first part in before adding the second."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "92c6c3b3-e299-4213-9ea0-c903873b6fad", "response": {"speechText": "Add vanilla and whip until creamy", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Add vanilla and whip until creamy.", "Keep the mixer (or your mixing hand) going until the frosting is nice smooth texture. If you want to thin it out a bit so it spreads better, add 2 tablespoons of cold milk."], "footer": "Add vanilla and whip until creamy", "imageList": [{"path": "https://www.wikihow.com/images/thumb/e/e8/Make-Red-Velvet-Cake-Step-13-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-13-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/8/81/Make Red Velvet Cake Step 13 Version 2.360p.mp4"}, "requirements": ["vanilla"], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. 
One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}, "description": "Keep the mixer (or your mixing hand) going until the frosting is nice smooth texture. If you want to thin it out a bit so it spreads better, add 2 tablespoons of cold milk."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "53eaeb70-24db-4a4d-b035-669d5a69c7cb", "response": {"speechText": "Cut the cake into layers and frost", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Cut the cake into layers and frost.", "Place a little pat of icing on the bottom of our plate or cake dish to keep the bottom layer from sliding around. Then frost it and stack another layer on top, frosting the top of that. Don't worry too much about the sides yet. Do not try to frost the cake while it is still hot. Let it cool completely."], "footer": "Cut the cake into layers and frost", "imageList": [{"path": "https://www.wikihow.com/images/thumb/5/58/Make-Red-Velvet-Cake-Step-14-Version-4.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-14-Version-4.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/0/04/Make Red Velvet Cake Step 14 Version 2.360p.mp4"}, "requirements": ["Cake pan"]}, "description": "Place a little pat of icing on the bottom of our plate or cake dish to keep the bottom layer from sliding around. Then frost it and stack another layer on top, frosting the top of that. Don't worry too much about the sides yet. Do not try to frost the cake while it is still hot. Let it cool completely."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "df328633-b041-4149-b5eb-19f6425fc338", "response": {"speechText": "Assemble the layers and continue frosting", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Assemble the layers and continue frosting.", "Stack the layers up high, frosting between each layer with a 1/4\" of frosting or so, adding more to taste."], "footer": "Assemble the layers and continue frosting", "imageList": [{"path": "https://www.wikihow.com/images/thumb/9/90/Make-Red-Velvet-Cake-Step-15-Version-3.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-15-Version-3.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/6/6c/Make Red Velvet Cake Step 15 Version 2.360p.mp4"}}, "description": "Stack the layers up high, frosting between each layer with a 1/4\" of frosting or so, adding more to taste."}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "ea85bd8d-48bb-401d-a56b-03d26c516479", "response": {"speechText": "Frost the cake and enjoy!", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Frost the cake and enjoy!.", "For bakery quality frosting, keep the knife clean after every pass, using a little warm water to ensure your frosting knife applies the icing smoothly and evenly. Use big globs of icing at a time and don't try to spread it too thin. By working in small areas, not spreading too thin, and cleaning the knife regularly, you can get a quality frosted cake. If you've got time, real pros will \"double frost.\" Start with a thin layer of frosting everywhere -- it is okay if it pulls up crumbs. 
Then freeze the cake for 15 minutes, pull it out, and frost \"for real.\" You'll be astonished how easily it goes on!"], "footer": "Frost the cake and enjoy!", "imageList": [{"path": "https://www.wikihow.com/images/thumb/1/1a/Make-Red-Velvet-Cake-Step-16-Version-3.jpg/550px-nowatermark-Make-Red-Velvet-Cake-Step-16-Version-3.jpg"}], "video": {"hostedMp4": "https://www.wikihow.com/video/f/fb/Make Red Velvet Cake Step 16 Version 2.360p.mp4"}, "requirements": ["Cake pan"]}, "description": "For bakery quality frosting, keep the knife clean after every pass, using a little warm water to ensure your frosting knife applies the icing smoothly and evenly. Use big globs of icing at a time and don't try to spread it too thin. By working in small areas, not spreading too thin, and cleaning the knife regularly, you can get a quality frosted cake. If you've got time, real pros will \"double frost.\" Start with a thin layer of frosting everywhere -- it is okay if it pulls up crumbs. Then freeze the cake for 15 minutes, pull it out, and frost \"for real.\" You'll be astonished how easily it goes on!"}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "ced2b2e2-657e-4df3-8bd2-a2b86a287d95", "response": {"screen": {"format": "TEXT_IMAGE", "paragraphs": [""], "footer": "Finished", "imageList": [{"path": "https://www.wikihow.com/images/thumb/b/bf/Make-Red-Velvet-Cake-Step-17-Version-2.jpg/v4-460px-Make-Red-Velvet-Cake-Step-17-Version-2.jpg"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}], "connectionList": [{"idFrom": "0ee4da68-4b9b-491d-a281-986bd18e0c5b", "idTo": "ad4a46f8-1624-4c74-99e3-6a3e47473212"}, {"idFrom": "24879201-13a5-4e99-94f4-e07be041ab62", "idTo": "f89f5e5a-2dd4-4e0c-bb69-9e25d6192c6f"}, {"idFrom": "3563062f-837e-4159-bfd3-5264e03fedc6", "idTo": "823c8451-1d93-4ce1-899d-b72b27c78ed0"}, {"idFrom": "df328633-b041-4149-b5eb-19f6425fc338", "idTo": "ea85bd8d-48bb-401d-a56b-03d26c516479"}, {"idFrom": "288894e4-55e2-463e-a239-a983cf11c608", "idTo": "2cadd7f5-b3a8-4448-a2e4-e95bb801bffe"}, {"idFrom": "ea85bd8d-48bb-401d-a56b-03d26c516479", "idTo": "ced2b2e2-657e-4df3-8bd2-a2b86a287d95"}, {"idFrom": "f89f5e5a-2dd4-4e0c-bb69-9e25d6192c6f", "idTo": "f038335a-c335-4386-bf9c-69e5fded5740"}, {"idFrom": "f038335a-c335-4386-bf9c-69e5fded5740", "idTo": "37bb9be1-1f02-458f-b6d6-24373b68a47e"}, {"idFrom": "2cadd7f5-b3a8-4448-a2e4-e95bb801bffe", "idTo": "92c6c3b3-e299-4213-9ea0-c903873b6fad"}, {"idFrom": "e422e3b4-a864-484f-ab46-0fd4333eeeb9", "idTo": "24879201-13a5-4e99-94f4-e07be041ab62"}, {"idFrom": "37bb9be1-1f02-458f-b6d6-24373b68a47e", "idTo": "0ee4da68-4b9b-491d-a281-986bd18e0c5b"}, {"idFrom": "823c8451-1d93-4ce1-899d-b72b27c78ed0", "idTo": "515e8948-f64c-47e9-8878-75ba41ad1087"}, {"idFrom": "515e8948-f64c-47e9-8878-75ba41ad1087", "idTo": "288894e4-55e2-463e-a239-a983cf11c608"}, {"idFrom": "53eaeb70-24db-4a4d-b035-669d5a69c7cb", "idTo": "df328633-b041-4149-b5eb-19f6425fc338"}, {"idFrom": "92c6c3b3-e299-4213-9ea0-c903873b6fad", "idTo": "53eaeb70-24db-4a4d-b035-669d5a69c7cb"}, {"idFrom": "ad4a46f8-1624-4c74-99e3-6a3e47473212", "idTo": "3563062f-837e-4159-bfd3-5264e03fedc6"}], "dataset": "common-crawl", "author": "Mathew Rice", "domainName": "wikihow"} \ No newline at end of file diff --git a/tester/tests/conftest.py b/tester/tests/conftest.py index a6d66e2..38553c1 100644 --- a/tester/tests/conftest.py +++ b/tester/tests/conftest.py @@ -1,4 +1,5 @@ import os +import json import uuid import pytest @@ -78,3 +79,13 @@ def invalid_downloads_path() -> str: @pytest.fixture def 
missing_values_downloads_path() -> str: return "/shared/test_data/missing_values_downloads.toml" + +@pytest.fixture +def sample_taskmap_json_wikihow() -> str: + # returns a JSON dump of a TaskMap extracted from the default index + return open("/shared/test_data/sample_taskmap_wikihow.json", "r").read() + +@pytest.fixture +def sample_taskmap_json_seriouseats() -> str: + # returns a JSON dump of a TaskMap extracted from the default index + return open("/shared/test_data/sample_taskmap_seriouseats.json", "r").read() From 2e91a9401e34d0515be65fe8a2658cee7071e862 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 26 Apr 2024 14:49:28 +0100 Subject: [PATCH 21/57] Add HUGGING_FACE_HUB_TOKEN to tgi service --- docker-compose.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker-compose.yml b/docker-compose.yml index 6f64773..e481c52 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -333,6 +333,8 @@ services: # any other TGI launcher parameters can be set in this env var, e.g.: # TGI_PARAMS="--param1 param1_value --param2 param2_value" docker compose up - TGI_PARAMS=${TGI_PARAMS:-} + # this is required for Mistral or other gated models + - HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-} networks: - internal - external From 7361b089725b3db682b3d88d8ef73f5fe9084345 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 26 Apr 2024 14:50:11 +0100 Subject: [PATCH 22/57] Fix session ID field name --- tester/tests/integration_tests/interaction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tester/tests/integration_tests/interaction.py b/tester/tests/integration_tests/interaction.py index bc5f14d..384c871 100644 --- a/tester/tests/integration_tests/interaction.py +++ b/tester/tests/integration_tests/interaction.py @@ -135,7 +135,7 @@ def run(self, input_text: str, session: dict, intent: str = '') -> dict: Returns: a JSON dict generated from the HTTP response """ - response = self.send_request(input_text, intent, session['id'], session['headless']) + response = self.send_request(input_text, intent, session['session_id'], session['headless']) response_json = response.json() print(f'USER: {input_text}') try: From 3eb171afe3d72c869934a1930c5fec1910450bba Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 26 Apr 2024 14:51:12 +0100 Subject: [PATCH 23/57] Add set_source calls for each code path This is currently used by one of the LLM tests to check if a particular step in the policy was activated to produce the current response. 
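For context, a rough sketch of the kind of test this supports, built on the `Interaction.run()` helper fixed in the previous commit. The `interaction` fixture and the final assertion on a "source" key are assumptions for illustration only; just the session dict keys and the `run()` signature come from the diffs above.

    import uuid

    def test_chitchat_turn_comes_from_chitchat_policy(interaction):
        # these keys match what Interaction.run() reads from the session dict
        session = {"session_id": str(uuid.uuid4()), "headless": True}
        response_json = interaction.run("tell me a joke", session)
        # hypothetical check: assumes set_source() surfaces the producing policy
        # somewhere under a "source" key in the response JSON
        assert "chitchat" in str(response_json.get("source", ""))
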
--- orchestrator/policy/chitchat_policy/chitchat_policy.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/orchestrator/policy/chitchat_policy/chitchat_policy.py b/orchestrator/policy/chitchat_policy/chitchat_policy.py index 39758ad..7321e0e 100644 --- a/orchestrator/policy/chitchat_policy/chitchat_policy.py +++ b/orchestrator/policy/chitchat_policy/chitchat_policy.py @@ -138,7 +138,7 @@ def step(self, session: Session) -> Tuple[Session, OutputInteraction]: output.speech_text = f'{helpful_prompt} ' else: output.speech_text = f'{chitchat_response.text}{helpful_prompt} ' - + set_source(output) elif got_llm_response: logger.info(f'CHIT CHAT LLM RESPONSE GIVEN: {chitchat_response.text}') keyword_helpful_prompt = self.__get_helpful_prompt(chitchat_request.text) @@ -149,9 +149,10 @@ def step(self, session: Session) -> Tuple[Session, OutputInteraction]: output.speech_text = f'{chitchat_response.text} {random.choice(transition_options)} {keyword_helpful_prompt.lower()}' else: output.speech_text = f'{chitchat_response.text} {random.choice(transition_options)} {helpful_prompt.lower()}' + set_source(output) elif chitchat_response.text != "": output.speech_text = chitchat_response.text - + set_source(output) if output.speech_text == "": # chitchat responses are bad, so fallback to confused session.turn[-1].user_request.interaction.intents.append("ConfusedIntent") @@ -160,6 +161,6 @@ def step(self, session: Session) -> Tuple[Session, OutputInteraction]: raise PhaseChangeException() else: output = repeat_screen_response(session, output) + set_source(output) - set_source(output) return session, output From bc3cfbfc1897b61938a5a5a68c8c95a0956b90b9 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 26 Apr 2024 14:52:27 +0100 Subject: [PATCH 24/57] Bump CUDA image version to avoid warnings --- neural_functionalities/Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/neural_functionalities/Dockerfile b/neural_functionalities/Dockerfile index a0a281c..c69fb42 100644 --- a/neural_functionalities/Dockerfile +++ b/neural_functionalities/Dockerfile @@ -1,5 +1,4 @@ -# syntax=docker/dockerfile:1.3 -FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 +FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04 COPY oat_common/requirements.txt /requirements.txt From 03cf959e999c3e5ac96aa6dde1bca60df63121ab Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 26 Apr 2024 14:54:46 +0100 Subject: [PATCH 25/57] Fix warnings about f-strings with no placeholders --- functionalities/llm_chit_chat/llm_chit_chat.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/functionalities/llm_chit_chat/llm_chit_chat.py b/functionalities/llm_chit_chat/llm_chit_chat.py index 0eecb98..5e40492 100644 --- a/functionalities/llm_chit_chat/llm_chit_chat.py +++ b/functionalities/llm_chit_chat/llm_chit_chat.py @@ -59,25 +59,25 @@ def build_prompt(request: LLMChitChatRequest) -> ModelRequest: if request.task_title != "": prompt += f'You are a friendly AI assistant who is assisting a human making {request.task_title}. ' else: - prompt += f'You are a friendly AI assistant who is assisting a human. ' + prompt += 'You are a friendly AI assistant who is assisting a human. ' prompt += 'Respond to the human or ask a question back. Try to not repeat what you said previously. ' - prompt += f'You specialize in cooking, arts & crafts, and DIY. You do not reveal your name in the spirit of ' \ - f'fair competition. You cannot play music, games or quizzes. 
You are not able to read the news, ' \ - f'turn on a light, or give recommendations for things outside cooking and DIY domains \n\n' + prompt += 'You specialize in cooking, arts & crafts, and DIY. You do not reveal your name in the spirit of ' \ + 'fair competition. You cannot play music, games or quizzes. You are not able to read the news, ' \ + 'turn on a light, or give recommendations for things outside cooking and DIY domains \n\n' if request.last_intent == "ChitChatIntent": - prompt += f"### Input: \n" + prompt += "### Input: \n" prompt += f'Human: {request.user_question} \n' else: if request.last_intent == "QuestionIntent": prompt += f"You just said: {request.last_agent_response}. Answer the given user question. \n" - prompt += f"### Input: \n" + prompt += "### Input: \n" prompt += f'Human: {request.user_question} \n' else: - prompt += f"### Input: \n" + prompt += "### Input: \n" prompt += f'You: {request.last_agent_response} \n' \ f'Human: {request.user_question} \n' - prompt += f'\n### Response: Your response:' + prompt += '\n### Response: Your response:' model_request.formatted_prompt = prompt model_request.max_tokens = 30 From 51a96e9c376d027097dd165114538b2e596de6b1 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 3 May 2024 17:21:37 +0100 Subject: [PATCH 26/57] Adding new methods to AbstractDB Adding `delete_taskmap` and `delete_session` methods. These aren't much use for the online system, they're used in some new tests. --- external_functionalities/database/abstract_db.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/external_functionalities/database/abstract_db.py b/external_functionalities/database/abstract_db.py index 5512c23..0b918f0 100644 --- a/external_functionalities/database/abstract_db.py +++ b/external_functionalities/database/abstract_db.py @@ -12,6 +12,10 @@ def load_session(self, session_id: str) -> Session: def save_session(self, session_id: str, session: Session) -> None: pass + @abstractmethod + def delete_session(self, session_id: str) -> None: + pass + @abstractmethod def load_taskmap(self, taskmap_id: str) -> TaskMap: pass @@ -19,3 +23,7 @@ def load_taskmap(self, taskmap_id: str) -> TaskMap: @abstractmethod def save_taskmap(self, session_id: str, session: Session) -> None: pass + + @abstractmethod + def delete_taskmap(self, taskmap_id: str) -> None: + pass From 590f498b307f0d2b46d30a0b5d87ce0c588f02fd Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 3 May 2024 17:26:18 +0100 Subject: [PATCH 27/57] Adding support for deleting TaskMaps and Sessions This is to make some new tests simpler to implement, not relevant for the online system. 
- Define new `delete_session` and `delete_taskmap` RPCs - Add `delete_session` and `delete_taskmap` methods to `DynamoDB` - Add a `delete` method to `ProtoDB` which those 2 methods call --- external_functionalities/database/dynamo_db.py | 6 ++++++ external_functionalities/database/servicer.py | 8 ++++++++ shared/protobufs/database.proto | 4 +++- shared/utils/aws/new_proto_db.py | 10 ++++++++++ 4 files changed, 27 insertions(+), 1 deletion(-) diff --git a/external_functionalities/database/dynamo_db.py b/external_functionalities/database/dynamo_db.py index 0f4ec65..0d0aa2a 100644 --- a/external_functionalities/database/dynamo_db.py +++ b/external_functionalities/database/dynamo_db.py @@ -60,12 +60,18 @@ def load_session(self, session_id: str) -> Session: else: return session + def delete_session(self, session_id:str) -> None: + self.session_db.delete(session_id) + def save_taskmap(self, session_id: str, session: Session) -> None: self.taskmap_db.put(session) def load_taskmap(self, taskmap_id: str) -> TaskMap: return self.taskmap_db.get(taskmap_id) + def delete_taskmap(self, taskmap_id: str) -> None: + self.taskmap_db.delete(taskmap_id) + def save_search_log(self, search_log: SearchLog) -> None: self.search_logs_db.put(search_log) diff --git a/external_functionalities/database/servicer.py b/external_functionalities/database/servicer.py index 28519f9..7bdec10 100644 --- a/external_functionalities/database/servicer.py +++ b/external_functionalities/database/servicer.py @@ -17,6 +17,10 @@ def save_session(self, request, context) -> None: self.instance.save_session(request.id, request.session) return Void() + def delete_session(self, request, context) -> None: + self.instance.delete_session(request.id) + return Void() + def load_taskmap(self, request, context) -> TaskMap: return self.instance.get_taskmap(request.id) @@ -24,6 +28,10 @@ def save_taskmap(self, request, context) -> None: self.instance.save_session(request.id, request.taskmap) return Void() + def delete_taskmap(self, request, context) -> None: + self.instance.delete_taskmap(request.id) + return Void() + def save_search_logs(self, request, context) -> None: self.instance.save_search_log(request) return Void() diff --git a/shared/protobufs/database.proto b/shared/protobufs/database.proto index d1ce06d..67787de 100644 --- a/shared/protobufs/database.proto +++ b/shared/protobufs/database.proto @@ -32,9 +32,11 @@ message QueryList{ service Database{ rpc load_taskmap(TaskMapRequest) returns (TaskMap) {} rpc save_taskmap(TaskMapRequest) returns (Void) {} + rpc delete_taskmap(TaskMapRequest) returns (Void) {} rpc load_session(SessionRequest) returns (Session) {} rpc save_session(SessionRequest) returns (Void) {} + rpc delete_session(SessionRequest) returns (Void) {} rpc save_search_logs(SearchLog) returns(Void) {} rpc save_asr_logs(ASRLog) returns(Void) {} @@ -44,4 +46,4 @@ service Database{ rpc get_theme(ThemeMapping) returns (ThemeMapping) {} rpc get_queries(Void) returns (QueryList) {} -} \ No newline at end of file +} diff --git a/shared/utils/aws/new_proto_db.py b/shared/utils/aws/new_proto_db.py index 1810b66..544ae25 100644 --- a/shared/utils/aws/new_proto_db.py +++ b/shared/utils/aws/new_proto_db.py @@ -140,6 +140,16 @@ def get(self, item_id: str, decode: bool = True) -> Message: else: return proto_dict + def delete(self, item_id: str) -> None: + try: + self.__table.delete_item( + Key={ + self.primary_key: item_id, + } + ) + except Exception as e: + pass + def batch_put(self, proto_obj_list: List[Message], check_for_changes: bool = 
True) -> List[str]: From 67e83d94a2ec030e2e42336a4ff67d2684d0ab97 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 3 May 2024 17:30:17 +0100 Subject: [PATCH 28/57] Adding a timeout field to LLM request protos This is required because all the LLM calls (and so timeout handling) is now being done inside `llm_functionalities`. Adding this field provides a way for the LLM components in `functionalities` and elsewhere to define acceptable timeouts for their own specific circumstances and pass this over to `llm_functionalities` with the rest of the request params. --- shared/protobufs/llm.proto | 2 ++ 1 file changed, 2 insertions(+) diff --git a/shared/protobufs/llm.proto b/shared/protobufs/llm.proto index 6537056..8c352ab 100644 --- a/shared/protobufs/llm.proto +++ b/shared/protobufs/llm.proto @@ -5,6 +5,7 @@ import "taskmap.proto"; message ModelRequest { string formatted_prompt = 1; int32 max_tokens = 2; + int32 timeout = 3; } message ModelResponse { @@ -14,6 +15,7 @@ message ModelResponse { message ModelBatchRequest { repeated string formatted_prompts = 1; int32 max_tokens = 2; + int32 timeout = 3; } message ModelBatchResponse { From e72770a9c448c6646ac2042d399171a70be682e4 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 3 May 2024 17:33:07 +0100 Subject: [PATCH 29/57] Removing old generate_summary methods Dropping this in favour of using .text_generation for all cases --- .../llm_runner/llm_runner_servicer.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/llm_functionalities/llm_runner/llm_runner_servicer.py b/llm_functionalities/llm_runner/llm_runner_servicer.py index 45bfcd9..a9a7498 100644 --- a/llm_functionalities/llm_runner/llm_runner_servicer.py +++ b/llm_functionalities/llm_runner/llm_runner_servicer.py @@ -3,10 +3,6 @@ ModelResponse, ModelBatchRequest, ModelBatchResponse, - TGISummaryRequest, - TGISummaryResponse, - TGIMultipleSummaryRequest, - TGIMultipleSummaryResponse, ) from compiled_protobufs.llm_pb2_grpc import ( LLMRunnerServicer, @@ -23,11 +19,3 @@ def call_model(self, query: ModelRequest, context) -> ModelResponse: def batch_call_model(self, query: ModelBatchRequest, context) -> ModelBatchResponse: return self.model.batch_call_model(query) - - def generate_summary(self, query: TGISummaryRequest, context) -> TGISummaryResponse: - return self.model.generate_summary(query) - - def generate_summaries( - self, query: TGIMultipleSummaryRequest, context - ) -> TGIMultipleSummaryResponse: - return self.model.generate_summaries(query) From 4d98c5a79cbeff1a028c751f726c1bbec6c76fbe Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 3 May 2024 17:37:39 +0100 Subject: [PATCH 30/57] Update LLM timeout/error handling This commit removes most of the timeout-handling and exception-checking for the LLM code in `functionalities`. Previously the system couldn't assume `llm_functionalities` would be available, so there had to be error handling for failed RPC calls within the various `llm_*` components in `functionalities`. There also had to be timeout handling in some cases to prevent lengthy LLM calls from delaying system responses. In the TGI-enabled version of OAT, `llm_functionalities` has no GPU requirements any more so we can probably assume it's always going to be available like the other services. That also means we can just do the timeout/error handling in `llm_functionalities` since all calls to TGI will be routed through there. 
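A minimal sketch of the caller-side pattern this enables, with the timeout expressed in milliseconds as elsewhere in the code; the prompt and the concrete values here are illustrative only, not taken from this patch.

    import os

    import grpc

    from compiled_protobufs.llm_pb2 import ModelRequest, ModelResponse
    from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub

    llm = LLMRunnerStub(grpc.insecure_channel(os.environ["LLM_FUNCTIONALITIES_URL"]))

    request = ModelRequest()
    request.formatted_prompt = "### Instruction: ..."  # illustrative prompt
    request.max_tokens = 30
    request.timeout = 2000  # milliseconds; leaving it at 0 lets llm_functionalities apply its default

    # the caller no longer wraps this in its own executor/timeout handling:
    # llm_functionalities enforces the timeout and returns empty text on failure
    response: ModelResponse = llm.call_model(request)
    if response.text == "":
        pass  # fall back as appropriate for the component
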
This ultimately means the code around the RPCs to `llm_functionalities` can be simplfiied to remove the existing timeout and exception handling. --- .../execution_search_manager.py | 34 ++------------- .../llm_chit_chat/llm_chit_chat.py | 43 +------------------ .../llm_description_generation.py | 7 +-- .../llm_ingredient_step_text_rewriter.py | 26 +++++------ .../llm_ingredient_substitution_generation.py | 30 +------------ .../llm_proactive_question_generation.py | 7 +-- 6 files changed, 19 insertions(+), 128 deletions(-) diff --git a/functionalities/execution_search_manager/execution_search_manager.py b/functionalities/execution_search_manager/execution_search_manager.py index c341355..fcf9a1f 100644 --- a/functionalities/execution_search_manager/execution_search_manager.py +++ b/functionalities/execution_search_manager/execution_search_manager.py @@ -72,12 +72,13 @@ def build_prompt(self, request: ExecutionSearchRequest) -> ModelRequest: model_request.max_tokens = 20 return model_request - def llm_generate_search_decision(self, request: ExecutionSearchRequest) -> ExecutionSearchResponse: + def generate_decision(self, request: ExecutionSearchRequest, default_timeout: int = 1000) -> ExecutionSearchResponse: + default_timeout = default_timeout if request.timeout == 0 else request.timeout model_request = self.build_prompt(request) llm_response: ModelResponse = self.llm.call_model(model_request) - llm_classification: ExecutionSearchRequest = ExecutionSearchResponse() + llm_classification: ExecutionSearchResponse = ExecutionSearchResponse() parsed_result = self.extract_classification_response(llm_response.text) if parsed_result == {}: return llm_classification @@ -85,32 +86,3 @@ def llm_generate_search_decision(self, request: ExecutionSearchRequest) -> Execu llm_classification.intent_classification = parsed_result.get("intent", "") llm_classification.ai_response = parsed_result.get("ai_response", "") return llm_classification - - def generate_decision(self, request: ExecutionSearchRequest, default_timeout=1000) -> ExecutionSearchResponse: - default_timeout = default_timeout if request.timeout == 0 else request.timeout - - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(self.llm_generate_search_decision, request) - timeout: float = default_timeout / 1000 + time.monotonic() - try: - if future.done() or timeout - time.monotonic() > 0: - response = future.result(timeout=timeout - time.monotonic()) - return response - - else: - future.cancel() - logger.warning(f"Timeout for Execution Search Decision Generation") - response = ExecutionSearchResponse() - return response - - except TimeoutError as e: - future.cancel() - logger.warning("TimeoutError while running Execution Search Decision Generation", exc_info=e) - response = ExecutionSearchResponse() - return response - - except _InactiveRpcError as e: - future.cancel() - logger.warning("Execution Search Decision Generation Channel is down") - response = ExecutionSearchResponse() - return response diff --git a/functionalities/llm_chit_chat/llm_chit_chat.py b/functionalities/llm_chit_chat/llm_chit_chat.py index 5e40492..35b8394 100644 --- a/functionalities/llm_chit_chat/llm_chit_chat.py +++ b/functionalities/llm_chit_chat/llm_chit_chat.py @@ -83,11 +83,10 @@ def build_prompt(request: LLMChitChatRequest) -> ModelRequest: model_request.max_tokens = 30 return model_request - def call_chit_chat_model(self, request: LLMChitChatRequest) -> ChitChatResponse: + def generate_chit_chat(self, request: LLMChitChatRequest) -> 
ChitChatResponse: model_request = self.build_prompt(request) llm_response: ModelResponse = self.llm.call_model(model_request) - logger.info(llm_response.text) agent_response: ChitChatResponse = ChitChatResponse() agent_response.text = process_response_text(extract_qa_answer(llm_response.text)) @@ -95,43 +94,3 @@ def call_chit_chat_model(self, request: LLMChitChatRequest) -> ChitChatResponse: logger.info(f'EXTRACTED LLM RESPONSE: {agent_response.text}') return agent_response - - def generate_chit_chat(self, request: LLMChitChatRequest, default_timeout=2000) -> ChitChatResponse: - # CALL CHIT CHAT LLM SERVICE FROM HERE - # default timeout is 10000, aka 1000 milliseconds aka 1 second - - default_timeout = default_timeout if request.timeout == 0 else request.timeout - logger.info(default_timeout) - - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(self.call_chit_chat_model, request) - timeout: float = default_timeout / 1000 + time.monotonic() - try: - if future.done() or timeout - time.monotonic() > 0: - response = future.result(timeout=timeout - time.monotonic()) - return response - - else: - future.cancel() - logger.warning(f"Timeout for LLM Chit Chat") - response = ChitChatResponse() - response.text = "" - return response - - except TimeoutError as e: - future.cancel() - logger.warning("TimeoutError while running LLM Chit Chat", exc_info=e) - response = ChitChatResponse() - response.text = "" - return response - - except _InactiveRpcError as e: - future.cancel() - logger.warning("LLM Channel is down") - response = ChitChatResponse() - response.text = "" - return response - - except Exception as e: - future.cancel() - logger.info(type(e)) diff --git a/functionalities/llm_description_generation/llm_description_generation.py b/functionalities/llm_description_generation/llm_description_generation.py index 5bba723..2f9810f 100644 --- a/functionalities/llm_description_generation/llm_description_generation.py +++ b/functionalities/llm_description_generation/llm_description_generation.py @@ -60,12 +60,7 @@ def generate_descriptions(self, request: LLMMultipleDescriptionGenerationRequest model_batch_request.formatted_prompts.append(str(prompt)) - try: - llm_responses: ModelBatchResponse = self.llm.batch_call_model(model_batch_request) - except _InactiveRpcError as e: - logger.info('LLM Channel is down during description generation') - llm_responses = ModelBatchResponse() - + llm_responses: ModelBatchResponse = self.llm.batch_call_model(model_batch_request) llm_descriptions: MultipleDescriptionGenerationResponse = MultipleDescriptionGenerationResponse() for text in llm_responses.text: llm_descriptions.description.append(extract_description(text)) diff --git a/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py b/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py index 02c609c..2798e79 100644 --- a/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py +++ b/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py @@ -189,18 +189,14 @@ def rewrite_steps(self, request: AdjustedStepGenerationRequest) -> AdjustedStepG llm_step_texts: AdjustedStepGenerationResponse = AdjustedStepGenerationResponse() - try: - for step, ingredient in zip(request.step, request.ingredient): - model_input = self.build_prompt(request.task_title, step.response.speech_text, ingredient) - ids.append(step.unique_id) - model_batch_request.formatted_prompts.append(model_input) - - llm_responses = 
self.llm.batch_call_model(model_batch_request) - - for idx, text in enumerate(llm_responses.text): - llm_step_texts.step_text.append(self.extract_response(text).get('step_text', '')) - llm_step_texts.ids.append(ids[idx]) - return llm_step_texts - except _InactiveRpcError as e: - logger.warning("Step Text Rewriter Channel is down") - return llm_step_texts + for step, ingredient in zip(request.step, request.ingredient): + model_input = self.build_prompt(request.task_title, step.response.speech_text, ingredient) + ids.append(step.unique_id) + model_batch_request.formatted_prompts.append(model_input) + + llm_responses = self.llm.batch_call_model(model_batch_request) + + for idx, text in enumerate(llm_responses.text): + llm_step_texts.step_text.append(self.extract_response(text).get('step_text', '')) + llm_step_texts.ids.append(ids[idx]) + return llm_step_texts diff --git a/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py b/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py index a28ffeb..9c96ef1 100644 --- a/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py +++ b/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py @@ -90,7 +90,7 @@ def build_prompt(request: IngredientReplacementRequest) -> ModelRequest: model_request.max_tokens = 20 return model_request - def llm_generate_search_decision(self, request: IngredientReplacementRequest) -> IngredientReplacementResponse: + def generate_replacement(self, request: IngredientReplacementRequest) -> IngredientReplacementResponse: model_request = self.build_prompt(request) llm_response: ModelResponse = self.llm.call_model(model_request) @@ -106,31 +106,3 @@ def llm_generate_search_decision(self, request: IngredientReplacementRequest) -> llm_replacement.new_ingredient.MergeFrom(ingredient) return llm_replacement - - def generate_replacement(self, request: IngredientReplacementRequest, default_timeout=1000) -> \ - IngredientReplacementResponse: - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(self.llm_generate_search_decision, request) - timeout: float = default_timeout / 1000 + time.monotonic() - try: - if future.done() or timeout - time.monotonic() > 0: - response = future.result(timeout=timeout - time.monotonic()) - return response - - else: - future.cancel() - logger.warning(f"Timeout for Ingredient Replacement Generation") - response = IngredientReplacementResponse() - return response - - except TimeoutError as e: - future.cancel() - logger.warning("TimeoutError while running Ingredient Replacement Generation", exc_info=e) - response = IngredientReplacementResponse() - return response - - except _InactiveRpcError as e: - future.cancel() - logger.warning("Ingredient Replacement Generation Channel is down") - response = IngredientReplacementResponse() - return response diff --git a/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py b/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py index b261850..18460d9 100644 --- a/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py +++ b/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py @@ -72,11 +72,8 @@ def generate_proactive_questions(self, f"or how the results turned out.\n\n" \ f"### Response: " model_batch_request.formatted_prompts.append(model_input) - try: - llm_responses: ModelBatchResponse = 
self.llm.batch_call_model(model_batch_request) - except _InactiveRpcError as e: - logger.info('LLM Channel is down during proactive question generation') - llm_responses = ModelBatchResponse() + + llm_responses: ModelBatchResponse = self.llm.batch_call_model(model_batch_request) llm_questions: ProactiveQuestionGenerationResponse = ProactiveQuestionGenerationResponse() for text in llm_responses.text: From 5f08def8a65fcbdf69782f808e0f848ea8b982b3 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 3 May 2024 17:46:10 +0100 Subject: [PATCH 31/57] Formatting/linting --- .../llm_summary_generation.py | 61 +++++++++++-------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/functionalities/llm_summary_generation/llm_summary_generation.py b/functionalities/llm_summary_generation/llm_summary_generation.py index 51e8b21..6a97ac7 100644 --- a/functionalities/llm_summary_generation/llm_summary_generation.py +++ b/functionalities/llm_summary_generation/llm_summary_generation.py @@ -2,8 +2,16 @@ import os import re -from compiled_protobufs.llm_pb2 import ModelRequest, ModelResponse, SummaryGenerationRequest, SummaryGenerationResponse, \ - ModelBatchRequest, SummaryGenerationRequest, MultipleSummaryGenerationRequest, MultipleSummaryGenerationResponse +from compiled_protobufs.llm_pb2 import ( + SummaryGenerationRequest, + SummaryGenerationResponse, + MultipleSummaryGenerationRequest, + MultipleSummaryGenerationResponse, + ModelRequest, + ModelResponse, + ModelBatchRequest, + ModelBatchResponse, +) from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub from utils import logger @@ -11,21 +19,21 @@ def process_response_text(response_text: str) -> str: # remove whitespace - text = ' '.join(response_text.split("\n")) - if not re.search(r'[.!?]', text[-1]): + text = " ".join(response_text.split("\n")) + if not re.search(r"[.!?]", text[-1]): # If not, add a period (.) to the end of the text - text += '.' + text += "." # split at punctuation - sentences = re.split(r'(?<=[.!?])\s+', text) + sentences = re.split(r"(?<=[.!?])\s+", text) complete_sentences = [] for sentence in sentences: - if sentence.endswith(('.', '!', '?')): + if sentence.endswith((".", "!", "?")): # remove numbered lists - sentence = re.sub(r'\d+\.', '', sentence) + sentence = re.sub(r"\d+\.", "", sentence) complete_sentences.append(sentence.strip()) - return ' '.join(complete_sentences) + return " ".join(complete_sentences) def truncate_details(details: str) -> str: @@ -34,16 +42,16 @@ def truncate_details(details: str) -> str: return " ".join(details_list[:n]) -def extract_qa_answer(generated_answer): +def extract_qa_answer(generated_answer: str) -> str: start_token = "### Response:" end_token = "#" if start_token in generated_answer: generated_answer = generated_answer.split(start_token)[1] if end_token in generated_answer: generated_answer = generated_answer.split(end_token)[0] - generated_answer = generated_answer.replace('\n', " ") + generated_answer = generated_answer.replace("\n", " ") if ". " in generated_answer: - generated_answer = ". ".join(generated_answer.split('.')[:-1]) + generated_answer = ". 
".join(generated_answer.split(".")[:-1]) return generated_answer.strip() @@ -53,14 +61,16 @@ def __init__(self): self.llm = LLMRunnerStub(llm_channel) def generate_summary(self, request: SummaryGenerationRequest) -> SummaryGenerationResponse: - model_request: ModelRequest = ModelRequest() + model_request = ModelRequest() step_text = request.step_text title = request.task_title details = request.more_details details = truncate_details(details) - summary_generation_prompt = "Create a brief 2-sentence text by condensing key points from both the Details " \ - "and the Step, omitting insignificant details." + summary_generation_prompt = ( + "Create a brief 2-sentence text by condensing key points from both the Details " + "and the Step, omitting insignificant details." + ) model_input = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. @@ -81,18 +91,20 @@ def generate_summary(self, request: SummaryGenerationRequest) -> SummaryGenerati model_request.formatted_prompt = model_input model_request.max_tokens = 100 - llm_response = self.llm.call_model(model_request) + llm_response: ModelResponse = self.llm.call_model(model_request) - summary_response: SummaryGenerationResponse = SummaryGenerationResponse() + summary_response = SummaryGenerationResponse() summary_response.summary = process_response_text(extract_qa_answer(llm_response.text)) return summary_response def generate_summaries(self, request: MultipleSummaryGenerationRequest) -> MultipleSummaryGenerationResponse: - model_batch_request: ModelBatchRequest = ModelBatchRequest() - model_batch_request.max_tokens = 100 + model_request = ModelBatchRequest() + model_request.max_tokens = 100 - summary_generation_prompt = "Create a brief 2-sentence text by condensing key points from both the Details " \ - "and the Step, omitting insignificant details." + summary_generation_prompt = ( + "Create a brief 2-sentence text by condensing key points from both the Details " + "and the Step, omitting insignificant details." + ) for title, step, details in zip(request.task_title, request.step_text, request.more_details): model_input = f"""Below is an instruction that describes a task, paired with an input that provides @@ -110,12 +122,13 @@ def generate_summaries(self, request: MultipleSummaryGenerationRequest) -> Multi ### Response: """ - model_batch_request.formatted_prompts.append(model_input) + model_request.formatted_prompts.append(model_input) - llm_responses = self.llm.batch_call_model(model_batch_request) + llm_responses: ModelBatchResponse = self.llm.batch_call_model(model_request) - llm_summaries: MultipleSummaryGenerationResponse = MultipleSummaryGenerationResponse() + llm_summaries = MultipleSummaryGenerationResponse() for text in llm_responses.text: + logger.info(f"Response: {text}") llm_summaries.summary.append(process_response_text(extract_qa_answer(text))) return llm_summaries From 69548a1bc823281d1a021438c46e686c46279b42 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 3 May 2024 17:46:33 +0100 Subject: [PATCH 32/57] Update LLM timeout/error handling More removal of timeout and error handling around RPCs to `llm_functionalities` (now handled when `llm_functionalities` makes calls to TGI). 
--- functionalities/qa/llm_qa.py | 41 +---------------------- functionalities/qa/substitution_helper.py | 16 ++++----- 2 files changed, 7 insertions(+), 50 deletions(-) diff --git a/functionalities/qa/llm_qa.py b/functionalities/qa/llm_qa.py index 70e8c3e..afe80a3 100644 --- a/functionalities/qa/llm_qa.py +++ b/functionalities/qa/llm_qa.py @@ -283,7 +283,7 @@ def build_prompt(self, request: QARequest) -> ModelRequest: return model_request - def call_llm_qa(self, request: QAQuery) -> ChitChatResponse: + def generate_qa_response(self, request: QARequest, default_timeout: int = 1500) -> QAResponse: model_request = self.build_prompt(request) llm_response: ModelResponse = self.llm.call_model(model_request) logger.info(llm_response.text) @@ -296,45 +296,6 @@ def call_llm_qa(self, request: QAQuery) -> ChitChatResponse: return agent_response - def generate_qa_response(self, request: QARequest, default_timeout=1500) -> QAResponse: - # CALL LLM SERVICE FROM HERE - - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(self.call_llm_qa, request) - timeout: float = default_timeout / 1000 + time.monotonic() - try: - if future.done() or timeout - time.monotonic() > 0: - response = future.result(timeout=timeout - time.monotonic()) - return response - - else: - future.cancel() - logger.warning(f"Timeout for LLM QA") - response = QAResponse() - response.text = "" - return response - - except TimeoutError as e: - future.cancel() - logger.warning("TimeoutError while running LLM QA", exc_info=e) - response = QAResponse() - response.text = "" - return response - - except _InactiveRpcError as e: - future.cancel() - logger.warning("LLM Channel is down") - response = QAResponse() - response.text = "" - return response - - except Exception as e: - future.cancel() - logger.info(e) - response = QAResponse() - response.text = "" - return response - def synth_response(self, request: QARequest) -> QAResponse: qa_response = self.generate_qa_response(request) user_question = request.query.text diff --git a/functionalities/qa/substitution_helper.py b/functionalities/qa/substitution_helper.py index 6371347..2552c1e 100644 --- a/functionalities/qa/substitution_helper.py +++ b/functionalities/qa/substitution_helper.py @@ -101,16 +101,12 @@ def generate_replacement(self, request: QARequest, response: QAResponse) -> Tupl original_req = self.get_original_ingredient(request) if not self.includes_amount(response): replacement_request.original_ingredient.MergeFrom(original_req) - try: - replacement: IngredientReplacementResponse = self.replacement_generator.generate_replacement( - replacement_request) - return replacement.new_ingredient, original_req - except _InactiveRpcError as e: - logger.info(e) - logger.warning("Replacement LLM Channel is down") - return Ingredient(), Ingredient() - - def create_substitution_idea(self, request, qa_response): + + replacement: IngredientReplacementResponse = self.replacement_generator.generate_replacement( + replacement_request) + return replacement.new_ingredient, original_req + + def create_substitution_idea(self, request: QARequest, qa_response: QAResponse) -> QAResponse: replacement_ingredient, original_ingredient = self.generate_replacement(request, qa_response) if replacement_ingredient.name != "": if original_ingredient.name.lower() != replacement_ingredient.name.lower(): From 77157b4a1035aaf4d3994e42f9f449ade518fb42 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 3 May 2024 17:50:00 +0100 Subject: [PATCH 33/57] Latest version of LLMRunner Various 
changes: - Remove special handling of summary requests - Define a default timeout for TGI calls - Update `_check_connectivity` to return True/False rather than throw an exception, to allow for returning empty responses to the client instead - Add a `_call_with_timeout` wrapper method for the two `call_model` methods to allow them to submit TGI requests with a timeout applied - Timeouts are set from the `ModelRequest` objects or if not set there the default value is used --- llm_functionalities/llm_runner/llm_runner.py | 162 +++++++++---------- 1 file changed, 81 insertions(+), 81 deletions(-) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index 0706125..b7144d6 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -1,7 +1,8 @@ import os import sys import time -import concurrent.futures +from concurrent.futures import TimeoutError, ThreadPoolExecutor +from typing import Optional, Callable, Dict, Any from huggingface_hub import InferenceClient @@ -11,12 +12,12 @@ ModelResponse, ModelBatchRequest, ModelBatchResponse, - TGISummaryRequest, - TGISummaryResponse, - TGIMultipleSummaryRequest, - TGIMultipleSummaryResponse, ) +# default timeout for requests that don't set a timeout explicitly +# TODO: allow override by env var? +DEFAULT_TGI_TIMEOUT_MS = 5000 + class LLMRunner: def __init__(self): @@ -25,6 +26,7 @@ def __init__(self): logger.error("No INFERENCE_ENDPOINT_URL defined, container will exit") sys.exit(-1) + # InferenceClient seems to require an http:// prefix for the URL if not endpoint_url.startswith("http://"): endpoint_url = f"http://{endpoint_url}" @@ -32,12 +34,11 @@ def __init__(self): retries = 0 retry_limit = int(os.environ.get("TGI_CONNECTION_RETRY_LIMIT", 10)) retry_delay = int(os.environ.get("TGI_CONNECTION_RETRY_DELAY", 10)) - logger.info( - f"Connecting to TGI (max {retry_limit} connections, {retry_delay}s apart)" - ) + logger.info(f"Connecting to TGI (max {retry_limit} connections, {retry_delay}s apart)") # might have to wait for the TGI container to finish starting up, especially if it - # needs to download model files first + # needs to download model files first. the two TGI_CONNECTION_RETRY_* env vars + # determine how long we'll wait for this to happen. while retries < retry_limit: client = self._connect_to_endpoint(endpoint_url) if client is None: @@ -50,12 +51,19 @@ def __init__(self): break if self.client is None: - logger.error( - f"LLMRunner failed to connect to the endpoint at {endpoint_url}" - ) + logger.error(f"LLMRunner failed to connect to the endpoint at {endpoint_url}") sys.exit(-1) - def _connect_to_endpoint(self, endpoint_url: str) -> InferenceClient: + def _connect_to_endpoint(self, endpoint_url: str) -> Optional[InferenceClient]: + """Attempt to make a connection to the configured TGI endpoint. + + Simply creating an InferenceClient object with the endpoint URL doesn't trigger + a connection, so to force that to happen this just submits a small text_generation + query. 
+ + Returns: + None if the connection failed, otherwise an InferenceClient object + """ client = InferenceClient(model=endpoint_url, timeout=10.0) try: # creating the object doesn't appear to actually make a connection, so @@ -65,90 +73,82 @@ def _connect_to_endpoint(self, endpoint_url: str) -> InferenceClient: return None return client - def _check_connectivity(self) -> None: + def _check_connectivity(self) -> bool: + """Test if we have a connected InferenceClient""" if self.client is None: - raise Exception("llm_functionalities isn't connected to an endpoint!") + logger.error("!!! llm_functionalities isn't connected to an endpoint!") + return False + + return True + + def _call_with_timeout(self, timeout_ms: int, callable: Callable, params: Dict[str, Any]) -> str: + """Call a TGI endpoint with a timeout applied. + + Since we want to avoid potentially lengthy LLM computations delaying the system's responses, + we need to enforce a timeout on each TGI call. This method handles that through a + ThreadPoolExecutor and a future object. + + Returns: + The text response from the LLM, which will be empty if a timeout or other error + occurs (str). + """ + logger.info(f"call_with_timeout, timeout={timeout_ms}, params={params}") + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(callable, **params) + try: + response = future.result(timeout=timeout_ms / 1000) + logger.info("TGI endpoint returned a response within timeout period") + return response + except TimeoutError as _: + future.cancel() + logger.warning(f"A call to the TGI endpoint has timed out after {timeout_ms} ms") + return "" + except Exception as e: + future.cancel() + logger.warning(f"A call to the TGI endpoint failed with an unexpected exception: {str(e)}") + return "" def call_model(self, model_request: ModelRequest) -> ModelResponse: + """Perform a single call to the LLM endpoint""" model_response: ModelResponse = ModelResponse() - self._check_connectivity() + if not self._check_connectivity(): + return model_response - try: - response = self.client.text_generation( - prompt=model_request.formatted_prompt, - max_new_tokens=model_request.max_tokens, - ) - logger.info(f"LLM response text: {response}") - model_response.text = response + timeout_ms = DEFAULT_TGI_TIMEOUT_MS if model_request.timeout == 0 else model_request.timeout - except Exception as e: - logger.warning(f"Call to inference endpoint failed: {e}") + args = { + "prompt": model_request.formatted_prompt, + "max_new_tokens": model_request.max_tokens, + } + response = self._call_with_timeout(timeout_ms, self.client.text_generation, args) + logger.info(f"LLM response text: {response}") + model_response.text = response return model_response def batch_call_model(self, model_request: ModelBatchRequest) -> ModelBatchResponse: + """Submit a batch of calls to the LLM endpoint""" model_responses: ModelBatchResponse = ModelBatchResponse() - self._check_connectivity() + if not self._check_connectivity(): + return model_responses - try: - formatted_prompts = list(model_request.formatted_prompts) - max_tokens = model_request.max_tokens - params = [ - {"prompt": p, "max_new_tokens": max_tokens} for p in formatted_prompts - ] + timeout_ms = DEFAULT_TGI_TIMEOUT_MS if model_request.timeout == 0 else model_request.timeout - logger.info(f"Submitting a batch of {len(params)} calls to TGI") - with concurrent.futures.ThreadPoolExecutor(max_workers=len(params)) as pool: - results = pool.map(lambda p: self.client.text_generation(**p), params) + formatted_prompts = 
list(model_request.formatted_prompts) + max_tokens = model_request.max_tokens + params = [{"prompt": p, "max_new_tokens": max_tokens} for p in formatted_prompts] - for response in results: - logger.info(f"LLM response text: {response}") - model_responses.text.append(response) + logger.info(f"Submitting a batch of {len(params)} calls to TGI") + # since TGI handles batching automatically in its backend, the recommended approach + # seems to be simply submitting multiple requests in parallel as in the test fixtures: + # https://github.com/huggingface/text-generation-inference/blob/main/integration-tests/conftest.py#L434 + with ThreadPoolExecutor(max_workers=len(params)) as pool: + results = pool.map(lambda p: self._call_with_timeout(timeout_ms, self.client.text_generation, p), params) - except Exception as e: - logger.warning(f"Call to inference endpoint failed: {e}") + for response in results: + logger.info(f"LLM response text: {response}") + model_responses.text.append(response) return model_responses - - def generate_summary(self, request: TGISummaryRequest) -> TGISummaryResponse: - response = TGISummaryResponse() - - self._check_connectivity() - - logger.info(f"generating summary from: {request.input_text}") - - try: - summarization_output = self.client.summarization(text=str(request.input_text)) - if summarization_output.summary_text is None: - # TODO raise exception? different string response? - response.summary_text = "" - else: - response.summary_text = summarization_output.summary_text - except Exception as e: - logger.warning(f"Call to summarization failed: {e}") - - return response - - def generate_summaries( - self, request: TGIMultipleSummaryRequest - ) -> TGIMultipleSummaryResponse: - response = TGIMultipleSummaryResponse() - - self._check_connectivity() - - try: - params = list(request.input_text) - with concurrent.futures.ThreadPoolExecutor(max_workers=len(params)) as pool: - results = pool.map(lambda p: self.client.summarization(p), params) - - for result in results: - if result.summary_text is None: - response.summary_text.append("") # TODO see comment above - else: - response.summary_text.append(result.summary_text) - except Exception as e: - logger.warning(f"Call to summarization failed: {e}") - - return response From 26d02968f9d38c79bb57d8babf98c8b685152b5b Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 3 May 2024 17:54:07 +0100 Subject: [PATCH 34/57] Small round of linting/formatting --- .../execution_search_manager.py | 43 +++++++++--------- .../llm_chit_chat/llm_chit_chat.py | 5 +-- .../llm_description_generation.py | 5 +-- .../llm_ingredient_step_text_rewriter.py | 3 +- .../llm_ingredient_substitution_generation.py | 45 ++++++++++--------- .../llm_proactive_question_generation.py | 1 - functionalities/qa/llm_qa.py | 20 ++++----- functionalities/qa/substitution_helper.py | 11 +++-- 8 files changed, 62 insertions(+), 71 deletions(-) diff --git a/functionalities/execution_search_manager/execution_search_manager.py b/functionalities/execution_search_manager/execution_search_manager.py index fcf9a1f..5c5269c 100644 --- a/functionalities/execution_search_manager/execution_search_manager.py +++ b/functionalities/execution_search_manager/execution_search_manager.py @@ -1,15 +1,11 @@ import grpc import os import json -import time from compiled_protobufs.llm_pb2 import ExecutionSearchResponse, ExecutionSearchRequest, ModelResponse, ModelRequest from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub from compiled_protobufs.taskmap_pb2 import Session -from 
grpc._channel import _InactiveRpcError -from concurrent.futures import TimeoutError, ThreadPoolExecutor - from utils import logger, SEARCH_AGAIN_QUESTION @@ -17,12 +13,10 @@ class ExecutionSearchManager: def __init__(self): llm_channel = grpc.insecure_channel(os.environ["LLM_FUNCTIONALITIES_URL"]) self.llm = LLMRunnerStub(llm_channel) - self.valid_intents = [ - 'recommend_new_search', 'continue_current_task', 'ask_clarifying_question' - ] + self.valid_intents = ["recommend_new_search", "continue_current_task", "ask_clarifying_question"] def extract_classification_response(self, generated_answer) -> dict: - logger.info(f'Raw LLM response for classification: {generated_answer}') + logger.info(f"Raw LLM response for classification: {generated_answer}") valid_dict = {} start_token = "{" end_token = "}" @@ -33,40 +27,47 @@ def extract_classification_response(self, generated_answer) -> dict: generated_answer = "{" + generated_answer.split(start_token)[-1] try: - generated_answer = generated_answer.replace(' \n ', '').replace('\n', '') + generated_answer = generated_answer.replace(" \n ", "").replace("\n", "") valid_dict = json.loads(generated_answer) - logger.info(f'Managed to parse classification: {valid_dict}') + logger.info(f"Managed to parse classification: {valid_dict}") if "intent" in list(valid_dict.keys()) and valid_dict["intent"] in self.valid_intents: return valid_dict else: - logger.info(f'Dictionary contents not valid: {valid_dict}') + logger.info(f"Dictionary contents not valid: {valid_dict}") except Exception as e: - logger.info(f'Could not parse response >{generated_answer}<: {e}') + logger.info(f"Could not parse response >{generated_answer}<: {e}") return valid_dict def build_prompt(self, request: ExecutionSearchRequest) -> ModelRequest: model_request: ModelRequest = ModelRequest() - task_type = f"making {request.taskmap.title}" if request.domain == Session.Domain.COOKING \ + task_type = ( + f"making {request.taskmap.title}" + if request.domain == Session.Domain.COOKING else f"with {request.taskmap.title}" + ) # instruction - prompt = f"### Instruction: Imagine you are an AI assistant currently helping a user with {request.taskmap.title}. " \ - f"You are able to switch to a new task, so based on the last user request you need to decide " \ - f"whether to continue with {request.taskmap.title}, recommend that you can start new search based " \ - f"on the last user request, or asking a clarifying question. " \ - f"Choose one of the intent options and follow this format for your response: {{\"intent\": \"\"}}\n" + prompt = ( + f"### Instruction: Imagine you are an AI assistant currently helping a user with {request.taskmap.title}. " + f"You are able to switch to a new task, so based on the last user request you need to decide " + f"whether to continue with {request.taskmap.title}, recommend that you can start new search based " + f"on the last user request, or asking a clarifying question. 
" + f'Choose one of the intent options and follow this format for your response: {{"intent": ""}}\n' + ) # input prompt += f"### Input:\nCurrent task: {request.taskmap.title}\nConversation history:\n" - if any([request.last_agent_response == prompt for prompt in - SEARCH_AGAIN_QUESTION]) and "step" in request.last_last_agent_response.lower(): + if ( + any([request.last_agent_response == prompt for prompt in SEARCH_AGAIN_QUESTION]) + and "step" in request.last_last_agent_response.lower() + ): prompt += f"You: {request.last_last_agent_response}\n" else: prompt += f"You: {request.last_agent_response}\n" prompt += f"Last user request: {request.user_question}\n" prompt += f"Intent options:{str(self.valid_intents)}\n\n" # response - prompt += "### Response: Your response:{\"intent\":" + prompt += '### Response: Your response:{"intent":' model_request.formatted_prompt = prompt model_request.max_tokens = 20 diff --git a/functionalities/llm_chit_chat/llm_chit_chat.py b/functionalities/llm_chit_chat/llm_chit_chat.py index 35b8394..737f7ee 100644 --- a/functionalities/llm_chit_chat/llm_chit_chat.py +++ b/functionalities/llm_chit_chat/llm_chit_chat.py @@ -1,18 +1,15 @@ import grpc import os import re -import time from compiled_protobufs.chitchat_classifier_pb2 import ChitChatResponse from compiled_protobufs.llm_pb2 import ModelRequest, ModelResponse, LLMChitChatRequest from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub -from grpc._channel import _InactiveRpcError -from concurrent.futures import TimeoutError, ThreadPoolExecutor from utils import logger -def extract_qa_answer(generated_answer): +def extract_qa_answer(generated_answer: str) -> str: start_token = "Your response:" end_token_1 = "Human" end_token_2 = "#" diff --git a/functionalities/llm_description_generation/llm_description_generation.py b/functionalities/llm_description_generation/llm_description_generation.py index 2f9810f..ce4c758 100644 --- a/functionalities/llm_description_generation/llm_description_generation.py +++ b/functionalities/llm_description_generation/llm_description_generation.py @@ -7,7 +7,6 @@ ) from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub -from grpc._channel import _InactiveRpcError from utils import logger @@ -47,11 +46,11 @@ def generate_descriptions(self, request: LLMMultipleDescriptionGenerationRequest model_batch_request: ModelBatchRequest = ModelBatchRequest() model_batch_request.max_tokens = 64 - for title, ingredients, domain in zip(request.task_title, request.ingredients, request.domains): + for title, _, domain in zip(request.task_title, request.ingredients, request.domains): if domain == "wikihow": domain = "" else: - domain = f"Domain: Cooking\n" + domain = "Domain: Cooking\n" prompt = f"### Instruction:\n" \ f"Generate a 2 sentence description. It should be fun entertaining, sell the task and make the " \ f"user wanna start it. 
Imagine it being the intro that just sells the task.\n\n" \ diff --git a/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py b/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py index 2798e79..62cb70c 100644 --- a/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py +++ b/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py @@ -10,7 +10,6 @@ ) from compiled_protobufs.taskmap_pb2 import Session, Task, ReplacedIngredient -from grpc._channel import _InactiveRpcError from utils import logger from spacy.matcher import Matcher from pyserini.analysis import Analyzer, get_lucene_analyzer @@ -41,7 +40,7 @@ def extract_ingredient_name(self, question_text: str, requirements_list) -> str: return str_item return "" - def extract_response(self, generated_answer) -> str: + def extract_response(self, generated_answer) -> dict: valid_dict = {} start_token = "{" end_token = "}" diff --git a/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py b/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py index 9c96ef1..42e5782 100644 --- a/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py +++ b/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py @@ -1,17 +1,16 @@ import grpc import os import json -import time from compiled_protobufs.llm_pb2_grpc import LLMRunnerStub from compiled_protobufs.llm_pb2 import ( - IngredientReplacementRequest, IngredientReplacementResponse, ModelRequest, ModelResponse + IngredientReplacementRequest, + IngredientReplacementResponse, + ModelRequest, + ModelResponse, ) from compiled_protobufs.taskmap_pb2 import Ingredient -from grpc._channel import _InactiveRpcError -from concurrent.futures import TimeoutError, ThreadPoolExecutor - from utils import logger @@ -26,14 +25,14 @@ def extract_replacement_response(generated_answer, original_ing: Ingredient) -> generated_answer = "{" + generated_answer.split(start_token)[-1] try: - generated_answer = generated_answer.replace(' \n ', '').replace('\n', '') + generated_answer = generated_answer.replace(" \n ", "").replace("\n", "") valid_dict = json.loads(generated_answer) - logger.info(f'Managed to parse ingredient: {valid_dict}') + logger.info(f"Managed to parse ingredient: {valid_dict}") if "name" in list(valid_dict.keys()) and "amount" in list(valid_dict.keys()): return valid_dict else: - logger.debug(f'Dictionary contents not valid: {valid_dict}') + logger.debug(f"Dictionary contents not valid: {valid_dict}") except Exception as e: if '"amount"' in generated_answer and "name" in generated_answer: @@ -43,20 +42,20 @@ def extract_replacement_response(generated_answer, original_ing: Ingredient) -> if "}" in generated_answer: second_part = generated_answer.split('"amount"')[1].replace(":", "").replace("}", "").strip() valid_dict["amount"] = second_part - logger.info(f'Managed to parse ingredient (2nd try): {valid_dict}') + logger.info(f"Managed to parse ingredient (2nd try): {valid_dict}") return valid_dict else: second_part = generated_answer.split('"amount"')[1] if second_part.replace(":", "").strip() in original_ing.name: valid_dict["amount"] = original_ing.name - logger.info(f'Managed to parse ingredient (3rd try): {valid_dict}') + logger.info(f"Managed to parse ingredient (3rd try): {valid_dict}") return valid_dict if second_part.replace(":", "").strip() in original_ing.amount: valid_dict["amount"] 
= original_ing.amount - logger.info(f'Managed to parse ingredient (3rd try): {valid_dict}') + logger.info(f"Managed to parse ingredient (3rd try): {valid_dict}") return valid_dict except Exception as e: - logger.debug(f'Second parsing did not work.') + logger.debug(f"Second parsing did not work.") return valid_dict @@ -75,16 +74,18 @@ def build_prompt(request: IngredientReplacementRequest) -> ModelRequest: original_ing = "" orginal_ing_amount = "" if request.original_ingredient.amount == "" else request.original_ingredient.amount - prompt = f"### Instruction: You are a friendly AI assistant who is assisting a human making " \ - f"{request.task_title}. You are helping the user to replace an ingredient in the recipe. " \ - f"If the amount is not specified in your earlier suggestion, extract the amount from the " \ - f"original ingredient. If your earlier suggestion responds that this substitution is not possible, " \ - f"please respond with an empty ingredient.\n " \ - f"Respond in this format: {{\"name\": \"\", \"amount\": \"{orginal_ing_amount}\"}}\n\n " \ - f"### Input:\n{original_ing}" \ - f"User: {request.user_question}\n" \ - f"Your suggestion: {request.agent_response}\n\n" \ - f"### Response:{{\"name\"" + prompt = ( + f"### Instruction: You are a friendly AI assistant who is assisting a human making " + f"{request.task_title}. You are helping the user to replace an ingredient in the recipe. " + f"If the amount is not specified in your earlier suggestion, extract the amount from the " + f"original ingredient. If your earlier suggestion responds that this substitution is not possible, " + f"please respond with an empty ingredient.\n " + f'Respond in this format: {{"name": "", "amount": "{orginal_ing_amount}"}}\n\n ' + f"### Input:\n{original_ing}" + f"User: {request.user_question}\n" + f"Your suggestion: {request.agent_response}\n\n" + f'### Response:{{"name"' + ) model_request.formatted_prompt = prompt model_request.max_tokens = 20 diff --git a/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py b/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py index 18460d9..85c2baf 100644 --- a/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py +++ b/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py @@ -10,7 +10,6 @@ from taskmap_pb2 import ExtraInfo from utils import logger -from grpc._channel import _InactiveRpcError def extract_question(generated_question): diff --git a/functionalities/qa/llm_qa.py b/functionalities/qa/llm_qa.py index afe80a3..822cee8 100644 --- a/functionalities/qa/llm_qa.py +++ b/functionalities/qa/llm_qa.py @@ -1,13 +1,10 @@ import os -import grpc -import time import random import re -import spacy +from typing import Optional -from grpc._channel import _InactiveRpcError -from concurrent.futures import TimeoutError, ThreadPoolExecutor -from typing import List +import spacy +import grpc from taskmap_pb2 import Session, Task, ExecutionStep from qa_pb2 import QAQuery, QARequest, QAResponse, DocumentList @@ -15,7 +12,6 @@ from llm_pb2_grpc import LLMRunnerStub from .abstract_qa import AbstractQA -from chitchat_classifier_pb2 import ChitChatResponse from task_graph import TaskGraph from utils import ( @@ -44,7 +40,7 @@ def __init__(self, environ_var: str): logger.info('LLM QA SpaCy initialized') @staticmethod - def __get_helpful_prompt(user_query: str) -> str: + def __get_helpful_prompt(user_query: str) -> Optional[str]: for word in 
user_query.split(" "): for keyword, answer in HELPFUL_PROMPT_PAIRS: if word == keyword: @@ -54,6 +50,8 @@ def __get_helpful_prompt(user_query: str) -> str: if keyword in user_query: return answer + return None + def process_last_sentence(self, last_sentence: str) -> str: if last_sentence.endswith(('.', '!', '?')): @@ -76,7 +74,7 @@ def process_last_sentence(self, last_sentence: str) -> str: return f" {last_sentence}".strip() - def extract_qa_answer(self, generated_answer): + def extract_qa_answer(self, generated_answer: str) -> str: start_token = "Your response: " end_token_1 = "Human" end_token_2 = "#" @@ -102,11 +100,10 @@ def domain_retrieve(self, query: QAQuery) -> DocumentList: pass @staticmethod - def strip_newlines(text): + def strip_newlines(text: str) -> str: return text.replace('\n', " ").replace(" ", " ") def __build_context_task_selected(self, task_graph: TaskGraph, query: str, question_type: str) -> str: - context = [] if task_graph.author != "": @@ -208,7 +205,6 @@ def __build_context_general(self, request) -> str: return " ".join(context) def build_prompt(self, request: QARequest) -> ModelRequest: - task_graph: TaskGraph = TaskGraph(request.query.taskmap) user_question: str = request.query.text model_request: ModelRequest = ModelRequest() diff --git a/functionalities/qa/substitution_helper.py b/functionalities/qa/substitution_helper.py index 2552c1e..077a218 100644 --- a/functionalities/qa/substitution_helper.py +++ b/functionalities/qa/substitution_helper.py @@ -1,20 +1,19 @@ import grpc import os import random -import spacy +from typing import Tuple -from utils import logger, REPLACE_SUGGESTION, jaccard_sim, NOT_POSSIBLE, indri_stop_words +import spacy +from spacy.matcher import Matcher +from pyserini.analysis import Analyzer, get_lucene_analyzer from qa_pb2 import QARequest, QAResponse from llm_pb2 import IngredientReplacementRequest, IngredientReplacementResponse from llm_pb2_grpc import LLMReplacementGenerationStub from taskmap_pb2 import ReplacedIngredient, Ingredient -from grpc._channel import _InactiveRpcError -from spacy.matcher import Matcher -from pyserini.analysis import Analyzer, get_lucene_analyzer -from typing import Tuple +from utils import logger, REPLACE_SUGGESTION, jaccard_sim, NOT_POSSIBLE, indri_stop_words class SubstitutionHelper: def __init__(self): From cd4ff58da90d3e25e4f3a19c430f9a923f4e41dd Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 3 May 2024 17:55:29 +0100 Subject: [PATCH 35/57] Small bugfix Python strings don't support `foo[1] = "."` assignments, so this would throw an exception if it was executed --- functionalities/qa/llm_qa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/functionalities/qa/llm_qa.py b/functionalities/qa/llm_qa.py index 822cee8..9cf5781 100644 --- a/functionalities/qa/llm_qa.py +++ b/functionalities/qa/llm_qa.py @@ -70,7 +70,7 @@ def process_last_sentence(self, last_sentence: str) -> str: return "" if last_sentence.endswith((',', ":", ";")): - last_sentence[-1] = "." + last_sentence = last_sentence[:-1] + "." 
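            # Note on this fix: Python strings are immutable, so item assignment
            # raises a TypeError and the string has to be rebuilt by slicing:
            #   >>> s = "add the salt,"
            #   >>> s[-1] = "."        # TypeError: 'str' object does not support item assignment
            #   >>> s = s[:-1] + "."   # 'add the salt.'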
return f" {last_sentence}".strip() From b8cac13b828bca480dafbb8fe11806f8a4c37685 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 3 May 2024 18:09:17 +0100 Subject: [PATCH 36/57] Adding some WIP LLM tests --- tester/tests/integration_tests/test_llm.py | 202 +++++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 tester/tests/integration_tests/test_llm.py diff --git a/tester/tests/integration_tests/test_llm.py b/tester/tests/integration_tests/test_llm.py new file mode 100644 index 0000000..0693ba9 --- /dev/null +++ b/tester/tests/integration_tests/test_llm.py @@ -0,0 +1,202 @@ +import os +import json +import time +import uuid +import random + +import grpc +from google.protobuf.json_format import Parse, ParseDict, MessageToDict +import pytest + +from utils import INTRO_PROMPTS, REPLACE_SUGGESTION +from taskmap_pb2 import TaskMap, Session, ConversationTurn, Task, SessionState +from database_pb2 import SessionRequest, TaskMapRequest +from database_pb2_grpc import DatabaseStub +from llm_pb2 import ( + MultipleSummaryGenerationRequest, + MultipleSummaryGenerationResponse, +) +from llm_pb2_grpc import ( + LLMSummaryGenerationStub, +) + +# TODO: test timeout behaviour +# TODO: test TGI being unavailable + + +def test_llm_summary_generation(sample_taskmap_json_wikihow: str): + """ + Tests functionality for generating steps from a TaskMap. + """ + + # TODO: adapt this to trigger the actual code like the other tests + + # parse the raw JSON sample taskmap into a TaskMap object + taskmap = TaskMap() + Parse(sample_taskmap_json_wikihow, taskmap) + + functionalities_channel = grpc.insecure_channel(os.environ["FUNCTIONALITIES_URL"]) + llm_summary_stub = LLMSummaryGenerationStub(functionalities_channel) + + # this code is a slightly modified version of LLMTaskMapEnhancer.enhance_taskmap from + # external_functionalities. 
it loads an example TaskMap instance and constructs and + # sends a summarisation request to exercise the llm_summary_generation path + steps = taskmap.steps + grouped_steps = [steps[i : i + 3] for i in range(0, len(steps), 3)] + + for group in grouped_steps: + request: MultipleSummaryGenerationRequest = MultipleSummaryGenerationRequest() + for step in group: + request.task_title.append(taskmap.title) + request.step_text.append(step.response.speech_text) + request.more_details.append(step.response.description) + summaries: MultipleSummaryGenerationResponse = llm_summary_stub.generate_summaries(request) + for i, step in enumerate(group): + step.response.speech_text = summaries.summary[i] + + # TODO check output + + +def test_llm_chit_chat(new_session_obj, interaction_obj): + session = new_session_obj + + # start things off with a "hi" + response = interaction_obj.run("hi", session) + # check that this response comes from the domain_policy as expected, and ends with one of the + # INTRO_PROMPT values + assert response["source"]["policy"] == "domain_policy" + interaction_obj.check_prompts(response, INTRO_PROMPTS) + + # next, reply with "what's up" - this should be classed as a ChitChatIntent and + # trigger a call to the LLM chit chat service + response = interaction_obj.run("what's up", session) + # the response should now come from a particular line in chitchat_policy.py if + # it successfully retrieved an LLM response + assert response["source"]["policy"] == "chitchat_policy" + assert response["source"]["lineNumber"] == 152 + + # TODO check output + + +@pytest.mark.slow +def test_llm_enhancement(new_session_obj, sample_taskmap_json_wikihow: str): + """ + Tests the TaskMap enhancement performed by external_functionalities. + + This should exercise both the llm_description_generation and + llm_proactive_question_generation gRPC services in functionalities. + + Normally this is triggered by loading a session with a non-enhanced TaskMap. + After a TaskMap has been enhanced it's saved to a database so it doesn't get + enhanced again. To make this work in a repeatable way, this test loads the same + sample TaskMap but overwrites its "taskmap_id" field with a new value each time. + + An additional problem is that the enhancement is triggered by loading an existing + session. To workaround this, the test creates a minimal valid Session, saves it, + then performs a load to start the enhancement process. + """ + session = new_session_obj + taskmap = TaskMap() + Parse(sample_taskmap_json_wikihow, taskmap) + + # construct a protobuf Session object from the basic JSON representation in "session" + # and copy in the sample TaskMap + session_proto = Session() + ParseDict(session, session_proto) + session_proto.task.taskmap.CopyFrom(taskmap) + + # construct a SessionRequest for an RPC to external_functionalities + session_request: SessionRequest = SessionRequest() + session_request.id = session["session_id"] + session_request.session.CopyFrom(session_proto) + + channel = grpc.insecure_channel(os.environ["EXTERNAL_FUNCTIONALITIES_URL"]) + db = DatabaseStub(channel) + + # make sure this TaskMap isn't already in the database + taskmap_request = TaskMapRequest() + taskmap_request.id = taskmap.taskmap_id + db.delete_taskmap(taskmap_request) + + # save the new session and then load it to trigger the enhancement + db.save_session(session_request) + session = db.load_session(session_request) + + # because the database object spawns a worker thread on the remote side, + # the RPC will return immediately. 
wait here until reloading the session + # shows it marked as enhanced (via session.task.state.enhanced == True), + # or we timeout + timeout = 30 + delay = 5 + enhanced = False + while not enhanced and timeout > 0: + time.sleep(delay) + timeout -= delay + session = db.load_session(session_request) + enhanced = session.task.state.enhanced + + assert enhanced + + +def test_llm_ingredient_substitution_step_text(interaction_obj, sample_taskmap_json_wikihow: str): + # this functionality is called from orchestrator/policy/validation_policy/validation_policy.py + + # it looks like what we need to trigger this is: + # session.task.phase = Task.TaskPhase.VALIDATING (to trigger ValidationPolicy) + # then in ValidationPolicy, the conditions necessary to hit the rewriter are: + # 1. session.task.state.requirements_displayed = True + # 2. session.turn[-2].user_request.interaction contains a QuestionIntent + # 3. session.turn[-2].agent_response.interaction.speech_text contains any REPLACE_SUGGESTION + # 4. session.turn[-1].user_request.interaction contains a YesIntent + + taskmap = TaskMap() + Parse(sample_taskmap_json_wikihow, taskmap) + + # construct a protobuf Session object from the basic JSON representation in session + session_proto = Session() + session_proto.session_id = f"test_{str(uuid.uuid4())}" + + # configure the rest of the session state required + session_proto.state = SessionState.RUNNING + session_proto.task.phase = Task.TaskPhase.VALIDATING + session_proto.task.state.requirements_displayed = True + + # need to add a turn where the user_request.interaction contains a QuestionIntent + t1 = ConversationTurn() + t1.user_request.interaction.intents.extend(["QuestionIntent"]) + t1.id = f"test_turn_{uuid.uuid4()}" + + # and the agent must respond with a REPLACE_SUGGESTION prompt + t1.agent_response.interaction.speech_text = random.choice(REPLACE_SUGGESTION) + session_proto.turn.extend([t1]) + + # the following turn should also contain a YesIntent, but that will be generated + # by setting the appropriate user input text below + + # save the new session first, so that the orchestrator will load it + # with the correct state when we submit the interaction request + session_request: SessionRequest = SessionRequest() + session_request.id = session_proto.session_id + session_request.session.CopyFrom(session_proto) + channel = grpc.insecure_channel(os.environ["EXTERNAL_FUNCTIONALITIES_URL"]) + + db = DatabaseStub(channel) + db.save_session(session_request) + + # make sure this TaskMap isn't already in the database + taskmap_request = TaskMapRequest() + taskmap_request.id = taskmap.taskmap_id + db.delete_taskmap(taskmap_request) + + s = MessageToDict(session_proto, including_default_value_fields=True) + + # use speech text that will generate a YesIntent, required to trigger the substitution + s["session_id"] = session_proto.session_id + resp = interaction_obj.run("yes ok", s) + + # TODO: the problem with this is it finds nothing to do with the current sample taskmap, need to try some others + + +def test_llm_ingredient_substitution_2(): + # see functionalities/qa/composed_qa.py + pass From 3b719477bab155ebd39131d3d731c6ca46f38727 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 10 May 2024 18:27:21 +0100 Subject: [PATCH 37/57] Adding a delete_stagedoutput to make a test work --- external_functionalities/database/dynamo_db.py | 15 +++++++++++++++ .../database/llm_taskmap_enhancer.py | 1 - external_functionalities/database/servicer.py | 4 ++++ shared/protobufs/database.proto | 2 ++ 4 files 
changed, 21 insertions(+), 1 deletion(-) diff --git a/external_functionalities/database/dynamo_db.py b/external_functionalities/database/dynamo_db.py index 0d0aa2a..4d489b4 100644 --- a/external_functionalities/database/dynamo_db.py +++ b/external_functionalities/database/dynamo_db.py @@ -53,9 +53,21 @@ def save_session(self, session_id: str, session: Session) -> None: def load_session(self, session_id: str) -> Session: session = self.session_db.get(session_id) + if self.taskmap_enhancer.trigger(session): + # this is triggered if the taskmap has a wikihow URL and it hasn't already been enhanced + # (i.e. session.task.state.enhanced is False). + # Calling .enhance here with questions_only defaulting to False will peform two types + # of enhancement: + # - step summarization + # - proactive question enhancement + logger.info("Taskmap Enhancement triggered: Summarization and Proactive QA") return self.taskmap_enhancer.enhance(session) elif session.task.state.enhanced is False and session.task.taskmap.taskmap_id != "": + # in this case (not a wikihow URL, not already enhanced, non-empty taskmap ID), + # calling enhance with questions_only set to True will *only* run the + # proactive_question_enhancement part + logger.info("Taskmap Enhancement triggered: Proactive QA only") return self.taskmap_enhancer.enhance(session, questions_only=True) else: return session @@ -72,6 +84,9 @@ def load_taskmap(self, taskmap_id: str) -> TaskMap: def delete_taskmap(self, taskmap_id: str) -> None: self.taskmap_db.delete(taskmap_id) + def delete_stagedoutput(self, taskmap_id: str) -> None: + self.taskmap_enhancer.db.delete(taskmap_id) + def save_search_log(self, search_log: SearchLog) -> None: self.search_logs_db.put(search_log) diff --git a/external_functionalities/database/llm_taskmap_enhancer.py b/external_functionalities/database/llm_taskmap_enhancer.py index b1670f2..1b06a44 100644 --- a/external_functionalities/database/llm_taskmap_enhancer.py +++ b/external_functionalities/database/llm_taskmap_enhancer.py @@ -51,7 +51,6 @@ def __get_indices_with_criteria(length): return list(indices) def proactive_question_enhancement(self, taskmap, staged_db): - logger.info("GENERATE QUESTIONS") request: ProactiveQuestionGenerationRequest = ProactiveQuestionGenerationRequest() indices = self.__get_indices_with_criteria(len(taskmap.steps)) diff --git a/external_functionalities/database/servicer.py b/external_functionalities/database/servicer.py index 7bdec10..68125d9 100644 --- a/external_functionalities/database/servicer.py +++ b/external_functionalities/database/servicer.py @@ -32,6 +32,10 @@ def delete_taskmap(self, request, context) -> None: self.instance.delete_taskmap(request.id) return Void() + def delete_stagedoutput(self, request, context) -> None: + self.instance.delete_stagedoutput(request.id) + return Void() + def save_search_logs(self, request, context) -> None: self.instance.save_search_log(request) return Void() diff --git a/shared/protobufs/database.proto b/shared/protobufs/database.proto index 67787de..a5efe10 100644 --- a/shared/protobufs/database.proto +++ b/shared/protobufs/database.proto @@ -38,6 +38,8 @@ service Database{ rpc save_session(SessionRequest) returns (Void) {} rpc delete_session(SessionRequest) returns (Void) {} + rpc delete_stagedoutput(TaskMapRequest) returns (Void) {} + rpc save_search_logs(SearchLog) returns(Void) {} rpc save_asr_logs(ASRLog) returns(Void) {} From 2d4bff4d7017e48d7cec542b636af5b8ab93e040 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 10 May 2024 18:29:43 +0100 
Subject: [PATCH 38/57] Update set_source call to add a message --- orchestrator/policy/chitchat_policy/chitchat_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orchestrator/policy/chitchat_policy/chitchat_policy.py b/orchestrator/policy/chitchat_policy/chitchat_policy.py index 7321e0e..c1d23a5 100644 --- a/orchestrator/policy/chitchat_policy/chitchat_policy.py +++ b/orchestrator/policy/chitchat_policy/chitchat_policy.py @@ -149,7 +149,7 @@ def step(self, session: Session) -> Tuple[Session, OutputInteraction]: output.speech_text = f'{chitchat_response.text} {random.choice(transition_options)} {keyword_helpful_prompt.lower()}' else: output.speech_text = f'{chitchat_response.text} {random.choice(transition_options)} {helpful_prompt.lower()}' - set_source(output) + set_source(output, msg="chitchat from LLM") elif chitchat_response.text != "": output.speech_text = chitchat_response.text set_source(output) From 88e07d80cea57a68d58977c5af2f9e31cb708b7b Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 10 May 2024 18:29:59 +0100 Subject: [PATCH 39/57] Update sample taskmap --- shared/test_data/sample_taskmap_seriouseats.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shared/test_data/sample_taskmap_seriouseats.json b/shared/test_data/sample_taskmap_seriouseats.json index 0adef02..53641ff 100644 --- a/shared/test_data/sample_taskmap_seriouseats.json +++ b/shared/test_data/sample_taskmap_seriouseats.json @@ -1 +1 @@ -{"taskmapId": "075f4135afedcb5870a371e9e85eac77", "title": "Chicken Maple Sausage Pigs in a Blanket with Maple\u00a0Cream", "date": "2018-11-15", "sourceUrl": "https://food52.com/recipes/78217-chicken-maple-sausage-pigs-in-a-blanket-with-maple-cream", "description": "A more sophisticated take on an old classic. 
\u2014My Stir Crazy Kitchen", "thumbnailUrl": "https://images.food52.com/PxRPfj1W6qHiGd4ta0b0EM-PM3g=/1000x1000/b22b1adc-e3d1-4005-b305-cbe602cb002c--7L7A1053.JPG", "totalTimeMinutes": "10", "ratingOut100": 100, "tags": ["American", "Serves a Crowd", "Appetizer"], "requirementList": [{"uniqueId": "630cf6ff-f6f1-4644-a98d-1d6431b7655b", "name": " chicken maple sausage links", "amount": "8 pieces"}, {"uniqueId": "11f97441-0874-4e0f-ba20-8e2c535a925b", "name": "packet crescent rolls", "amount": "1"}, {"uniqueId": "15960d6b-ca07-47ac-8b29-aade24aa30f2", "name": " heavy cream", "amount": "6 tablespoons"}, {"uniqueId": "4265f16f-dfb9-4ff2-a316-85cce7f540ed", "name": " maple syrup", "amount": "3 tablespoons"}, {"uniqueId": "a6b93747-cf41-4499-b5eb-2c5adb2f3771", "name": " kosher salt", "amount": "1 teaspoon"}], "serves": "\nMakes\n 8\n ", "steps": [{"uniqueId": "0483c625-5b1e-4893-9761-0c0869200ea7", "response": {"speechText": "In a small bowl, stir to combine the cream, maple syrup and salt.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["In a small bowl, stir to combine the cream, maple syrup and salt."], "imageList": [{"path": "https://www.seriouseats.com/thmb/fsxoprpah3N9VyuUcm2S7i4RlfI=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__images__2012__01__20120117-187614_GFTues_PigsInABlanket_610-c075c91ac6c84089934469d43382cbb0.jpg"}], "requirements": [" heavy cream", " kosher salt", " maple syrup"], "extraInformation": [{"type": "FUNFACT", "text": "If you take a bowl and put it upside down, it becomes a poml.", "keyword": "bowl", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/facts/fact_bowl.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "345b60f3-9ec1-496c-b578-12774723e682", "response": {"speechText": "Preheat oven to 350F.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Preheat oven to 350F."], "imageList": [{"path": "https://www.seriouseats.com/thmb/fsxoprpah3N9VyuUcm2S7i4RlfI=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__images__2012__01__20120117-187614_GFTues_PigsInABlanket_610-c075c91ac6c84089934469d43382cbb0.jpg"}], "extraInformation": [{"type": "JOKE", "text": "Give the oven a hand. 
It's the real cook, we're just it's assistants.", "keyword": "oven", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_oven.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "391ddab7-b1b4-478d-b5ae-c32939f4d905", "response": {"speechText": "Lay out each crescent triangle and top with a sausage link.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Lay out each crescent triangle and top with a sausage link."], "imageList": [{"path": "https://www.seriouseats.com/thmb/fsxoprpah3N9VyuUcm2S7i4RlfI=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__images__2012__01__20120117-187614_GFTues_PigsInABlanket_610-c075c91ac6c84089934469d43382cbb0.jpg"}], "requirements": [" chicken maple sausage links"]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "14ada727-31c1-418f-8c84-6e6764de95e9", "response": {"speechText": "Roll the crescent rolls around the links.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Roll the crescent rolls around the links."], "imageList": [{"path": "https://images.food52.com/PxRPfj1W6qHiGd4ta0b0EM-PM3g=/1000x1000/b22b1adc-e3d1-4005-b305-cbe602cb002c--7L7A1053.JPG"}], "requirements": [" chicken maple sausage links"]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "1e99067c-d101-41ba-8641-04b0b1384664", "response": {"speechText": "Brush with maple cream sauce.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Brush with maple cream sauce."], "imageList": [{"path": "https://www.seriouseats.com/thmb/fsxoprpah3N9VyuUcm2S7i4RlfI=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__images__2012__01__20120117-187614_GFTues_PigsInABlanket_610-c075c91ac6c84089934469d43382cbb0.jpg"}], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "204a681b-23b8-4138-8de1-7307f0f14083", "response": {"speechText": "Bake for 12-14 minutes until golden brown. Serve with remaining maple cream sauce.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Bake for 12-14 minutes until golden brown. Serve with remaining maple cream sauce."], "imageList": [{"path": "https://www.seriouseats.com/thmb/fsxoprpah3N9VyuUcm2S7i4RlfI=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__images__2012__01__20120117-187614_GFTues_PigsInABlanket_610-c075c91ac6c84089934469d43382cbb0.jpg"}], "extraInformation": [{"type": "JOKE", "text": "Cream is the magician of the kitchen. 
One moment it's liquid, the next it's whipped, and just when you thought you had it figured out, it turns into butter.", "keyword": "cream", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_cream_v1.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}], "connectionList": [{"idFrom": "1e99067c-d101-41ba-8641-04b0b1384664", "idTo": "204a681b-23b8-4138-8de1-7307f0f14083"}, {"idFrom": "0483c625-5b1e-4893-9761-0c0869200ea7", "idTo": "345b60f3-9ec1-496c-b578-12774723e682"}, {"idFrom": "14ada727-31c1-418f-8c84-6e6764de95e9", "idTo": "1e99067c-d101-41ba-8641-04b0b1384664"}, {"idFrom": "391ddab7-b1b4-478d-b5ae-c32939f4d905", "idTo": "14ada727-31c1-418f-8c84-6e6764de95e9"}, {"idFrom": "345b60f3-9ec1-496c-b578-12774723e682", "idTo": "391ddab7-b1b4-478d-b5ae-c32939f4d905"}], "dataset": "common-crawl", "author": "My Stir Crazy Kitchen", "domainName": "food52.com"} \ No newline at end of file +{"taskmapId": "fe1cfe172799b69d125e6fd3af16241e", "title": "Raw Fig\u00a0Bars", "date": "2014-12-14", "sourceUrl": "https://food52.com/recipes/32540-roasted-strawberry-brioche-pudding", "description": "Raw dried fig bars are not often found and these are topped with dark chocolate drizzle and a pinch of salt which make them hard to resist. \u2014Sylvie Taylor @ Roamingtaste", "thumbnailUrl": "https://images.food52.com/oOnHcjzuF1_CkEJa4Go7Vca7vaE=/1000x1000/52f804f9-0298-4b41-9b53-cb5d9a7d66a4--raw_fig_bars.jpg", "ratingOut100": 100, "tags": ["Pudding", "Candy", "Australian/New Zealander", "Milk/Cream", "Strawberry", "Chocolate", "Fruit", "Vegan", "Dessert", "Snack"], "requirementList": [{"uniqueId": "4829157f-7a04-47d7-b430-7e8f8e4b7f85", "name": " dried figs, ends removed and sliced into quarters", "amount": "3 cups"}, {"uniqueId": "40a0ab87-69c5-48ba-8837-2062e06ee4a7", "name": " walnuts", "amount": "1/2 cup"}, {"uniqueId": "69ef2860-7354-4125-b828-5fca4ed449f4", "name": "cup brazil nuts", "amount": "1/3"}, {"uniqueId": "ed3e1a75-c2e8-43aa-8224-fa19ab39f2ec", "name": " chia seeds", "amount": "1 tablespoon"}, {"uniqueId": "beba33df-f12b-4709-b898-16b7ce481f8f", "name": " linseeds", "amount": "1 tablespoon"}, {"uniqueId": "7f55ef57-d350-41c0-8769-9a4f5434a52f", "name": " dark chocolate", "amount": "30 grams"}, {"uniqueId": "ba9e7441-c3ac-42cc-be66-eea42c043568", "name": " coconut oil", "amount": "1/2 teaspoon"}, {"uniqueId": "5428c96f-97f9-4587-b4d5-c7333cc9e577", "name": "Pinch course sea salt", "amount": " "}], "serves": "\nServes\n 8\n ", "steps": [{"uniqueId": "65d7ca29-c84c-4303-9da6-d35be303e206", "response": {"speechText": "Place the quartered figs into a bowl and place in a double boiler, cover and add two tablespoons boiling water until the figs are warmed through, approximately 5 minutes.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Place the quartered figs into a bowl and place in a double boiler, cover and add two tablespoons boiling water until the figs are warmed through, approximately 5 minutes."], "imageList": [{"path": "https://images.food52.com/bvKhL9d3y6F8WKxDcMCmvjIKLYg=/1000x1000/a8a61e51-8734-40d6-92ef-98c5e86a0bba--Toffee_1.JPG"}], "requirements": [" dried figs, ends removed and sliced into quarters"]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "19d73de6-f991-469a-ad77-a6cc98f187dd", "response": {"speechText": "Meanwhile, place the nuts and seeds into a blender and pulse until roughly chopped. 
Set aside.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Meanwhile, place the nuts and seeds into a blender and pulse until roughly chopped. Set aside."], "imageList": [{"path": "https://images.food52.com/ZL62D1IGpUJJIWaujlJVplYLfQQ=/1000x1000/e701cb7c-0002-4607-8de7-792ec99f9814--Quince_cheese_004.JPG"}], "requirements": [" chia seeds"], "extraInformation": [{"type": "JOKE", "text": "Much like a baseball game, a blender includes a base and a pitcher. (And sometimes, a batter!)", "keyword": "blender", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_blender_2.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "d09ba3ac-46a9-4050-8021-0d7ded4cea0f", "response": {"speechText": "Place the warmed figs in a food processor and pulse on and off until roughly chopped and sticky (your food processor may become rather warm and you will have to clear out the blades every 30 seconds to a minute of pulsing, so give it time to cool down if necessary).", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Place the warmed figs in a food processor and pulse on and off until roughly chopped and sticky (your food processor may become rather warm and you will have to clear out the blades every 30 seconds to a minute of pulsing, so give it time to cool down if necessary)."], "imageList": [{"path": "https://images.food52.com/tq8Tlgmiy2cY2Xpa5CyJZHtq4yw=/1000x1000/fb01978e-59b5-4296-90b7-da1a40be3ac7--power_bar.jpg"}], "requirements": [" dried figs, ends removed and sliced into quarters"]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "60b13db3-d932-446f-8345-53c63dc35146", "response": {"speechText": "Place the figs in the bowl with nuts and seeds and stir until well combined.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Place the figs in the bowl with nuts and seeds and stir until well combined."], "imageList": [{"path": "https://images.food52.com/ZL62D1IGpUJJIWaujlJVplYLfQQ=/1000x1000/e701cb7c-0002-4607-8de7-792ec99f9814--Quince_cheese_004.JPG"}], "requirements": [" dried figs, ends removed and sliced into quarters", " chia seeds"], "extraInformation": [{"type": "FUNFACT", "text": "If you take a bowl and put it upside down, it becomes a poml.", "keyword": "bowl", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/facts/fact_bowl.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "241279a8-170f-4d73-9c64-d9796e0a3468", "response": {"speechText": "Line a loaf dish with greasproof paper and spoon the mixture in, pressing down evenly on all sides. Refrigerate for a minimum of an hour.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Line a loaf dish with greasproof paper and spoon the mixture in, pressing down evenly on all sides. 
Refrigerate for a minimum of an hour."], "imageList": [{"path": "https://images.food52.com/bvKhL9d3y6F8WKxDcMCmvjIKLYg=/1000x1000/a8a61e51-8734-40d6-92ef-98c5e86a0bba--Toffee_1.JPG"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "f1bbbb36-89e7-4f8a-81d0-227bd9201cb5", "response": {"speechText": "Melt the chocolate and coconut oil, stirring until smooth.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Melt the chocolate and coconut oil, stirring until smooth."], "imageList": [{"path": "https://images.food52.com/bvKhL9d3y6F8WKxDcMCmvjIKLYg=/1000x1000/a8a61e51-8734-40d6-92ef-98c5e86a0bba--Toffee_1.JPG"}], "requirements": [" coconut oil", " dark chocolate"], "extraInformation": [{"type": "JOKE", "text": "Did you know that it took over eight years to develop the recipe for milk chocolate? Be glad you weren't one of those testers.", "keyword": "chocolate", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_chocolate.png"}, {"type": "FUNFACT", "text": "Chocolate is a one of the most popular forms of apology.", "keyword": "chocolate", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/facts/fact_chocolate.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "4b07f2bd-767c-4b0b-b3a9-8890d7fe00ce", "response": {"speechText": "Remove the chilled bars from the fridge, slice and drizzle the chocolate over the top.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Remove the chilled bars from the fridge, slice and drizzle the chocolate over the top."], "imageList": [{"path": "https://images.food52.com/tq8Tlgmiy2cY2Xpa5CyJZHtq4yw=/1000x1000/fb01978e-59b5-4296-90b7-da1a40be3ac7--power_bar.jpg"}], "requirements": [" dark chocolate"], "extraInformation": [{"type": "JOKE", "text": "Did you know that it took over eight years to develop the recipe for milk chocolate? Be glad you weren't one of those testers.", "keyword": "chocolate", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_chocolate.png"}, {"type": "FUNFACT", "text": "Chocolate is a one of the most popular forms of apology.", "keyword": "chocolate", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/facts/fact_chocolate.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "f9f47158-d828-43a7-90e4-7b06647ec4ed", "response": {"speechText": "Sprinkle the pinch of course sea salt over the top.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Sprinkle the pinch of course sea salt over the top."], "imageList": [{"path": "https://images.food52.com/bvKhL9d3y6F8WKxDcMCmvjIKLYg=/1000x1000/a8a61e51-8734-40d6-92ef-98c5e86a0bba--Toffee_1.JPG"}], "requirements": ["Pinch course sea salt"], "extraInformation": [{"type": "JOKE", "text": "Careful with this one. 
Add too much salt, and it's going to taste all salty.", "keyword": "salt", "imageUrl": "https://alexa-oat-images.s3.amazonaws.com/multimodal_images/jokes/joke_salt_2.png"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}, {"uniqueId": "4848afc8-de32-4321-8633-afc9e2ce4ce5", "response": {"speechText": "Keep refrigerated until serving.", "screen": {"format": "TEXT_IMAGE", "paragraphs": ["Keep refrigerated until serving."], "imageList": [{"path": "https://images.food52.com/bvKhL9d3y6F8WKxDcMCmvjIKLYg=/1000x1000/a8a61e51-8734-40d6-92ef-98c5e86a0bba--Toffee_1.JPG"}]}}, "activeDurationMinutes": 1, "totalDurationMinutes": 1}], "connectionList": [{"idFrom": "19d73de6-f991-469a-ad77-a6cc98f187dd", "idTo": "d09ba3ac-46a9-4050-8021-0d7ded4cea0f"}, {"idFrom": "f9f47158-d828-43a7-90e4-7b06647ec4ed", "idTo": "4848afc8-de32-4321-8633-afc9e2ce4ce5"}, {"idFrom": "241279a8-170f-4d73-9c64-d9796e0a3468", "idTo": "f1bbbb36-89e7-4f8a-81d0-227bd9201cb5"}, {"idFrom": "60b13db3-d932-446f-8345-53c63dc35146", "idTo": "241279a8-170f-4d73-9c64-d9796e0a3468"}, {"idFrom": "4b07f2bd-767c-4b0b-b3a9-8890d7fe00ce", "idTo": "f9f47158-d828-43a7-90e4-7b06647ec4ed"}, {"idFrom": "f1bbbb36-89e7-4f8a-81d0-227bd9201cb5", "idTo": "4b07f2bd-767c-4b0b-b3a9-8890d7fe00ce"}, {"idFrom": "d09ba3ac-46a9-4050-8021-0d7ded4cea0f", "idTo": "60b13db3-d932-446f-8345-53c63dc35146"}, {"idFrom": "65d7ca29-c84c-4303-9da6-d35be303e206", "idTo": "19d73de6-f991-469a-ad77-a6cc98f187dd"}], "dataset": "common-crawl", "author": "Sylvie Taylor @ Roamingtaste", "domainName": "food52.com"} \ No newline at end of file From b857c9decf82716b57ae3455f0df7d322b7ffa91 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Fri, 10 May 2024 18:30:06 +0100 Subject: [PATCH 40/57] Current version of LLM tests --- tester/tests/integration_tests/test_llm.py | 184 +++++++++++---------- 1 file changed, 101 insertions(+), 83 deletions(-) diff --git a/tester/tests/integration_tests/test_llm.py b/tester/tests/integration_tests/test_llm.py index 0693ba9..b8bada8 100644 --- a/tester/tests/integration_tests/test_llm.py +++ b/tester/tests/integration_tests/test_llm.py @@ -9,7 +9,7 @@ import pytest from utils import INTRO_PROMPTS, REPLACE_SUGGESTION -from taskmap_pb2 import TaskMap, Session, ConversationTurn, Task, SessionState +from taskmap_pb2 import TaskMap, Session, ConversationTurn, Task, SessionState, ReplacedIngredient from database_pb2 import SessionRequest, TaskMapRequest from database_pb2_grpc import DatabaseStub from llm_pb2 import ( @@ -24,56 +24,35 @@ # TODO: test TGI being unavailable -def test_llm_summary_generation(sample_taskmap_json_wikihow: str): - """ - Tests functionality for generating steps from a TaskMap. +def test_llm_chit_chat(new_session_obj, interaction_obj): """ + Simple test of LLM chit chat functionality. - # TODO: adapt this to trigger the actual code like the other tests - - # parse the raw JSON sample taskmap into a TaskMap object - taskmap = TaskMap() - Parse(sample_taskmap_json_wikihow, taskmap) - - functionalities_channel = grpc.insecure_channel(os.environ["FUNCTIONALITIES_URL"]) - llm_summary_stub = LLMSummaryGenerationStub(functionalities_channel) - - # this code is a slightly modified version of LLMTaskMapEnhancer.enhance_taskmap from - # external_functionalities. 
it loads an example TaskMap instance and constructs and - # sends a summarisation request to exercise the llm_summary_generation path - steps = taskmap.steps - grouped_steps = [steps[i : i + 3] for i in range(0, len(steps), 3)] - - for group in grouped_steps: - request: MultipleSummaryGenerationRequest = MultipleSummaryGenerationRequest() - for step in group: - request.task_title.append(taskmap.title) - request.step_text.append(step.response.speech_text) - request.more_details.append(step.response.description) - summaries: MultipleSummaryGenerationResponse = llm_summary_stub.generate_summaries(request) - for i, step in enumerate(group): - step.response.speech_text = summaries.summary[i] + To trigger this code, we just create a new Session and send + a couple of messages that will trigger the ChitChatPolicy. - # TODO check output - - -def test_llm_chit_chat(new_session_obj, interaction_obj): + To check that the response comes from the LLM successfully + the "source" object associated with the response needs to + have a message that's set in the appropriate block of code + in chitchat_policy.py + """ session = new_session_obj # start things off with a "hi" response = interaction_obj.run("hi", session) - # check that this response comes from the domain_policy as expected, and ends with one of the - # INTRO_PROMPT values + # check that this response comes from the domain_policy as expected, + # and ends with one of the INTRO_PROMPT values assert response["source"]["policy"] == "domain_policy" interaction_obj.check_prompts(response, INTRO_PROMPTS) # next, reply with "what's up" - this should be classed as a ChitChatIntent and # trigger a call to the LLM chit chat service response = interaction_obj.run("what's up", session) - # the response should now come from a particular line in chitchat_policy.py if - # it successfully retrieved an LLM response + + # the response should now come from a particular code block in + # chitchat_policy.py if it successfully retrieved an LLM response assert response["source"]["policy"] == "chitchat_policy" - assert response["source"]["lineNumber"] == 152 + assert response["source"]["message"] == "chitchat from LLM" # TODO check output @@ -83,20 +62,35 @@ def test_llm_enhancement(new_session_obj, sample_taskmap_json_wikihow: str): """ Tests the TaskMap enhancement performed by external_functionalities. - This should exercise both the llm_description_generation and - llm_proactive_question_generation gRPC services in functionalities. + Specifically this is the process triggered by loading/saving a session in + external_functionalities/database/dynamo_db.py. - Normally this is triggered by loading a session with a non-enhanced TaskMap. - After a TaskMap has been enhanced it's saved to a database so it doesn't get - enhanced again. To make this work in a repeatable way, this test loads the same - sample TaskMap but overwrites its "taskmap_id" field with a new value each time. + This test sets up a minimal valid Session with a Wikihow taskmap, and deletes any + existing copies of the TaskMap and its StagedOutput representation from the + DynamoDB instance as a first step. This is to force the enhancement to run + every time as normally the results are cached in the database. - An additional problem is that the enhancement is triggered by loading an existing - session. To workaround this, the test creates a minimal valid Session, saves it, - then performs a load to start the enhancement process. 
+ Doing this should exercise multiple different LLM components in functionalities. + When a suitable session is loaded in DynamoDB.load_session (one which has a + Wikihow URL and is not marked as already enhanced), it will call the enhance() + method of the StagedEnhance class. For TaskMaps that haven't yet been enhanced, + it spawns a thread which will end up calling LLMTaskMapEnhancer.enhance_taskmap. + + This does 2 things: + 1. It makes some llm_summary_generation requests based on groups of steps + in the TaskMap + 2. It makes an llm_proactive_question_answering request + + i.e. both of these components are exercised by this single test. """ + + # create the objects required session = new_session_obj taskmap = TaskMap() + channel = grpc.insecure_channel(os.environ["EXTERNAL_FUNCTIONALITIES_URL"]) + db = DatabaseStub(channel) + + # parse the TaskMap JSON into a protobuf Parse(sample_taskmap_json_wikihow, taskmap) # construct a protobuf Session object from the basic JSON representation in "session" @@ -105,29 +99,29 @@ def test_llm_enhancement(new_session_obj, sample_taskmap_json_wikihow: str): ParseDict(session, session_proto) session_proto.task.taskmap.CopyFrom(taskmap) - # construct a SessionRequest for an RPC to external_functionalities + # construct a SessionRequest to make an RPC to external_functionalities session_request: SessionRequest = SessionRequest() session_request.id = session["session_id"] session_request.session.CopyFrom(session_proto) - channel = grpc.insecure_channel(os.environ["EXTERNAL_FUNCTIONALITIES_URL"]) - db = DatabaseStub(channel) - - # make sure this TaskMap isn't already in the database + # delete any stored copies of the TaskMap and its StagedOutput (enhanced) representation taskmap_request = TaskMapRequest() - taskmap_request.id = taskmap.taskmap_id + taskmap_request.id = taskmap.taskmap_id db.delete_taskmap(taskmap_request) + db.delete_stagedoutput(taskmap_request) # save the new session and then load it to trigger the enhancement + # (saving will also trigger some enhancement, but not the full process) db.save_session(session_request) session = db.load_session(session_request) # because the database object spawns a worker thread on the remote side, # the RPC will return immediately. wait here until reloading the session # shows it marked as enhanced (via session.task.state.enhanced == True), - # or we timeout - timeout = 30 - delay = 5 + # or we timeout. Repeated calls to load_session won't trigger repeated + # enhancement requests so this is OK to do. + timeout = 60 + delay = 10 enhanced = False while not enhanced and timeout > 0: time.sleep(delay) @@ -138,36 +132,62 @@ def test_llm_enhancement(new_session_obj, sample_taskmap_json_wikihow: str): assert enhanced -def test_llm_ingredient_substitution_step_text(interaction_obj, sample_taskmap_json_wikihow: str): - # this functionality is called from orchestrator/policy/validation_policy/validation_policy.py - - # it looks like what we need to trigger this is: - # session.task.phase = Task.TaskPhase.VALIDATING (to trigger ValidationPolicy) - # then in ValidationPolicy, the conditions necessary to hit the rewriter are: - # 1. session.task.state.requirements_displayed = True - # 2. session.turn[-2].user_request.interaction contains a QuestionIntent - # 3. session.turn[-2].agent_response.interaction.speech_text contains any REPLACE_SUGGESTION - # 4. 
session.turn[-1].user_request.interaction contains a YesIntent +def test_llm_ingredient_substitution_step_text(interaction_obj, sample_taskmap_json_seriouseats: str): + """ + Tests the llm_ingredient_substitution step rewriting functionality. + + This can be triggered by orchestrator/policy/validation_policy/validation_policy.py + in certain conditions: + - the session.task.phase must be Task.TaskPhase.VALIDATING (to trigger the policy) + - session.task.state.requirements_displayed must be True + - session.turn[-2].user_request.interaction must contain a QuestionIntent + - session.turn[-2].agent_response.interaction.speech_text must contain a REPLACE_SUGGESTION + - session.turn[-1].user_request.interaction must contain a YesIntent + + If all these conditions are satisifed, a call is made to adjust_step_texts inside + LLMIngredientStepTextRewriter. + """ + channel = grpc.insecure_channel(os.environ["EXTERNAL_FUNCTIONALITIES_URL"]) + db = DatabaseStub(channel) taskmap = TaskMap() - Parse(sample_taskmap_json_wikihow, taskmap) + + Parse(sample_taskmap_json_seriouseats, taskmap) + print("ID: ", taskmap.taskmap_id) + + # make sure our TaskMap isn't already in the database + taskmap_request = TaskMapRequest() + taskmap_request.id = taskmap.taskmap_id + db.delete_taskmap(taskmap_request) + db.delete_stagedoutput(taskmap_request) # construct a protobuf Session object from the basic JSON representation in session session_proto = Session() session_proto.session_id = f"test_{str(uuid.uuid4())}" - # configure the rest of the session state required + # configure the session state required to trigger this LLM functionality session_proto.state = SessionState.RUNNING session_proto.task.phase = Task.TaskPhase.VALIDATING session_proto.task.state.requirements_displayed = True + # set up a replaced ingredient (this is normally done via QA policy) + ri = ReplacedIngredient() + ri.original.name = "chocolate" + ri.original.amount = "30 grams" + ri.replacement.name = "cheese" + ri.replacement.amount = "100 grams" + session_proto.task.taskmap.CopyFrom(taskmap) + session_proto.task.taskmap.replaced_ingredients.append(ri) + # need to add a turn where the user_request.interaction contains a QuestionIntent t1 = ConversationTurn() t1.user_request.interaction.intents.extend(["QuestionIntent"]) + # also need to have an ingredient mentioned in the interaction text + t1.user_request.interaction.text = "can you replace the chocolate in this recipe?" 
t1.id = f"test_turn_{uuid.uuid4()}" # and the agent must respond with a REPLACE_SUGGESTION prompt - t1.agent_response.interaction.speech_text = random.choice(REPLACE_SUGGESTION) + t1.agent_response.interaction.speech_text = f"{random.choice(REPLACE_SUGGESTION)} {ri.replacement.name}" session_proto.turn.extend([t1]) # the following turn should also contain a YesIntent, but that will be generated @@ -178,25 +198,23 @@ def test_llm_ingredient_substitution_step_text(interaction_obj, sample_taskmap_j session_request: SessionRequest = SessionRequest() session_request.id = session_proto.session_id session_request.session.CopyFrom(session_proto) - channel = grpc.insecure_channel(os.environ["EXTERNAL_FUNCTIONALITIES_URL"]) - - db = DatabaseStub(channel) db.save_session(session_request) - # make sure this TaskMap isn't already in the database - taskmap_request = TaskMapRequest() - taskmap_request.id = taskmap.taskmap_id - db.delete_taskmap(taskmap_request) + # working with a session via interaction_obj requires a dict s = MessageToDict(session_proto, including_default_value_fields=True) - - # use speech text that will generate a YesIntent, required to trigger the substitution s["session_id"] = session_proto.session_id + # use speech text that will generate a YesIntent, required to trigger the substitution resp = interaction_obj.run("yes ok", s) - # TODO: the problem with this is it finds nothing to do with the current sample taskmap, need to try some others - - -def test_llm_ingredient_substitution_2(): - # see functionalities/qa/composed_qa.py - pass + # check that at least one of the steps now contains our replaced ingredient + db.save_session(session_request) + session = db.load_session(session_request) + found_replacement = False + for step in session.task.taskmap.steps: + if ri.replacement.name in step.response.screen.requirements or ri.replacement.name in step.response.speech_text: + found_replacement = True + break + + # TODO: currently failing, need to try with the original model + assert found_replacement From d19c3381328a39c05cca7e0d554689379dc3b27b Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 14:09:29 +0100 Subject: [PATCH 41/57] Updating docker-compose.yml - Remove old TGI_CONNECTION_* env vars - Add a volume to the TGI container to load local models from - Update comments - Adjust default SHM size - Set the default MODEL_ID to be the local Alpaca model --- docker-compose.yml | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index e481c52..91ad30b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,3 @@ -version: "3.9" # optional since v1.27.0 services: builder: @@ -294,15 +293,10 @@ services: - FUNCTIONALITIES_URL=functionalities:8000 - EXTERNAL_FUNCTIONALITIES_URL=external_functionalities:8000 # the location of the TGI endpoint (note that it's using the internal - # Docker network, so it's port 80 rather than 8081) + # Docker network, so it's port 80 rather than 8081). this could also be + # set to a remote location not managed by the current Compose instance, + # in which case the URL should reference port 8081 instead. - INFERENCE_ENDPOINT_URL=tgi:80 - # these values can be used to control how long this service waits - # for the TGI endpoint to become available. this might take a significant - # amount of time in some cases, e.g. 
if it has to download a large model - # Number of retries (default 10) - - TGI_CONNECTION_RETRY_LIMIT=${TGI_CONNECTION_RETRY_LIMIT:-10} - # delay between retries in seconds - - TGI_CONNECTION_RETRY_DELAY=${TGI_CONNECTION_RETRY_DELAY:-10} networks: - internal - external @@ -321,25 +315,34 @@ services: ports: # this isn't needed for llm_functionalities (it uses the internal network), # but might be useful to be able to submit requests to the instance from - # external scripts for testing/debugging + # external scripts for testing/debugging, or if the TGI service is running + # in a different location than the rest of the system. - "8081:80" volumes: - # the container downloads weights and other files to this path + # if the model is downloaded from huggingface, the files will be stored + # in this location - ./shared/file_system/tgi:/data + # this volume is used to allow the use of the existing Alpaca LLM + # that OAT is currently set up to use. setting MODEL_ID below to a + # path instead of a huggingface model ID will cause TGI to load model + # files from that location + - ./shared/file_system/downloads/llm_functionalities/:/models environment: # setting the value of MODEL_ID is equivalent to passing "--model-id" parameter - # to the TGI launcher - - MODEL_ID=${MODEL_ID:-google/flan-t5-large} + # to the TGI launcher. the default is to use the Alpaca LLM model that OAT + # has previously used + - MODEL_ID=${MODEL_ID:-/models/alpaca_llm/} # any other TGI launcher parameters can be set in this env var, e.g.: # TGI_PARAMS="--param1 param1_value --param2 param2_value" docker compose up - TGI_PARAMS=${TGI_PARAMS:-} - # this is required for Mistral or other gated models + # this is required for Mistral or other gated models. you can find your + # tokens by browsing to https://huggingface.co/settings/tokens - HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN:-} networks: - internal - external - # larger sharded models will need this increased from the default (usually 64MB) - shm_size: ${SHM_SIZE:-2gb} + # larger sharded models will need this increased from the Docker default of 64MB + shm_size: ${SHM_SIZE:-1gb} deploy: resources: reservations: From d3f823f1eca3fda25b0bf06c72c94c78a8466929 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 14:10:57 +0100 Subject: [PATCH 42/57] Add some extra logging around enhancing taskmaps --- external_functionalities/database/staged_enhance.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/external_functionalities/database/staged_enhance.py b/external_functionalities/database/staged_enhance.py index 478e15c..60c75c8 100644 --- a/external_functionalities/database/staged_enhance.py +++ b/external_functionalities/database/staged_enhance.py @@ -73,6 +73,8 @@ def _enhance_desc(self, session_db, taskmap_db, session: Session) -> None: taskmaps_to_enhance = [] request: LLMMultipleDescriptionGenerationRequest = LLMMultipleDescriptionGenerationRequest() + logger.info("Enhancing descriptions") + for i, candidate in enumerate(session.task_selection.candidates_union): # avoid categories if candidate.HasField('category'): @@ -99,6 +101,7 @@ def _enhance_desc(self, session_db, taskmap_db, session: Session) -> None: return # Generate Descriptions and update session + logger.info(f"Enhancing descriptions: {len(request.task_title)} candidates found") descriptions = self.llm_desc_generator.generate_descriptions(request) for taskmap, description in zip(taskmaps_to_enhance, descriptions.description): From e17f40c3b11ab3992ee4105c6563b418af303488 Mon Sep 17 
00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 14:12:20 +0100 Subject: [PATCH 43/57] Update LLM ingredient substitution test Instead of reloading the session from the database, we can just check the JSON object returned by the orchestrator, since it should include some of the changed text. --- tester/tests/integration_tests/test_llm.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tester/tests/integration_tests/test_llm.py b/tester/tests/integration_tests/test_llm.py index b8bada8..4785ea0 100644 --- a/tester/tests/integration_tests/test_llm.py +++ b/tester/tests/integration_tests/test_llm.py @@ -146,6 +146,8 @@ def test_llm_ingredient_substitution_step_text(interaction_obj, sample_taskmap_j If all these conditions are satisfied, a call is made to adjust_step_texts inside LLMIngredientStepTextRewriter. + + (this will also trigger the proactive question answering component) """ channel = grpc.insecure_channel(os.environ["EXTERNAL_FUNCTIONALITIES_URL"]) @@ -153,7 +155,6 @@ def test_llm_ingredient_substitution_step_text(interaction_obj, sample_taskmap_j taskmap = TaskMap() Parse(sample_taskmap_json_seriouseats, taskmap) - print("ID: ", taskmap.taskmap_id) # make sure our TaskMap isn't already in the database taskmap_request = TaskMapRequest() @@ -205,16 +206,7 @@ def test_llm_ingredient_substitution_step_text(interaction_obj, sample_taskmap_j s = MessageToDict(session_proto, including_default_value_fields=True) s["session_id"] = session_proto.session_id # use speech text that will generate a YesIntent, required to trigger the substitution - resp = interaction_obj.run("yes ok", s) + new_session_dict = interaction_obj.run("yes ok", s) - # check that at least one of the steps now contains our replaced ingredient - db.save_session(session_request) - session = db.load_session(session_request) - found_replacement = False - for step in session.task.taskmap.steps: - if ri.replacement.name in step.response.screen.requirements or ri.replacement.name in step.response.speech_text: - found_replacement = True - break - - # TODO: currently failing, need to try with the original model - assert found_replacement + # check that the requirements now include the replacement ingredient + assert any(ri.replacement.name in req for req in new_session_dict["screen"]["requirements"]) From 747050da2f1da1f4d2b170b6aac5c87db1c3c1d6 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 14:22:07 +0100 Subject: [PATCH 44/57] Refactoring LLMRunner connection handling The previous version of this class entered a loop on startup to periodically check if the TGI endpoint address was connectable. If it was able to make a connection, it assumed that it would remain valid indefinitely. This version should be a bit more flexible in that it now only attempts to connect when LLM requests are triggered. It should also handle the endpoint becoming unavailable and then available again: the first request sent after it comes back up will create a new client object for future requests. The other important change here is the use of dnspython to test if the TGI endpoint hostname is resolvable. I found that if this isn't true (e.g. if you launch OAT with a remote TGI endpoint that is still starting up), the DNS resolution process takes 10+ seconds to time out. This seems to be difficult to handle using the Python stdlib, but the dnspython package makes it very simple to set a timeout on the resolution process.
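As a rough illustration, a minimal standalone sketch of this kind of time-limited lookup (the helper name and return-value handling here are just examples, not part of the patch; the diff below uses the same dns.resolver.resolve() call with lifetime=1 on the hostname part of INFERENCE_ENDPOINT_URL):

    import dns.resolver

    def hostname_resolvable(endpoint_url: str, timeout_s: float = 1.0) -> bool:
        # strip the scheme and port, leaving just the hostname (e.g. "tgi")
        hostname = endpoint_url.split("://")[-1].split(":")[0]
        try:
            # lifetime caps the total time spent on the DNS query, so an
            # unresolvable host fails within timeout_s instead of blocking for 10+ seconds
            dns.resolver.resolve(hostname, lifetime=timeout_s)
            return True
        except Exception:
            return False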
This allows the per-request connection checking to work without excessive delays. --- llm_functionalities/llm_runner/llm_runner.py | 86 ++++++++++---------- 1 file changed, 44 insertions(+), 42 deletions(-) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index b7144d6..1dd1294 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -4,6 +4,7 @@ from concurrent.futures import TimeoutError, ThreadPoolExecutor from typing import Optional, Callable, Dict, Any +import dns.resolver from huggingface_hub import InferenceClient from utils import logger @@ -16,7 +17,8 @@ # default timeout for requests that don't set a timeout explicitly # TODO: allow override by env var? -DEFAULT_TGI_TIMEOUT_MS = 5000 +DEFAULT_TGI_TIMEOUT = 5 +DEFAULT_TGI_TIMEOUT_MS = DEFAULT_TGI_TIMEOUT * 1000 class LLMRunner: @@ -30,62 +32,57 @@ def __init__(self): if not endpoint_url.startswith("http://"): endpoint_url = f"http://{endpoint_url}" + self.endpoint_url = endpoint_url self.client = None - retries = 0 - retry_limit = int(os.environ.get("TGI_CONNECTION_RETRY_LIMIT", 10)) - retry_delay = int(os.environ.get("TGI_CONNECTION_RETRY_DELAY", 10)) - logger.info(f"Connecting to TGI (max {retry_limit} connections, {retry_delay}s apart)") - - # might have to wait for the TGI container to finish starting up, especially if it - # needs to download model files first. the two TGI_CONNECTION_RETRY_* env vars - # determine how long we'll wait for this to happen. - while retries < retry_limit: - client = self._connect_to_endpoint(endpoint_url) - if client is None: - logger.info(f"LLMRunner retrying connection to {endpoint_url}") - time.sleep(retry_delay) - retries += 1 - else: - logger.info("LLMRunner connected to endpoint!") - self.client = client - break - - if self.client is None: - logger.error(f"LLMRunner failed to connect to the endpoint at {endpoint_url}") - sys.exit(-1) + self._connect_to_endpoint() - def _connect_to_endpoint(self, endpoint_url: str) -> Optional[InferenceClient]: + def _connect_to_endpoint(self) -> bool: """Attempt to make a connection to the configured TGI endpoint. Simply creating an InferenceClient object with the endpoint URL doesn't trigger a connection, so to force that to happen this just submits a small text_generation - query. + query. If the client object already exists it's taken to mean the connection + has already been successfully established. Returns: - None if the connection failed, otherwise an InferenceClient object + True if a connection exists/was created, False otherwise. """ - client = InferenceClient(model=endpoint_url, timeout=10.0) + + # a client has been instantiated, assume everything is OK + if self.client is not None: + return True + try: - # creating the object doesn't appear to actually make a connection, so - # try something that will fail if it can't connect - client.text_generation(prompt="hello?", max_new_tokens=10) - except Exception: - return None - return client - - def _check_connectivity(self) -> bool: - """Test if we have a connected InferenceClient""" - if self.client is None: - logger.error("!!! llm_functionalities isn't connected to an endpoint!") + # First, need to check if we can actually resolve the TGI endpoint + # hostname. If this isn't available, it will cause urllib3 and the Python + # socket library to spend >10s waiting for a result. 
It seems to be hard + # to do a time-limited host lookup using the stdlib, + # but dnspython makes it easy (lifetime is the timeout parameter here): + dns.resolver.resolve(self.endpoint_url[7:].split(":")[0], lifetime=1) + + # If the hostname is resolved, we then setup a longer timeout which + # TGI will use internally when making HTTP requests via request.post() etc + client = InferenceClient(model=self.endpoint_url, timeout=DEFAULT_TGI_TIMEOUT) + + # Finally, creating the object doesn't appear to actually make a connection, + # so try something that requires a valid connection to succeed + response = client.text_generation(prompt="hello?", max_new_tokens=10) + except Exception as e: + logger.warn(f"Failed to connect to TGI at {self.endpoint_url} ({e})") + return False + + if len(response) == 0: + logger.warn(f"Failed to connect to TGI at {self.endpoint_url} (timed out)") return False + self.client = client return True def _call_with_timeout(self, timeout_ms: int, callable: Callable, params: Dict[str, Any]) -> str: """Call a TGI endpoint with a timeout applied. Since we want to avoid potentially lengthy LLM computations delaying the system's responses, - we need to enforce a timeout on each TGI call. This method handles that through a + we need to enforce a timeout on each TGI call. This method handles that using a ThreadPoolExecutor and a future object. Returns: @@ -101,18 +98,23 @@ def _call_with_timeout(self, timeout_ms: int, callable: Callable, params: Dict[s return response except TimeoutError as _: future.cancel() + executor.shutdown(wait=False, cancel_futures=True) logger.warning(f"A call to the TGI endpoint has timed out after {timeout_ms} ms") return "" except Exception as e: future.cancel() + executor.shutdown(wait=False, cancel_futures=True) logger.warning(f"A call to the TGI endpoint failed with an unexpected exception: {str(e)}") + # assume this means a problem with the endpoint and connect again + # on a subsequent request + self.client = None return "" def call_model(self, model_request: ModelRequest) -> ModelResponse: - """Perform a single call to the LLM endpoint""" + """Submit a single call to the LLM endpoint""" model_response: ModelResponse = ModelResponse() - if not self._check_connectivity(): + if not self._connect_to_endpoint(): return model_response timeout_ms = DEFAULT_TGI_TIMEOUT_MS if model_request.timeout == 0 else model_request.timeout @@ -131,7 +133,7 @@ def batch_call_model(self, model_request: ModelBatchRequest) -> ModelBatchRespon """Submit a batch of calls to the LLM endpoint""" model_responses: ModelBatchResponse = ModelBatchResponse() - if not self._check_connectivity(): + if not self._connect_to_endpoint(): return model_responses timeout_ms = DEFAULT_TGI_TIMEOUT_MS if model_request.timeout == 0 else model_request.timeout From 38f33a7448abd6845039f6830391c2608e3987f9 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 14:28:02 +0100 Subject: [PATCH 45/57] Adding dnspython as a dependency --- llm_functionalities/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/llm_functionalities/requirements.txt b/llm_functionalities/requirements.txt index 1b081ae..d007041 100644 --- a/llm_functionalities/requirements.txt +++ b/llm_functionalities/requirements.txt @@ -1,2 +1,3 @@ # summarization seems to be broken in the recent released versions git+https://github.com/huggingface/huggingface_hub.git +dnspython==2.6.1 From 62bb76f71ed32f9205a5b007988fa33ce2985a15 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 
14:29:14 +0100 Subject: [PATCH 46/57] Update LLM ingredient substitution While I was trying to add some LLM tests, I found that this class didn't seem to be set up to parse the model responses correctly. The original prompt asks it to produce output in the format: > {"step_text": ""} and then ends the prompt with: > {"step_text:": " The Alpaca model we've been using does successfully complete this with something like: > one two three"} but the problem is the `extract_response` method assumes that the generated text will be a complete parseable JSON string, not a partial one. This leads the method to return an empty dict as if the response had failed to generate any text, when it will normally have generated something valid. I've changed the prompt so it will output the text without any extra formatting, removed `extract_response` because it's now not required, and removed `process_response_text` because it seemed to be already unused. --- .../llm_ingredient_step_text_rewriter.py | 47 ++----------------- 1 file changed, 3 insertions(+), 44 deletions(-) diff --git a/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py b/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py index 62cb70c..3ebd1b1 100644 --- a/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py +++ b/functionalities/llm_ingredient_substitution/llm_ingredient_step_text_rewriter.py @@ -40,51 +40,9 @@ def extract_ingredient_name(self, question_text: str, requirements_list) -> str: return str_item return "" - def extract_response(self, generated_answer) -> dict: - valid_dict = {} - start_token = "{" - end_token = "}" - if end_token not in generated_answer: - return valid_dict - - if start_token in generated_answer: - generated_answer = "{" + generated_answer.split(start_token)[-1] - - try: - generated_answer = generated_answer.replace(' \n ', '').replace('\n', '') - valid_dict = json.loads(generated_answer) - logger.info(f'Managed to parse rewritten step text: {valid_dict}') - - if "step_text" in list(valid_dict.keys()): - return valid_dict - else: - logger.info(f'Dictionary contents not valid: {valid_dict}') - except Exception as e: - logger.info(f'Could not parse response >{generated_answer}<: {e}') - return valid_dict - - def process_response_text(self, response_text: str) -> str: - # remove whitespace - text = ' '.join(response_text.split("\n")) - if not re.search(r'[.!?]', text[-1]): - # If not, add a period (.) to the end of the text - text += '.' - - # split at punctuation - sentences = re.split(r'(?<=[.!?])\s+', text) - - complete_sentences = [] - for sentence in sentences: - if sentence.endswith(('.', '!', '?')): - # remove numbered lists - sentence = re.sub(r'\d+\.', '', sentence) - complete_sentences.append(sentence.strip()) - return ' '.join(complete_sentences) - def build_prompt(self, task_title: str, step_text: str, ingredient: ReplacedIngredient) -> str: model_input = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. - Follow this format: {{\"step_text\": \"\"}} ### Instruction: Adjust the step text to replace the original ingredient with the replacement ingredient. 
@@ -94,7 +52,7 @@ def build_prompt(self, task_title: str, step_text: str, ingredient: ReplacedIngr Replacement ingredient:{ingredient.replacement.amount} {ingredient.replacement.name} Step: {step_text} - ### Response: {{\"step_text\": \"""" + ### Response: """ return model_input def adjust_step_texts(self, request: Session) -> Session: @@ -131,6 +89,7 @@ def adjust_step_texts(self, request: Session) -> Session: if len(adjusting_request.step) == 0: logger.info('Nothing rewritten because no matches in steps') else: + logger.info("Rewriting steps via LLM") adjusted_step_response: AdjustedStepGenerationResponse = self.rewrite_steps(adjusting_request) for step, id in zip(adjusted_step_response.step_text, adjusted_step_response.ids): rewritten_steps[id] = step @@ -196,6 +155,6 @@ def rewrite_steps(self, request: AdjustedStepGenerationRequest) -> AdjustedStepG llm_responses = self.llm.batch_call_model(model_batch_request) for idx, text in enumerate(llm_responses.text): - llm_step_texts.step_text.append(self.extract_response(text).get('step_text', '')) + llm_step_texts.step_text.append(text) llm_step_texts.ids.append(ids[idx]) return llm_step_texts From 09591f6eb1e7cde1271454a8ff7ed1708e5a5d2b Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 14:38:48 +0100 Subject: [PATCH 47/57] Prevent a crash if LLM response is empty If the input string was empty, these methods would throw an IndexError when calling `re.search` (because of trying to do `""[-1]`) --- .../llm_proactive_question_generation.py | 4 ++++ .../llm_summary_generation/llm_summary_generation.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py b/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py index 85c2baf..341f8de 100644 --- a/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py +++ b/functionalities/llm_proactive_question_generation/llm_proactive_question_generation.py @@ -26,6 +26,10 @@ def extract_question(generated_question): def process_response_text(response_text: str) -> str: + if len(response_text) == 0: + logger.warn("Empty response from LLMProactiveQuestionGeneration") + return "" + # remove whitespace text = ' '.join(response_text.split("\n")) if not re.search(r'[.!?]', text[-1]): diff --git a/functionalities/llm_summary_generation/llm_summary_generation.py b/functionalities/llm_summary_generation/llm_summary_generation.py index 6a97ac7..c366315 100644 --- a/functionalities/llm_summary_generation/llm_summary_generation.py +++ b/functionalities/llm_summary_generation/llm_summary_generation.py @@ -18,6 +18,10 @@ def process_response_text(response_text: str) -> str: + if len(response_text) == 0: + logger.warn("Empty response from LLMSummaryGeneration") + return "" + # remove whitespace text = " ".join(response_text.split("\n")) if not re.search(r"[.!?]", text[-1]): From 90ed0fb78862098e6c6fd65dac065912bde18573 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 16:03:20 +0100 Subject: [PATCH 48/57] Fix parsing of LLM responses The prompt for this type of request expects a JSON-compatible string response with 2 fields. However the prompt already includes the first part of the expected string (`"{\"name\"`) and the parsing fails because it's run only on the response, which is an incomplete JSON string. 
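To make the failure concrete, a small sketch (the field names and values are only examples; the real response format isn't shown in full here):

    import json

    # the prompt already ends with the opening '{"name"', so the model only
    # generates the remainder of the JSON object, e.g.:
    generated_answer = ': "olive oil", "amount": "2 tablespoons"}'

    try:
        json.loads(generated_answer)   # fails: not a complete JSON document
    except json.JSONDecodeError:
        pass

    # prepending the part of the format that was included in the prompt
    # makes the response parseable
    parsed = json.loads('{"name"' + generated_answer)
    # parsed == {"name": "olive oil", "amount": "2 tablespoons"}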
This just adjusts the parsing method to add the section of the response format included in the prompt to the start of the actual response, allowing the parsing to function as intended. --- .../llm_ingredient_substitution_generation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py b/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py index 42e5782..7deda0f 100644 --- a/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py +++ b/functionalities/llm_ingredient_substitution/llm_ingredient_substitution_generation.py @@ -14,15 +14,15 @@ from utils import logger -def extract_replacement_response(generated_answer, original_ing: Ingredient) -> dict: +def extract_replacement_response(generated_answer: str, original_ing: Ingredient) -> dict: +    valid_dict = {} - start_token = "{" end_token = "}" + # the part of the response format included in the prompt + prefix = "{\"name\"" if end_token not in generated_answer: return valid_dict - if start_token in generated_answer: - generated_answer = "{" + generated_answer.split(start_token)[-1] + generated_answer = f"{prefix}{generated_answer}" try: generated_answer = generated_answer.replace(" \n ", "").replace("\n", "") From 5787c5c280aba07a64f0fc06e8d2c7f0b67e214f Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 16:50:38 +0100 Subject: [PATCH 49/57] move llm downloads.toml to TGI --- {llm_functionalities => hg_tgi}/downloads.toml | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {llm_functionalities => hg_tgi}/downloads.toml (100%) diff --git a/llm_functionalities/downloads.toml b/hg_tgi/downloads.toml similarity index 100% rename from llm_functionalities/downloads.toml rename to hg_tgi/downloads.toml From 2d412c91c00d94a68f2e06367c9f4af741404ca9 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 17:00:29 +0100 Subject: [PATCH 50/57] Sync folder name with service name --- {hg_tgi => tgi}/Dockerfile | 0 {hg_tgi => tgi}/downloads.toml | 0 {hg_tgi => tgi}/tgi-wrapper.sh | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename {hg_tgi => tgi}/Dockerfile (100%) rename {hg_tgi => tgi}/downloads.toml (100%) rename {hg_tgi => tgi}/tgi-wrapper.sh (100%) diff --git a/hg_tgi/Dockerfile b/tgi/Dockerfile similarity index 100% rename from hg_tgi/Dockerfile rename to tgi/Dockerfile diff --git a/hg_tgi/downloads.toml b/tgi/downloads.toml similarity index 100% rename from hg_tgi/downloads.toml rename to tgi/downloads.toml diff --git a/hg_tgi/tgi-wrapper.sh b/tgi/tgi-wrapper.sh similarity index 100% rename from hg_tgi/tgi-wrapper.sh rename to tgi/tgi-wrapper.sh From c12cf13604336dab04e8e9e445c920fd5c343089 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 17:11:51 +0100 Subject: [PATCH 51/57] Move downloads.toml back to llm_functionalities Keeping this in llm_functionalities allows the `Downloader` class to continue working as normal, even though it's technically going to be downloading files for use by TGI --- {tgi => llm_functionalities}/downloads.toml | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {tgi => llm_functionalities}/downloads.toml (100%) diff --git a/tgi/downloads.toml b/llm_functionalities/downloads.toml similarity index 100% rename from tgi/downloads.toml rename to llm_functionalities/downloads.toml From b76aad2543b893a6cda2ac9750bcaea49377663c Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May
2024 17:14:21 +0100 Subject: [PATCH 52/57] Download alpaca_llm to a TGI Docker volume --- llm_functionalities/downloads.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llm_functionalities/downloads.toml b/llm_functionalities/downloads.toml index 4a84c6e..41e64e4 100644 --- a/llm_functionalities/downloads.toml +++ b/llm_functionalities/downloads.toml @@ -1,6 +1,6 @@ name = "llm_functionalities" -base_path = "/shared/file_system/downloads/llm_functionalities" +base_path = "/shared/file_system/downloads/tgi/local/" # from model_requirements.txt [[sources]] From 39e00f986061c1199c2bea3e39e027bcfc22e811 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 17:14:57 +0100 Subject: [PATCH 53/57] Fix path to tgi Dockerfile --- docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index 91ad30b..12866c5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -311,7 +311,7 @@ services: container_name: tgi build: context: ./ - dockerfile: hg_tgi/Dockerfile + dockerfile: tgi/Dockerfile ports: # this isn't needed for llm_functionalities (it uses the internal network), # but might be useful to be able to submit requests to the instance from From 88e21c834ba8287862813303276eda96c9b5b5d3 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 17:15:05 +0100 Subject: [PATCH 54/57] Update TGI Docker volumes This helps make clear which files are local and which have been downloaded from huggingface.co --- docker-compose.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 12866c5..d30cb7a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -321,12 +321,12 @@ services: volumes: # if the model is downloaded from huggingface, the files will be stored # in this location - - ./shared/file_system/tgi:/data + - ./shared/file_system/downloads/tgi/huggingface:/data # this volume is used to allow the use of the existing Alpaca LLM # that OAT is currently set up to use. setting MODEL_ID below to a # path instead of a huggingface model ID will cause TGI to load model # files from that location - - ./shared/file_system/downloads/llm_functionalities/:/models + - ./shared/file_system/downloads/tgi/local:/models environment: # setting the value of MODEL_ID is equivalent to passing "--model-id" parameter # to the TGI launcher. 
the default is to use the Alpaca LLM model that OAT From 196707dcedc38be9eec6c4b8960fd5e0c4942f3e Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 17:25:25 +0100 Subject: [PATCH 55/57] Add auto-downloading of "local" models --- llm_functionalities/llm_runner/llm_runner.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index 1dd1294..2cdd922 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -7,7 +7,7 @@ import dns.resolver from huggingface_hub import InferenceClient -from utils import logger +from utils import logger, Downloader from compiled_protobufs.llm_pb2 import ( ModelRequest, ModelResponse, @@ -23,6 +23,11 @@ class LLMRunner: def __init__(self): + downloader = Downloader() + logger.info(f"Found {len(downloader.sources)} enabled download sources, retrieving them...") + # just download any defined sources + downloader.download() + endpoint_url = os.environ.get("INFERENCE_ENDPOINT_URL", None) if endpoint_url is None: logger.error("No INFERENCE_ENDPOINT_URL defined, container will exit") From da5e20bb07f1d67d17da281facdf2e22635f3901 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 17:25:41 +0100 Subject: [PATCH 56/57] Remove the connection attempt in __init__ The connection will be attempted for the first time when the first LLM request is received instead --- llm_functionalities/llm_runner/llm_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llm_functionalities/llm_runner/llm_runner.py b/llm_functionalities/llm_runner/llm_runner.py index 2cdd922..5908795 100644 --- a/llm_functionalities/llm_runner/llm_runner.py +++ b/llm_functionalities/llm_runner/llm_runner.py @@ -39,7 +39,7 @@ def __init__(self): self.endpoint_url = endpoint_url self.client = None - self._connect_to_endpoint() + logger.info("llm_functionalities initialized") def _connect_to_endpoint(self) -> bool: """Attempt to make a connection to the configured TGI endpoint. From 0416bb53cffc1d56785407af13ce59deb4226f17 Mon Sep 17 00:00:00 2001 From: Andrew Ramsay Date: Thu, 23 May 2024 17:26:52 +0100 Subject: [PATCH 57/57] fix path in TGI dockerfile --- tgi/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tgi/Dockerfile b/tgi/Dockerfile index 493a93f..174b527 100644 --- a/tgi/Dockerfile +++ b/tgi/Dockerfile @@ -4,5 +4,5 @@ RUN apt-get update && apt-get install -y curl # wrapper script for the TGI launcher which just unpacks # parameters passed in using the TGI_PARAMS env var via # docker-compose.yml -COPY hg_tgi/tgi-wrapper.sh /tmp +COPY tgi/tgi-wrapper.sh /tmp ENTRYPOINT ["bash", "/tmp/tgi-wrapper.sh"]
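A quick way to sanity-check the TGI service after these patches is to hit it from outside the Compose network. This is only a sketch, assuming the default 8081:80 port mapping from docker-compose.yml, a model that has finished loading, and an arbitrary example prompt:

    from huggingface_hub import InferenceClient

    # port 8081 on the host is mapped to port 80 inside the tgi container
    client = InferenceClient(model="http://localhost:8081", timeout=5)

    # the same call style llm_functionalities uses internally
    print(client.text_generation(prompt="Suggest a substitute for butter:", max_new_tokens=32))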