Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ap/fix mistral template #183

Closed
wants to merge 10 commits into from
13 changes: 7 additions & 6 deletions src/instructlab/training/chat_templates/ibm_generic_tmpl.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# SPDX-License-Identifier: Apache-2.0

# First Party
from instructlab.training.tokenizer_utils import SpecialTokens
from instructlab.training.tokenizer_utils import SpecialTokens, TokenInfo

SPECIAL_TOKENS = SpecialTokens(
system="<|system|>",
user="<|user|>",
assistant="<|assistant|>",
eos="<|endoftext|>",
pad="<|pad|>",
system=TokenInfo("<|system|>", add_to_tokenizer=True),
user=TokenInfo("<|user|>", add_to_tokenizer=True),
assistant=TokenInfo("<|assistant|>", add_to_tokenizer=True),
eos=TokenInfo("<|endoftext|>", add_to_tokenizer=True),
pad=TokenInfo("<|pad|>", add_to_tokenizer=True),
bos=TokenInfo("<|begginingoftext|>", add_to_tokenizer=True),
)

CHAT_TEMPLATE = (
Expand Down
60 changes: 45 additions & 15 deletions src/instructlab/training/chat_templates/mistral_tmpl.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,54 @@
# SPDX-License-Identifier: Apache-2.0

# First Party
from instructlab.training.tokenizer_utils import SpecialTokens
from instructlab.training.tokenizer_utils import SpecialTokens, TokenInfo

SPECIAL_TOKENS = SpecialTokens(
bos="<s>",
eos="</s>",
user="[INST]",
assistant="[/INST]",
bos=TokenInfo("<s>", add_to_tokenizer=True),
eos=TokenInfo("</s>", add_to_tokenizer=True),
user=TokenInfo("[INST]", add_to_tokenizer=False),
assistant=TokenInfo("[/INST]", add_to_tokenizer=False),
# user="[INST]",
# assistant="[/INST]",
)

# CHAT_TEMPLATE = (
# "{{ '<s>' }}"
# "{% for message in messages %}"
# "{% if message['role'] == 'pretraining' %}"
# "{{'<|pretrain|>' + message['content'] + '</s>' + '<|/pretrain|>'}}"
# "{% elif message['role'] == 'user' %}"
# "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
# "{% elif message['role'] == 'assistant' %}"
# "{{ message['content'] + '</s>'}}"
# "{% endif %}"
# "{% endfor %}"
# )

CHAT_TEMPLATE = (
"{%- if messages[0]['role'] == 'system' %}"
"{%- set system_message = messages[0]['content'] %}"
"{%- set loop_messages = messages[1:] %}"
"{%- else %}"
"{%- set loop_messages = messages %}"
"{%- endif %}"
"{{ '<s>' }}"
"{% for message in messages %}"
"{% if message['role'] == 'pretraining' %}"
"{{'<|pretrain|>' + message['content'] + '</s>' + '<|/pretrain|>'}}"
"{% elif message['role'] == 'user' %}"
"{{ '[INST] ' + message['content'] + ' [/INST]' }}"
"{% elif message['role'] == 'assistant' %}"
"{{ message['content'] + '</s>'}}"
"{% endif %}"
"{% endfor %}"
)
"{%- for message in loop_messages %}"
"{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
"{{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}"
"{%- endif %}"
"{%- if message['role'] == 'user' %}"
"{%- if loop.first and system_message is defined %}"
"{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}"
"{%- else %}"
"{{- ' [INST] ' + message['content'] + ' [/INST]' }}"
"{%- endif %}"
"{%- elif message['role'] == 'pretraining' %}"
"{{- '<|pretrain|>' + message['content'] + '</s>' + '<|/pretrain|>' }}"
"{%- elif message['role'] == 'assistant' %}"
"{{- ' ' + message['content'] + '</s>'}}"
"{%- else %}"
"{{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}"
"{%- endif %}"
"{%- endfor %}"
)
Loading
Loading