Improve the Indexer
* Index with batches to prevent rate limiting (a sketch of the batching pattern follows the file summary below)
* Clean up the table after the test run
* Add more logs
muralov committed Oct 30, 2024
1 parent f853909 commit 6408e80
Showing 6 changed files with 83 additions and 47 deletions.
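
Before the diff itself, a minimal, self-contained sketch of the batching pattern this commit introduces. It is illustrative only: add_in_batches, store, and pause_seconds are placeholder names, not the project's API; the real implementation is in doc_indexer/src/indexing/indexer.py below and is driven by the new CHUNKS_BATCH_SIZE setting.

    import time

    def add_in_batches(store, chunks, batch_size=100, pause_seconds=3):
        """Write chunks to a vector store in fixed-size batches, pausing
        between batches so the embedding backend is not rate limited."""
        for i in range(0, len(chunks), batch_size):
            batch = chunks[i : i + batch_size]
            store.add_documents(batch)        # one bounded request per batch
            if i + batch_size < len(chunks):  # no pause needed after the last batch
                time.sleep(pause_seconds)
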
39 changes: 28 additions & 11 deletions doc_indexer/src/indexing/indexer.py
@@ -1,4 +1,5 @@
import logging
import time
from typing import Protocol

from hdbcli import dbapi
@@ -10,6 +11,8 @@
from langchain_core.embeddings import Embeddings
from langchain_text_splitters import MarkdownHeaderTextSplitter

from utils.settings import CHUNKS_BATCH_SIZE


def create_chunks(
documents: list[Document], headers_to_split_on: list[tuple[str, str]]
@@ -56,9 +59,6 @@ def __init__(
if not table_name:
table_name = docs_path.split("/")[-1]

# if not headers_to_split_on:
# headers_to_split_on = [HEADER1]

self.docs_path = docs_path
self.table_name = table_name
self.embedding = embedding
@@ -96,15 +96,32 @@ def index(self) -> None:
f"Indexing {len(all_chunks)} markdown files chunks for {self.table_name}..."
)

# deletion is necessary to avoid duplicates
logging.info("Deleting existing documents in HanaDB...")
try:
self.db.delete(filter={})
except Exception:
logging.error("Error while deleting existing documents in HanaDB.")
raise

try:
self.db.add_documents(all_chunks)
except Exception:
logging.error("Error while storing documents chunks in HanaDB.")
logging.exception("Error while deleting existing documents in HanaDB.")
raise
logging.info("Successfully deleted existing documents in HanaDB.")

logging.info("Adding documents to HanaDB...")
for i in range(0, len(all_chunks), CHUNKS_BATCH_SIZE):
batch = all_chunks[i : i + CHUNKS_BATCH_SIZE]
try:
# Add current batch of documents
self.db.add_documents(batch)
logging.info(
f"Indexed batch {i//CHUNKS_BATCH_SIZE + 1} of {len(all_chunks)//CHUNKS_BATCH_SIZE + 1}"
)

# Wait 3 seconds before processing next batch
if i + CHUNKS_BATCH_SIZE < len(all_chunks):
time.sleep(3)

except Exception as e:
logging.error(
f"Error while storing documents batch {i//CHUNKS_BATCH_SIZE + 1} in HanaDB: {str(e)}"
)
raise

logging.info(f"Successfully indexed {len(all_chunks)} markdown files chunks.")
3 changes: 2 additions & 1 deletion doc_indexer/src/main.py
@@ -11,6 +11,7 @@
DATABASE_URL,
DATABASE_USER,
DOCS_PATH,
DOCS_TABLE_NAME,
EMBEDDING_MODEL_DEPLOYMENT_ID,
)

@@ -25,7 +26,7 @@ def main() -> None:
DATABASE_URL, DATABASE_PORT, DATABASE_USER, DATABASE_PASSWORD
)

indexer = MarkdownIndexer(DOCS_PATH, embeddings_model, hana_conn)
indexer = MarkdownIndexer(DOCS_PATH, embeddings_model, hana_conn, DOCS_TABLE_NAME)
indexer.index()


2 changes: 2 additions & 0 deletions doc_indexer/src/utils/models.py
@@ -6,6 +6,8 @@
from gen_ai_hub.proxy.langchain.openai import OpenAIEmbeddings
from langchain_core.embeddings import Embeddings

# TODO: re-use the model factory from the parent project


def create_embedding_factory(
embedding_creator: Callable[[str, Any], Embeddings]
6 changes: 5 additions & 1 deletion doc_indexer/src/utils/settings.py
@@ -5,6 +5,8 @@
from decouple import Config, RepositoryEnv, config
from dotenv import find_dotenv, load_dotenv

# TODO: re-use the settings from the parent project


def is_running_pytest() -> bool:
"""Check if the code is running with pytest.
@@ -30,7 +32,9 @@ def is_running_pytest() -> bool:
EMBEDDING_MODEL_DEPLOYMENT_ID = config("EMBEDDING_MODEL_DEPLOYMENT_ID")
EMBEDDING_MODEL_NAME = config("EMBEDDING_MODEL_NAME")

DOCS_PATH = config("DOCS_PATH", default=None)
DOCS_PATH = config("DOCS_PATH", default="data/output")
DOCS_TABLE_NAME = config("DOCS_TABLE_NAME", default="kyma_docs")
CHUNKS_BATCH_SIZE = config("CHUNKS_BATCH_SIZE", cast=int, default=100)
DATABASE_URL = config("DATABASE_URL")
DATABASE_PORT = config("DATABASE_PORT", cast=int)
DATABASE_USER = config("DATABASE_USER")
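
For reference, each of the new settings ships with a default, so existing deployments keep working without any configuration change. To override them, an environment file along these lines would work (the values below are illustrative, not required):

    # .env (illustrative values)
    # Where the markdown sources live
    DOCS_PATH=data/output
    # HANA table the chunks are written to
    DOCS_TABLE_NAME=kyma_docs
    # Documents sent per add_documents() call; lower it if rate limits persist
    CHUNKS_BATCH_SIZE=50
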
9 changes: 7 additions & 2 deletions doc_indexer/tests/integration/indexing/test_indexer.py
@@ -1,3 +1,4 @@
import logging
import uuid
from unittest.mock import patch

@@ -42,9 +43,13 @@ def indexer(embedding_model, hana_conn, table_name):
indexer = MarkdownIndexer("", embedding_model, hana_conn, table_name=table_name)
yield indexer
try:
indexer.db.drop_table()
logging.info(f"Dropping table {table_name}")
cursor = hana_conn.cursor()
# Add double quotes around both schema and table names
cursor.execute(f'DROP TABLE "{DATABASE_USER}"."{table_name}"')
cursor.close()
except Exception as e:
print(f"Error while dropping table: {e}")
logging.error(f"Error while dropping table: {e}")


@pytest.fixture
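
A note on the fixture change above: SAP HANA folds unquoted identifiers to uppercase, so dropping a mixed-case test table only resolves when the schema and table names are double-quoted. A standalone sketch of the same cleanup (the helper name and signature are placeholders, not part of the test suite):

    from hdbcli import dbapi

    def drop_test_table(conn: dbapi.Connection, schema: str, table_name: str) -> None:
        """Drop a test table, keeping the exact (possibly mixed-case) name via quoting."""
        cursor = conn.cursor()
        try:
            # Double quotes stop HANA from upper-casing the identifiers
            cursor.execute(f'DROP TABLE "{schema}"."{table_name}"')
        finally:
            cursor.close()
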
71 changes: 39 additions & 32 deletions doc_indexer/tests/unit/indexing/test_indexer.py
@@ -1,5 +1,5 @@
from typing import Any
from unittest.mock import Mock, patch
from unittest.mock import Mock, call, patch

import pytest
from indexing.indexer import MarkdownIndexer, create_chunks
@@ -185,56 +185,39 @@ def test_load_documents(
"test_case,headers_to_split_on,loaded_docs,expected_chunks,delete_error,add_error,expected_exception",
[
(
"Default header",
"Single batch",
None,
[
Document(
page_content="# My Header 1\nContent",
)
],
[
Document(
page_content="# My second Header 1 \nContent",
)
],
[Document(page_content="# My Header 1\nContent")],
[Document(page_content="# My Header 1\nContent")],
None,
None,
None,
),
(
"Custom headers",
[("##", "Header2")],
[
Document(
page_content="# H1\n## H2\nContent",
)
],
"Multiple batches",
None,
[
Document(
page_content="# H1\n", metadata={"source": "/test/docs/file1.md"}
),
Document(
page_content="## H2\nContent",
),
],
Document(page_content=f"# Header {i}\nContent") for i in range(6)
], # Assuming CHUNKS_BATCH_SIZE is 5
[Document(page_content=f"# Header {i}\nContent") for i in range(6)],
None,
None,
None,
),
(
"Delete error",
None,
[],
[],
[Document(page_content="# My Header 1\nContent")],
[Document(page_content="# My Header 1\nContent")],
Exception("Delete error"),
None,
Exception,
),
(
"Add documents error",
None,
[],
[],
[Document(page_content="# My Header 1\nContent")],
[Document(page_content="# My Header 1\nContent")],
None,
Exception("Add documents error"),
Exception,
@@ -257,17 +240,41 @@ def test_index(

with patch.object(indexer, "_load_documents", return_value=loaded_docs), patch(
"indexing.indexer.create_chunks", return_value=expected_chunks
) as mock_create_chunks:
) as mock_create_chunks, patch("indexing.indexer.CHUNKS_BATCH_SIZE", 5), patch(
"time.sleep"
) as mock_sleep: # Add mock for sleep

indexer.headers_to_split_on = headers_to_split_on

if expected_exception:
with pytest.raises(expected_exception):
indexer.index()
else:
indexer.index()

# Verify create_chunks was called
mock_create_chunks.assert_called_once_with(loaded_docs, headers_to_split_on)

# Verify delete was called
indexer.db.delete.assert_called_once_with(filter={})
indexer.db.add_documents.assert_called_once_with(expected_chunks)

# Calculate expected number of batches
num_chunks = len(expected_chunks)
batch_size = 5 # From the mocked CHUNKS_BATCH_SIZE
expected_batches = [
expected_chunks[i : i + batch_size]
for i in range(0, num_chunks, batch_size)
]

# Verify add_documents was called for each batch
assert indexer.db.add_documents.call_count == len(expected_batches)
for i, batch in enumerate(expected_batches):
assert indexer.db.add_documents.call_args_list[i] == call(batch)

# Verify sleep was called between batches
if len(expected_batches) > 1:
assert mock_sleep.call_count == len(expected_batches) - 1
mock_sleep.assert_has_calls([call(3)] * (len(expected_batches) - 1))

if delete_error:
assert str(delete_error) in str(indexer.db.delete.side_effect)