From a4f9cebef8f141168d0778a486b6c06104cc023c Mon Sep 17 00:00:00 2001
From: Florian Maas
Date: Thu, 4 Jul 2024 09:27:49 +0200
Subject: [PATCH 1/5] refactor a bit

---
 sentence_transformers/SentenceTransformer.py | 115 +++++++++++--------
 1 file changed, 66 insertions(+), 49 deletions(-)

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index eea804139..37aa0b9ea 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -538,45 +538,23 @@ def encode(
         self.to(device)

-        all_embeddings = []
         length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
         sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

         for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
             sentences_batch = sentences_sorted[start_index : start_index + batch_size]
             features = self.tokenize(sentences_batch)
+
             if self.device.type == "hpu":
                 if "input_ids" in features:
-                    curr_tokenize_len = features["input_ids"].shape
-                    additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
-                    features["input_ids"] = torch.cat(
-                        (
-                            features["input_ids"],
-                            torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
-                        ),
-                        -1,
-                    )
-                    features["attention_mask"] = torch.cat(
-                        (
-                            features["attention_mask"],
-                            torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
-                        ),
-                        -1,
-                    )
-                    if "token_type_ids" in features:
-                        features["token_type_ids"] = torch.cat(
-                            (
-                                features["token_type_ids"],
-                                torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
-                            ),
-                            -1,
-                        )
+                    self._pad_features(features)

             features = batch_to_device(features, device)
             features.update(extra_features)

             with torch.no_grad():
                 out_features = self.forward(features)
+
                 if self.device.type == "hpu":
                     out_features = copy.deepcopy(out_features)

@@ -584,30 +562,7 @@ def encode(
                     out_features["sentence_embedding"], self.truncate_dim
                 )

-                if output_value == "token_embeddings":
-                    embeddings = []
-                    for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
-                        last_mask_id = len(attention) - 1
-                        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
-                            last_mask_id -= 1
-
-                        embeddings.append(token_emb[0 : last_mask_id + 1])
-                elif output_value is None:  # Return all outputs
-                    embeddings = []
-                    for sent_idx in range(len(out_features["sentence_embedding"])):
-                        row = {name: out_features[name][sent_idx] for name in out_features}
-                        embeddings.append(row)
-                else:  # Sentence embeddings
-                    embeddings = out_features[output_value]
-                    embeddings = embeddings.detach()
-                    if normalize_embeddings:
-                        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-                    # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
-                    if convert_to_numpy:
-                        embeddings = embeddings.cpu()
-
-                all_embeddings.extend(embeddings)
+                all_embeddings = self._process_embeddings(out_features, output_value)

         all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

@@ -636,6 +591,68 @@ def encode(

         return all_embeddings

+    @staticmethod
+    def _pad_features(features):
+        curr_tokenize_len = features["input_ids"].shape
+        additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
+        features["input_ids"] = torch.cat(
+            (
+                features["input_ids"],
+                torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+            ),
+            -1,
+        )
+        features["attention_mask"] = torch.cat(
+            (
+                features["attention_mask"],
+                torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+            ),
+            -1,
+        )
+        if "token_type_ids" in features:
+            features["token_type_ids"] = torch.cat(
+                (
+                    features["token_type_ids"],
+                    torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                ),
+                -1,
+            )
+
+    @staticmethod
+    def _process_token_embeddings(out_features):
+        embeddings = []
+        for token_emb, attention in zip(out_features["token_embeddings"], out_features["attention_mask"]):
+            last_mask_id = len(attention) - 1
+            while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+                last_mask_id -= 1
+            embeddings.append(token_emb[0 : last_mask_id + 1])
+        return embeddings
+
+    @staticmethod
+    def _process_all_outputs(out_features):
+        embeddings = []
+        for sent_idx in range(len(out_features["sentence_embedding"])):
+            row = {name: out_features[name][sent_idx] for name in out_features}
+            embeddings.append(row)
+        return embeddings
+
+    def _process_sentence_embeddings(self, out_features):
+        embeddings = out_features["sentence_embedding"]
+        embeddings = embeddings.detach()
+        if self.normalize_embeddings:
+            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+        if self.convert_to_numpy:
+            embeddings = embeddings.cpu()
+        return embeddings
+
+    def _process_embeddings(self, out_features, output_value):
+        if output_value == "token_embeddings":
+            return self._process_token_embeddings(out_features)
+        elif output_value is None:
+            return self._process_all_outputs(out_features)
+        else:
+            return self._process_sentence_embeddings(out_features)
+
     @property
     def similarity_fn_name(self) -> Optional[str]:
         """Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.
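A minimal, runnable sketch of the next-power-of-two padding that the extracted `_pad_features` helper performs (the patch only applies it when `self.device.type == "hpu"`, per the docstring added in patch 5: padding "for compatibility with certain hardware accelerators"). The helper name `pad_to_power_of_two` is illustrative, and unlike the patch, which pads with `torch.int8` tensors, this sketch reuses each feature's own dtype:

import math

import torch


def pad_to_power_of_two(features: dict) -> None:
    # Round the sequence dimension up to the next power of two, mirroring _pad_features:
    # input_ids are padded with 1s, attention_mask / token_type_ids with 0s.
    batch_size, seq_len = features["input_ids"].shape
    additional_pad_len = 2 ** math.ceil(math.log2(seq_len)) - seq_len
    for key, fill_value in (("input_ids", 1), ("attention_mask", 0), ("token_type_ids", 0)):
        if key in features:
            pad = torch.full((batch_size, additional_pad_len), fill_value, dtype=features[key].dtype)
            features[key] = torch.cat((features[key], pad), dim=-1)


features = {
    "input_ids": torch.randint(0, 1000, (2, 7)),
    "attention_mask": torch.ones((2, 7), dtype=torch.int64),
}
pad_to_power_of_two(features)
print(features["input_ids"].shape)  # torch.Size([2, 8]): 7 tokens are padded up to 8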
From 1c5d6a26f624994a74927c73de474d7a20bef297 Mon Sep 17 00:00:00 2001
From: Florian Maas
Date: Thu, 4 Jul 2024 09:39:07 +0200
Subject: [PATCH 2/5] improve

---
 sentence_transformers/SentenceTransformer.py | 24 ++++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 37aa0b9ea..94dbb1f5f 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -562,7 +562,15 @@ def encode(
                     out_features["sentence_embedding"], self.truncate_dim
                 )

-                all_embeddings = self._process_embeddings(out_features, output_value)
+                all_embeddings: list = []
+                if output_value == "token_embeddings":
+                    all_embeddings.extend(self._process_token_embeddings(out_features))
+                elif output_value == "sentence_embeddings":
+                    all_embeddings.extend(
+                        self._process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy)
+                    )
+                elif not output_value:
+                    all_embeddings.extend(self._process_all_outputs(out_features))

         all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

@@ -636,23 +644,15 @@ def _process_all_outputs(out_features):
             embeddings.append(row)
         return embeddings

-    def _process_sentence_embeddings(self, out_features):
+    def _process_sentence_embeddings(self, out_features, normalize_embeddings, convert_to_numpy):
         embeddings = out_features["sentence_embedding"]
         embeddings = embeddings.detach()
-        if self.normalize_embeddings:
+        if normalize_embeddings:
             embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-        if self.convert_to_numpy:
+        if convert_to_numpy:
             embeddings = embeddings.cpu()
         return embeddings

-    def _process_embeddings(self, out_features, output_value):
-        if output_value == "token_embeddings":
-            return self._process_token_embeddings(out_features)
-        elif output_value is None:
-            return self._process_all_outputs(out_features)
-        else:
-            return self._process_sentence_embeddings(out_features)
-
     @property
     def similarity_fn_name(self) -> Optional[str]:
         """Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.
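A toy illustration of the trailing-padding trimming that `_process_token_embeddings` (extracted in patch 1, reused by the dispatch above) performs: positions whose attention mask is 0 at the end of a sequence are dropped. The tensors here are fabricated purely for demonstration:

import torch

token_embeddings = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # 6 token vectors of dim 2
attention_mask = torch.tensor([1, 1, 1, 1, 0, 0])  # last two positions are padding

# Walk backwards past trailing zeros, exactly as the helper does.
last_mask_id = len(attention_mask) - 1
while last_mask_id > 0 and attention_mask[last_mask_id].item() == 0:
    last_mask_id -= 1

trimmed = token_embeddings[0 : last_mask_id + 1]
print(trimmed.shape)  # torch.Size([4, 2]): the two padded positions are removed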
From c7b783b5d2429acaf27d15dd85b41d5fd66ee9a0 Mon Sep 17 00:00:00 2001
From: Florian Maas
Date: Thu, 4 Jul 2024 13:27:31 +0200
Subject: [PATCH 3/5] small fix

---
 sentence_transformers/SentenceTransformer.py | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 94dbb1f5f..54538aac4 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -537,6 +537,7 @@ def encode(
             device = self.device

         self.to(device)
+        all_embeddings = []

         length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
         sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
@@ -562,15 +563,16 @@ def encode(
                     out_features["sentence_embedding"], self.truncate_dim
                 )

-                all_embeddings: list = []
-                if output_value == "token_embeddings":
-                    all_embeddings.extend(self._process_token_embeddings(out_features))
-                elif output_value == "sentence_embeddings":
-                    all_embeddings.extend(
-                        self._process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy)
-                    )
-                elif not output_value:
-                    all_embeddings.extend(self._process_all_outputs(out_features))
+                if output_value == "token_embeddings":
+                    all_embeddings.extend(self._process_token_embeddings(out_features))
+                elif output_value == "sentence_embedding":
+                    all_embeddings.extend(
+                        self._process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy)
+                    )
+                elif not output_value:
+                    all_embeddings.extend(self._process_all_outputs(out_features))
+                else:
+                    raise ValueError(f"Got unexpected value for 'output_value' : {output_value}")

         all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

@@ -644,7 +646,8 @@ def _process_all_outputs(out_features):
             embeddings.append(row)
         return embeddings

-    def _process_sentence_embeddings(self, out_features, normalize_embeddings, convert_to_numpy):
+    @staticmethod
+    def _process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy):
         embeddings = out_features["sentence_embedding"]
         embeddings = embeddings.detach()
         if normalize_embeddings:

From ff20b4f15fdcecd29c39db6f2eb20ed6d7b0eacb Mon Sep 17 00:00:00 2001
From: Florian Maas
Date: Fri, 5 Jul 2024 13:58:52 +0200
Subject: [PATCH 4/5] improve error

---
 sentence_transformers/SentenceTransformer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 54538aac4..412dd253f 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -572,7 +572,9 @@ def encode(
                 elif not output_value:
                     all_embeddings.extend(self._process_all_outputs(out_features))
                 else:
-                    raise ValueError(f"Got unexpected value for 'output_value' : {output_value}")
+                    raise ValueError(
+                        f"Got unexpected value for 'output_value' : {output_value}. Valid values are 'token_embeddings', 'sentence_embedding' or None."
+                    )

         all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

From 10dff7a32ee68d2f8a206f21a382c299a5212227 Mon Sep 17 00:00:00 2001
From: Florian Maas
Date: Tue, 9 Jul 2024 14:45:36 +0200
Subject: [PATCH 5/5] add typehints

---
 sentence_transformers/SentenceTransformer.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 412dd253f..cda70ce8a 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -604,7 +604,10 @@ def encode(
         return all_embeddings

     @staticmethod
-    def _pad_features(features):
+    def _pad_features(features: Dict[str, torch.Tensor]) -> None:
+        """
+        Pads the input features to the next power of 2 for compatibility with certain hardware accelerators.
+        """
         curr_tokenize_len = features["input_ids"].shape
         additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
         features["input_ids"] = torch.cat(
@@ -634,7 +637,7 @@ def _pad_features(features):
             )

     @staticmethod
-    def _process_token_embeddings(out_features):
+    def _process_token_embeddings(out_features: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
         embeddings = []
         for token_emb, attention in zip(out_features["token_embeddings"], out_features["attention_mask"]):
             last_mask_id = len(attention) - 1
@@ -644,7 +647,7 @@ def _process_token_embeddings(out_features):
         return embeddings

     @staticmethod
-    def _process_all_outputs(out_features):
+    def _process_all_outputs(out_features: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
         embeddings = []
         for sent_idx in range(len(out_features["sentence_embedding"])):
             row = {name: out_features[name][sent_idx] for name in out_features}
@@ -649,7 +652,9 @@ def _process_all_outputs(out_features):
         return embeddings

     @staticmethod
-    def _process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy):
+    def _process_sentence_embeddings(
+        out_features: Dict[str, torch.Tensor], normalize_embeddings: bool, convert_to_numpy: bool
+    ) -> Union[List[np.ndarray], List[torch.Tensor]]:
         embeddings = out_features["sentence_embedding"]
         embeddings = embeddings.detach()
         if normalize_embeddings:
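For context, a usage sketch of the three `encode` paths this series factors out into `_process_sentence_embeddings`, `_process_token_embeddings`, and `_process_all_outputs`. The checkpoint name is only an example, and the printed shapes depend on the model chosen:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # example checkpoint; any SentenceTransformer model works
sentences = ["The cat sits outside", "A man is playing guitar"]

# Default path (_process_sentence_embeddings): one pooled vector per sentence.
embeddings = model.encode(sentences, normalize_embeddings=True)
print(embeddings.shape)  # e.g. (2, 384) for this model

# Token path (_process_token_embeddings): one tensor of per-token vectors per sentence,
# with trailing padding trimmed via the attention mask.
token_embeddings = model.encode(sentences, output_value="token_embeddings")
print(len(token_embeddings), token_embeddings[0].shape)

# output_value=None (_process_all_outputs): a dict of all model outputs per sentence.
all_outputs = model.encode(sentences, output_value=None)
print(sorted(all_outputs[0].keys()))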