Skip to content

Commit

Permalink
refactor - move saving & loading of artifacts to example app and ther…
Browse files Browse the repository at this point in the history
…eby removing the depedency on pyarrow & parquet packages
  • Loading branch information
ksachdeva committed Aug 13, 2024
1 parent 9fd467e commit 95a5f0b
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 36 deletions.
26 changes: 26 additions & 0 deletions examples/simple-app/app/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from enum import Enum
from pathlib import Path

import pandas as pd
from langchain.embeddings.cache import CacheBackedEmbeddings
from langchain_community.cache import SQLiteCache
from langchain_community.storage import SQLStore
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLLM
from langchain_graphrag.indexing.artifacts import IndexerArtifacts
from langchain_ollama import OllamaEmbeddings, OllamaLLM
from langchain_openai import (
AzureChatOpenAI,
Expand Down Expand Up @@ -116,3 +118,27 @@ def make_embedding_instance(
underlying_embeddings=underlying_embedding,
document_embedding_cache=store,
)


def save_artifacts(artifacts: IndexerArtifacts, path: Path):
artifacts.entities.to_parquet(f"{path}/entities.parquet")
artifacts.relationships.to_parquet(f"{path}/relationships.parquet")
artifacts.text_units.to_parquet(f"{path}/text_units.parquet")
artifacts.communities.to_parquet(f"{path}/communities.parquet")
artifacts.communities_reports.to_parquet(f"{path}/communities_reports.parquet")


def load_artifacts(path: Path) -> IndexerArtifacts:
entities = pd.read_parquet(f"{path}/entities.parquet")
relationships = pd.read_parquet(f"{path}/relationships.parquet")
text_units = pd.read_parquet(f"{path}/text_units.parquet")
communities = pd.read_parquet(f"{path}/communities.parquet")
communities_reports = pd.read_parquet(f"{path}/communities_reports.parquet")

return IndexerArtifacts(
entities,
relationships,
text_units,
communities,
communities_reports,
)
6 changes: 4 additions & 2 deletions examples/simple-app/app/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
EmbeddingModelType,
LLMModel,
LLMType,
load_artifacts,
make_embedding_instance,
make_llm_instance,
save_artifacts,
)
from langchain_chroma.vectorstores import Chroma as ChromaVectorStore
from langchain_community.document_loaders.directory import DirectoryLoader
Expand Down Expand Up @@ -196,7 +198,7 @@ def index(
# save the artifacts
artifacts_dir = output_dir / "artifacts"
artifacts_dir.mkdir(parents=True, exist_ok=True)
artifacts.save(artifacts_dir)
save_artifacts(artifacts, artifacts_dir)

artifacts.report()

Expand All @@ -205,5 +207,5 @@ def index(
def report(
artifacts_dir: Path = typer.Option(..., dir_okay=True, file_okay=False),
):
artifacts = IndexerArtifacts.load(artifacts_dir)
artifacts: IndexerArtifacts = load_artifacts(artifacts_dir)
artifacts.report()
6 changes: 3 additions & 3 deletions examples/simple-app/app/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
EmbeddingModelType,
LLMModel,
LLMType,
load_artifacts,
make_embedding_instance,
make_llm_instance,
)
from langchain_chroma.vectorstores import Chroma as ChromaVectorStore
from langchain_core.output_parsers.string import StrOutputParser
from langchain_graphrag.indexing.artifacts import IndexerArtifacts
from langchain_graphrag.query.global_search import GlobalSearch
from langchain_graphrag.query.global_search.community_weight_calculator import (
CommunityWeightCalculator,
Expand Down Expand Up @@ -84,7 +84,7 @@ def global_search(
key_points_aggregator=points_aggregator,
)

artifacts = IndexerArtifacts.load(artifacts_dir)
artifacts = load_artifacts(artifacts_dir)
response = searcher.invoke(query, artifacts)

print(response)
Expand Down Expand Up @@ -150,7 +150,7 @@ def local_search(
context_builder=context_builder,
)

artifacts = IndexerArtifacts.load(artifacts_dir)
artifacts = load_artifacts(artifacts_dir)
response = searcher.invoke(query, artifacts)

print(response)
2 changes: 2 additions & 0 deletions examples/simple-app/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ dependencies = [
"langchain-openai",
"langchain-ollama",
"langchain-chroma",
"pyarrow>=17.0.0",
"fastparquet>=2024.5.0",
]
name = "simple-app"
version = "0.0.1"
Expand Down
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
[project]
name = "langchain-graphrag"
version = "0.0.2-beta.4"
version = "0.0.2-beta.5"
description = "Implementation of GraphRAG (https://arxiv.org/pdf/2404.16130)"
authors = [{ name = "Kapil Sachdeva", email = "[email protected]" }]
dependencies = [
"pandas>=2.2.2",
"pyarrow>=17.0.0",
"fastparquet>=2024.5.0",
"networkx>=3.3",
"langchain-core>=0.2.27",
"langchain-text-splitters>=0.2.2",
Expand Down
2 changes: 0 additions & 2 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ fastapi==0.112.0
# via chromadb
# via langchain-chroma
fastparquet==2024.5.0
# via langchain-graphrag
filelock==3.15.4
# via huggingface-hub
filetype==1.2.0
Expand Down Expand Up @@ -425,7 +424,6 @@ ptyprocess==0.7.0
pure-eval==0.2.3
# via stack-data
pyarrow==17.0.0
# via langchain-graphrag
pyasn1==0.6.0
# via pyasn1-modules
# via rsa
Expand Down
2 changes: 0 additions & 2 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ fastapi==0.112.0
# via chromadb
# via langchain-chroma
fastparquet==2024.5.0
# via langchain-graphrag
filelock==3.15.4
# via huggingface-hub
filetype==1.2.0
Expand Down Expand Up @@ -335,7 +334,6 @@ protobuf==4.25.4
psutil==6.0.0
# via unstructured
pyarrow==17.0.0
# via langchain-graphrag
pyasn1==0.6.0
# via pyasn1-modules
# via rsa
Expand Down
24 changes: 0 additions & 24 deletions src/langchain_graphrag/indexing/artifacts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pathlib import Path
from typing import NamedTuple

import pandas as pd
Expand All @@ -12,29 +11,6 @@ class IndexerArtifacts(NamedTuple):
communities: pd.DataFrame
communities_reports: pd.DataFrame

def save(self, path: Path):
self.entities.to_parquet(f"{path}/entities.parquet")
self.relationships.to_parquet(f"{path}/relationships.parquet")
self.text_units.to_parquet(f"{path}/text_units.parquet")
self.communities.to_parquet(f"{path}/communities.parquet")
self.communities_reports.to_parquet(f"{path}/communities_reports.parquet")

@staticmethod
def load(path: Path) -> "IndexerArtifacts":
entities = pd.read_parquet(f"{path}/entities.parquet")
relationships = pd.read_parquet(f"{path}/relationships.parquet")
text_units = pd.read_parquet(f"{path}/text_units.parquet")
communities = pd.read_parquet(f"{path}/communities.parquet")
communities_reports = pd.read_parquet(f"{path}/communities_reports.parquet")

return IndexerArtifacts(
entities,
relationships,
text_units,
communities,
communities_reports,
)

def _entity_info(self, top_k: int) -> None:
tableprint.banner("Entities")

Expand Down

0 comments on commit 95a5f0b

Please sign in to comment.