Skip to content

Commit

Permalink
feat(llm): store faiss binary and optimze the workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
ramchaik committed Aug 27, 2024
1 parent f330818 commit c0c46b2
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 7 deletions.
2 changes: 2 additions & 0 deletions llm_api/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import traceback
from flask import Flask, request, jsonify
from main import LangChainModel

Expand All @@ -17,6 +18,7 @@ def predict():

return jsonify(result)
except Exception as e:
print(traceback.format_exc())
return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
Expand Down
4 changes: 3 additions & 1 deletion llm_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@

URLS = [
"https://tip.golang.org/tour/concurrency.article",
"https://tip.golang.org/doc/effective_go"
"https://tip.golang.org/doc/effective_go",
# "https://socrates.acadiau.ca/courses/engl/rcunningham/resources/Shpe/Hamlet.pdf",
# "https://gosafir.com/mag/wp-content/uploads/2019/12/Tolkien-J.-The-lord-of-the-rings-HarperCollins-ebooks-2010.pdf"
]
Binary file added llm_api/data/faiss_index.bin
Binary file not shown.
Binary file added llm_api/data/index_to_id.npy
Binary file not shown.
Binary file added llm_api/data/vectors.npy
Binary file not shown.
41 changes: 38 additions & 3 deletions llm_api/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,42 @@
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.docstore.in_memory import InMemoryDocstore
import faiss
import numpy as np
import os
from typing import TypedDict, List

def create_vectorstore(documents):
ollama_embeddings = OllamaEmbeddings(model="nomic-embed-text")
vectorstore = FAISS.from_documents(documents, ollama_embeddings)
def create_vectorstore(documents: List[str], force_rebuild: bool = False) -> FAISS:
index_path = "data/faiss_index.bin"
vectors_path = "data/vectors.npy"
index_to_id_path = "data/index_to_id.npy"

if not force_rebuild and os.path.exists(index_path) and os.path.exists(vectors_path) and os.path.exists(index_to_id_path):
# Load existing index and vectors
index = faiss.read_index(index_path)
docstore_dict = np.load(vectors_path, allow_pickle=True).item()
index_to_id = np.load(index_to_id_path, allow_pickle=True).item()

# Create InMemoryDocstore from the loaded dictionary
docstore = InMemoryDocstore(docstore_dict)

ollama_embeddings = OllamaEmbeddings(model="nomic-embed-text")
vectorstore = FAISS(
embedding_function=ollama_embeddings,
index=index,
docstore=docstore,
index_to_docstore_id=index_to_id
)
print("Loaded existing vectorstore from disk.")
else:
# Create new vectorstore
ollama_embeddings = OllamaEmbeddings(model="nomic-embed-text")
vectorstore = FAISS.from_documents(documents, ollama_embeddings)

# Save the index, vectors, and index_to_docstore_id separately
faiss.write_index(vectorstore.index, index_path)
np.save(vectors_path, vectorstore.docstore._dict)
np.save(index_to_id_path, vectorstore.index_to_docstore_id)
print("Created new vectorstore and saved to disk.")

return vectorstore
10 changes: 7 additions & 3 deletions llm_api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

llm = create_llm(model="phi3:mini")
rag_chain = rag_prompt | llm | StrOutputParser()
retrieval_grader = grading_prompt | create_llm(model="saikatkumardey/tinyllama", format="json") | JsonOutputParser()
retrieval_grader = grading_prompt | create_llm(model="tinyllama", format="json") | JsonOutputParser()

web_search_tool = TavilySearchResults(api_key=TAVILY_API_KEY)

Expand All @@ -52,6 +52,10 @@ def retrieve(state: GraphState) -> GraphState:
logger.info(f"Retrieve time: {time.time() - start_time:.2f} seconds")
return state

@lru_cache(maxsize=100)
def cached_rag_chain_invoke(documents, question):
return rag_chain.invoke({"documents": documents, "question": question})

def generate(state: GraphState) -> GraphState:
start_time = time.time()
docs_content = "\n".join(doc.page_content for doc in state["documents"][:3]) # Limit to top 3 documents
Expand Down Expand Up @@ -87,7 +91,7 @@ def process_batch(batch):
if len(filtered_docs) >= 5: # Top 5 docs
break

state["documents"] = filtered_docs[:3] # Limit to top 3 relevant documents
state["documents"] = filtered_docs[:5] # Limit to top 5 relevant documents
state["search"] = "Yes" if search_needed and len(filtered_docs) < 3 else "No"
state["steps"].append("grade_document_retrieval")
logger.info(f"Grade documents time: {time.time() - start_time:.2f} seconds")
Expand All @@ -102,7 +106,7 @@ def web_search(state: GraphState) -> GraphState:
return state

def decide_to_generate(state: GraphState) -> str:
return "search" if state["search"] == "Yes" and len(state["documents"]) < 3 else "generate"
return "search" if state.get("search") == "Yes" and len(state["documents"]) < 3 else "generate"


def add_nodes(workflow: StateGraph):
Expand Down

0 comments on commit c0c46b2

Please sign in to comment.