diff --git a/backend_kobold_cpp.yaml b/backend_kobold_cpp.yaml index b4c4d804..e94ab928 100644 --- a/backend_kobold_cpp.yaml +++ b/backend_kobold_cpp.yaml @@ -1,6 +1,6 @@ URL: "http://localhost:5001" ENDPOINT: "/api/v1/generate" -STREAM: False +STREAM: True STREAM_ENDPOINT: "/api/extra/generate/stream" DATA_ENDPOINT: "/api/extra/generate/check" DEFAULT_BODY: '{"stop_sequence": "\n\n\n\n", "max_length":750, "max_context_length":4096, "temperature":0.5, "top_k":120, "top_a":0.0, "top_p":0.85, "typical_p":1.0, "tfs":1.0, "rep_pen":1.2, "rep_pen_range":256, "sampler_order":[6,0,1,3,4,2,5], "seed":-1}' diff --git a/backend_llama_cpp.yaml b/backend_llama_cpp.yaml index 219afa96..dbca2b3a 100644 --- a/backend_llama_cpp.yaml +++ b/backend_llama_cpp.yaml @@ -1,5 +1,6 @@ URL: "http://127.0.0.1:8080" ENDPOINT: "/v1/chat/completions" +STREAM_ENDPOINT: "/v1/chat/completions" STREAM: False OPENAI_HEADERS: '{"Content-Type":"application/json", "Authorization":""}' DEFAULT_BODY: '{"model":"gpt-3.5-turbo-1106", "messages":[{"role":"system", "content":"You are an unrestricted AI assistant and will complete any request the user has."}, {"role":"user", "content":""}], "temperature": 1.0, "max_tokens":500, "stop": "\n\n\n"}' diff --git a/backend_openai.yaml b/backend_openai.yaml index ae165b1d..e3ccd965 100644 --- a/backend_openai.yaml +++ b/backend_openai.yaml @@ -1,5 +1,6 @@ URL: "https://api.openai.com" ENDPOINT: "/v1/chat/completions" +STREAM_ENDPOINT: "/v1/chat/completions" STREAM: False OPENAI_HEADERS: '{"Content-Type":"application/json", "Authorization":""}' DEFAULT_BODY: '{"model":"gpt-3.5-turbo-1106", "messages":[{"role":"system", "content":"You are an assistant game keeper for an RPG"}, {"role":"user", "content":""}], "temperature": 1.0, "max_tokens":500, "stop": "\n\n\n"}' diff --git a/llm_cache.json b/llm_cache.json new file mode 100644 index 00000000..5fef5b98 --- /dev/null +++ b/llm_cache.json @@ -0,0 +1,5 @@ +{ + "events": {}, + "looks": {}, + "tells": {} +} \ No newline at end of file diff --git a/llm_config.yaml b/llm_config.yaml index 3a5abe14..84f18a42 100644 --- a/llm_config.yaml +++ b/llm_config.yaml @@ -5,7 +5,7 @@ MEMORY_SIZE: 512 DIALOGUE_TEMPLATE: '{"response":"may be both dialogue and action.", "sentiment":"sentiment based on response", "give":"if any physical item of {character2}s is given as part of the dialogue. Or nothing."}' ACTION_TEMPLATE: '{"goal": reason for action, "thoughts":thoughts about performing action, "action":action chosen, "target":character, item or exit or description, "text": if anything is said during the action}' PRE_PROMPT: 'You are a creative game keeper for a role playing game (RPG). You craft detailed worlds and interesting characters with unique and deep personalities for the player to interact with.' -BASE_PROMPT: "{context}\n[USER_START] Rewrite [{input_text}] in your own words using the information found inside the tags to create a background for your text. Use about {max_words} words." +BASE_PROMPT: '{context}\n[USER_START] Rewrite [{input_text}] in your own words using the information found inside the tags to create a background for your text. Use about {max_words} words.' DIALOGUE_PROMPT: '{context}\nThe following is a conversation between {character1} and {character2}; {character2}s sentiment towards {character1}: {sentiment}. Write a single response as {character2} in third person pov, using {character2} description and other information found inside the tags. If {character2} has a quest active, they will discuss it based on its status. Respond in JSON using this template: """{dialogue_template}""". [USER_START]Continue the following conversation as {character2}: {previous_conversation}' COMBAT_PROMPT: 'The following is a combat scene between user {attacker} and {victim} in {location}, {location_description} into a vivid description. [USER_START] Rewrite the following combat result in about 150 words, using the characters weapons and their health status: 1.0 is highest, 0.0 is lowest. Combat Result: {attacker_msg}' PRE_JSON_PROMPT: 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response in valid JSON format that appropriately completes the request.' diff --git a/requirements_dev.txt b/requirements_dev.txt index 476e4474..044ba760 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -9,5 +9,6 @@ pillow packaging==20.3 pillow>=8.3.2 responses==0.13.3 +aioresponses==0.7.6 diff --git a/tale/llm/LivingNpc.py b/tale/llm/LivingNpc.py index 29a9b4b6..06d929ab 100644 --- a/tale/llm/LivingNpc.py +++ b/tale/llm/LivingNpc.py @@ -260,7 +260,7 @@ def tell_action_deferred(self): actions = '\n'.join(self.deferred_actions) deferred_action = ParseResult(verb='idle-action', unparsed=actions, who_info=None) self.tell_others(actions + '\n') - #self.location._notify_action_all(deferred_action, actor=self) + self.location._notify_action_all(deferred_action, actor=self) self.deferred_actions.clear() def _clear_quest(self): diff --git a/tale/llm/character.py b/tale/llm/character.py index 2a1ae137..dce23001 100644 --- a/tale/llm/character.py +++ b/tale/llm/character.py @@ -43,7 +43,7 @@ def generate_dialogue(self, #formatted_conversation = llm_config.params['USER_START'] formatted_conversation = conversation.replace('', '\n')#llm_config.params['USER_END'] + '\n' + llm_config.params['USER_START']) prompt += self.dialogue_prompt.format( - context=context.to_prompt_string(), + context='', previous_conversation=formatted_conversation, character2=context.speaker_name, character1=context.target_name, @@ -52,10 +52,7 @@ def generate_dialogue(self, sentiment=sentiment) request_body = deepcopy(self.default_body) request_body['grammar'] = self.json_grammar - - - #if not self.stream: - response = self.io_util.synchronous_request(request_body, prompt=prompt) + response = self.io_util.synchronous_request(request_body, prompt=prompt, context=context.to_prompt_string()) try: json_result = json.loads(parse_utils.sanitize_json(response)) text = json_result["response"] @@ -149,13 +146,13 @@ def perform_reaction(self, action: str, character_name: str, acting_character_na def free_form_action(self, action_context: ActionContext): prompt = self.pre_prompt prompt += self.free_form_action_prompt.format( - context=action_context.to_prompt_string(), + context = '', character_name=action_context.character_name, action_template=self.action_template) request_body = deepcopy(self.default_body) request_body['grammar'] = self.json_grammar try : - text = self.io_util.synchronous_request(request_body, prompt=prompt) + text = self.io_util.synchronous_request(request_body, prompt=prompt, context=action_context.to_prompt_string()) if not text: return None response = json.loads(parse_utils.sanitize_json(text)) diff --git a/tale/llm/io_adapters.py b/tale/llm/io_adapters.py new file mode 100644 index 00000000..d7110d61 --- /dev/null +++ b/tale/llm/io_adapters.py @@ -0,0 +1,147 @@ + +from abc import ABC, abstractmethod +import asyncio +import json +import time + +import aiohttp +import requests + +from tale.errors import LlmResponseException + + +class AbstractIoAdapter(ABC): + + def __init__(self, url: str, stream_endpoint: str, user_start_prompt: str, user_end_prompt: str): + self.url = url + self.stream_endpoint = stream_endpoint + self.user_start_prompt = user_start_prompt + self.user_end_prompt = user_end_prompt + + @abstractmethod + def stream_request(self, request_body: dict, io = None, wait: bool = False) -> str: + pass + + @abstractmethod + async def _do_stream_request(self, url: str, request_body: dict,) -> bool: + pass + + @abstractmethod + def _parse_result(self, result: str) -> str: + pass + + @abstractmethod + def _set_prompt(self, request_body: dict, prompt: str, context: str = '') -> dict: + pass + +class KoboldCppAdapter(AbstractIoAdapter): + + def __init__(self, url: str, stream_endpoint: str, data_endpoint: str, user_start_prompt: str, user_end_prompt: str): + super().__init__(url, stream_endpoint, user_start_prompt, user_end_prompt) + self.data_endpoint = data_endpoint + + def stream_request(self, request_body: dict, io = None, wait: bool = False) -> str: + result = asyncio.run(self._do_stream_request(self.url + self.stream_endpoint, request_body)) + + try: + if result: + return self._do_process_result(self.url + self.data_endpoint, io, wait) + except LlmResponseException as exc: + print("Error parsing response from backend - ", exc) + return '' + + async def _do_stream_request(self, url: str, request_body: dict,) -> bool: + """ Send request to stream endpoint async to not block the main thread""" + async with aiohttp.ClientSession() as session: + async with session.post(url, data=json.dumps(request_body)) as response: + if response.status == 200: + return True + else: + print("Error occurred:", response.status) + + def _do_process_result(self, url, io = None, wait: bool = False) -> str: + """ Process the result from the stream endpoint """ + tries = 0 + old_text = '' + while tries < 4: + time.sleep(0.25) + data = requests.post(url) + + text = json.loads(data.text)['results'][0]['text'] + + if len(text) == len(old_text): + tries += 1 + continue + if not wait: + new_text = text[len(old_text):] + io.output_no_newline(new_text, new_paragraph=False) + old_text = text + return old_text + + def _parse_result(self, result: str) -> str: + """ Parse the result from the stream endpoint """ + return json.loads(result)['results'][0]['text'] + + def _set_prompt(self, request_body: dict, prompt: str, context: str = '') -> dict: + if self.user_start_prompt: + prompt = prompt.replace('[USER_START]', self.user_start_prompt) + if self.user_end_prompt: + prompt = prompt + self.user_end_prompt + prompt.replace('{context}', '') + request_body['prompt'] = prompt + request_body['memory'] = context + return request_body + +class LlamaCppAdapter(AbstractIoAdapter): + + def stream_request(self, request_body: dict, io = None, wait: bool = False) -> str: + return asyncio.run(self._do_stream_request(self.url + self.stream_endpoint, request_body, io = io)) + + async def _do_stream_request(self, url: str, request_body: dict, io = None) -> str: + """ Send request to stream endpoint async to not block the main thread""" + request_body['stream'] = True + text = '' + async with aiohttp.ClientSession() as session: + async with session.post(url, data=json.dumps(request_body)) as response: + if response.status != 200: + print("Error occurred:", response.status) + return False + async for chunk in response.content.iter_any(): + decoded = chunk.decode('utf-8') + lines = decoded.split('\n') + for line in lines: + # Ignore empty lines + if not line.strip(): + continue + key, value = line.split(':', 1) + key = key.strip() + value = value.strip() + if key == 'data': + data = json.loads(value) + choice = data['choices'][0]['delta'] + content = choice.get('content', None) + + if content: + io.output_no_newline(content, new_paragraph=False) + text += content + #while len(lines) == 0: + # await asyncio.sleep(0.05) + + return text + + def _parse_result(self, result: str) -> str: + """ Parse the result from the stream endpoint """ + try: + return json.loads(result)['choices'][0]['message']['content'] + except: + raise LlmResponseException("Error parsing result from backend") + + def _set_prompt(self, request_body: dict, prompt: str, context: str = '') -> dict: + if self.user_start_prompt: + prompt = prompt.replace('[USER_START]', self.user_start_prompt) + if self.user_end_prompt: + prompt = prompt + self.user_end_prompt + if context: + prompt = prompt.format(context=context) + request_body['messages'][1]['content'] = prompt + return request_body \ No newline at end of file diff --git a/tale/llm/llm_io.py b/tale/llm/llm_io.py index 81abcb3d..d1663864 100644 --- a/tale/llm/llm_io.py +++ b/tale/llm/llm_io.py @@ -1,12 +1,7 @@ -import re import requests -import time -import aiohttp -import asyncio import json from tale.errors import LlmResponseException -import tale.parse_utils as parse_utils -from tale.player_utils import TextBuffer +from tale.llm.io_adapters import KoboldCppAdapter, LlamaCppAdapter class IoUtil(): """ Handles connection and data retrieval from backend """ @@ -19,107 +14,41 @@ def __init__(self, config: dict = None, backend_config: dict = None): self.url = backend_config['URL'] self.endpoint = backend_config['ENDPOINT'] - if self.backend != 'kobold_cpp': headers = json.loads(backend_config['OPENAI_HEADERS']) headers['Authorization'] = f"Bearer {backend_config['OPENAI_API_KEY']}" self.openai_json_format = json.loads(backend_config['OPENAI_JSON_FORMAT']) self.headers = headers + self.io_adapter = LlamaCppAdapter(self.url, backend_config['STREAM_ENDPOINT'], config['USER_START'], config['USER_END']) else: + self.io_adapter = KoboldCppAdapter(self.url, backend_config['STREAM_ENDPOINT'], backend_config['DATA_ENDPOINT'], config['USER_START'], config['USER_END']) self.headers = {} + self.stream = backend_config['STREAM'] - if self.stream: - self.stream_endpoint = backend_config['STREAM_ENDPOINT'] - self.data_endpoint = backend_config['DATA_ENDPOINT'] - self.user_start_prompt = config['USER_START'] - self.user_end_prompt = config['USER_END'] - def synchronous_request(self, request_body: dict, prompt: str) -> str: + + def synchronous_request(self, request_body: dict, prompt: str, context: str = '') -> str: """ Send request to backend and return the result """ if request_body.get('grammar', None) and 'openai' in self.url: # TODO: temp fix for openai request_body.pop('grammar') request_body['response_format'] = self.openai_json_format - self._set_prompt(request_body, prompt) + request_body = self.io_adapter._set_prompt(request_body, prompt, context) + print(request_body) response = requests.post(self.url + self.endpoint, headers=self.headers, data=json.dumps(request_body)) - try: - if self.backend == 'kobold_cpp': - parsed_response = self._parse_kobold_result(response.text) - else: - parsed_response = self._parse_openai_result(response.text) - except LlmResponseException as exc: - print("Error parsing response from backend - ", exc) - return '' - return parsed_response + if response.status_code == 200: + return self.io_adapter._parse_result(response.text) + return '' - def asynchronous_request(self, request_body: dict, prompt: str) -> str: + def asynchronous_request(self, request_body: dict, prompt: str, context: str = '') -> str: if self.backend != 'kobold_cpp': - return self.synchronous_request(request_body, prompt) - return self.stream_request(request_body, wait=True, prompt=prompt) - - def stream_request(self, request_body: dict, prompt: str, io = None, wait: bool = False) -> str: - if self.backend != 'kobold_cpp': - raise NotImplementedError("Currently does not support streaming requests for OpenAI") - self._set_prompt(request_body, prompt) - result = asyncio.run(self._do_stream_request(self.url + self.stream_endpoint, request_body)) - if result: - return self._do_process_result(self.url + self.data_endpoint, io, wait) - return '' + return self.synchronous_request(request_body=request_body, prompt=prompt, context=context) + return self.stream_request(request_body, wait=True, prompt=prompt, context=context) - async def _do_stream_request(self, url: str, request_body: dict,) -> bool: - """ Send request to stream endpoint async to not block the main thread""" - async with aiohttp.ClientSession() as session: - async with session.post(url, data=json.dumps(request_body)) as response: - if response.status == 200: - return True - else: - # Handle errors - print("Error occurred:", response.status) + def stream_request(self, request_body: dict, prompt: str, context: str = '', io = None, wait: bool = False) -> str: + if self.io_adapter: + request_body = self.io_adapter._set_prompt(request_body, prompt, context) + return self.io_adapter.stream_request(request_body, io, wait) + # fall back if no io adapter + return self.synchronous_request(request_body=request_body, prompt=prompt, context=context) - def _do_process_result(self, url, io = None, wait: bool = False) -> str: - """ Process the result from the stream endpoint """ - tries = 0 - old_text = '' - while tries < 4: - time.sleep(0.5) - data = requests.post(url) - text = self._parse_kobold_result(data.text) - - if len(text) == len(old_text): - tries += 1 - continue - if not wait: - new_text = text[len(old_text):] - io.output_no_newline(new_text, new_paragraph=False) - old_text = text - return old_text - - def _parse_kobold_result(self, result: str) -> str: - """ Parse the result from the kobold endpoint """ - return json.loads(result)['results'][0]['text'] - - def _parse_openai_result(self, result: str) -> str: - """ Parse the result from the openai endpoint """ - try: - return json.loads(result)['choices'][0]['message']['content'] - except: - raise LlmResponseException("Error parsing result from backend") - - def _set_prompt(self, request_body: dict, prompt: str) -> dict: - if self.user_start_prompt: - prompt = prompt.replace('[USER_START]', self.user_start_prompt) - if self.user_end_prompt: - prompt = prompt + self.user_end_prompt - if self.backend == 'kobold_cpp': - request_body['prompt'] = prompt - else : - request_body['messages'][1]['content'] = prompt - return request_body - - def _extract_context(self, full_string): - pattern = re.escape('') + "(.*?)" + re.escape('') - match = re.search(pattern, full_string, re.DOTALL) - if match: - return '' + match.group(1) + '' - else: - return '' \ No newline at end of file diff --git a/tale/llm/llm_utils.py b/tale/llm/llm_utils.py index 807a1ef7..ffa3c179 100644 --- a/tale/llm/llm_utils.py +++ b/tale/llm/llm_utils.py @@ -87,22 +87,22 @@ def evoke(self, message: str, short_len : bool=False, rolling_prompt='', alt_pro return output_template.format(message=message, text=cached_look), rolling_prompt trimmed_message = parse_utils.remove_special_chars(str(message)) - context = EvokeContext(story_context=self.__story_context, history=rolling_prompt if not skip_history or alt_prompt else '') + story_context = EvokeContext(story_context=self.__story_context, history=rolling_prompt if not skip_history or alt_prompt else '') prompt = self.pre_prompt prompt += alt_prompt or (self.evoke_prompt.format( - context=context.to_prompt_string(), + context = '', max_words=self.word_limit if not short_len else self.short_word_limit, input_text=str(trimmed_message))) request_body = deepcopy(self.default_body) if not self.stream: - text = self.io_util.synchronous_request(request_body, prompt=prompt) + text = self.io_util.synchronous_request(request_body, prompt=prompt, context=story_context.to_prompt_string()) llm_cache.cache_look(text, text_hash_value) return output_template.format(message=message, text=text), rolling_prompt if self.connection: self.connection.output(output_template.format(message=message, text='')) - text = self.io_util.stream_request(request_body=request_body, prompt=prompt, io=self.connection) + text = self.io_util.stream_request(request_body=request_body, prompt=prompt, context=story_context.to_prompt_string(), io=self.connection) llm_cache.cache_look(text, text_hash_value) return '\n', rolling_prompt diff --git a/tests/supportstuff.py b/tests/supportstuff.py index a4019885..158636a7 100644 --- a/tests/supportstuff.py +++ b/tests/supportstuff.py @@ -10,6 +10,7 @@ from wsgiref.simple_server import WSGIServer from tale import pubsub, util, driver, base, story +from tale.llm.io_adapters import AbstractIoAdapter from tale.llm.llm_utils import LlmUtil from tale.llm.llm_io import IoUtil @@ -66,11 +67,13 @@ def __init__(self, response: list = []) -> None: super().__init__() self.response = response # type: list self.backend = 'kobold_cpp' + self.io_adapter = None + self.stream = False - def synchronous_request(self, request_body: dict, prompt: str = None) -> str: + def synchronous_request(self, request_body: dict, prompt: str = None, context: str = '') -> str: return self.response.pop(0) if isinstance(self.response, list) > 0 and len(self.response) > 0 else self.response - def asynchronous_request(self, request_body: dict, prompt: str = None): + def asynchronous_request(self, request_body: dict, prompt: str = None, context: str = ''): return self.synchronous_request(request_body, prompt) def set_response(self, response: any): @@ -89,5 +92,3 @@ def get_request(self): def clear_requests(self): self.requests = [] - - \ No newline at end of file diff --git a/tests/test_llm_ext.py b/tests/test_llm_ext.py index 2460cf75..c85a2d09 100644 --- a/tests/test_llm_ext.py +++ b/tests/test_llm_ext.py @@ -136,6 +136,7 @@ class TestLivingNpcActions(): driver = FakeDriver() driver.story = DynamicStory() llm_util = LlmUtil(IoUtil(config=dummy_config, backend_config=dummy_backend_config)) # type: LlmUtil + llm_util.backend = dummy_config['BACKEND'] driver.llm_util = llm_util story = DynamicStory() driver.story = story diff --git a/tests/test_llm_io.py b/tests/test_llm_io.py index acd527a1..2223b6f7 100644 --- a/tests/test_llm_io.py +++ b/tests/test_llm_io.py @@ -2,74 +2,131 @@ import json import os +from aioresponses import aioresponses +import responses import yaml from tale.llm.llm_io import IoUtil +from tale.player import Player, PlayerConnection +from tale.tio.iobase import IoAdapterBase class TestLlmIo(): - llm_io = IoUtil() + - def setup(self): + def _load_config(self) -> dict: with open(os.path.realpath(os.path.join(os.path.dirname(__file__), "../llm_config.yaml")), "r") as stream: try: - self.config_file = yaml.safe_load(stream) + return yaml.safe_load(stream) except yaml.YAMLError as exc: print(exc) - self.llm_io.user_start_prompt = self.config_file['USER_START'] - self.llm_io.user_end_prompt = self.config_file['USER_END'] - def _load_backend_config(self, backend): + def _load_backend_config(self, backend) -> dict: with open(os.path.realpath(os.path.join(os.path.dirname(__file__), f"../backend_{backend}.yaml")), "r") as stream: try: - self.backend_config = yaml.safe_load(stream) + return yaml.safe_load(stream) except yaml.YAMLError as exc: print(exc) def test_set_prompt_kobold_cpp(self): - self.llm_io.backend = 'kobold_cpp' - self._load_backend_config('kobold_cpp') - prompt = self.config_file['BASE_PROMPT'] + config_file = self._load_config() + backend_config = self._load_backend_config('kobold_cpp') + prompt = config_file['BASE_PROMPT'] assert('### Instruction' not in prompt) assert('### Response' not in prompt) assert('USER_START' in prompt) assert('USER_END' not in prompt) - request_body = json.loads(self.backend_config['DEFAULT_BODY']) + request_body = json.loads(backend_config['DEFAULT_BODY']) - result = self.llm_io._set_prompt(request_body, prompt) - assert(self.config_file['USER_START'] in result['prompt']) - assert(self.config_file['USER_END'] in result['prompt']) + io_util = IoUtil(config=config_file, backend_config=backend_config) + result = io_util.io_adapter._set_prompt(request_body=request_body, prompt=prompt, context='') + assert(config_file['USER_START'] in result['prompt']) + assert(config_file['USER_END'] in result['prompt']) def test_set_prompt_openai(self): - self.backend = 'openai' - self._load_backend_config('openai') - self.llm_io.backend = 'openai' - prompt = self.config_file['BASE_PROMPT'] + config_file = self._load_config() + config_file['BACKEND'] = 'openai' + backend_config = self._load_backend_config('openai') + prompt = config_file['BASE_PROMPT'] assert('### Instruction' not in prompt) assert('### Response' not in prompt) assert('USER_START' in prompt) assert('USER_END' not in prompt) - request_body = json.loads(self.backend_config['DEFAULT_BODY']) - - result = self.llm_io._set_prompt(request_body, prompt) - assert(self.config_file['USER_START'] in result['messages'][1]['content']) - assert(self.config_file['USER_END'] in result['messages'][1]['content']) + request_body = json.loads(backend_config['DEFAULT_BODY']) + io_util = IoUtil(config=config_file, backend_config=backend_config) + result = io_util.io_adapter._set_prompt(request_body=request_body, prompt=prompt, context='') + assert(config_file['USER_START'] in result['messages'][1]['content']) + assert(config_file['USER_END'] in result['messages'][1]['content']) def test_set_prompt_llama_cpp(self): - self.backend = 'llama_cpp' - self._load_backend_config('llama_cpp') - self.llm_io.backend = 'llama_cpp' - prompt = self.config_file['BASE_PROMPT'] + + config_file = self._load_config() + config_file['BACKEND'] = 'llama_cpp' + backend_config = self._load_backend_config('llama_cpp') + prompt = config_file['BASE_PROMPT'] assert('### Instruction' not in prompt) assert('### Response' not in prompt) assert('USER_START' in prompt) assert('USER_END' not in prompt) - request_body = json.loads(self.backend_config['DEFAULT_BODY']) + request_body = json.loads(backend_config['DEFAULT_BODY']) + + io_util = IoUtil(config=config_file, backend_config=backend_config) + result = io_util.io_adapter._set_prompt(request_body=request_body, prompt=prompt, context='') + assert(config_file['USER_START'] in result['messages'][1]['content']) + assert(config_file['USER_END'] in result['messages'][1]['content']) - result = self.llm_io._set_prompt(request_body, prompt) - assert(self.config_file['USER_START'] in result['messages'][1]['content']) - assert(self.config_file['USER_END'] in result['messages'][1]['content']) \ No newline at end of file + @responses.activate + def test_error_response(self): + config_file = self._load_config() + backend_config = self._load_backend_config('kobold_cpp') + responses.add(responses.POST, backend_config['URL'] + backend_config['ENDPOINT'], + json={'results':['']}, status=500) + io_util = IoUtil(config=config_file, backend_config=backend_config) + + response = io_util.synchronous_request(request_body=json.loads(backend_config['DEFAULT_BODY']), prompt='test evoke', context='') + assert(response == '') + + @responses.activate + def test_stream_kobold_cpp(self): + config = {'BACKEND':'kobold_cpp', 'USER_START':'', 'USER_END':''} + with open(os.path.realpath(os.path.join(os.path.dirname(__file__), f"../backend_kobold_cpp.yaml")), "r") as stream: + try: + backend_config = yaml.safe_load(stream) + except yaml.YAMLError as exc: + print(exc) + io_util = IoUtil(config=config, backend_config=backend_config) # type: IoUtil + io_util.stream = True + conn = PlayerConnection(Player('test', 'm')) + + responses.add(responses.POST, backend_config['URL'] + backend_config['DATA_ENDPOINT'], + json={'results':[{'text':'stream test'}]}, status=200) + with aioresponses() as mocked_responses: + # Mock the response for the specified URL + mocked_responses.post(backend_config['URL'] + backend_config['STREAM_ENDPOINT'], + status=200, + body="{'results':[{'text':'stream test'}]}") + result = io_util.stream_request(request_body=json.loads(backend_config['DEFAULT_BODY']), prompt='test evoke', context='', io = IoAdapterBase(conn)) + assert(result == 'stream test') + + def test_stream_llama_cpp(self): + config = {'BACKEND':'llama_cpp', 'USER_START':'', 'USER_END':''} + with open(os.path.realpath(os.path.join(os.path.dirname(__file__), f"../backend_llama_cpp.yaml")), "r") as stream: + try: + backend_config = yaml.safe_load(stream) + except yaml.YAMLError as exc: + print(exc) + io_util = IoUtil(config=config, backend_config=backend_config) # type: IoUtil + io_util.stream = True + conn = PlayerConnection(Player('test', 'm')) + + with aioresponses() as mocked_responses: + # Mock the response for the specified URL + mocked_responses.post(backend_config['URL'] + backend_config['STREAM_ENDPOINT'], + status=200, + body='data: {"choices":[{"delta":{"content":"stream test"}}]}') + result = io_util.stream_request(request_body=json.loads(backend_config['DEFAULT_BODY']), prompt='test evoke', context='', io = IoAdapterBase(conn)) + assert(result == 'stream test') diff --git a/tests/test_llm_utils.py b/tests/test_llm_utils.py index 65945676..514065f3 100644 --- a/tests/test_llm_utils.py +++ b/tests/test_llm_utils.py @@ -1,5 +1,8 @@ import datetime import json +import os + +import yaml from tale.image_gen.automatic1111 import Automatic1111 import tale.llm.llm_cache as llm_cache from tale import mud_context, weapon_type @@ -8,6 +11,7 @@ from tale.base import Item, Location, Weapon from tale.coord import Coord from tale.json_story import JsonStory +from tale.llm.llm_io import IoUtil from tale.llm.llm_utils import LlmUtil from tale.npc_defs import StationaryMob from tale.races import UnarmedAttack