From 650b8e38b7971cd4e0776863f99d13816e840773 Mon Sep 17 00:00:00 2001
From: Magic_yuan <317617749@qq.com>
Date: Sat, 28 Dec 2024 00:11:25 +0800
Subject: [PATCH 1/2] feat(lightrag): Add document status tracking and checkpoint support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Add DocStatus enum and DocProcessingStatus class for document processing state management
- Implement JsonDocStatusStorage for persistent status storage
- Add document-level deduplication in batch processing
- Add checkpoint support in ainsert method for resumable document processing (see the usage sketch below)
- Add status query methods for monitoring processing progress
- Update LightRAG initialization to support document status tracking
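
A minimal usage sketch of the new status tracking and checkpoint flow (assumes the default storage backends and a configured LLM/embedding setup; the working directory path is a placeholder):

```python
import asyncio

from lightrag import LightRAG

rag = LightRAG(working_dir="./rag_storage")  # placeholder directory

async def main():
    # Documents already recorded in doc_status are skipped, so re-running
    # this call after an interruption resumes from the last checkpoint.
    await rag.ainsert(["TEXT1", "TEXT2"])

    # Per-status document counts: pending / processing / processed / failed.
    print(await rag.get_processing_status())

    # Failed documents can be inspected through the status storage.
    failed = await rag.doc_status.get_failed_docs()
    for doc_id, status in failed.items():
        print(doc_id, status["error"])

asyncio.run(main())
```
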
---
lightrag/base.py | 42 +++++++-
lightrag/kg/age_impl.py | 5 +-
lightrag/lightrag.py | 229 +++++++++++++++++++++++++++++-----------
lightrag/storage.py | 49 ++++++++-
4 files changed, 260 insertions(+), 65 deletions(-)
diff --git a/lightrag/base.py b/lightrag/base.py
index f5a6e0c0..c3ba3e09 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass, field
-from typing import TypedDict, Union, Literal, Generic, TypeVar
+from typing import TypedDict, Union, Literal, Generic, TypeVar, Optional, Dict, Any
+from enum import Enum
import numpy as np
@@ -129,3 +130,42 @@ async def delete_node(self, node_id: str):
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
raise NotImplementedError("Node embedding is not used in lightrag.")
+
+
+class DocStatus(str, Enum):
+ """Document processing status enum"""
+
+ PENDING = "pending"
+ PROCESSING = "processing"
+ PROCESSED = "processed"
+ FAILED = "failed"
+
+
+@dataclass
+class DocProcessingStatus:
+ """Document processing status data structure"""
+
+ content_summary: str # First 100 chars of document content
+ content_length: int # Total length of document
+ status: DocStatus # Current processing status
+ created_at: str # ISO format timestamp
+ updated_at: str # ISO format timestamp
+ chunks_count: Optional[int] = None # Number of chunks after splitting
+ error: Optional[str] = None # Error message if failed
+ metadata: Dict[str, Any] = field(default_factory=dict) # Additional metadata
+
+
+class DocStatusStorage(BaseKVStorage):
+ """Base class for document status storage"""
+
+ async def get_status_counts(self) -> Dict[str, int]:
+ """Get counts of documents in each status"""
+ raise NotImplementedError
+
+ async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+ """Get all failed documents"""
+ raise NotImplementedError
+
+ async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+ """Get all pending documents"""
+ raise NotImplementedError
diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py
index 2a97bc37..275f5775 100644
--- a/lightrag/kg/age_impl.py
+++ b/lightrag/kg/age_impl.py
@@ -1,7 +1,8 @@
import asyncio
import inspect
import json
-import os, sys
+import os
+import sys
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
@@ -22,8 +23,10 @@
if sys.platform.startswith("win"):
import asyncio.windows_events
+
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
class AGEQueryException(Exception):
"""Exception for the AGE queries."""
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 51ac204d..992c43a4 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -4,7 +4,7 @@
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
-from typing import Type, cast
+from typing import Type, cast, Dict
from .llm import (
gpt_4o_mini_complete,
@@ -32,12 +32,14 @@
BaseVectorStorage,
StorageNameSpace,
QueryParam,
+ DocStatus,
)
from .storage import (
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
+ JsonDocStatusStorage,
)
# future KG integrations
@@ -172,6 +174,9 @@ class LightRAG:
addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json
+ # Add new field for document status storage type
+ doc_status_storage: str = field(default="JsonDocStatusStorage")
+
def __post_init__(self):
log_file = os.path.join("lightrag.log")
set_logger(log_file)
@@ -263,7 +268,15 @@ def __post_init__(self):
)
)
- def _get_storage_class(self) -> Type[BaseGraphStorage]:
+ # Initialize document status storage
+ self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
+ self.doc_status = self.doc_status_storage_cls(
+ namespace="doc_status",
+ global_config=asdict(self),
+ embedding_func=None,
+ )
+
+ def _get_storage_class(self) -> dict:
return {
# kv storage
"JsonKVStorage": JsonKVStorage,
@@ -284,6 +297,7 @@ def _get_storage_class(self) -> Type[BaseGraphStorage]:
"TiDBGraphStorage": TiDBGraphStorage,
"GremlinStorage": GremlinStorage,
# "ArangoDBStorage": ArangoDBStorage
+ "JsonDocStatusStorage": JsonDocStatusStorage,
}
def insert(self, string_or_strings):
@@ -291,71 +305,139 @@ def insert(self, string_or_strings):
return loop.run_until_complete(self.ainsert(string_or_strings))
async def ainsert(self, string_or_strings):
- update_storage = False
- try:
- if isinstance(string_or_strings, str):
- string_or_strings = [string_or_strings]
-
- new_docs = {
- compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
- for c in string_or_strings
+ """Insert documents with checkpoint support
+
+ Args:
+ string_or_strings: Single document string or list of document strings
+ """
+ if isinstance(string_or_strings, str):
+ string_or_strings = [string_or_strings]
+
+ # 1. Remove duplicate contents from the list
+ unique_contents = list(set(doc.strip() for doc in string_or_strings))
+
+ # 2. Generate document IDs and initial status
+ new_docs = {
+ compute_mdhash_id(content, prefix="doc-"): {
+ "content": content,
+ "content_summary": self._get_content_summary(content),
+ "content_length": len(content),
+ "status": DocStatus.PENDING,
+ "created_at": datetime.now().isoformat(),
+ "updated_at": datetime.now().isoformat(),
}
- _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
- new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
- if not len(new_docs):
- logger.warning("All docs are already in the storage")
- return
- update_storage = True
- logger.info(f"[New Docs] inserting {len(new_docs)} docs")
-
- inserting_chunks = {}
- for doc_key, doc in tqdm_async(
- new_docs.items(), desc="Chunking documents", unit="doc"
+ for content in unique_contents
+ }
+
+ # 3. Filter out already processed documents
+ _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
+ new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+
+ if not new_docs:
+ logger.info("All documents have been processed or are duplicates")
+ return
+
+ logger.info(f"Processing {len(new_docs)} new unique documents")
+
+ # Process documents in batches
+ batch_size = self.addon_params.get("insert_batch_size", 10)
+ for i in range(0, len(new_docs), batch_size):
+ batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+
+ for doc_id, doc in tqdm_async(
+ batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
):
- chunks = {
- compute_mdhash_id(dp["content"], prefix="chunk-"): {
- **dp,
- "full_doc_id": doc_key,
+ try:
+ # Update status to processing
+ doc_status = {
+ "content_summary": doc["content_summary"],
+ "content_length": doc["content_length"],
+ "status": DocStatus.PROCESSING,
+ "created_at": doc["created_at"],
+ "updated_at": datetime.now().isoformat(),
}
- for dp in chunking_by_token_size(
- doc["content"],
- overlap_token_size=self.chunk_overlap_token_size,
- max_token_size=self.chunk_token_size,
- tiktoken_model=self.tiktoken_model_name,
+ await self.doc_status.upsert({doc_id: doc_status})
+
+ # Generate chunks from document
+ chunks = {
+ compute_mdhash_id(dp["content"], prefix="chunk-"): {
+ **dp,
+ "full_doc_id": doc_id,
+ }
+ for dp in chunking_by_token_size(
+ doc["content"],
+ overlap_token_size=self.chunk_overlap_token_size,
+ max_token_size=self.chunk_token_size,
+ tiktoken_model=self.tiktoken_model_name,
+ )
+ }
+
+ # Update status with chunks information
+ doc_status.update(
+ {
+ "chunks_count": len(chunks),
+ "updated_at": datetime.now().isoformat(),
+ }
)
- }
- inserting_chunks.update(chunks)
- _add_chunk_keys = await self.text_chunks.filter_keys(
- list(inserting_chunks.keys())
- )
- inserting_chunks = {
- k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
- }
- if not len(inserting_chunks):
- logger.warning("All chunks are already in the storage")
- return
- logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
-
- await self.chunks_vdb.upsert(inserting_chunks)
-
- logger.info("[Entity Extraction]...")
- maybe_new_kg = await extract_entities(
- inserting_chunks,
- knowledge_graph_inst=self.chunk_entity_relation_graph,
- entity_vdb=self.entities_vdb,
- relationships_vdb=self.relationships_vdb,
- global_config=asdict(self),
- )
- if maybe_new_kg is None:
- logger.warning("No new entities and relationships found")
- return
- self.chunk_entity_relation_graph = maybe_new_kg
+ await self.doc_status.upsert({doc_id: doc_status})
+
+ try:
+ # Store chunks in vector database
+ await self.chunks_vdb.upsert(chunks)
+
+ # Extract and store entities and relationships
+ maybe_new_kg = await extract_entities(
+ chunks,
+ knowledge_graph_inst=self.chunk_entity_relation_graph,
+ entity_vdb=self.entities_vdb,
+ relationships_vdb=self.relationships_vdb,
+ global_config=asdict(self),
+ )
- await self.full_docs.upsert(new_docs)
- await self.text_chunks.upsert(inserting_chunks)
- finally:
- if update_storage:
- await self._insert_done()
+ if maybe_new_kg is None:
+ raise Exception(
+ "Failed to extract entities and relationships"
+ )
+
+ self.chunk_entity_relation_graph = maybe_new_kg
+
+ # Store original document and chunks
+ await self.full_docs.upsert(
+ {doc_id: {"content": doc["content"]}}
+ )
+ await self.text_chunks.upsert(chunks)
+
+ # Update status to processed
+ doc_status.update(
+ {
+ "status": DocStatus.PROCESSED,
+ "updated_at": datetime.now().isoformat(),
+ }
+ )
+ await self.doc_status.upsert({doc_id: doc_status})
+
+ except Exception as e:
+ # Mark as failed if any step fails
+ doc_status.update(
+ {
+ "status": DocStatus.FAILED,
+ "error": str(e),
+ "updated_at": datetime.now().isoformat(),
+ }
+ )
+ await self.doc_status.upsert({doc_id: doc_status})
+ raise e
+
+ except Exception as e:
+ import traceback
+
+ error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+ continue
+
+ finally:
+ # Ensure all indexes are updated after each document
+ await self._insert_done()
async def _insert_done(self):
tasks = []
@@ -591,3 +673,26 @@ async def _delete_by_entity_done(self):
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
+
+ def _get_content_summary(self, content: str, max_length: int = 100) -> str:
+ """Get summary of document content
+
+ Args:
+ content: Original document content
+ max_length: Maximum length of summary
+
+ Returns:
+ Truncated content with ellipsis if needed
+ """
+ content = content.strip()
+ if len(content) <= max_length:
+ return content
+ return content[:max_length] + "..."
+
+ async def get_processing_status(self) -> Dict[str, int]:
+ """Get current document processing status counts
+
+ Returns:
+ Dict with counts for each status
+ """
+ return await self.doc_status.get_status_counts()
diff --git a/lightrag/storage.py b/lightrag/storage.py
index 0c880bb7..0f65d09c 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -3,7 +3,7 @@
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
-from typing import Any, Union, cast
+from typing import Any, Union, cast, Dict
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
@@ -19,6 +19,9 @@
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
+ DocStatus,
+ DocProcessingStatus,
+ DocStatusStorage,
)
@@ -315,3 +318,47 @@ async def _node2vec_embed(self):
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids
+
+
+@dataclass
+class JsonDocStatusStorage(DocStatusStorage):
+ """JSON implementation of document status storage"""
+
+ def __post_init__(self):
+ working_dir = self.global_config["working_dir"]
+ self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+ self._data = load_json(self._file_name) or {}
+ logger.info(f"Loaded document status storage with {len(self._data)} records")
+
+ async def filter_keys(self, data: list[str]) -> set[str]:
+ """Return keys that don't exist in storage"""
+ return set([k for k in data if k not in self._data])
+
+ async def get_status_counts(self) -> Dict[str, int]:
+ """Get counts of documents in each status"""
+ counts = {status: 0 for status in DocStatus}
+ for doc in self._data.values():
+ counts[doc["status"]] += 1
+ return counts
+
+ async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+ """Get all failed documents"""
+ return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
+
+ async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+ """Get all pending documents"""
+ return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
+
+ async def index_done_callback(self):
+ """Save data to file after indexing"""
+ write_json(self._data, self._file_name)
+
+ async def upsert(self, data: dict[str, dict]):
+ """Update or insert document status
+
+ Args:
+ data: Dictionary of document IDs and their status data
+ """
+ self._data.update(data)
+ await self.index_done_callback()
+ return data
From e6b2f68e7c79405468f08e4abde34825a981af4f Mon Sep 17 00:00:00 2001
From: Magic_yuan <317617749@qq.com>
Date: Sat, 28 Dec 2024 00:16:53 +0800
Subject: [PATCH 2/2] docs(readme): Add batch size configuration documentation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Add documentation for insert_batch_size parameter in addon_params
- Explain default batch size value and its usage
- Add example configuration for batch processing
---
README.md | 19 +++++++++++++++++--
1 file changed, 17 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 410049fe..0d7016d2 100644
--- a/README.md
+++ b/README.md
@@ -278,10 +278,25 @@ class QueryParam:
### Batch Insert
```python
-# Batch Insert: Insert multiple texts at once
+# Basic Batch Insert: Insert multiple texts at once
rag.insert(["TEXT1", "TEXT2",...])
+
+# Batch Insert with custom batch size configuration
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ addon_params={
+ "insert_batch_size": 20 # Process 20 documents per batch
+ }
+)
+rag.insert(["TEXT1", "TEXT2", "TEXT3", ...]) # Documents will be processed in batches of 20
```
+The `insert_batch_size` parameter in `addon_params` controls how many documents are processed in each batch during insertion.
+If it is not specified, the batch size defaults to 10. Batching is useful for:
+- Managing memory usage with large document collections
+- Optimizing processing speed
+- Providing better progress tracking
+
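+Because the status of every document is tracked, an interrupted batch insert can be resumed by calling `insert` again with the same documents; documents already recorded in the status store are skipped. A minimal sketch (the document list is illustrative):
+
+```python
+import asyncio
+
+# Re-running the same call resumes from the last checkpoint:
+# documents already tracked in the document status store are skipped.
+rag.insert(["TEXT1", "TEXT2", "TEXT3", ...])
+
+# Check how many documents are in each processing state
+# (pending / processing / processed / failed).
+print(asyncio.run(rag.get_processing_status()))
+```
+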
### Incremental Insert
```python
@@ -594,7 +609,7 @@ if __name__ == "__main__":
| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | |
| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | |
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
-| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"]}`: sets example limit and output language | `example_number: all examples, language: English` |
+| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |