Commit

some tidying up for streaming
neph1 committed Jan 7, 2024
1 parent c1b2c9a commit 540d250
Showing 6 changed files with 17 additions and 19 deletions.
17 changes: 10 additions & 7 deletions tale/llm/llm_io.py
@@ -4,6 +4,7 @@
import aiohttp
import asyncio
import json
from tale.errors import LlmResponseException
import tale.parse_utils as parse_utils
from tale.player_utils import TextBuffer

@@ -41,10 +42,13 @@ def synchronous_request(self, request_body: dict, prompt: str) -> str:
request_body['response_format'] = self.openai_json_format
self._set_prompt(request_body, prompt)
response = requests.post(self.url + self.endpoint, headers=self.headers, data=json.dumps(request_body))
if self.backend == 'kobold_cpp':
parsed_response = self._parse_kobold_result(response.text)
else:
parsed_response = self._parse_openai_result(response.text)
try:
if self.backend == 'kobold_cpp':
parsed_response = self._parse_kobold_result(response.text)
else:
parsed_response = self._parse_openai_result(response.text)
except LlmResponseException as exc:
return ''
return parsed_response

def asynchronous_request(self, request_body: dict, prompt: str) -> str:
@@ -87,7 +91,7 @@ def _do_process_result(self, url, io = None, wait: bool = False) -> str:
new_text = text[len(old_text):]
io.output_no_newline(new_text, new_paragraph=False)
old_text = text
io.output_no_newline("")
io.output_no_newline("</p>", new_paragraph=False)
return old_text

def _parse_kobold_result(self, result: str) -> str:
@@ -99,8 +103,7 @@ def _parse_openai_result(self, result: str) -> str:
try:
return json.loads(result)['choices'][0]['message']['content']
except:
print("Error parsing result from OpenAI")
print(result)
raise LlmResponseException("Error parsing result from backend")

def _set_prompt(self, request_body: dict, prompt: str) -> dict:
if self.user_start_prompt:
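The net effect of the llm_io.py hunks: parsing failures are now signalled with LlmResponseException instead of being printed, and synchronous_request() catches that exception and returns an empty string. Below is a minimal sketch of the combined flow, assuming a requests-style response object; the class name IoUtilSketch, the headers dict, and the kobold_cpp JSON shape are illustrative stand-ins, not the exact file contents.

import json

import requests


class LlmResponseException(Exception):
    """Stand-in for tale.errors.LlmResponseException."""


class IoUtilSketch:
    """Condensed view of the reworked error-handling path, not the full IoUtil class."""

    def __init__(self, url: str, endpoint: str, backend: str = 'kobold_cpp'):
        self.url = url
        self.endpoint = endpoint
        self.backend = backend
        self.headers = {'Content-Type': 'application/json'}

    def synchronous_request(self, request_body: dict) -> str:
        response = requests.post(self.url + self.endpoint,
                                 headers=self.headers,
                                 data=json.dumps(request_body))
        try:
            if self.backend == 'kobold_cpp':
                return self._parse_kobold_result(response.text)
            return self._parse_openai_result(response.text)
        except LlmResponseException:
            # Parsing problems no longer print to stdout; the caller
            # simply receives an empty string.
            return ''

    def _parse_kobold_result(self, result: str) -> str:
        # Assumed KoboldCpp shape: {"results": [{"text": "..."}]}
        try:
            return json.loads(result)['results'][0]['text']
        except (KeyError, IndexError, ValueError) as exc:
            raise LlmResponseException("Error parsing result from backend") from exc

    def _parse_openai_result(self, result: str) -> str:
        # OpenAI-compatible shape: {"choices": [{"message": {"content": "..."}}]}
        try:
            return json.loads(result)['choices'][0]['message']['content']
        except (KeyError, IndexError, ValueError) as exc:
            raise LlmResponseException("Error parsing result from backend") from exc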
5 changes: 3 additions & 2 deletions tale/llm/llm_utils.py
@@ -69,7 +69,7 @@ def __init__(self, io_util: IoUtil = None):
io_util=self.io_util,
backend=self.backend)

def evoke(self, player_io: TextBuffer, message: str, short_len : bool=False, rolling_prompt='', alt_prompt='', skip_history=True):
def evoke(self, message: str, short_len : bool=False, rolling_prompt='', alt_prompt='', skip_history=True):
"""Evoke a response from LLM. Async if stream is True, otherwise synchronous.
Update the rolling prompt with the latest message.
Will put generated text in lm_cache.look_hashes, and reuse it if same hash is generated."""
@@ -101,7 +101,8 @@ def evoke(self, player_io: TextBuffer, message: str, short_len : bool=False, rol
llm_cache.cache_look(text, text_hash_value)
return output_template.format(message=message, text=text), rolling_prompt

text = self.io_util.stream_request(request_body=request_body, player_io=player_io, prompt=prompt, io=self.connection)
self.connection.output(output_template.format(message=message, text='<p>'))
text = self.io_util.stream_request(request_body=request_body, prompt=prompt, io=self.connection)
llm_cache.cache_look(text, text_hash_value)
return '\n', rolling_prompt

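In llm_utils.py the streaming branch of evoke() no longer takes a player_io TextBuffer: it writes the opening <p> through self.connection and leaves it to the streaming request (via _do_process_result above) to append the closing </p>. A toy illustration of that handshake follows; FakeConnection and stream_chunks are stand-ins for the Tale connection and IoUtil.stream_request, not real Tale classes.

class FakeConnection:
    """Stand-in for the player connection used by evoke() and _do_process_result()."""

    def __init__(self):
        self.buffer = []

    def output(self, text: str) -> None:
        self.buffer.append(text)

    def output_no_newline(self, text: str, new_paragraph: bool = False) -> None:
        self.buffer.append(text)


def stream_chunks(chunks, io: FakeConnection) -> str:
    """Mimics _do_process_result: emit each new chunk, then close the paragraph."""
    old_text = ''
    for chunk in chunks:
        io.output_no_newline(chunk, new_paragraph=False)
        old_text += chunk
    io.output_no_newline("</p>", new_paragraph=False)
    return old_text


conn = FakeConnection()
# evoke() now opens the paragraph on the connection itself ...
conn.output("You look around. <p>")
# ... and the streaming helper writes the chunks and the closing tag.
text = stream_chunks(["The room ", "is dark."], conn)
assert text == "The room is dark."
assert conn.buffer[-1] == "</p>"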
3 changes: 1 addition & 2 deletions tale/player.py
@@ -78,8 +78,7 @@ def tell(self, message: str, *, end: bool=False, format: bool=True, evoke: bool=
if evoke:
if self.title in message:
message = message.replace(self.title, 'you')
msg, rolling_prompt = mud_context.driver.llm_util.evoke(self._output,
message,
msg, rolling_prompt = mud_context.driver.llm_util.evoke(message,
short_len = short_len,
rolling_prompt = self.rolling_prompt,
alt_prompt = alt_prompt)
1 change: 1 addition & 0 deletions tests/supportstuff.py
@@ -65,6 +65,7 @@ class FakeIoUtil(IoUtil):
def __init__(self, response: list = []) -> None:
super().__init__()
self.response = response # type: list
self.backend = 'kobold_cpp'

def synchronous_request(self, request_body: dict, prompt: str = None) -> str:
return self.response.pop(0) if isinstance(self.response, list) > 0 and len(self.response) > 0 else self.response
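The test stub gains a backend attribute because the real IoUtil now branches on self.backend before choosing a parser (see the synchronous_request hunk above); code that reads the attribute on a FakeIoUtil would otherwise fail. A hypothetical usage, assuming the stub is imported from tests.supportstuff:

from tests.supportstuff import FakeIoUtil

fake = FakeIoUtil(response=['test response'])
assert fake.backend == 'kobold_cpp'          # deterministic parser selection in tests
assert fake.synchronous_request(request_body={}) == 'test response'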
3 changes: 2 additions & 1 deletion tests/test_llm_utils.py
@@ -36,8 +36,9 @@ def test_read_items(self):
def test_evoke(self):
evoke_string = 'test response'
self.llm_util.io_util = FakeIoUtil(response=evoke_string)

self.llm_util.set_story(self.story)
result = self.llm_util.evoke(message='test evoke', player_io=None)
result = self.llm_util.evoke(message='test evoke')
assert(result)
assert(llm_cache.get_looks([llm_cache.generate_hash('test evoke')]) == evoke_string)

7 changes: 0 additions & 7 deletions tests/test_player.py
@@ -582,13 +582,6 @@ def test_strip(self):
output.print(" 1 ", format=False)
self.assertEqual([(" 1 \n", False)], output.get_paragraphs())

def test_no_line_break(self):
output = TextBuffer()
output.print("1", line_breaks=False)
output.print("2", line_breaks=False)
output.print("3", line_breaks=False)
self.assertEqual([("123\n", True)], output.get_paragraphs())


class TestCharacterBuilders(unittest.TestCase):
def setUp(self):
