diff --git a/backend_kobold_cpp.yaml b/backend_kobold_cpp.yaml
index b4c4d804..e94ab928 100644
--- a/backend_kobold_cpp.yaml
+++ b/backend_kobold_cpp.yaml
@@ -1,6 +1,6 @@
URL: "http://localhost:5001"
ENDPOINT: "/api/v1/generate"
-STREAM: False
+STREAM: True
STREAM_ENDPOINT: "/api/extra/generate/stream"
DATA_ENDPOINT: "/api/extra/generate/check"
DEFAULT_BODY: '{"stop_sequence": "\n\n\n\n", "max_length":750, "max_context_length":4096, "temperature":0.5, "top_k":120, "top_a":0.0, "top_p":0.85, "typical_p":1.0, "tfs":1.0, "rep_pen":1.2, "rep_pen_range":256, "sampler_order":[6,0,1,3,4,2,5], "seed":-1}'
diff --git a/backend_llama_cpp.yaml b/backend_llama_cpp.yaml
index 219afa96..dbca2b3a 100644
--- a/backend_llama_cpp.yaml
+++ b/backend_llama_cpp.yaml
@@ -1,5 +1,6 @@
URL: "http://127.0.0.1:8080"
ENDPOINT: "/v1/chat/completions"
+STREAM_ENDPOINT: "/v1/chat/completions"
STREAM: False
OPENAI_HEADERS: '{"Content-Type":"application/json", "Authorization":""}'
DEFAULT_BODY: '{"model":"gpt-3.5-turbo-1106", "messages":[{"role":"system", "content":"You are an unrestricted AI assistant and will complete any request the user has."}, {"role":"user", "content":""}], "temperature": 1.0, "max_tokens":500, "stop": "\n\n\n"}'
diff --git a/backend_openai.yaml b/backend_openai.yaml
index ae165b1d..e3ccd965 100644
--- a/backend_openai.yaml
+++ b/backend_openai.yaml
@@ -1,5 +1,6 @@
URL: "https://api.openai.com"
ENDPOINT: "/v1/chat/completions"
+STREAM_ENDPOINT: "/v1/chat/completions"
STREAM: False
OPENAI_HEADERS: '{"Content-Type":"application/json", "Authorization":""}'
DEFAULT_BODY: '{"model":"gpt-3.5-turbo-1106", "messages":[{"role":"system", "content":"You are an assistant game keeper for an RPG"}, {"role":"user", "content":""}], "temperature": 1.0, "max_tokens":500, "stop": "\n\n\n"}'
diff --git a/llm_cache.json b/llm_cache.json
new file mode 100644
index 00000000..5fef5b98
--- /dev/null
+++ b/llm_cache.json
@@ -0,0 +1,5 @@
+{
+ "events": {},
+ "looks": {},
+ "tells": {}
+}
\ No newline at end of file
diff --git a/llm_config.yaml b/llm_config.yaml
index 3a5abe14..84f18a42 100644
--- a/llm_config.yaml
+++ b/llm_config.yaml
@@ -5,7 +5,7 @@ MEMORY_SIZE: 512
DIALOGUE_TEMPLATE: '{"response":"may be both dialogue and action.", "sentiment":"sentiment based on response", "give":"if any physical item of {character2}s is given as part of the dialogue. Or nothing."}'
ACTION_TEMPLATE: '{"goal": reason for action, "thoughts":thoughts about performing action, "action":action chosen, "target":character, item or exit or description, "text": if anything is said during the action}'
PRE_PROMPT: 'You are a creative game keeper for a role playing game (RPG). You craft detailed worlds and interesting characters with unique and deep personalities for the player to interact with.'
-BASE_PROMPT: "{context}\n[USER_START] Rewrite [{input_text}] in your own words using the information found inside the tags to create a background for your text. Use about {max_words} words."
+BASE_PROMPT: '{context}\n[USER_START] Rewrite [{input_text}] in your own words using the information found inside the tags to create a background for your text. Use about {max_words} words.'
DIALOGUE_PROMPT: '{context}\nThe following is a conversation between {character1} and {character2}; {character2}s sentiment towards {character1}: {sentiment}. Write a single response as {character2} in third person pov, using {character2} description and other information found inside the tags. If {character2} has a quest active, they will discuss it based on its status. Respond in JSON using this template: """{dialogue_template}""". [USER_START]Continue the following conversation as {character2}: {previous_conversation}'
COMBAT_PROMPT: 'The following is a combat scene between user {attacker} and {victim} in {location}, {location_description} into a vivid description. [USER_START] Rewrite the following combat result in about 150 words, using the characters weapons and their health status: 1.0 is highest, 0.0 is lowest. Combat Result: {attacker_msg}'
PRE_JSON_PROMPT: 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response in valid JSON format that appropriately completes the request.'
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 476e4474..044ba760 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -9,5 +9,6 @@ pillow
packaging==20.3
pillow>=8.3.2
responses==0.13.3
+aioresponses==0.7.6
diff --git a/tale/llm/LivingNpc.py b/tale/llm/LivingNpc.py
index 29a9b4b6..06d929ab 100644
--- a/tale/llm/LivingNpc.py
+++ b/tale/llm/LivingNpc.py
@@ -260,7 +260,7 @@ def tell_action_deferred(self):
actions = '\n'.join(self.deferred_actions)
deferred_action = ParseResult(verb='idle-action', unparsed=actions, who_info=None)
self.tell_others(actions + '\n')
- #self.location._notify_action_all(deferred_action, actor=self)
+ self.location._notify_action_all(deferred_action, actor=self)
self.deferred_actions.clear()
def _clear_quest(self):
diff --git a/tale/llm/character.py b/tale/llm/character.py
index 2a1ae137..dce23001 100644
--- a/tale/llm/character.py
+++ b/tale/llm/character.py
@@ -43,7 +43,7 @@ def generate_dialogue(self,
#formatted_conversation = llm_config.params['USER_START']
formatted_conversation = conversation.replace('', '\n')#llm_config.params['USER_END'] + '\n' + llm_config.params['USER_START'])
prompt += self.dialogue_prompt.format(
- context=context.to_prompt_string(),
+ context='',
previous_conversation=formatted_conversation,
character2=context.speaker_name,
character1=context.target_name,
@@ -52,10 +52,7 @@ def generate_dialogue(self,
sentiment=sentiment)
request_body = deepcopy(self.default_body)
request_body['grammar'] = self.json_grammar
-
-
- #if not self.stream:
- response = self.io_util.synchronous_request(request_body, prompt=prompt)
+ response = self.io_util.synchronous_request(request_body, prompt=prompt, context=context.to_prompt_string())
try:
json_result = json.loads(parse_utils.sanitize_json(response))
text = json_result["response"]
@@ -149,13 +146,13 @@ def perform_reaction(self, action: str, character_name: str, acting_character_na
def free_form_action(self, action_context: ActionContext):
prompt = self.pre_prompt
prompt += self.free_form_action_prompt.format(
- context=action_context.to_prompt_string(),
+ context='',
character_name=action_context.character_name,
action_template=self.action_template)
request_body = deepcopy(self.default_body)
request_body['grammar'] = self.json_grammar
try :
- text = self.io_util.synchronous_request(request_body, prompt=prompt)
+ text = self.io_util.synchronous_request(request_body, prompt=prompt, context=action_context.to_prompt_string())
if not text:
return None
response = json.loads(parse_utils.sanitize_json(text))
diff --git a/tale/llm/io_adapters.py b/tale/llm/io_adapters.py
new file mode 100644
index 00000000..d7110d61
--- /dev/null
+++ b/tale/llm/io_adapters.py
@@ -0,0 +1,147 @@
+
+from abc import ABC, abstractmethod
+import asyncio
+import json
+import time
+
+import aiohttp
+import requests
+
+from tale.errors import LlmResponseException
+
+
+class AbstractIoAdapter(ABC):
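+ """ Adapts prompt handling and request flow to a specific backend API. Concrete adapters implement streaming, result parsing, and where the context string is placed. """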
+
+ def __init__(self, url: str, stream_endpoint: str, user_start_prompt: str, user_end_prompt: str):
+ self.url = url
+ self.stream_endpoint = stream_endpoint
+ self.user_start_prompt = user_start_prompt
+ self.user_end_prompt = user_end_prompt
+
+ @abstractmethod
+ def stream_request(self, request_body: dict, io = None, wait: bool = False) -> str:
+ pass
+
+ @abstractmethod
+ async def _do_stream_request(self, url: str, request_body: dict) -> bool:
+ pass
+
+ @abstractmethod
+ def _parse_result(self, result: str) -> str:
+ pass
+
+ @abstractmethod
+ def _set_prompt(self, request_body: dict, prompt: str, context: str = '') -> dict:
+ pass
+
+class KoboldCppAdapter(AbstractIoAdapter):
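+ """ Adapter for the KoboldCpp API. Streaming posts the request to the stream endpoint, then polls the data (check) endpoint for the partial text. """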
+
+ def __init__(self, url: str, stream_endpoint: str, data_endpoint: str, user_start_prompt: str, user_end_prompt: str):
+ super().__init__(url, stream_endpoint, user_start_prompt, user_end_prompt)
+ self.data_endpoint = data_endpoint
+
+ def stream_request(self, request_body: dict, io = None, wait: bool = False) -> str:
+ result = asyncio.run(self._do_stream_request(self.url + self.stream_endpoint, request_body))
+
+ try:
+ if result:
+ return self._do_process_result(self.url + self.data_endpoint, io, wait)
+ except LlmResponseException as exc:
+ print("Error parsing response from backend - ", exc)
+ return ''
+
+ async def _do_stream_request(self, url: str, request_body: dict) -> bool:
+ """ Send the request to the stream endpoint asynchronously so it does not block the main thread """
+ async with aiohttp.ClientSession() as session:
+ async with session.post(url, data=json.dumps(request_body)) as response:
+ if response.status == 200:
+ return True
+ else:
+ print("Error occurred:", response.status)
+ return False
+
+ def _do_process_result(self, url, io = None, wait: bool = False) -> str:
+ """ Process the result from the stream endpoint """
+ tries = 0
+ old_text = ''
+ while tries < 4:
+ time.sleep(0.25)
+ data = requests.post(url)
+
+ text = json.loads(data.text)['results'][0]['text']
+
+ if len(text) == len(old_text):
+ tries += 1
+ continue
+ if not wait:
+ new_text = text[len(old_text):]
+ io.output_no_newline(new_text, new_paragraph=False)
+ old_text = text
+ return old_text
+
+ def _parse_result(self, result: str) -> str:
+ """ Parse the result from the stream endpoint """
+ return json.loads(result)['results'][0]['text']
+
+ def _set_prompt(self, request_body: dict, prompt: str, context: str = '') -> dict:
+ if self.user_start_prompt:
+ prompt = prompt.replace('[USER_START]', self.user_start_prompt)
+ if self.user_end_prompt:
+ prompt = prompt + self.user_end_prompt
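+ # KoboldCpp receives the context separately through its 'memory' field, so the {context} placeholder is stripped from the prompt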
+ prompt = prompt.replace('{context}', '')
+ request_body['prompt'] = prompt
+ request_body['memory'] = context
+ return request_body
+
+class LlamaCppAdapter(AbstractIoAdapter):
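+ """ Adapter for OpenAI-compatible chat completion APIs (llama.cpp server, OpenAI). Streaming sets 'stream': true and parses the server-sent 'data:' lines. """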
+
+ def stream_request(self, request_body: dict, io = None, wait: bool = False) -> str:
+ return asyncio.run(self._do_stream_request(self.url + self.stream_endpoint, request_body, io = io))
+
+ async def _do_stream_request(self, url: str, request_body: dict, io = None) -> str:
+ """ Send request to stream endpoint async to not block the main thread"""
+ request_body['stream'] = True
+ text = ''
+ async with aiohttp.ClientSession() as session:
+ async with session.post(url, data=json.dumps(request_body)) as response:
+ if response.status != 200:
+ print("Error occurred:", response.status)
+ return False
+ async for chunk in response.content.iter_any():
+ decoded = chunk.decode('utf-8')
+ lines = decoded.split('\n')
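+ # Each chunk may contain several SSE-style lines; only 'data: {json}' lines carry content deltas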
+ for line in lines:
+ # Ignore empty lines
+ if not line.strip():
+ continue
+ key, value = line.split(':', 1)
+ key = key.strip()
+ value = value.strip()
+ if key == 'data':
+ data = json.loads(value)
+ choice = data['choices'][0]['delta']
+ content = choice.get('content', None)
+
+ if content:
+ io.output_no_newline(content, new_paragraph=False)
+ text += content
+ #while len(lines) == 0:
+ # await asyncio.sleep(0.05)
+
+ return text
+
+ def _parse_result(self, result: str) -> str:
+ """ Parse the result from the stream endpoint """
+ try:
+ return json.loads(result)['choices'][0]['message']['content']
+ except:
+ raise LlmResponseException("Error parsing result from backend")
+
+ def _set_prompt(self, request_body: dict, prompt: str, context: str = '') -> dict:
+ if self.user_start_prompt:
+ prompt = prompt.replace('[USER_START]', self.user_start_prompt)
+ if self.user_end_prompt:
+ prompt = prompt + self.user_end_prompt
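+ # Chat completion backends take no separate memory field, so the context is formatted directly into the prompt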
+ if context:
+ prompt = prompt.format(context=context)
+ request_body['messages'][1]['content'] = prompt
+ return request_body
\ No newline at end of file
diff --git a/tale/llm/llm_io.py b/tale/llm/llm_io.py
index 81abcb3d..d1663864 100644
--- a/tale/llm/llm_io.py
+++ b/tale/llm/llm_io.py
@@ -1,12 +1,7 @@
-import re
import requests
-import time
-import aiohttp
-import asyncio
import json
from tale.errors import LlmResponseException
-import tale.parse_utils as parse_utils
-from tale.player_utils import TextBuffer
+from tale.llm.io_adapters import KoboldCppAdapter, LlamaCppAdapter
class IoUtil():
""" Handles connection and data retrieval from backend """
@@ -19,107 +14,41 @@ def __init__(self, config: dict = None, backend_config: dict = None):
self.url = backend_config['URL']
self.endpoint = backend_config['ENDPOINT']
-
if self.backend != 'kobold_cpp':
headers = json.loads(backend_config['OPENAI_HEADERS'])
headers['Authorization'] = f"Bearer {backend_config['OPENAI_API_KEY']}"
self.openai_json_format = json.loads(backend_config['OPENAI_JSON_FORMAT'])
self.headers = headers
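+ # OpenAI and llama.cpp expose the same chat completions API, so both use the LlamaCppAdapter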
+ self.io_adapter = LlamaCppAdapter(self.url, backend_config['STREAM_ENDPOINT'], config['USER_START'], config['USER_END'])
else:
+ self.io_adapter = KoboldCppAdapter(self.url, backend_config['STREAM_ENDPOINT'], backend_config['DATA_ENDPOINT'], config['USER_START'], config['USER_END'])
self.headers = {}
+
self.stream = backend_config['STREAM']
- if self.stream:
- self.stream_endpoint = backend_config['STREAM_ENDPOINT']
- self.data_endpoint = backend_config['DATA_ENDPOINT']
- self.user_start_prompt = config['USER_START']
- self.user_end_prompt = config['USER_END']
- def synchronous_request(self, request_body: dict, prompt: str) -> str:
+
+ def synchronous_request(self, request_body: dict, prompt: str, context: str = '') -> str:
""" Send request to backend and return the result """
if request_body.get('grammar', None) and 'openai' in self.url:
# TODO: temp fix for openai
request_body.pop('grammar')
request_body['response_format'] = self.openai_json_format
- self._set_prompt(request_body, prompt)
+ request_body = self.io_adapter._set_prompt(request_body, prompt, context)
response = requests.post(self.url + self.endpoint, headers=self.headers, data=json.dumps(request_body))
- try:
- if self.backend == 'kobold_cpp':
- parsed_response = self._parse_kobold_result(response.text)
- else:
- parsed_response = self._parse_openai_result(response.text)
- except LlmResponseException as exc:
- print("Error parsing response from backend - ", exc)
- return ''
- return parsed_response
+ if response.status_code == 200:
+ return self.io_adapter._parse_result(response.text)
+ return ''
- def asynchronous_request(self, request_body: dict, prompt: str) -> str:
+ def asynchronous_request(self, request_body: dict, prompt: str, context: str = '') -> str:
if self.backend != 'kobold_cpp':
- return self.synchronous_request(request_body, prompt)
- return self.stream_request(request_body, wait=True, prompt=prompt)
-
- def stream_request(self, request_body: dict, prompt: str, io = None, wait: bool = False) -> str:
- if self.backend != 'kobold_cpp':
- raise NotImplementedError("Currently does not support streaming requests for OpenAI")
- self._set_prompt(request_body, prompt)
- result = asyncio.run(self._do_stream_request(self.url + self.stream_endpoint, request_body))
- if result:
- return self._do_process_result(self.url + self.data_endpoint, io, wait)
- return ''
+ return self.synchronous_request(request_body=request_body, prompt=prompt, context=context)
+ return self.stream_request(request_body, wait=True, prompt=prompt, context=context)
- async def _do_stream_request(self, url: str, request_body: dict,) -> bool:
- """ Send request to stream endpoint async to not block the main thread"""
- async with aiohttp.ClientSession() as session:
- async with session.post(url, data=json.dumps(request_body)) as response:
- if response.status == 200:
- return True
- else:
- # Handle errors
- print("Error occurred:", response.status)
+ def stream_request(self, request_body: dict, prompt: str, context: str = '', io = None, wait: bool = False) -> str:
+ if self.io_adapter:
+ request_body = self.io_adapter._set_prompt(request_body, prompt, context)
+ return self.io_adapter.stream_request(request_body, io, wait)
+ # fall back if no io adapter
+ return self.synchronous_request(request_body=request_body, prompt=prompt, context=context)
- def _do_process_result(self, url, io = None, wait: bool = False) -> str:
- """ Process the result from the stream endpoint """
- tries = 0
- old_text = ''
- while tries < 4:
- time.sleep(0.5)
- data = requests.post(url)
- text = self._parse_kobold_result(data.text)
-
- if len(text) == len(old_text):
- tries += 1
- continue
- if not wait:
- new_text = text[len(old_text):]
- io.output_no_newline(new_text, new_paragraph=False)
- old_text = text
- return old_text
-
- def _parse_kobold_result(self, result: str) -> str:
- """ Parse the result from the kobold endpoint """
- return json.loads(result)['results'][0]['text']
-
- def _parse_openai_result(self, result: str) -> str:
- """ Parse the result from the openai endpoint """
- try:
- return json.loads(result)['choices'][0]['message']['content']
- except:
- raise LlmResponseException("Error parsing result from backend")
-
- def _set_prompt(self, request_body: dict, prompt: str) -> dict:
- if self.user_start_prompt:
- prompt = prompt.replace('[USER_START]', self.user_start_prompt)
- if self.user_end_prompt:
- prompt = prompt + self.user_end_prompt
- if self.backend == 'kobold_cpp':
- request_body['prompt'] = prompt
- else :
- request_body['messages'][1]['content'] = prompt
- return request_body
-
- def _extract_context(self, full_string):
- pattern = re.escape('') + "(.*?)" + re.escape('')
- match = re.search(pattern, full_string, re.DOTALL)
- if match:
- return '' + match.group(1) + ''
- else:
- return ''
\ No newline at end of file
diff --git a/tale/llm/llm_utils.py b/tale/llm/llm_utils.py
index 807a1ef7..ffa3c179 100644
--- a/tale/llm/llm_utils.py
+++ b/tale/llm/llm_utils.py
@@ -87,22 +87,22 @@ def evoke(self, message: str, short_len : bool=False, rolling_prompt='', alt_pro
return output_template.format(message=message, text=cached_look), rolling_prompt
trimmed_message = parse_utils.remove_special_chars(str(message))
- context = EvokeContext(story_context=self.__story_context, history=rolling_prompt if not skip_history or alt_prompt else '')
+ story_context = EvokeContext(story_context=self.__story_context, history=rolling_prompt if not skip_history or alt_prompt else '')
prompt = self.pre_prompt
prompt += alt_prompt or (self.evoke_prompt.format(
- context=context.to_prompt_string(),
+ context='',
max_words=self.word_limit if not short_len else self.short_word_limit,
input_text=str(trimmed_message)))
request_body = deepcopy(self.default_body)
if not self.stream:
- text = self.io_util.synchronous_request(request_body, prompt=prompt)
+ text = self.io_util.synchronous_request(request_body, prompt=prompt, context=story_context.to_prompt_string())
llm_cache.cache_look(text, text_hash_value)
return output_template.format(message=message, text=text), rolling_prompt
if self.connection:
self.connection.output(output_template.format(message=message, text=''))
- text = self.io_util.stream_request(request_body=request_body, prompt=prompt, io=self.connection)
+ text = self.io_util.stream_request(request_body=request_body, prompt=prompt, context=story_context.to_prompt_string(), io=self.connection)
llm_cache.cache_look(text, text_hash_value)
return '\n', rolling_prompt
diff --git a/tests/supportstuff.py b/tests/supportstuff.py
index a4019885..158636a7 100644
--- a/tests/supportstuff.py
+++ b/tests/supportstuff.py
@@ -10,6 +10,7 @@
from wsgiref.simple_server import WSGIServer
from tale import pubsub, util, driver, base, story
+from tale.llm.io_adapters import AbstractIoAdapter
from tale.llm.llm_utils import LlmUtil
from tale.llm.llm_io import IoUtil
@@ -66,11 +67,13 @@ def __init__(self, response: list = []) -> None:
super().__init__()
self.response = response # type: list
self.backend = 'kobold_cpp'
+ self.io_adapter = None
+ self.stream = False
- def synchronous_request(self, request_body: dict, prompt: str = None) -> str:
+ def synchronous_request(self, request_body: dict, prompt: str = None, context: str = '') -> str:
return self.response.pop(0) if isinstance(self.response, list) > 0 and len(self.response) > 0 else self.response
- def asynchronous_request(self, request_body: dict, prompt: str = None):
+ def asynchronous_request(self, request_body: dict, prompt: str = None, context: str = ''):
return self.synchronous_request(request_body, prompt)
def set_response(self, response: any):
@@ -89,5 +92,3 @@ def get_request(self):
def clear_requests(self):
self.requests = []
-
-
\ No newline at end of file
diff --git a/tests/test_llm_ext.py b/tests/test_llm_ext.py
index 2460cf75..c85a2d09 100644
--- a/tests/test_llm_ext.py
+++ b/tests/test_llm_ext.py
@@ -136,6 +136,7 @@ class TestLivingNpcActions():
driver = FakeDriver()
driver.story = DynamicStory()
llm_util = LlmUtil(IoUtil(config=dummy_config, backend_config=dummy_backend_config)) # type: LlmUtil
+ llm_util.backend = dummy_config['BACKEND']
driver.llm_util = llm_util
story = DynamicStory()
driver.story = story
diff --git a/tests/test_llm_io.py b/tests/test_llm_io.py
index acd527a1..2223b6f7 100644
--- a/tests/test_llm_io.py
+++ b/tests/test_llm_io.py
@@ -2,74 +2,131 @@
import json
import os
+from aioresponses import aioresponses
+import responses
import yaml
from tale.llm.llm_io import IoUtil
+from tale.player import Player, PlayerConnection
+from tale.tio.iobase import IoAdapterBase
class TestLlmIo():
- llm_io = IoUtil()
+
- def setup(self):
+ def _load_config(self) -> dict:
with open(os.path.realpath(os.path.join(os.path.dirname(__file__), "../llm_config.yaml")), "r") as stream:
try:
- self.config_file = yaml.safe_load(stream)
+ return yaml.safe_load(stream)
except yaml.YAMLError as exc:
print(exc)
- self.llm_io.user_start_prompt = self.config_file['USER_START']
- self.llm_io.user_end_prompt = self.config_file['USER_END']
- def _load_backend_config(self, backend):
+ def _load_backend_config(self, backend) -> dict:
with open(os.path.realpath(os.path.join(os.path.dirname(__file__), f"../backend_{backend}.yaml")), "r") as stream:
try:
- self.backend_config = yaml.safe_load(stream)
+ return yaml.safe_load(stream)
except yaml.YAMLError as exc:
print(exc)
def test_set_prompt_kobold_cpp(self):
- self.llm_io.backend = 'kobold_cpp'
- self._load_backend_config('kobold_cpp')
- prompt = self.config_file['BASE_PROMPT']
+ config_file = self._load_config()
+ backend_config = self._load_backend_config('kobold_cpp')
+ prompt = config_file['BASE_PROMPT']
assert('### Instruction' not in prompt)
assert('### Response' not in prompt)
assert('USER_START' in prompt)
assert('USER_END' not in prompt)
- request_body = json.loads(self.backend_config['DEFAULT_BODY'])
+ request_body = json.loads(backend_config['DEFAULT_BODY'])
- result = self.llm_io._set_prompt(request_body, prompt)
- assert(self.config_file['USER_START'] in result['prompt'])
- assert(self.config_file['USER_END'] in result['prompt'])
+ io_util = IoUtil(config=config_file, backend_config=backend_config)
+ result = io_util.io_adapter._set_prompt(request_body=request_body, prompt=prompt, context='')
+ assert(config_file['USER_START'] in result['prompt'])
+ assert(config_file['USER_END'] in result['prompt'])
def test_set_prompt_openai(self):
- self.backend = 'openai'
- self._load_backend_config('openai')
- self.llm_io.backend = 'openai'
- prompt = self.config_file['BASE_PROMPT']
+ config_file = self._load_config()
+ config_file['BACKEND'] = 'openai'
+ backend_config = self._load_backend_config('openai')
+ prompt = config_file['BASE_PROMPT']
assert('### Instruction' not in prompt)
assert('### Response' not in prompt)
assert('USER_START' in prompt)
assert('USER_END' not in prompt)
- request_body = json.loads(self.backend_config['DEFAULT_BODY'])
-
- result = self.llm_io._set_prompt(request_body, prompt)
- assert(self.config_file['USER_START'] in result['messages'][1]['content'])
- assert(self.config_file['USER_END'] in result['messages'][1]['content'])
+ request_body = json.loads(backend_config['DEFAULT_BODY'])
+ io_util = IoUtil(config=config_file, backend_config=backend_config)
+ result = io_util.io_adapter._set_prompt(request_body=request_body, prompt=prompt, context='')
+ assert(config_file['USER_START'] in result['messages'][1]['content'])
+ assert(config_file['USER_END'] in result['messages'][1]['content'])
def test_set_prompt_llama_cpp(self):
- self.backend = 'llama_cpp'
- self._load_backend_config('llama_cpp')
- self.llm_io.backend = 'llama_cpp'
- prompt = self.config_file['BASE_PROMPT']
+
+ config_file = self._load_config()
+ config_file['BACKEND'] = 'llama_cpp'
+ backend_config = self._load_backend_config('llama_cpp')
+ prompt = config_file['BASE_PROMPT']
assert('### Instruction' not in prompt)
assert('### Response' not in prompt)
assert('USER_START' in prompt)
assert('USER_END' not in prompt)
- request_body = json.loads(self.backend_config['DEFAULT_BODY'])
+ request_body = json.loads(backend_config['DEFAULT_BODY'])
+
+ io_util = IoUtil(config=config_file, backend_config=backend_config)
+ result = io_util.io_adapter._set_prompt(request_body=request_body, prompt=prompt, context='')
+ assert(config_file['USER_START'] in result['messages'][1]['content'])
+ assert(config_file['USER_END'] in result['messages'][1]['content'])
- result = self.llm_io._set_prompt(request_body, prompt)
- assert(self.config_file['USER_START'] in result['messages'][1]['content'])
- assert(self.config_file['USER_END'] in result['messages'][1]['content'])
\ No newline at end of file
+ @responses.activate
+ def test_error_response(self):
+ config_file = self._load_config()
+ backend_config = self._load_backend_config('kobold_cpp')
+ responses.add(responses.POST, backend_config['URL'] + backend_config['ENDPOINT'],
+ json={'results':['']}, status=500)
+ io_util = IoUtil(config=config_file, backend_config=backend_config)
+
+ response = io_util.synchronous_request(request_body=json.loads(backend_config['DEFAULT_BODY']), prompt='test evoke', context='')
+ assert(response == '')
+
+ @responses.activate
+ def test_stream_kobold_cpp(self):
+ config = {'BACKEND':'kobold_cpp', 'USER_START':'', 'USER_END':''}
+ with open(os.path.realpath(os.path.join(os.path.dirname(__file__), f"../backend_kobold_cpp.yaml")), "r") as stream:
+ try:
+ backend_config = yaml.safe_load(stream)
+ except yaml.YAMLError as exc:
+ print(exc)
+ io_util = IoUtil(config=config, backend_config=backend_config) # type: IoUtil
+ io_util.stream = True
+ conn = PlayerConnection(Player('test', 'm'))
+
+ responses.add(responses.POST, backend_config['URL'] + backend_config['DATA_ENDPOINT'],
+ json={'results':[{'text':'stream test'}]}, status=200)
+ with aioresponses() as mocked_responses:
+ # Mock the response for the specified URL
+ mocked_responses.post(backend_config['URL'] + backend_config['STREAM_ENDPOINT'],
+ status=200,
+ body="{'results':[{'text':'stream test'}]}")
+ result = io_util.stream_request(request_body=json.loads(backend_config['DEFAULT_BODY']), prompt='test evoke', context='', io = IoAdapterBase(conn))
+ assert(result == 'stream test')
+
+ def test_stream_llama_cpp(self):
+ config = {'BACKEND':'llama_cpp', 'USER_START':'', 'USER_END':''}
+ with open(os.path.realpath(os.path.join(os.path.dirname(__file__), f"../backend_llama_cpp.yaml")), "r") as stream:
+ try:
+ backend_config = yaml.safe_load(stream)
+ except yaml.YAMLError as exc:
+ print(exc)
+ io_util = IoUtil(config=config, backend_config=backend_config) # type: IoUtil
+ io_util.stream = True
+ conn = PlayerConnection(Player('test', 'm'))
+
+ with aioresponses() as mocked_responses:
+ # Mock the response for the specified URL
+ mocked_responses.post(backend_config['URL'] + backend_config['STREAM_ENDPOINT'],
+ status=200,
+ body='data: {"choices":[{"delta":{"content":"stream test"}}]}')
+ result = io_util.stream_request(request_body=json.loads(backend_config['DEFAULT_BODY']), prompt='test evoke', context='', io = IoAdapterBase(conn))
+ assert(result == 'stream test')
diff --git a/tests/test_llm_utils.py b/tests/test_llm_utils.py
index 65945676..514065f3 100644
--- a/tests/test_llm_utils.py
+++ b/tests/test_llm_utils.py
@@ -1,5 +1,8 @@
import datetime
import json
+import os
+
+import yaml
from tale.image_gen.automatic1111 import Automatic1111
import tale.llm.llm_cache as llm_cache
from tale import mud_context, weapon_type
@@ -8,6 +11,7 @@
from tale.base import Item, Location, Weapon
from tale.coord import Coord
from tale.json_story import JsonStory
+from tale.llm.llm_io import IoUtil
from tale.llm.llm_utils import LlmUtil
from tale.npc_defs import StationaryMob
from tale.races import UnarmedAttack