Commit

update tutorials
ZiyiXia committed Dec 27, 2024
1 parent 51da3a7 commit bae5fab
Showing 3 changed files with 306 additions and 273 deletions.
13 changes: 1 addition & 12 deletions Tutorials/1_Embedding/1.2.5_BGE_EN_ICL.ipynb
@@ -44,17 +44,6 @@
"## 1. BGE-EN-ICL structure"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"HF_ENDPOINT\"]=\"https://hf-mirror.com\""
]
},
{
"cell_type": "code",
"execution_count": 2,
@@ -140,7 +129,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"BERT like encoder only networks are considered with strong capacity for representation learning because of their bidirectional attention structure. Some previous work replace unidirectional attention with bidirectional attention during the embedding training phase. But this might creates a mismatch with the model's pre-training design, which could potentially undermine its in-context learning and generative properties.\n",
"BERT-like encoder only networks are considered with strong capacity for representation learning because of their bidirectional attention structure. Some previous work replace unidirectional attention with bidirectional attention during the embedding training phase. But this might creates a mismatch with the model's pre-training design, which could potentially undermine its in-context learning and generative properties.\n",
"\n",
"Thus BGE-EN-ICL introduces a [EOS] token's output embedding to address this issue."
]
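In practice, this [EOS] pooling takes the decoder's hidden state at the final token position as the sentence embedding. A minimal sketch of the idea, assuming a Hugging Face causal LM with right padding (the pooling helper below is illustrative, not the library's own API; FlagEmbedding appends [EOS] during its own preprocessing):

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-en-icl")
model = AutoModel.from_pretrained("BAAI/bge-en-icl")
if tokenizer.pad_token is None:
    # decoder-only tokenizers often lack a pad token; reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token

def eos_pooling(last_hidden_state, attention_mask):
    # with unidirectional attention, only the last token has attended to the
    # whole sequence, so pool the last non-padding position of each sequence
    last_pos = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(last_hidden_state.size(0))
    return last_hidden_state[batch_idx, last_pos]

inputs = tokenizer(["an example sentence to embed"],
                   return_tensors="pt", padding=True)
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state
embedding = torch.nn.functional.normalize(
    eos_pooling(hidden, inputs["attention_mask"]), dim=-1)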
267 changes: 6 additions & 261 deletions Tutorials/7_Fine-tuning/7.1.1_Data_preparation.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data preparation for fine-tuning"
"# Data Preparation for Fine-tuning"
]
},
{
@@ -27,18 +27,7 @@
"metadata": {},
"outputs": [],
"source": [
"# % pip install -U datasets"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"HF_ENDPOINT\"]=\"https://hf-mirror.com\""
"% pip install -U datasets"
]
},
{
@@ -59,17 +48,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/share/project/xzy/Envs/ft/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"data": {
"text/plain": [
Expand All @@ -79,7 +60,7 @@
"})"
]
},
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@@ -306,14 +287,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test Data for Evaluation"
"## 2. Test Data for Evaluation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last step is to construct the testing dataset following the [format](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/evaluation#8-custom-dataset) for evaluation."
"The last step is to construct the testing dataset for evaluaton."
]
},
{
@@ -461,242 +442,6 @@
"corpus.to_json(\"ft_data/corpus.jsonl\")\n",
"qrels.to_json(\"ft_data/test_qrels.jsonl\")"
]
},
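For reference, the three files written above are JSON-lines files. Their rough shape is sketched below; the field names follow the FlagEmbedding custom-dataset example and the values are purely illustrative:

# test_queries.jsonl: one query per line
{"id": "0", "text": "an example query"}
# corpus.jsonl: one passage per line
{"id": "0", "text": "an example passage of the corpus"}
# test_qrels.jsonl: relevance labels linking query ids to passage ids
{"qid": "0", "docid": "0", "relevance": 1}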
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finetune"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from FlagEmbedding import FlagModel\n",
"\n",
"finetuned_path = \"test_encoder_only_base_bge-large-en-v1.5\"\n",
"model_name = \"BAAI/bge-large-en-v1.5\"\n",
"model = FlagModel(finetuned_path, \n",
"# model = FlagModel(model_name,\n",
" query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
" devices=[0,1],\n",
" use_fp16=False)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"initial target device: 100%|██████████| 2/2 [00:30<00:00, 15.31s/it]\n",
"pre tokenize: 100%|██████████| 2/2 [00:00<00:00, 116.32it/s]\n",
"You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
"pre tokenize: 100%|██████████| 2/2 [00:00<00:00, 123.47it/s]\n",
"You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
"/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n",
" warnings.warn(\n",
"/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n",
" warnings.warn(\n",
"Inference Embeddings: 100%|██████████| 2/2 [00:00<00:00, 13.06it/s]\n",
"Inference Embeddings: 100%|██████████| 2/2 [00:00<00:00, 13.14it/s]\n",
"Chunks: 100%|██████████| 2/2 [00:05<00:00, 2.56s/it]\n",
"pre tokenize: 100%|██████████| 14/14 [00:00<00:00, 55.58it/s]\n",
"pre tokenize: 100%|██████████| 14/14 [00:00<00:00, 27.82it/s]\n",
"Inference Embeddings: 100%|██████████| 14/14 [00:02<00:00, 6.24it/s]\n",
"Inference Embeddings: 100%|██████████| 14/14 [00:03<00:00, 4.07it/s]\n",
"Chunks: 100%|██████████| 2/2 [00:04<00:00, 2.05s/it]\n"
]
}
],
"source": [
"queries_text = [q[1] for q in queries.items()]\n",
"corpus_text = [corpus[str(i)][0] for i in range(len(corpus))]\n",
"\n",
"queries_embeddings = model.encode_queries(queries_text)\n",
"corpus_embeddings = model.encode_corpus(corpus_text)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total number of vectors: 7000\n"
]
}
],
"source": [
"import faiss\n",
"import numpy as np\n",
"\n",
"# get the length of our embedding vectors, vectors by bge-base-en-v1.5 have length 768\n",
"dim = corpus_embeddings.shape[-1]\n",
"\n",
"# create the faiss index and store the corpus embeddings into the vector space\n",
"index = faiss.index_factory(dim, 'Flat', faiss.METRIC_INNER_PRODUCT)\n",
"# corpus_embeddings = corpus_embeddings.astype(np.float32)\n",
"# train and add the embeddings to the index\n",
"index.train(corpus_embeddings)\n",
"index.add(corpus_embeddings)\n",
"\n",
"print(f\"total number of vectors: {index.ntotal}\")"
]
},
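Since FlagModel L2-normalizes its output embeddings by default, the inner-product metric used above amounts to cosine similarity. A quick, illustrative sanity check:

import numpy as np

# norms of normalized embeddings should all be close to 1.0
print(np.linalg.norm(corpus_embeddings, axis=1)[:3])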
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Searching: 100%|██████████| 22/22 [00:00<00:00, 31.84it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"query_size = len(queries_embeddings)\n",
"\n",
"all_scores = []\n",
"all_indices = []\n",
"\n",
"for i in tqdm(range(0, query_size, 32), desc=\"Searching\"):\n",
" j = min(i + 32, query_size)\n",
" query_embedding = queries_embeddings[i: j]\n",
" score, indice = index.search(query_embedding.astype(np.float32), k=100)\n",
" all_scores.append(score)\n",
" all_indices.append(indice)\n",
"\n",
"all_scores = np.concatenate(all_scores, axis=0)\n",
"all_indices = np.concatenate(all_indices, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"results = {}\n",
"for idx, (scores, indices) in enumerate(zip(all_scores, all_indices)):\n",
" results[queries_ids[idx]] = {}\n",
" for score, index in zip(scores, indices):\n",
" if index != -1:\n",
" results[queries_ids[idx]][corpus_ids[index]] = float(score)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"defaultdict(<class 'list'>, {'NDCG@10': 0.84061, 'NDCG@100': 0.85484})\n",
"defaultdict(<class 'list'>, {'MAP@10': 0.81157, 'MAP@100': 0.81471})\n",
"defaultdict(<class 'list'>, {'Recall@10': 0.93, 'Recall@100': 0.99429})\n",
"defaultdict(<class 'list'>, {'P@10': 0.093, 'P@100': 0.00994})\n",
"defaultdict(<class 'list'>, {'MRR@10': 0.81157, 'MRR@100': 0.81471})\n"
]
}
],
"source": [
"from FlagEmbedding.abc.evaluation.utils import evaluate_metrics, evaluate_mrr\n",
"\n",
"k_values = [10,100]\n",
"eval_res = evaluate_metrics(qrels, results, k_values)\n",
"mrr = evaluate_mrr(qrels, results, k_values)\n",
"\n",
"for res in eval_res:\n",
" print(res)\n",
"print(mrr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"defaultdict(<class 'list'>, {'NDCG@1': 0.58286, 'NDCG@5': 0.68588, 'NDCG@10': 0.70405})\n",
"defaultdict(<class 'list'>, {'Recall@1': 0.58286, 'Recall@5': 0.76714, 'Recall@10': 0.82286})\n"
]
}
],
"source": [
"# Original test result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"defaultdict(<class 'list'>, {'NDCG@1': 0.75571, 'NDCG@5': 0.84706, 'NDCG@10': 0.85623})\n",
"defaultdict(<class 'list'>, {'Recall@1': 0.75571, 'Recall@5': 0.92286, 'Recall@10': 0.95143})\n"
]
}
],
"source": [
"# Fake test result"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[6.453125]\n"
]
}
],
"source": [
"from FlagEmbedding import FlagReranker\n",
"\n",
"reranker = FlagReranker(\n",
" 'BAAI/bge-reranker-base', \n",
" query_max_length=256,\n",
" use_fp16=True,\n",
" devices=['cuda:1'],\n",
")\n",
"\n",
"score = reranker.compute_score(['I am happy to help', 'Assisting you is my pleasure'])\n",
"print(score)"
]
}
],
"metadata": {