-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinteract_mistral_llamacpp.py
66 lines (54 loc) · 1.98 KB
/
interact_mistral_llamacpp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import codecs

import fire
from llama_cpp import Llama
# Default system prompt (Russian). English translation: "You are Saiga, a
# Russian-language automatic assistant. You talk with people and help them."
SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."

# Hard-coded vocabulary ids for the role names and the newline separator.
# NOTE(review): these ids are tokenizer-specific; confirm they match the
# vocabulary of the model file passed to interact().
SYSTEM_TOKEN = 1587
USER_TOKEN = 2188
BOT_TOKEN = 12435
LINEBREAK_TOKEN = 13

# Role name -> token id used to tag a message of that role.
ROLE_TOKENS = dict(user=USER_TOKEN, bot=BOT_TOKEN, system=SYSTEM_TOKEN)
def get_message_tokens(model, role, content):
    """Tokenize one chat message into the role-tagged layout used by this script.

    The UTF-8 encoded text is tokenized with ``model.tokenize``. The role's token
    id and ``LINEBREAK_TOKEN`` are spliced in right after the first token, and
    the model's EOS id is appended at the end.

    Args:
        model: a loaded ``llama_cpp.Llama`` (anything exposing ``tokenize`` and
            ``token_eos``).
        role: a key of ``ROLE_TOKENS`` ("user", "bot" or "system").
        content: the message text.

    Returns:
        A list of token ids.

    Raises:
        KeyError: if ``role`` is not a key of ``ROLE_TOKENS``.
    """
    # NOTE(review): the first token is presumably BOS, relying on the
    # add_bos default of Llama.tokenize(); confirm for the model in use.
    raw = model.tokenize(content.encode("utf-8"))
    head, tail = raw[:1], raw[1:]
    return head + [ROLE_TOKENS[role], LINEBREAK_TOKEN] + tail + [model.token_eos()]
def get_system_tokens(model):
    """Return the token ids of the system message built from ``SYSTEM_PROMPT``.

    Args:
        model: a loaded ``llama_cpp.Llama``.

    Returns:
        The list produced by ``get_message_tokens`` for role "system".
    """
    return get_message_tokens(model=model, role="system", content=SYSTEM_PROMPT)
def interact(
    model_path, n_ctx=2000, top_k=30, top_p=0.9, temperature=0.2, repeat_penalty=1.1
):
    """Run an interactive console chat with a llama.cpp model.

    Loads the model and evaluates the system message. It then repeatedly reads
    a user line from stdin, appends it and the bot-turn header to the running
    token history, and streams the sampled reply to stdout until the model
    emits EOS. Reply tokens (including the final EOS) stay in the history, so
    every turn is conditioned on the whole conversation so far.

    Args:
        model_path: model file path, passed to ``llama_cpp.Llama``.
        n_ctx: context window size, passed to ``llama_cpp.Llama``.
        top_k: top-k sampling cutoff forwarded to ``Llama.generate``.
        top_p: nucleus-sampling threshold forwarded to ``Llama.generate``.
        temperature: sampling temperature (``temp`` in ``Llama.generate``).
        repeat_penalty: repetition penalty forwarded to ``Llama.generate``.

    The chat ends cleanly on end-of-input (Ctrl-D) or Ctrl-C at the prompt.
    """
    model = Llama(
        model_path=model_path,
        n_ctx=n_ctx,
        n_parts=1,
    )
    eos_token = model.token_eos()

    system_tokens = get_system_tokens(model)
    # Copy, so that extending the history below never mutates system_tokens.
    tokens = list(system_tokens)
    model.eval(tokens)

    # NOTE(review): the history only grows; nothing here keeps len(tokens)
    # within n_ctx, so very long conversations will eventually overflow.
    while True:
        try:
            user_message = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # End of input / Ctrl-C at the prompt: leave the chat cleanly
            # instead of dying with a traceback.
            print()
            break

        message_tokens = get_message_tokens(
            model=model, role="user", content=user_message
        )
        # Header of the bot's turn: BOS, bot role token, newline.
        role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
        tokens += message_tokens + role_tokens

        # One character may be split across several tokens (byte fallback).
        # An incremental decoder buffers incomplete UTF-8 sequences until they
        # are complete, instead of discarding their bytes token by token.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
        generator = model.generate(
            tokens,
            top_k=top_k,
            top_p=top_p,
            temp=temperature,
            repeat_penalty=repeat_penalty,
        )
        for token in generator:
            tokens.append(token)
            if token == eos_token:
                break
            print(decoder.decode(model.detokenize([token])), end="", flush=True)
        # Flush whatever the decoder still holds, then end the reply line.
        print(decoder.decode(b"", final=True))
if __name__ == "__main__":
    # Expose interact()'s parameters as command-line arguments via python-fire.
    fire.Fire(component=interact)