diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 384dfed6..66b49fe3 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -19,8 +19,10 @@ def __init__(self, config, **kwargs): self.password = config.get("password", None) self.database = config.get("database", None) self.workspace = config.get("workspace", None) - connection_string = (f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" - f"?ssl_verify_cert=true&ssl_verify_identity=true") + connection_string = ( + f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" + f"?ssl_verify_cert=true&ssl_verify_identity=true" + ) try: self.engine = create_engine(connection_string) @@ -49,7 +51,7 @@ async def query( self, sql: str, params: dict = None, multirows: bool = False ) -> Union[dict, None]: if params is None: - params = { "workspace": self.workspace } + params = {"workspace": self.workspace} else: params.update({"workspace": self.workspace}) with self.engine.connect() as conn, conn.begin(): @@ -130,8 +132,8 @@ async def filter_keys(self, keys: list[str]) -> set[str]: """过滤掉重复内容""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=N_T[self.namespace], - id_field= N_ID[self.namespace], - ids=",".join([f"'{id}'" for id in keys]) + id_field=N_ID[self.namespace], + ids=",".join([f"'{id}'" for id in keys]), ) try: await self.db.query(SQL) @@ -161,7 +163,7 @@ async def upsert(self, data: dict[str, dict]): ] contents = [v["content"] for v in data.values()] batches = [ - contents[i: i + self._max_batch_size] + contents[i : i + self._max_batch_size] for i in range(0, len(contents), self._max_batch_size) ] embeddings_list = await asyncio.gather( @@ -174,26 +176,30 @@ async def upsert(self, data: dict[str, dict]): merge_sql = SQL_TEMPLATES["upsert_chunk"] data = [] for item in list_data: - data.append({ - "id": item["__id__"], - "content": item["content"], - "tokens": item["tokens"], - "chunk_order_index": item["chunk_order_index"], - "full_doc_id": item["full_doc_id"], - "content_vector": f"{item["__vector__"].tolist()}", - "workspace": self.db.workspace, - }) + data.append( + { + "id": item["__id__"], + "content": item["content"], + "tokens": item["tokens"], + "chunk_order_index": item["chunk_order_index"], + "full_doc_id": item["full_doc_id"], + "content_vector": f"{item["__vector__"].tolist()}", + "workspace": self.db.workspace, + } + ) await self.db.execute(merge_sql, data) if self.namespace == "full_docs": merge_sql = SQL_TEMPLATES["upsert_doc_full"] data = [] for k, v in self._data.items(): - data.append({ - "id": k, - "content": v["content"], - "workspace": self.db.workspace, - }) + data.append( + { + "id": k, + "content": v["content"], + "workspace": self.db.workspace, + } + ) await self.db.execute(merge_sql, data) return left_data @@ -201,6 +207,7 @@ async def index_done_callback(self): if self.namespace in ["full_docs", "text_chunks"]: logger.info("full doc and chunk data had been saved into TiDB db!") + @dataclass class TiDBVectorDBStorage(BaseVectorStorage): cosine_better_than_threshold: float = 0.2 @@ -215,7 +222,7 @@ def __post_init__(self): ) async def query(self, query: str, top_k: int) -> list[dict]: - """ search from tidb vector""" + """search from tidb vector""" embeddings = await self.embedding_func([query]) embedding = embeddings[0] @@ -228,8 +235,10 @@ async def query(self, query: str, top_k: int) -> list[dict]: "better_than_threshold": self.cosine_better_than_threshold, } - results = await self.db.query(SQL_TEMPLATES[self.namespace], params=params, multirows=True) - print("vector search result:",results) + results = await self.db.query( + SQL_TEMPLATES[self.namespace], params=params, multirows=True + ) + print("vector search result:", results) if not results: return [] return results @@ -253,16 +262,16 @@ async def upsert(self, data: dict[str, dict]): ] contents = [v["content"] for v in data.values()] batches = [ - contents[i: i + self._max_batch_size] + contents[i : i + self._max_batch_size] for i in range(0, len(contents), self._max_batch_size) ] embedding_tasks = [self.embedding_func(batch) for batch in batches] embeddings_list = [] for f in tqdm( - asyncio.as_completed(embedding_tasks), - total=len(embedding_tasks), - desc="Generating embeddings", - unit="batch", + asyncio.as_completed(embedding_tasks), + total=len(embedding_tasks), + desc="Generating embeddings", + unit="batch", ): embeddings = await f embeddings_list.append(embeddings) @@ -274,27 +283,31 @@ async def upsert(self, data: dict[str, dict]): data = [] for item in list_data: merge_sql = SQL_TEMPLATES["upsert_entity"] - data.append({ - "id": item["id"], - "name": item["entity_name"], - "content": item["content"], - "content_vector": f"{item["content_vector"].tolist()}", - "workspace": self.db.workspace, - }) + data.append( + { + "id": item["id"], + "name": item["entity_name"], + "content": item["content"], + "content_vector": f"{item["content_vector"].tolist()}", + "workspace": self.db.workspace, + } + ) await self.db.execute(merge_sql, data) elif self.namespace == "relationships": data = [] for item in list_data: merge_sql = SQL_TEMPLATES["upsert_relationship"] - data.append({ - "id": item["id"], - "source_name": item["src_id"], - "target_name": item["tgt_id"], - "content": item["content"], - "content_vector": f"{item["content_vector"].tolist()}", - "workspace": self.db.workspace, - }) + data.append( + { + "id": item["id"], + "source_name": item["src_id"], + "target_name": item["tgt_id"], + "content": item["content"], + "content_vector": f"{item["content_vector"].tolist()}", + "workspace": self.db.workspace, + } + ) await self.db.execute(merge_sql, data) @@ -346,8 +359,7 @@ async def upsert(self, data: dict[str, dict]): """ }, "LIGHTRAG_GRAPH_NODES": { - "ddl": - """ + "ddl": """ CREATE TABLE LIGHTRAG_GRAPH_NODES ( `id` BIGINT PRIMARY KEY AUTO_RANDOM, `entity_id` VARCHAR(256) NOT NULL, @@ -362,8 +374,7 @@ async def upsert(self, data: dict[str, dict]): """ }, "LIGHTRAG_GRAPH_EDGES": { - "ddl": - """ + "ddl": """ CREATE TABLE LIGHTRAG_GRAPH_EDGES ( `id` BIGINT PRIMARY KEY AUTO_RANDOM, `relation_id` VARCHAR(256) NOT NULL, @@ -400,7 +411,6 @@ async def upsert(self, data: dict[str, dict]): "get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace", "get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace", "filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace", - # SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE) "upsert_doc_full": """ INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace) @@ -408,13 +418,12 @@ async def upsert(self, data: dict[str, dict]): ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP """, "upsert_chunk": """ - INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace) + INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace) VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace) - ON DUPLICATE KEY UPDATE - content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index), + ON DUPLICATE KEY UPDATE + content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index), full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP """, - # SQL for VectorStorage "entities": """SELECT n.name as entity_name FROM (SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance @@ -428,19 +437,18 @@ async def upsert(self, data: dict[str, dict]): (SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""", - "upsert_entity": """ - INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace) - VALUES(:id, :name, :content, :content_vector, :workspace) - ON DUPLICATE KEY UPDATE - name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector), + INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace) + VALUES(:id, :name, :content, :content_vector, :workspace) + ON DUPLICATE KEY UPDATE + name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP """, "upsert_relationship": """ - INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace) + INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace) VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace) - ON DUPLICATE KEY UPDATE - source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content), + ON DUPLICATE KEY UPDATE + source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP - """ -} \ No newline at end of file + """, +} diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 9e9beb1e..5a337a08 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -80,6 +80,7 @@ def import_class(*args, **kwargs): TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage") TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage") + def always_get_an_event_loop() -> asyncio.AbstractEventLoop: """ Ensure that there is always an event loop available. diff --git a/requirements.txt b/requirements.txt index 32a92ab5..3cc48028 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,11 +13,11 @@ openai oracledb pymilvus pymongo +pymysql pyvis -tenacity # lmdeploy[all] sqlalchemy -pymysql +tenacity # LLM packages