# make_embeddings.py
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch
from torch import Tensor
import os


# Mean-pool token embeddings, zeroing out padding positions via the attention mask.
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
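
# For intuition (toy values, not part of the pipeline): with an attention mask of
# [1, 1, 0], only the first two token vectors contribute to the mean, e.g.
#   average_pool(torch.ones(1, 3, 4), torch.tensor([[1, 1, 0]]))
# returns a (1, 4) tensor of ones - the padded position is zeroed before summing.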

def create_embeddings(string_dict, device="cuda"):
    embeddings_raw_name = 'embeddings.npy'
    embeddings_text_name = 'embeddings.txt'
    embeddings_answer_name = 'embeddings_answer.txt'
    if not os.path.exists(embeddings_raw_name):
        tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-base')
        model = AutoModel.from_pretrained('intfloat/multilingual-e5-base').to(device)
        embeddings = []
        # For now, just concatenate the main cause with its sub-cause into one string.
        string_list = []
        answer_list = []
        for key in string_dict.keys():
            for subkey in string_dict[key]:
                temp_str = f'{key}. {subkey}.'
                # Collapse the double period that appears when a key already ends with one.
                string_list.append(temp_str.replace('..', '.'))
                answer_list.append(string_dict[key][subkey])
        # Slow (batching would be faster), but simple, and it only runs once.
        with torch.no_grad():
            for line in string_list:
                batch_dict = tokenizer(line, max_length=512, padding=True, truncation=True, return_tensors='pt').to(device)
                outputs = model(**batch_dict)
                embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
                embeddings.append(embedding[0])
        embeddings = torch.stack(embeddings).cpu().numpy()
        np.save(embeddings_raw_name, embeddings)
        with open(embeddings_text_name, 'w', encoding='utf-8') as f:
            for line in string_list:
                f.write(line + '\n')
        with open(embeddings_answer_name, 'w', encoding='utf-8') as f:
            for line in answer_list:
                f.write(line + '\n')
        embeddings_raw = embeddings
        embeddings_text = string_list
        embeddings_answer = answer_list
    else:
        embeddings_raw = np.load(embeddings_raw_name)
        # Strip trailing newlines so the cached branch returns the same strings
        # as a fresh build (plain readlines() would keep the '\n').
        with open(embeddings_text_name, 'r', encoding='utf-8') as f:
            embeddings_text = [line.rstrip('\n') for line in f]
        with open(embeddings_answer_name, 'r', encoding='utf-8') as f:
            embeddings_answer = [line.rstrip('\n') for line in f]
    return embeddings_raw, embeddings_text, embeddings_answer
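
# Cache layout written above (file names are hardcoded in create_embeddings):
#   embeddings.npy        - 2-D float array, one row per "cause. sub-cause." string
#   embeddings.txt        - those strings, one per line
#   embeddings_answer.txt - the matching answers, one per line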

# Runs fast enough even on CPU for single text inputs.
# Ideally the model would be kept loaded in memory; this will do for now.
def get_embedding(text):
    tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-base')
    model = AutoModel.from_pretrained('intfloat/multilingual-e5-base')
    batch_dict = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**batch_dict)
    embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask']).cpu().numpy()
    return embedding
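

# --- Minimal usage sketch (not part of the original file; the `faq` dict and the
# query string are hypothetical). It wires the pieces together: build or load the
# embedding cache, embed a user query, and pick the closest entry by cosine
# similarity.
if __name__ == '__main__':
    faq = {
        'Train delay': {
            'More than 2 hours': 'Compensation is available at the ticket office.',
        },
    }
    embeddings_raw, embeddings_text, embeddings_answer = create_embeddings(faq, device='cpu')

    query = get_embedding('my train is running late')  # shape (1, 768) for e5-base
    # Cosine similarity of the query against every cached row.
    scores = (embeddings_raw @ query[0]) / (
        np.linalg.norm(embeddings_raw, axis=1) * np.linalg.norm(query[0])
    )
    best = int(np.argmax(scores))
    print(embeddings_text[best], '->', embeddings_answer[best])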