Skip to content

Commit

Permalink
fix for headers and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
neph1 committed Nov 3, 2024
1 parent 3ef4812 commit d9c6165
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tale/llm/llm_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def __init__(self, config: dict = None, backend_config: dict = None):
self.headers = headers
self.io_adapter = LlamaCppAdapter(self.url, backend_config['STREAM_ENDPOINT'], config.get('USER_START', ''), config.get('USER_END', ''), config.get('SYSTEM_START', ''), config.get('PROMPT_END', ''))
else:
headers['Authorization'] = f"Bearer {backend_config['API_PASSWORD']}"
if 'API_PASSWORD' in backend_config and backend_config['API_PASSWORD']:
headers['Authorization'] = f"Bearer {backend_config['API_PASSWORD']}"
self.headers = headers
self.io_adapter = KoboldCppAdapter(self.url, backend_config['STREAM_ENDPOINT'], backend_config['DATA_ENDPOINT'], config.get('USER_START', ''), config.get('USER_END', ''), config.get('SYSTEM_START', ''), config.get('PROMPT_END', ''))

Expand Down
1 change: 1 addition & 0 deletions tests/test_json_story.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tale.mob_spawner import MobSpawner

class TestJsonStory():
wearable.wearbles_story = []
driver = IFDriver(screen_delay=99, gui=False, web=True, wizard_override=True)
driver.game_clock = util.GameDateTime(datetime.datetime(year=2023, month=1, day=1), 1)
story = JsonStory('tests/files/world_story/', parse_utils.load_story_config(parse_utils.load_json('tests/files/world_story/story_config.json')))
Expand Down
16 changes: 16 additions & 0 deletions tests/test_llm_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ def test_set_prompt_llama_cpp(self):
assert('<context>context</context>' in result['messages'][1]['content'])
assert(result['messages'][0]['content'] != '<context>context</context>')


def test_password_in_header(self):
config_file = self._load_config()
config_file['BACKEND'] = 'kobold_cpp'

backend_config = self._load_backend_config('kobold_cpp')
io_util = IoUtil(config=config_file, backend_config=backend_config)

assert not io_util.headers

backend_config = self._load_backend_config('kobold_cpp')
backend_config['API_PASSWORD'] = 'test_password'
io_util = IoUtil(config=config_file, backend_config=backend_config)

assert io_util.headers['Authorization'] == f"Bearer {backend_config['API_PASSWORD']}"

@responses.activate
def test_error_response(self):
config_file = self._load_config()
Expand Down

0 comments on commit d9c6165

Please sign in to comment.