From 650b8e38b7971cd4e0776863f99d13816e840773 Mon Sep 17 00:00:00 2001 From: Magic_yuan <317617749@qq.com> Date: Sat, 28 Dec 2024 00:11:25 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat(lightrag):=20Add=20document=20status?= =?UTF-8?q?=20tracking=20and=20checkpoint=20support=20=E5=8A=9F=E8=83=BD(l?= =?UTF-8?q?ightrag):=20=E6=B7=BB=E5=8A=A0=E6=96=87=E6=A1=A3=E7=8A=B6?= =?UTF-8?q?=E6=80=81=E8=B7=9F=E8=B8=AA=E5=92=8C=E6=96=AD=E7=82=B9=E7=BB=AD?= =?UTF-8?q?=E4=BC=A0=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add DocStatus enum and DocProcessingStatus class for document processing state management - 添加 DocStatus 枚举和 DocProcessingStatus 类用于文档处理状态管理 - Implement JsonDocStatusStorage for persistent status storage - 实现 JsonDocStatusStorage 用于持久化状态存储 - Add document-level deduplication in batch processing - 在批处理中添加文档级别的去重功能 - Add checkpoint support in ainsert method for resumable document processing - 在 ainsert 方法中添加断点续传支持,实现可恢复的文档处理 - Add status query methods for monitoring processing progress - 添加状态查询方法用于监控处理进度 - Update LightRAG initialization to support document status tracking - 更新 LightRAG 初始化以支持文档状态跟踪 --- lightrag/base.py | 42 +++++++- lightrag/kg/age_impl.py | 5 +- lightrag/lightrag.py | 229 +++++++++++++++++++++++++++++----------- lightrag/storage.py | 49 ++++++++- 4 files changed, 260 insertions(+), 65 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index f5a6e0c0..c3ba3e09 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field -from typing import TypedDict, Union, Literal, Generic, TypeVar +from typing import TypedDict, Union, Literal, Generic, TypeVar, Optional, Dict, Any +from enum import Enum import numpy as np @@ -129,3 +130,42 @@ async def delete_node(self, node_id: str): async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: raise NotImplementedError("Node embedding is not used in lightrag.") + + +class DocStatus(str, Enum): + """Document processing status enum""" + + PENDING = "pending" + PROCESSING = "processing" + PROCESSED = "processed" + FAILED = "failed" + + +@dataclass +class DocProcessingStatus: + """Document processing status data structure""" + + content_summary: str # First 100 chars of document content + content_length: int # Total length of document + status: DocStatus # Current processing status + created_at: str # ISO format timestamp + updated_at: str # ISO format timestamp + chunks_count: Optional[int] = None # Number of chunks after splitting + error: Optional[str] = None # Error message if failed + metadata: Dict[str, Any] = field(default_factory=dict) # Additional metadata + + +class DocStatusStorage(BaseKVStorage): + """Base class for document status storage""" + + async def get_status_counts(self) -> Dict[str, int]: + """Get counts of documents in each status""" + raise NotImplementedError + + async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all failed documents""" + raise NotImplementedError + + async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all pending documents""" + raise NotImplementedError diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 2a97bc37..275f5775 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -1,7 +1,8 @@ import asyncio import inspect import json -import os, sys +import os +import sys from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any, Dict, List, 
NamedTuple, Optional, Tuple, Union @@ -22,8 +23,10 @@ if sys.platform.startswith("win"): import asyncio.windows_events + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + class AGEQueryException(Exception): """Exception for the AGE queries.""" diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 51ac204d..992c43a4 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial -from typing import Type, cast +from typing import Type, cast, Dict from .llm import ( gpt_4o_mini_complete, @@ -32,12 +32,14 @@ BaseVectorStorage, StorageNameSpace, QueryParam, + DocStatus, ) from .storage import ( JsonKVStorage, NanoVectorDBStorage, NetworkXStorage, + JsonDocStatusStorage, ) # future KG integrations @@ -172,6 +174,9 @@ class LightRAG: addon_params: dict = field(default_factory=dict) convert_response_to_json_func: callable = convert_response_to_json + # Add new field for document status storage type + doc_status_storage: str = field(default="JsonDocStatusStorage") + def __post_init__(self): log_file = os.path.join("lightrag.log") set_logger(log_file) @@ -263,7 +268,15 @@ def __post_init__(self): ) ) - def _get_storage_class(self) -> Type[BaseGraphStorage]: + # Initialize document status storage + self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage] + self.doc_status = self.doc_status_storage_cls( + namespace="doc_status", + global_config=asdict(self), + embedding_func=None, + ) + + def _get_storage_class(self) -> dict: return { # kv storage "JsonKVStorage": JsonKVStorage, @@ -284,6 +297,7 @@ def _get_storage_class(self) -> Type[BaseGraphStorage]: "TiDBGraphStorage": TiDBGraphStorage, "GremlinStorage": GremlinStorage, # "ArangoDBStorage": ArangoDBStorage + "JsonDocStatusStorage": JsonDocStatusStorage, } def insert(self, string_or_strings): @@ -291,71 +305,139 @@ def insert(self, string_or_strings): return loop.run_until_complete(self.ainsert(string_or_strings)) async def ainsert(self, string_or_strings): - update_storage = False - try: - if isinstance(string_or_strings, str): - string_or_strings = [string_or_strings] - - new_docs = { - compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()} - for c in string_or_strings + """Insert documents with checkpoint support + + Args: + string_or_strings: Single document string or list of document strings + """ + if isinstance(string_or_strings, str): + string_or_strings = [string_or_strings] + + # 1. Remove duplicate contents from the list + unique_contents = list(set(doc.strip() for doc in string_or_strings)) + + # 2. Generate document IDs and initial status + new_docs = { + compute_mdhash_id(content, prefix="doc-"): { + "content": content, + "content_summary": self._get_content_summary(content), + "content_length": len(content), + "status": DocStatus.PENDING, + "created_at": datetime.now().isoformat(), + "updated_at": datetime.now().isoformat(), } - _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) - new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} - if not len(new_docs): - logger.warning("All docs are already in the storage") - return - update_storage = True - logger.info(f"[New Docs] inserting {len(new_docs)} docs") - - inserting_chunks = {} - for doc_key, doc in tqdm_async( - new_docs.items(), desc="Chunking documents", unit="doc" + for content in unique_contents + } + + # 3. 
Filter out already processed documents + _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys())) + new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} + + if not new_docs: + logger.info("All documents have been processed or are duplicates") + return + + logger.info(f"Processing {len(new_docs)} new unique documents") + + # Process documents in batches + batch_size = self.addon_params.get("insert_batch_size", 10) + for i in range(0, len(new_docs), batch_size): + batch_docs = dict(list(new_docs.items())[i : i + batch_size]) + + for doc_id, doc in tqdm_async( + batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}" ): - chunks = { - compute_mdhash_id(dp["content"], prefix="chunk-"): { - **dp, - "full_doc_id": doc_key, + try: + # Update status to processing + doc_status = { + "content_summary": doc["content_summary"], + "content_length": doc["content_length"], + "status": DocStatus.PROCESSING, + "created_at": doc["created_at"], + "updated_at": datetime.now().isoformat(), } - for dp in chunking_by_token_size( - doc["content"], - overlap_token_size=self.chunk_overlap_token_size, - max_token_size=self.chunk_token_size, - tiktoken_model=self.tiktoken_model_name, + await self.doc_status.upsert({doc_id: doc_status}) + + # Generate chunks from document + chunks = { + compute_mdhash_id(dp["content"], prefix="chunk-"): { + **dp, + "full_doc_id": doc_id, + } + for dp in chunking_by_token_size( + doc["content"], + overlap_token_size=self.chunk_overlap_token_size, + max_token_size=self.chunk_token_size, + tiktoken_model=self.tiktoken_model_name, + ) + } + + # Update status with chunks information + doc_status.update( + { + "chunks_count": len(chunks), + "updated_at": datetime.now().isoformat(), + } ) - } - inserting_chunks.update(chunks) - _add_chunk_keys = await self.text_chunks.filter_keys( - list(inserting_chunks.keys()) - ) - inserting_chunks = { - k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys - } - if not len(inserting_chunks): - logger.warning("All chunks are already in the storage") - return - logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks") - - await self.chunks_vdb.upsert(inserting_chunks) - - logger.info("[Entity Extraction]...") - maybe_new_kg = await extract_entities( - inserting_chunks, - knowledge_graph_inst=self.chunk_entity_relation_graph, - entity_vdb=self.entities_vdb, - relationships_vdb=self.relationships_vdb, - global_config=asdict(self), - ) - if maybe_new_kg is None: - logger.warning("No new entities and relationships found") - return - self.chunk_entity_relation_graph = maybe_new_kg + await self.doc_status.upsert({doc_id: doc_status}) + + try: + # Store chunks in vector database + await self.chunks_vdb.upsert(chunks) + + # Extract and store entities and relationships + maybe_new_kg = await extract_entities( + chunks, + knowledge_graph_inst=self.chunk_entity_relation_graph, + entity_vdb=self.entities_vdb, + relationships_vdb=self.relationships_vdb, + global_config=asdict(self), + ) - await self.full_docs.upsert(new_docs) - await self.text_chunks.upsert(inserting_chunks) - finally: - if update_storage: - await self._insert_done() + if maybe_new_kg is None: + raise Exception( + "Failed to extract entities and relationships" + ) + + self.chunk_entity_relation_graph = maybe_new_kg + + # Store original document and chunks + await self.full_docs.upsert( + {doc_id: {"content": doc["content"]}} + ) + await self.text_chunks.upsert(chunks) + + # Update status to processed + doc_status.update( + { + "status": 
DocStatus.PROCESSED, + "updated_at": datetime.now().isoformat(), + } + ) + await self.doc_status.upsert({doc_id: doc_status}) + + except Exception as e: + # Mark as failed if any step fails + doc_status.update( + { + "status": DocStatus.FAILED, + "error": str(e), + "updated_at": datetime.now().isoformat(), + } + ) + await self.doc_status.upsert({doc_id: doc_status}) + raise e + + except Exception as e: + import traceback + + error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + continue + + finally: + # Ensure all indexes are updated after each document + await self._insert_done() async def _insert_done(self): tasks = [] @@ -591,3 +673,26 @@ async def _delete_by_entity_done(self): continue tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) await asyncio.gather(*tasks) + + def _get_content_summary(self, content: str, max_length: int = 100) -> str: + """Get summary of document content + + Args: + content: Original document content + max_length: Maximum length of summary + + Returns: + Truncated content with ellipsis if needed + """ + content = content.strip() + if len(content) <= max_length: + return content + return content[:max_length] + "..." + + async def get_processing_status(self) -> Dict[str, int]: + """Get current document processing status counts + + Returns: + Dict with counts for each status + """ + return await self.doc_status.get_status_counts() diff --git a/lightrag/storage.py b/lightrag/storage.py index 0c880bb7..0f65d09c 100644 --- a/lightrag/storage.py +++ b/lightrag/storage.py @@ -3,7 +3,7 @@ import os from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass -from typing import Any, Union, cast +from typing import Any, Union, cast, Dict import networkx as nx import numpy as np from nano_vectordb import NanoVectorDB @@ -19,6 +19,9 @@ BaseGraphStorage, BaseKVStorage, BaseVectorStorage, + DocStatus, + DocProcessingStatus, + DocStatusStorage, ) @@ -315,3 +318,47 @@ async def _node2vec_embed(self): nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids + + +@dataclass +class JsonDocStatusStorage(DocStatusStorage): + """JSON implementation of document status storage""" + + def __post_init__(self): + working_dir = self.global_config["working_dir"] + self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") + self._data = load_json(self._file_name) or {} + logger.info(f"Loaded document status storage with {len(self._data)} records") + + async def filter_keys(self, data: list[str]) -> set[str]: + """Return keys that don't exist in storage""" + return set([k for k in data if k not in self._data]) + + async def get_status_counts(self) -> Dict[str, int]: + """Get counts of documents in each status""" + counts = {status: 0 for status in DocStatus} + for doc in self._data.values(): + counts[doc["status"]] += 1 + return counts + + async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all failed documents""" + return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED} + + async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all pending documents""" + return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING} + + async def index_done_callback(self): + """Save data to file after indexing""" + write_json(self._data, self._file_name) + + async def upsert(self, data: dict[str, dict]): + """Update or insert document status + + Args: + data: 
Dictionary of document IDs and their status data + """ + self._data.update(data) + await self.index_done_callback() + return data From e6b2f68e7c79405468f08e4abde34825a981af4f Mon Sep 17 00:00:00 2001 From: Magic_yuan <317617749@qq.com> Date: Sat, 28 Dec 2024 00:16:53 +0800 Subject: [PATCH 2/2] =?UTF-8?q?docs(readme):=20Add=20batch=20size=20config?= =?UTF-8?q?uration=20documentation=20=E6=96=87=E6=A1=A3(readme):=20?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=89=B9=E5=A4=84=E7=90=86=E5=A4=A7=E5=B0=8F?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add documentation for insert_batch_size parameter in addon_params - 在 addon_params 中添加 insert_batch_size 参数的文档说明 - Explain default batch size value and its usage - 说明默认批处理大小值及其用途 - Add example configuration for batch processing - 添加批处理配置的示例 --- README.md | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 410049fe..0d7016d2 100644 --- a/README.md +++ b/README.md @@ -278,10 +278,25 @@ class QueryParam: ### Batch Insert ```python -# Batch Insert: Insert multiple texts at once +# Basic Batch Insert: Insert multiple texts at once rag.insert(["TEXT1", "TEXT2",...]) + +# Batch Insert with custom batch size configuration +rag = LightRAG( + working_dir=WORKING_DIR, + addon_params={ + "insert_batch_size": 20 # Process 20 documents per batch + } +) +rag.insert(["TEXT1", "TEXT2", "TEXT3", ...]) # Documents will be processed in batches of 20 ``` +The `insert_batch_size` parameter in `addon_params` controls how many documents are processed in each batch during insertion. This is useful for: +- Managing memory usage with large document collections +- Optimizing processing speed +- Providing better progress tracking +- Default value is 10 if not specified + ### Incremental Insert ```python @@ -594,7 +609,7 @@ if __name__ == "__main__": | **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | | | **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | | | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` | -| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"]}`: sets example limit and output language | `example_number: all examples, language: English` | +| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` | | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` | | **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1). When a new question's similarity with a cached question exceeds this threshold, the cached answer is returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, the LLM is used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
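
As a usage note for the two patches above: the sketch below shows, under a few assumptions, how the new status tracking and checkpoint behaviour could be exercised. `WORKING_DIR` and the sample texts are placeholders, the LLM/embedding configuration is elided (configure it as in the existing README examples), and wrapping the new async status methods (`get_processing_status`, `doc_status.get_failed_docs`) in `asyncio.run(...)` is just one way to call them from synchronous code.

```python
import asyncio
from lightrag import LightRAG

WORKING_DIR = "./lightrag_workdir"  # placeholder working directory
docs = ["TEXT1", "TEXT2", "TEXT3"]  # placeholder documents

rag = LightRAG(
    working_dir=WORKING_DIR,
    addon_params={"insert_batch_size": 20},  # process 20 documents per batch
    # llm_model_func / embedding_func configured as in the existing README examples
)

# Documents are deduplicated and processed in batches; each one's status is tracked
# in doc_status (PROCESSING, then PROCESSED or FAILED).
rag.insert(docs)

# Re-running the same call acts as a checkpoint/resume: document IDs already tracked
# in doc_status are filtered out, so an interrupted run can simply be restarted.
rag.insert(docs)

# Per-status counts (pending / processing / processed / failed) for progress monitoring.
print(asyncio.run(rag.get_processing_status()))

# Failed documents keep their error message; in the JSON-backed storage each entry is a plain dict.
failed = asyncio.run(rag.doc_status.get_failed_docs())
for doc_id, status in failed.items():
    print(doc_id, status.get("error"))
```

Note that, as written, `filter_keys` skips every document ID already present in `doc_status`, including failed ones, so failed documents are not retried automatically on a re-run; they can be inspected via `get_failed_docs` and handled separately.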
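For reference, one entry in the JSON file written by `JsonDocStatusStorage` (`kv_store_doc_status.json` under the working directory) roughly takes the shape below; the key and all values are invented for illustration.

```python
# Hypothetical entry in kv_store_doc_status.json; the key is produced by
# compute_mdhash_id(content, prefix="doc-"), and the values are invented.
{
    "doc-3f2a...": {
        "content_summary": "First 100 characters of the document...",
        "content_length": 12345,
        "status": "processed",  # pending / processing / processed / failed
        "created_at": "2024-12-28T00:11:25",
        "updated_at": "2024-12-28T00:12:40",
        "chunks_count": 8,      # filled in after chunking
        # "error": "...",       # present only when status == "failed"
    }
}
```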