Merge pull request #521 from magicyuan876/main
feat(lightrag): Add document status tracking and checkpoint-resume support
LarFii authored Dec 27, 2024
2 parents c022db4 + e6b2f68 commit 8a0b59c
Showing 5 changed files with 277 additions and 67 deletions.
19 changes: 17 additions & 2 deletions README.md
@@ -278,10 +278,25 @@ class QueryParam:
### Batch Insert

```python
# Batch Insert: Insert multiple texts at once
# Basic Batch Insert: Insert multiple texts at once
rag.insert(["TEXT1", "TEXT2",...])

# Batch Insert with custom batch size configuration
rag = LightRAG(
working_dir=WORKING_DIR,
addon_params={
"insert_batch_size": 20 # Process 20 documents per batch
}
)
rag.insert(["TEXT1", "TEXT2", "TEXT3", ...]) # Documents will be processed in batches of 20
```

The `insert_batch_size` parameter in `addon_params` controls how many documents are processed in each batch during insertion. The default is 10 if not specified. This is useful for:
- Managing memory usage with large document collections
- Optimizing processing speed
- Providing better progress tracking

Batched insertion also works together with the document status tracking introduced in this change, as shown in the sketch below.
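Documents already recorded as processed in the status store are skipped on later runs, so repeating the same insert acts as a checkpoint resume. A minimal sketch, reusing the `rag` instance configured above (the text list is a placeholder):

```python
texts = ["TEXT1", "TEXT2", "TEXT3"]  # placeholder documents

# First run: unique texts are processed in batches of 20, and each
# document's status (pending/processing/processed/failed) is persisted
rag.insert(texts)

# Running the same call again resumes from the checkpoint: documents
# that already completed are filtered out before chunking and embedding
rag.insert(texts)
```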

### Incremental Insert

```python
@@ -594,7 +609,7 @@ if __name__ == "__main__":
| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | |
| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | |
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"]}`: sets example limit and output language | `example_number: all examples, language: English` |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:<br>- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.<br>- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.<br>- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
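Since `embedding_cache_config` is listed above as a LightRAG init parameter, it can be passed directly to the constructor. A short configuration sketch, assuming the `WORKING_DIR` used in the earlier examples:

```python
rag = LightRAG(
    working_dir=WORKING_DIR,
    embedding_cache_config={
        "enabled": True,               # check the cache before calling the LLM
        "similarity_threshold": 0.95,  # reuse answers above this similarity
        "use_llm_check": False,        # skip the secondary LLM verification
    },
)
```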

42 changes: 41 additions & 1 deletion lightrag/base.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar
from typing import TypedDict, Union, Literal, Generic, TypeVar, Optional, Dict, Any
from enum import Enum

import numpy as np

@@ -129,3 +130,42 @@ async def delete_node(self, node_id: str):

async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
raise NotImplementedError("Node embedding is not used in lightrag.")


class DocStatus(str, Enum):
"""Document processing status enum"""

PENDING = "pending"
PROCESSING = "processing"
PROCESSED = "processed"
FAILED = "failed"


@dataclass
class DocProcessingStatus:
"""Document processing status data structure"""

content_summary: str # First 100 chars of document content
content_length: int # Total length of document
status: DocStatus # Current processing status
created_at: str # ISO format timestamp
updated_at: str # ISO format timestamp
chunks_count: Optional[int] = None # Number of chunks after splitting
error: Optional[str] = None # Error message if failed
metadata: Dict[str, Any] = field(default_factory=dict) # Additional metadata


class DocStatusStorage(BaseKVStorage):
"""Base class for document status storage"""

async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status"""
raise NotImplementedError

async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all failed documents"""
raise NotImplementedError

async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all pending documents"""
raise NotImplementedError
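
The concrete `JsonDocStatusStorage` used elsewhere in this change lives in `lightrag/storage.py` and is not shown in this diff. For illustration only, a minimal in-memory subclass of `DocStatusStorage` might look like the sketch below; the class name and `_by_status` helper are hypothetical, and it assumes `BaseKVStorage` is a dataclass exposing `filter_keys`/`upsert`, as in `lightrag/base.py`:

```python
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict

from lightrag.base import DocProcessingStatus, DocStatus, DocStatusStorage


@dataclass
class InMemoryDocStatusStorage(DocStatusStorage):
    """Illustrative sketch only, not part of this commit."""

    def __post_init__(self):
        self._data: Dict[str, Dict[str, Any]] = {}

    async def filter_keys(self, data: list) -> set:
        # Treat ids that are not tracked yet as new documents
        return {k for k in data if k not in self._data}

    async def upsert(self, data: Dict[str, Dict[str, Any]]):
        self._data.update(data)
        return data

    async def get_status_counts(self) -> Dict[str, int]:
        return dict(Counter(doc["status"] for doc in self._data.values()))

    def _by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
        # Assumes stored records carry the DocProcessingStatus fields,
        # matching the upserts performed in LightRAG.ainsert
        keys = ("content_summary", "content_length", "status",
                "created_at", "updated_at", "chunks_count", "error")
        return {
            doc_id: DocProcessingStatus(**{k: doc[k] for k in keys if k in doc})
            for doc_id, doc in self._data.items()
            if doc["status"] == status
        }

    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
        return self._by_status(DocStatus.FAILED)

    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
        return self._by_status(DocStatus.PENDING)
```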
5 changes: 4 additions & 1 deletion lightrag/kg/age_impl.py
@@ -1,7 +1,8 @@
import asyncio
import inspect
import json
import os, sys
import os
import sys
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
@@ -22,8 +23,10 @@

if sys.platform.startswith("win"):
import asyncio.windows_events

asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


class AGEQueryException(Exception):
"""Exception for the AGE queries."""

229 changes: 167 additions & 62 deletions lightrag/lightrag.py
@@ -4,7 +4,7 @@
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast
from typing import Type, cast, Dict

from .llm import (
gpt_4o_mini_complete,
@@ -32,12 +32,14 @@
BaseVectorStorage,
StorageNameSpace,
QueryParam,
DocStatus,
)

from .storage import (
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
JsonDocStatusStorage,
)

# future KG integrations
@@ -172,6 +174,9 @@ class LightRAG:
addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json

# Add new field for document status storage type
doc_status_storage: str = field(default="JsonDocStatusStorage")

def __post_init__(self):
log_file = os.path.join("lightrag.log")
set_logger(log_file)
@@ -263,7 +268,15 @@ def __post_init__(self):
)
)

def _get_storage_class(self) -> Type[BaseGraphStorage]:
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
self.doc_status = self.doc_status_storage_cls(
namespace="doc_status",
global_config=asdict(self),
embedding_func=None,
)

def _get_storage_class(self) -> dict:
return {
# kv storage
"JsonKVStorage": JsonKVStorage,
@@ -284,78 +297,147 @@ def _get_storage_class(self) -> Type[BaseGraphStorage]:
"TiDBGraphStorage": TiDBGraphStorage,
"GremlinStorage": GremlinStorage,
# "ArangoDBStorage": ArangoDBStorage
"JsonDocStatusStorage": JsonDocStatusStorage,
}

def insert(self, string_or_strings):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert(string_or_strings))

async def ainsert(self, string_or_strings):
update_storage = False
try:
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]

new_docs = {
compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
for c in string_or_strings
"""Insert documents with checkpoint support
Args:
string_or_strings: Single document string or list of document strings
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]

# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))

# 2. Generate document IDs and initial status
new_docs = {
compute_mdhash_id(content, prefix="doc-"): {
"content": content,
"content_summary": self._get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
}
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs):
logger.warning("All docs are already in the storage")
return
update_storage = True
logger.info(f"[New Docs] inserting {len(new_docs)} docs")

inserting_chunks = {}
for doc_key, doc in tqdm_async(
new_docs.items(), desc="Chunking documents", unit="doc"
for content in unique_contents
}

# 3. Filter out already processed documents
_add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}

if not new_docs:
logger.info("All documents have been processed or are duplicates")
return

logger.info(f"Processing {len(new_docs)} new unique documents")

# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
batch_docs = dict(list(new_docs.items())[i : i + batch_size])

for doc_id, doc in tqdm_async(
batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
):
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_key,
try:
# Update status to processing
doc_status = {
"content_summary": doc["content_summary"],
"content_length": doc["content_length"],
"status": DocStatus.PROCESSING,
"created_at": doc["created_at"],
"updated_at": datetime.now().isoformat(),
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
await self.doc_status.upsert({doc_id: doc_status})

# Generate chunks from document
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}

# Update status with chunks information
doc_status.update(
{
"chunks_count": len(chunks),
"updated_at": datetime.now().isoformat(),
}
)
}
inserting_chunks.update(chunks)
_add_chunk_keys = await self.text_chunks.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
if not len(inserting_chunks):
logger.warning("All chunks are already in the storage")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

await self.chunks_vdb.upsert(inserting_chunks)

logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
inserting_chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.warning("No new entities and relationships found")
return
self.chunk_entity_relation_graph = maybe_new_kg
await self.doc_status.upsert({doc_id: doc_status})

try:
# Store chunks in vector database
await self.chunks_vdb.upsert(chunks)

# Extract and store entities and relationships
maybe_new_kg = await extract_entities(
chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
)

await self.full_docs.upsert(new_docs)
await self.text_chunks.upsert(inserting_chunks)
finally:
if update_storage:
await self._insert_done()
if maybe_new_kg is None:
raise Exception(
"Failed to extract entities and relationships"
)

self.chunk_entity_relation_graph = maybe_new_kg

# Store original document and chunks
await self.full_docs.upsert(
{doc_id: {"content": doc["content"]}}
)
await self.text_chunks.upsert(chunks)

# Update status to processed
doc_status.update(
{
"status": DocStatus.PROCESSED,
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})

except Exception as e:
# Mark as failed if any step fails
doc_status.update(
{
"status": DocStatus.FAILED,
"error": str(e),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
raise e

except Exception as e:
import traceback

error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
continue

finally:
# Ensure all indexes are updated after each document
await self._insert_done()

async def _insert_done(self):
tasks = []
@@ -591,3 +673,26 @@ async def _delete_by_entity_done(self):
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)

def _get_content_summary(self, content: str, max_length: int = 100) -> str:
"""Get summary of document content
Args:
content: Original document content
max_length: Maximum length of summary
Returns:
Truncated content with ellipsis if needed
"""
content = content.strip()
if len(content) <= max_length:
return content
return content[:max_length] + "..."

async def get_processing_status(self) -> Dict[str, int]:
"""Get current document processing status counts
Returns:
Dict with counts for each status
"""
return await self.doc_status.get_status_counts()
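
Taken together with the storage changes above, these additions let a caller monitor a long-running insert and inspect failures. A minimal async usage sketch, assuming the default `JsonDocStatusStorage` backing `doc_status` implements the `DocStatusStorage` interface (the working directory and document list are placeholders):

```python
import asyncio

from lightrag import LightRAG


async def main():
    rag = LightRAG(
        working_dir="./rag_storage",              # placeholder path
        addon_params={"insert_batch_size": 20},   # optional batch size
    )

    # Each document moves through pending -> processing -> processed/failed,
    # and its status is persisted after every step
    await rag.ainsert(["TEXT1", "TEXT2", "TEXT3"])

    # Counts of documents per status, e.g. processed vs. failed
    print(await rag.get_processing_status())

    # Inspect documents that failed, using the status storage directly
    failed = await rag.doc_status.get_failed_docs()
    for doc_id, doc in failed.items():
        print(doc_id, doc.error)


asyncio.run(main())
```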
