From aeace96c48f2e108b8440d15b87badf0cb873ab1 Mon Sep 17 00:00:00 2001 From: "Tuan Anh Nguyen Dang (Tadashi_Cin)" Date: Thu, 7 Nov 2024 17:22:41 +0700 Subject: [PATCH] feat: add lightrag support (#474) bump:patch * feat: add lightrag support * docs: update README --- README.md | 12 + flowsettings.py | 50 +-- libs/ktem/ktem/index/file/graph/__init__.py | 3 +- .../index/file/graph/light_graph_index.py | 26 ++ .../index/file/graph/lightrag_pipelines.py | 386 ++++++++++++++++++ .../ktem/index/file/graph/nano_pipelines.py | 34 +- 6 files changed, 477 insertions(+), 34 deletions(-) create mode 100644 libs/ktem/ktem/index/file/graph/light_graph_index.py create mode 100644 libs/ktem/ktem/index/file/graph/lightrag_pipelines.py diff --git a/README.md b/README.md index f18643c00..0d1e4ab8e 100644 --- a/README.md +++ b/README.md @@ -187,6 +187,18 @@ documents and developers who want to build their own RAG pipeline.
 
 <details>
 
+<summary>Setup LIGHTRAG</summary>
+
+- Install LightRAG: `pip install git+https://github.com/HKUDS/LightRAG.git`
+- Installing `LightRAG` may introduce version conflicts; see [this issue](https://github.com/Cinnamon/kotaemon/issues/440)
+  - Quick fix: `pip uninstall hnswlib chroma-hnswlib && pip install chroma-hnswlib`
+- Launch Kotaemon with the `USE_LIGHTRAG=true` environment variable set (see the example below).
+- Set your default LLM & embedding models in the Resources settings; LightRAG will pick them up automatically.
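+
+For example, if you run Kotaemon from source (assuming the standard `python app.py` entrypoint; adjust to your own launch command):
+
+```bash
+USE_LIGHTRAG=true python app.py
+```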
+
+</details>
+
+<details>
+
 <summary>Setup MS GRAPHRAG</summary>
 
 - **Non-Docker Installation**: If you are not using Docker, install GraphRAG with the following command:
diff --git a/flowsettings.py b/flowsettings.py
index 5513f582c..a2f9e4c41 100644
--- a/flowsettings.py
+++ b/flowsettings.py
@@ -287,41 +287,31 @@
 }
 
 USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
-GRAPHRAG_INDEX_TYPE = (
-    "ktem.index.file.graph.GraphRAGIndex"
-    if not USE_NANO_GRAPHRAG
-    else "ktem.index.file.graph.NanoGraphRAGIndex"
-)
+USE_LIGHTRAG = config("USE_LIGHTRAG", default=False, cast=bool)
+
+if USE_NANO_GRAPHRAG:
+    GRAPHRAG_INDEX_TYPE = "ktem.index.file.graph.NanoGraphRAGIndex"
+elif USE_LIGHTRAG:
+    GRAPHRAG_INDEX_TYPE = "ktem.index.file.graph.LightRAGIndex"
+else:
+    GRAPHRAG_INDEX_TYPE = "ktem.index.file.graph.GraphRAGIndex"
+
 KH_INDEX_TYPES = [
     "ktem.index.file.FileIndex",
     GRAPHRAG_INDEX_TYPE,
 ]
-GRAPHRAG_INDEX = (
-    {
-        "name": "GraphRAG",
-        "config": {
-            "supported_file_types": (
-                ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
-                ".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
-            ),
-            "private": False,
-        },
-        "index_type": "ktem.index.file.graph.GraphRAGIndex",
-    }
-    if not USE_NANO_GRAPHRAG
-    else {
-        "name": "NanoGraphRAG",
-        "config": {
-            "supported_file_types": (
-                ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
-                ".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
-            ),
-            "private": False,
-        },
-        "index_type": "ktem.index.file.graph.NanoGraphRAGIndex",
-    }
-)
+GRAPHRAG_INDEX = {
+    "name": GRAPHRAG_INDEX_TYPE.split(".")[-1].replace("Index", ""),  # e.g. "ktem.index.file.graph.LightRAGIndex" -> "LightRAG"
+    "config": {
+        "supported_file_types": (
+            ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
+            ".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
+        ),
+        "private": False,
+    },
+    "index_type": GRAPHRAG_INDEX_TYPE,
+}
 
 KH_INDICES = [
     {
diff --git a/libs/ktem/ktem/index/file/graph/__init__.py b/libs/ktem/ktem/index/file/graph/__init__.py
index e2836f704..afe1db443 100644
--- a/libs/ktem/ktem/index/file/graph/__init__.py
+++ b/libs/ktem/ktem/index/file/graph/__init__.py
@@ -1,4 +1,5 @@
 from .graph_index import GraphRAGIndex
+from .light_graph_index import LightRAGIndex
 from .nano_graph_index import NanoGraphRAGIndex
 
-__all__ = ["GraphRAGIndex", "NanoGraphRAGIndex"]
+__all__ = ["GraphRAGIndex", "NanoGraphRAGIndex", "LightRAGIndex"]
diff --git a/libs/ktem/ktem/index/file/graph/light_graph_index.py b/libs/ktem/ktem/index/file/graph/light_graph_index.py
new file mode 100644
index 000000000..583945eeb
--- /dev/null
+++ b/libs/ktem/ktem/index/file/graph/light_graph_index.py
@@ -0,0 +1,26 @@
+from typing import Any
+
+from ..base import BaseFileIndexRetriever
+from .graph_index import GraphRAGIndex
+from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipeline
+
+
+class LightRAGIndex(GraphRAGIndex):
+    def _setup_indexing_cls(self):
+        self._indexing_pipeline_cls = LightRAGIndexingPipeline
+
+    def _setup_retriever_cls(self):
+        self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]
+
+    def get_retriever_pipelines(
+        self, settings: dict, user_id: int, selected: Any = None
+    ) -> list["BaseFileIndexRetriever"]:
+        _, file_ids, _ = selected
+        retrievers = [
+            LightRAGRetrieverPipeline(
+                file_ids=file_ids,
+                Index=self._resources["Index"],
+            )
+        ]
+
+        return retrievers
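To make the flag-based selection in `flowsettings.py` concrete, here is a minimal standalone sketch of how `GRAPHRAG_INDEX_TYPE` and the index display name are derived (flag values are hard-coded here instead of being read through `config`):

```python
USE_NANO_GRAPHRAG = False
USE_LIGHTRAG = True

if USE_NANO_GRAPHRAG:
    GRAPHRAG_INDEX_TYPE = "ktem.index.file.graph.NanoGraphRAGIndex"
elif USE_LIGHTRAG:
    GRAPHRAG_INDEX_TYPE = "ktem.index.file.graph.LightRAGIndex"
else:
    GRAPHRAG_INDEX_TYPE = "ktem.index.file.graph.GraphRAGIndex"

# Display name: last dotted component minus the "Index" suffix,
# e.g. "ktem.index.file.graph.LightRAGIndex" -> "LightRAG"
name = GRAPHRAG_INDEX_TYPE.split(".")[-1].replace("Index", "")
assert name == "LightRAG"
```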
diff --git a/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py b/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py
new file mode 100644
index 000000000..5144828d8
--- /dev/null
+++ b/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py
@@ -0,0 +1,386 @@
+import asyncio
+import glob
+import logging
+import os
+import re
+from pathlib import Path
+from typing import Generator
+
+import numpy as np
+import pandas as pd
+from ktem.db.models import engine
+from ktem.embeddings.manager import embedding_models_manager as embeddings
+from ktem.llms.manager import llms
+from sqlalchemy.orm import Session
+from theflow.settings import settings
+
+from kotaemon.base import Document, Param, RetrievedDocument
+from kotaemon.base.schema import AIMessage, HumanMessage, SystemMessage
+
+from ..pipelines import BaseFileIndexRetriever
+from .pipelines import GraphRAGIndexingPipeline
+from .visualize import create_knowledge_graph, visualize_graph
+
+try:
+    from lightrag import LightRAG, QueryParam
+    from lightrag.operate import (
+        _find_most_related_edges_from_entities,
+        _find_most_related_text_unit_from_entities,
+    )
+    from lightrag.utils import EmbeddingFunc, compute_args_hash
+
+except ImportError:
+    print(
+        (
+            "LightRAG dependencies not installed. "
+            "Try `pip install git+https://github.com/HKUDS/LightRAG.git` to install. "
+            "LightRAG retriever pipeline will not work properly."
+        )
+    )
+
+
+logging.getLogger("lightrag").setLevel(logging.INFO)
+
+
+filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "lightrag"
+filestorage_path.mkdir(parents=True, exist_ok=True)
+
+INDEX_BATCHSIZE = 2
+
+
+def get_llm_func(model):
+    async def llm_func(
+        prompt, system_prompt=None, history_messages=[], **kwargs
+    ) -> str:
+        input_messages = [SystemMessage(text=system_prompt)] if system_prompt else []
+
+        hashing_kv = kwargs.pop("hashing_kv", None)
+        if history_messages:
+            for msg in history_messages:
+                if msg.get("role") == "user":
+                    input_messages.append(HumanMessage(text=msg["content"]))
+                else:
+                    input_messages.append(AIMessage(text=msg["content"]))
+
+        input_messages.append(HumanMessage(text=prompt))
+
+        if hashing_kv is not None:
+            args_hash = compute_args_hash("model", input_messages)
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
+
+        output = model(input_messages).text
+
+        print("-" * 50)
+        print(output, "\n", "-" * 50)
+
+        if hashing_kv is not None:
+            await hashing_kv.upsert({args_hash: {"return": output, "model": "model"}})
+
+        return output
+
+    return llm_func
+
+
+def get_embedding_func(model):
+    async def embedding_func(texts: list[str]) -> np.ndarray:
+        outputs = model(texts)
+        embedding_outputs = np.array([doc.embedding for doc in outputs])
+
+        return embedding_outputs
+
+    return embedding_func
+
+
+def get_default_models_wrapper():
+    # setup model functions
+    default_embedding = embeddings.get_default()
+    default_embedding_dim = len(default_embedding(["Hi"])[0].embedding)
+    embedding_func = EmbeddingFunc(
+        embedding_dim=default_embedding_dim,
+        max_token_size=8192,
+        func=get_embedding_func(default_embedding),
+    )
+    print("GraphRAG embedding dim", default_embedding_dim)
+
+    default_llm = llms.get_default()
+    llm_func = get_llm_func(default_llm)
+
+    return llm_func, embedding_func, default_llm, default_embedding
+
+
+def prepare_graph_index_path(graph_id: str):
+    root_path = Path(filestorage_path) / graph_id
+    input_path = root_path / "input"
+
+    return root_path, input_path
+
+
+def list_of_list_to_df(data: list[list]) -> pd.DataFrame:
+    df = pd.DataFrame(data[1:], columns=data[0])
+    return df
+
+
+def clean_quote(text: str) -> str:
+    return re.sub(r"[\"']", "", text)
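+
+# Usage sketch (illustrative only; nothing in this module calls the wrappers
+# this way). The helpers above adapt Kotaemon's default LLM and embedding
+# managers to the async callables LightRAG expects:
+#
+#   llm_func, embedding_func, _, _ = get_default_models_wrapper()
+#   answer = asyncio.run(llm_func("Say hi", system_prompt="Be brief"))
+#   vectors = asyncio.run(embedding_func(["hello", "world"]))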
+
+
+async def lightrag_build_local_query_context(
+    graph_func,
+    query,
+    query_param,
+):
+    knowledge_graph_inst = graph_func.chunk_entity_relation_graph
+    entities_vdb = graph_func.entities_vdb
+    text_chunks_db = graph_func.text_chunks
+
+    results = await entities_vdb.query(query, top_k=query_param.top_k)
+    if not len(results):
+        raise ValueError("No results found")
+
+    node_datas = await asyncio.gather(
+        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
+    )
+    node_degrees = await asyncio.gather(
+        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
+    )
+    node_datas = [
+        {**n, "entity_name": k["entity_name"], "rank": d}
+        for k, n, d in zip(results, node_datas, node_degrees)
+        if n is not None
+    ]
+    use_text_units = await _find_most_related_text_unit_from_entities(
+        node_datas, query_param, text_chunks_db, knowledge_graph_inst
+    )
+    use_relations = await _find_most_related_edges_from_entities(
+        node_datas, query_param, knowledge_graph_inst
+    )
+    logging.info(
+        f"Local query uses {len(node_datas)} entities, "
+        f"{len(use_relations)} relations, {len(use_text_units)} text units"
+    )
+    entities_section_list = [["id", "entity", "type", "description", "rank"]]
+    for i, n in enumerate(node_datas):
+        entities_section_list.append(
+            [
+                str(i),
+                clean_quote(n["entity_name"]),
+                n.get("entity_type", "UNKNOWN"),
+                clean_quote(n.get("description", "UNKNOWN")),
+                n["rank"],
+            ]
+        )
+    entities_df = list_of_list_to_df(entities_section_list)
+
+    relations_section_list = [
+        ["id", "source", "target", "description", "keywords", "weight", "rank"]
+    ]
+    for i, e in enumerate(use_relations):
+        relations_section_list.append(
+            [
+                str(i),
+                clean_quote(e["src_tgt"][0]),
+                clean_quote(e["src_tgt"][1]),
+                clean_quote(e["description"]),
+                e["keywords"],
+                e["weight"],
+                e["rank"],
+            ]
+        )
+    relations_df = list_of_list_to_df(relations_section_list)
+
+    text_units_section_list = [["id", "content"]]
+    for i, t in enumerate(use_text_units):
+        text_units_section_list.append([str(i), t["content"]])
+    sources_df = list_of_list_to_df(text_units_section_list)
+
+    return entities_df, relations_df, sources_df
+
+
+def build_graphrag(working_dir, llm_func, embedding_func):
+    graphrag_func = LightRAG(
+        working_dir=working_dir,
+        llm_model_func=llm_func,
+        embedding_func=embedding_func,
+    )
+    return graphrag_func
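+
+# Standalone usage sketch (illustrative; the working dir, document text, and
+# query are assumptions, and this module never calls LightRAG this way itself):
+#
+#   llm_func, embedding_func, _, _ = get_default_models_wrapper()
+#   rag = build_graphrag("/tmp/lightrag_demo", llm_func, embedding_func)
+#   rag.insert(["Some document text to index."])
+#   print(rag.query("What does the corpus discuss?", param=QueryParam(mode="local")))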
+
+
+class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
+    """LightRAG-specific indexing pipeline"""
+
+    def call_graphrag_index(self, graph_id: str, docs: list[Document]):
+        _, input_path = prepare_graph_index_path(graph_id)
+        input_path.mkdir(parents=True, exist_ok=True)
+
+        (
+            llm_func,
+            embedding_func,
+            default_llm,
+            default_embedding,
+        ) = get_default_models_wrapper()
+        print(
+            f"Indexing GraphRAG with LLM {default_llm} "
+            f"and Embedding {default_embedding}..."
+        )
+
+        all_docs = [
+            doc.text for doc in docs if doc.metadata.get("type", "text") == "text"
+        ]
+
+        yield Document(
+            channel="debug",
+            text="[GraphRAG] Creating index... This can take a long time.",
+        )
+
+        # remove all .json files in the input_path directory (previous cache)
+        json_files = glob.glob(f"{input_path}/*.json")
+        for json_file in json_files:
+            os.remove(json_file)
+
+        # indexing
+        graphrag_func = build_graphrag(
+            input_path,
+            llm_func=llm_func,
+            embedding_func=embedding_func,
+        )
+        # output must contain: Loaded graph from
+        # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
+        total_docs = len(all_docs)
+        process_doc_count = 0
+        yield Document(
+            channel="debug",
+            text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
+        )
+
+        for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
+            cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
+            graphrag_func.insert(cur_docs)
+            process_doc_count += len(cur_docs)
+            yield Document(
+                channel="debug",
+                text=(
+                    f"[GraphRAG] Indexed {process_doc_count} "
+                    f"/ {total_docs} documents."
+                ),
+            )
+
+        yield Document(
+            channel="debug",
+            text="[GraphRAG] Indexing finished.",
+        )
+
+    def stream(
+        self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
+    ) -> Generator[
+        Document, None, tuple[list[str | None], list[str | None], list[Document]]
+    ]:
+        file_ids, errors, all_docs = yield from super().stream(
+            file_paths, reindex=reindex, **kwargs
+        )
+
+        return file_ids, errors, all_docs
+
+
+class LightRAGRetrieverPipeline(BaseFileIndexRetriever):
+    """LightRAG-specific retriever pipeline"""
+
+    Index = Param(help="The SQLAlchemy Index table")
+    file_ids: list[str] = []
+
+    def _build_graph_search(self):
+        file_id = self.file_ids[0]
+
+        # retrieve the graph_id from the index
+        with Session(engine) as session:
+            graph_id = (
+                session.query(self.Index.target_id)
+                .filter(self.Index.source_id == file_id)
+                .filter(self.Index.relation_type == "graph")
+                .first()
+            )
+            graph_id = graph_id[0] if graph_id else None
+            assert graph_id, f"GraphRAG index not found for file_id: {file_id}"
+
+        _, input_path = prepare_graph_index_path(graph_id)
+        input_path.mkdir(parents=True, exist_ok=True)
+
+        llm_func, embedding_func, _, _ = get_default_models_wrapper()
+        graphrag_func = build_graphrag(
+            input_path,
+            llm_func=llm_func,
+            embedding_func=embedding_func,
+        )
+        query_params = QueryParam(mode="local", only_need_context=True)
+
+        return graphrag_func, query_params
+
+    def _to_document(self, header: str, context_text: str) -> RetrievedDocument:
+        return RetrievedDocument(
+            text=context_text,
+            metadata={
+                "file_name": header,
+                "type": "table",
+                "llm_trulens_score": 1.0,
+            },
+            score=1.0,
+        )
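+
+    # Usage sketch (illustrative; the file id and Index table are assumptions):
+    #
+    #   retriever = LightRAGRetrieverPipeline(file_ids=["<file-id>"], Index=Index)
+    #   docs = retriever.run("What are the key entities?")
+    #   # -> entity/relationship/source tables plus a graph plot document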
+
+    def format_context_records(
+        self, entities, relationships, sources
+    ) -> list[RetrievedDocument]:
+        docs = []
+        context: str = ""
+
+        # entities (rendered as a markdown table)
+        header = "Entities\n"
+        context = entities[["entity", "description"]].to_markdown(index=False)
+        docs.append(self._to_document(header, context))
+
+        header = "\nRelationships\n"
+        context = relationships[["source", "target", "description"]].to_markdown(
+            index=False
+        )
+        docs.append(self._to_document(header, context))
+
+        header = "\nSources\n"
+        context = ""
+        for _, row in sources.iterrows():
+            title, content = row["id"], row["content"]
+            context += f"\n\n<details>\n<summary>Source #{title}</summary>\n"
+            context += content
+            context += "\n\n</details>"
+        docs.append(self._to_document(header, context))
+
+        return docs
+
+    def plot_graph(self, relationships):
+        G = create_knowledge_graph(relationships)
+        plot = visualize_graph(G)
+        return plot
+
+    def run(
+        self,
+        text: str,
+    ) -> list[RetrievedDocument]:
+        if not self.file_ids:
+            return []
+
+        graphrag_func, query_params = self._build_graph_search()
+        entities, relationships, sources = asyncio.run(
+            lightrag_build_local_query_context(graphrag_func, text, query_params)
+        )
+
+        documents = self.format_context_records(entities, relationships, sources)
+        plot = self.plot_graph(relationships)
+
+        return documents + [
+            RetrievedDocument(
+                text="",
+                metadata={
+                    "file_name": "GraphRAG",
+                    "type": "plot",
+                    "data": plot,
+                },
+            ),
+        ]
diff --git a/libs/ktem/ktem/index/file/graph/nano_pipelines.py b/libs/ktem/ktem/index/file/graph/nano_pipelines.py
index b3f6b9594..332edcd1c 100644
--- a/libs/ktem/ktem/index/file/graph/nano_pipelines.py
+++ b/libs/ktem/ktem/index/file/graph/nano_pipelines.py
@@ -28,7 +28,7 @@
         _find_most_related_edges_from_entities,
         _find_most_related_text_unit_from_entities,
     )
-    from nano_graphrag._utils import EmbeddingFunc
+    from nano_graphrag._utils import EmbeddingFunc, compute_args_hash
 
 except ImportError:
     print(
@@ -46,6 +46,8 @@
 filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "nano_graphrag"
 filestorage_path.mkdir(parents=True, exist_ok=True)
 
+INDEX_BATCHSIZE = 4
+
 
 def get_llm_func(model):
     async def llm_func(
@@ -53,6 +55,7 @@ async def llm_func(
     ) -> str:
         input_messages = [SystemMessage(text=system_prompt)] if system_prompt else []
 
+        hashing_kv = kwargs.pop("hashing_kv", None)
         if history_messages:
             for msg in history_messages:
                 if msg.get("role") == "user":
@@ -61,11 +64,21 @@ async def llm_func(
                     input_messages.append(AIMessage(text=msg["content"]))
 
         input_messages.append(HumanMessage(text=prompt))
+
+        if hashing_kv is not None:
+            args_hash = compute_args_hash("model", input_messages)
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
+
         output = model(input_messages).text
 
         print("-" * 50)
         print(output, "\n", "-" * 50)
 
+        if hashing_kv is not None:
+            await hashing_kv.upsert({args_hash: {"return": output, "model": "model"}})
+
         return output
 
     return llm_func
@@ -196,7 +209,6 @@ def build_graphrag(working_dir, llm_func, embedding_func):
         best_model_func=llm_func,
         cheap_model_func=llm_func,
         embedding_func=embedding_func,
-        embedding_func_max_async=4,
     )
     return graphrag_func
 
@@ -241,7 +253,23 @@ def call_graphrag_index(self, graph_id: str, docs: list[Document]):
         )
         # output must be contain: Loaded graph from
         # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
-        graphrag_func.insert(all_docs)
+        total_docs = len(all_docs)
+        process_doc_count = 0
+        yield Document(
+            channel="debug",
+            text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
+        )
+        for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
+            cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
+            graphrag_func.insert(cur_docs)
+            process_doc_count += len(cur_docs)
+            yield Document(
+                channel="debug",
+                text=(
+                    f"[GraphRAG] Indexed {process_doc_count} "
+                    f"/ {total_docs} documents."
+                ),
+            )
 
         yield Document(
             channel="debug",