Skip to content

Commit

Permalink
fix for headers and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
neph1 committed Nov 3, 2024
1 parent 3ef4812 commit 5ae6a17
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tale/llm/llm_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def __init__(self, config: dict = None, backend_config: dict = None):
self.headers = headers
self.io_adapter = LlamaCppAdapter(self.url, backend_config['STREAM_ENDPOINT'], config.get('USER_START', ''), config.get('USER_END', ''), config.get('SYSTEM_START', ''), config.get('PROMPT_END', ''))
else:
headers['Authorization'] = f"Bearer {backend_config['API_PASSWORD']}"
if 'API_PASSWORD' in backend_config and backend_config['API_PASSWORD']:
headers['Authorization'] = f"Bearer {backend_config['API_PASSWORD']}"
self.headers = headers
self.io_adapter = KoboldCppAdapter(self.url, backend_config['STREAM_ENDPOINT'], backend_config['DATA_ENDPOINT'], config.get('USER_START', ''), config.get('USER_END', ''), config.get('SYSTEM_START', ''), config.get('PROMPT_END', ''))

Expand Down
16 changes: 16 additions & 0 deletions tests/test_llm_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ def test_set_prompt_llama_cpp(self):
assert('<context>context</context>' in result['messages'][1]['content'])
assert(result['messages'][0]['content'] != '<context>context</context>')


def test_password_in_header(self):
config_file = self._load_config()
config_file['BACKEND'] = 'kobold_cpp'

backend_config = self._load_backend_config('kobold_cpp')
io_util = IoUtil(config=config_file, backend_config=backend_config)

assert not io_util.headers

backend_config = self._load_backend_config('kobold_cpp')
backend_config['API_PASSWORD'] = 'test_password'
io_util = IoUtil(config=config_file, backend_config=backend_config)

assert io_util.headers['Authorization'] == f"Bearer {backend_config['API_PASSWORD']}"

@responses.activate
def test_error_response(self):
config_file = self._load_config()
Expand Down

0 comments on commit 5ae6a17

Please sign in to comment.