feat - add support to return & save the merged graph [useful to investigate `all descriptions` of nodes & of the edges between them]
ksachdeva committed Sep 19, 2024
1 parent 3125e0a commit 59ab206
Showing 5 changed files with 34 additions and 20 deletions.
29 changes: 20 additions & 9 deletions examples/simple-app/app/common.py
```diff
@@ -193,9 +193,13 @@ def save_artifacts(artifacts: IndexerArtifacts, path: Path):
     artifacts.text_units.to_parquet(f"{path}/text_units.parquet")
     artifacts.communities_reports.to_parquet(f"{path}/communities_reports.parquet")
 
-    if artifacts.graph is not None:
-        with path.joinpath("graph.pickle").open("wb") as fp:
-            pickle.dump(artifacts.graph, fp)
+    if artifacts.merged_graph is not None:
+        with path.joinpath("merged-graph.pickle").open("wb") as fp:
+            pickle.dump(artifacts.merged_graph, fp)
+
+    if artifacts.summarized_graph is not None:
+        with path.joinpath("summarized-graph.pickle").open("wb") as fp:
+            pickle.dump(artifacts.summarized_graph, fp)
 
     if artifacts.communities is not None:
         with path.joinpath("community_info.pickle").open("wb") as fp:
@@ -208,13 +212,19 @@ def load_artifacts(path: Path) -> IndexerArtifacts:
     text_units = pd.read_parquet(f"{path}/text_units.parquet")
     communities_reports = pd.read_parquet(f"{path}/communities_reports.parquet")
 
-    graph = None
+    merged_graph = None
+    summarized_graph = None
     communities = None
 
-    graph_pickled = path.joinpath("graph.pickle")
-    if graph_pickled.exists():
-        with graph_pickled.open("rb") as fp:
-            graph = pickle.load(fp)  # noqa: S301
+    merged_graph_pickled = path.joinpath("merged-graph.pickle")
+    if merged_graph_pickled.exists():
+        with merged_graph_pickled.open("rb") as fp:
+            merged_graph = pickle.load(fp)  # noqa: S301
+
+    summarized_graph_pickled = path.joinpath("summarized-graph.pickle")
+    if summarized_graph_pickled.exists():
+        with summarized_graph_pickled.open("rb") as fp:
+            summarized_graph = pickle.load(fp)  # noqa: S301
 
     community_info_pickled = path.joinpath("community_info.pickle")
     if community_info_pickled.exists():
@@ -226,7 +236,8 @@ def load_artifacts(path: Path) -> IndexerArtifacts:
         relationships,
         text_units,
         communities_reports,
-        graph=graph,
+        merged_graph=merged_graph,
+        summarized_graph=summarized_graph,
        communities=communities,
    )
```
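The new pickles make the pre-summarization graph inspectable offline. A minimal sketch of reading it back, assuming an indexing run has already written `merged-graph.pickle` via `save_artifacts`, and that nodes and edges carry a `description` attribute (as the commit message implies):

```python
import pickle
from pathlib import Path

artifacts_dir = Path("output/artifacts")  # hypothetical output location

with artifacts_dir.joinpath("merged-graph.pickle").open("rb") as fp:
    merged_graph = pickle.load(fp)  # noqa: S301

# The merged graph still holds the raw descriptions gathered from every
# text unit, before the summarizer condenses them.
for node, data in merged_graph.nodes(data=True):
    print(node, data.get("description"))

for u, v, data in merged_graph.edges(data=True):
    print(f"{u} -> {v}", data.get("description"))
```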
2 changes: 1 addition & 1 deletion pyproject.toml
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "langchain-graphrag"
-version = "0.0.4"
+version = "0.0.5"
 description = "Implementation of GraphRAG (https://arxiv.org/pdf/2404.16130)"
 authors = [{ name = "Kapil Sachdeva", email = "[email protected]" }]
 dependencies = [
```
3 changes: 2 additions & 1 deletion src/langchain_graphrag/indexing/artifacts.py
```diff
@@ -12,7 +12,8 @@ class IndexerArtifacts(NamedTuple):
     relationships: pd.DataFrame
     text_units: pd.DataFrame
     communities_reports: pd.DataFrame
-    graph: nx.Graph | None = None
+    merged_graph: nx.Graph | None = None
+    summarized_graph: nx.Graph | None = None
     communities: CommunityDetectionResult | None = None
 
     def _entity_info(self, top_k: int) -> None:
```
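With both graphs on `IndexerArtifacts`, raw and summarized descriptions can be compared side by side. A hedged sketch, assuming the artifacts were persisted as above, that `load_artifacts` is imported from `examples/simple-app/app/common.py`, and that the two graphs share node ids and a `description` attribute:

```python
from pathlib import Path

from common import load_artifacts  # examples/simple-app/app/common.py

artifacts = load_artifacts(Path("output/artifacts"))  # hypothetical path

if artifacts.merged_graph is not None and artifacts.summarized_graph is not None:
    for node in artifacts.summarized_graph.nodes:
        raw = artifacts.merged_graph.nodes[node].get("description")
        condensed = artifacts.summarized_graph.nodes[node].get("description")
        print(f"{node}\n  merged:     {raw}\n  summarized: {condensed}")
```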
5 changes: 3 additions & 2 deletions src/langchain_graphrag/indexing/graph_generation/generator.py
```diff
@@ -17,7 +17,8 @@ def __init__(
         self._graphs_merger = graphs_merger
         self._er_description_summarizer = er_description_summarizer
 
-    def run(self, text_units: pd.DataFrame) -> nx.Graph:
+    def run(self, text_units: pd.DataFrame) -> tuple[nx.Graph, nx.Graph]:
         er_graphs = self._er_extractor.invoke(text_units)
         er_merged_graph = self._graphs_merger(er_graphs)
-        return self._er_description_summarizer.invoke(er_merged_graph)
+        er_summarized_graph = self._er_description_summarizer.invoke(er_merged_graph)
+        return er_merged_graph, er_summarized_graph
```
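Note that this is a breaking change for direct callers of `GraphGenerator.run`, which now returns a tuple. A sketch of the new calling convention, assuming `generator` is a `GraphGenerator` already wired with an extractor, a merger, and a description summarizer:

```python
import networkx as nx
import pandas as pd

def build_graphs(generator, text_units: pd.DataFrame) -> tuple[nx.Graph, nx.Graph]:
    # run() now yields both variants: the merged graph keeps every raw
    # description; the summarized graph holds the LLM-condensed ones.
    merged_graph, summarized_graph = generator.run(text_units)
    return merged_graph, summarized_graph
```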
15 changes: 8 additions & 7 deletions src/langchain_graphrag/indexing/simple_indexer.py
```diff
@@ -45,26 +45,26 @@ def run(self, documents: list[Document]) -> IndexerArtifacts:
         # Step 1 - Text Unit extraction
         df_base_text_units = self._text_unit_extractor.run(documents)
 
-        # Step 2 - Generate graph
-        graph = self._graph_generator.run(df_base_text_units)
+        # Step 2 - Generate graphs
+        merged_graph, summarized_graph = self._graph_generator.run(df_base_text_units)
 
         # Step 3 - Detect communities in Graph
-        community_detection_result = self._community_detector.run(graph)
+        community_detection_result = self._community_detector.run(summarized_graph)
 
         # Step 4 - Reports for detected Communities (depends on Step 2 & Step 3)
         df_communities_reports = self._communities_report_artifacts_generator.run(
             community_detection_result,
-            graph,
+            summarized_graph,
         )
 
         # Step 5 - Entities generation (depends on Step 2 & Step 3)
         df_entities = self._entities_artifacts_generator.run(
             community_detection_result,
-            graph,
+            summarized_graph,
         )
 
         # Step 6 - Relationships generation (depends on Step 2)
-        df_relationships = self._relationships_artifacts_generator.run(graph)
+        df_relationships = self._relationships_artifacts_generator.run(summarized_graph)
 
         # Step 7 - Text Units generation (depends on Steps 1, 5, 6)
         df_text_units = self._text_units_artifacts_generator.run(
@@ -78,6 +78,7 @@ def run(self, documents: list[Document]) -> IndexerArtifacts:
             relationships=df_relationships,
             text_units=df_text_units,
             communities_reports=df_communities_reports,
-            graph=graph,
+            summarized_graph=summarized_graph,
+            merged_graph=merged_graph,
             communities=community_detection_result,
         )
```
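End to end, `SimpleIndexer.run` now threads both graphs into the returned artifacts: community detection and all downstream artifacts are computed from the summarized graph, while the merged graph is carried along purely for inspection. A sketch, assuming `indexer` is a configured `SimpleIndexer` and `save_artifacts` is imported from `examples/simple-app/app/common.py`:

```python
from pathlib import Path

from langchain_core.documents import Document

from common import save_artifacts  # examples/simple-app/app/common.py

def index_and_save(indexer, documents: list[Document], out_dir: Path) -> None:
    artifacts = indexer.run(documents)
    # Writes merged-graph.pickle and summarized-graph.pickle alongside the
    # parquet artifacts and community_info.pickle.
    save_artifacts(artifacts, out_dir)
```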
