Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
qinxuye committed Oct 23, 2024
1 parent dff8ed5 commit 9f2904c
Showing 1 changed file with 14 additions and 23 deletions.
37 changes: 14 additions & 23 deletions xinference/model/rerank/mindie.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,29 +67,20 @@ def rerank(
attention_mask = encoded_input.data["attention_mask"].numpy().astype(np.int64)
n_tokens = attention_mask.sum().item()

similarity_scores = []
for i in range(len(sentence_combinations)):
output = session.infer(
feeds=[input_ids, attention_mask],
mode="dymshape",
custom_sizes=10000000,
)
scores = (
torch.from_numpy(output[0][:, 0])
.view(
-1,
)
.float()
)
exec_time = session.summary().exec_time_list[-1]
logger.info(
"%s%s inference time: %.2f ms",
i + 1,
"tsnrhtdd"[(i + 1) % 5 * ((i + 1) % 100 ^ 15 > 4 > (i + 1) % 10) :: 4],
exec_time[1] - exec_time[0],
output = session.infer(
feeds=[input_ids, attention_mask],
mode="dymshape",
custom_sizes=10000000,
)
scores = (
torch.from_numpy(output[0][:, 0])
.view(
-1,
)
logger.info("scores [positive, negative]: %s", scores)
similarity_scores.append(float(scores[0]))
.float()
)
logger.info("scores [positive, negative]: %s", scores)
similarity_scores = scores.tolist()

sim_scores_argsort = list(reversed(np.argsort(similarity_scores)))
if top_n is not None:
Expand Down Expand Up @@ -132,7 +123,7 @@ def rerank(
}

del similarity_scores
self._session.free_resource()
# self._session.free_resource()

return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)

Expand Down

0 comments on commit 9f2904c

Please sign in to comment.