From d9c6165afe3d9e9160cfeb27b1679272cafbe2f2 Mon Sep 17 00:00:00 2001 From: rickard Date: Sun, 3 Nov 2024 20:23:24 +0100 Subject: [PATCH] fix for headers and add test --- tale/llm/llm_io.py | 3 ++- tests/test_json_story.py | 1 + tests/test_llm_io.py | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/tale/llm/llm_io.py b/tale/llm/llm_io.py index aa6bbdc5..593f54e6 100644 --- a/tale/llm/llm_io.py +++ b/tale/llm/llm_io.py @@ -21,7 +21,8 @@ def __init__(self, config: dict = None, backend_config: dict = None): self.headers = headers self.io_adapter = LlamaCppAdapter(self.url, backend_config['STREAM_ENDPOINT'], config.get('USER_START', ''), config.get('USER_END', ''), config.get('SYSTEM_START', ''), config.get('PROMPT_END', '')) else: - headers['Authorization'] = f"Bearer {backend_config['API_PASSWORD']}" + if 'API_PASSWORD' in backend_config and backend_config['API_PASSWORD']: + headers['Authorization'] = f"Bearer {backend_config['API_PASSWORD']}" self.headers = headers self.io_adapter = KoboldCppAdapter(self.url, backend_config['STREAM_ENDPOINT'], backend_config['DATA_ENDPOINT'], config.get('USER_START', ''), config.get('USER_END', ''), config.get('SYSTEM_START', ''), config.get('PROMPT_END', '')) diff --git a/tests/test_json_story.py b/tests/test_json_story.py index c3e6616e..0cdb7fab 100644 --- a/tests/test_json_story.py +++ b/tests/test_json_story.py @@ -12,6 +12,7 @@ from tale.mob_spawner import MobSpawner class TestJsonStory(): + wearable.wearbles_story = [] driver = IFDriver(screen_delay=99, gui=False, web=True, wizard_override=True) driver.game_clock = util.GameDateTime(datetime.datetime(year=2023, month=1, day=1), 1) story = JsonStory('tests/files/world_story/', parse_utils.load_story_config(parse_utils.load_json('tests/files/world_story/story_config.json'))) diff --git a/tests/test_llm_io.py b/tests/test_llm_io.py index 2227c3af..11b41f46 100644 --- a/tests/test_llm_io.py +++ b/tests/test_llm_io.py @@ -76,6 +76,22 @@ def test_set_prompt_llama_cpp(self): assert('context' in result['messages'][1]['content']) assert(result['messages'][0]['content'] != 'context') + + def test_password_in_header(self): + config_file = self._load_config() + config_file['BACKEND'] = 'kobold_cpp' + + backend_config = self._load_backend_config('kobold_cpp') + io_util = IoUtil(config=config_file, backend_config=backend_config) + + assert not io_util.headers + + backend_config = self._load_backend_config('kobold_cpp') + backend_config['API_PASSWORD'] = 'test_password' + io_util = IoUtil(config=config_file, backend_config=backend_config) + + assert io_util.headers['Authorization'] == f"Bearer {backend_config['API_PASSWORD']}" + @responses.activate def test_error_response(self): config_file = self._load_config()