Merge pull request #154 from amosproj/develop
Develop
borges-filipe authored Jun 19, 2024
2 parents d24e4b3 + 614535a commit fb4ba9d
Showing 26 changed files with 2,440 additions and 813 deletions.
186 changes: 186 additions & 0 deletions Project/backend/codebase/graph_analysis/graph_analysis.py
@@ -0,0 +1,186 @@
import networkx as nx
import os
import json


def analyze_graph_structure(G):
"""Analyzes the structure of a knowledge graph and provides hopefully useful information.
Currently, I am not sure how to use most of the information, but we may find a way to use it
Args:
G: A networkx graph.
Returns:
A dictionary containing information about the graph's structure.
"""

# Basic Graph Statistics
num_nodes = G.number_of_nodes() # Total number of nodes
num_edges = G.number_of_edges() # Total number of edges
density = nx.density(G) # Ratio of actual edges to possible edges (0 to 1)
average_degree = 2 * num_edges / num_nodes if num_nodes else 0  # Average degree (guarded against an empty graph)

# Degree Distribution
degree_distribution = dict(G.degree())
# Degree distribution can indicate the presence of hubs or important nodes

degree_centrality = nx.degree_centrality(G)

""" Centrality Measures
- Degree Centrality: Measures node connectivity
- Nodes with high degree centrality are important in the network
Examples: 3 nodes are connected in a line
node1 - node2 - node3
- Degree: node1 = 1, node2 = 2, node3 = 1
- Degree Centrality: node1 = 0.33(1/3), node2 = 0.66(2/3), node3 = 0.33(1/3)
"""

betweenness_centrality = nx.betweenness_centrality(G)

"""
- Betweenness Centrality: Measures node's control over information flow
- Nodes with high betweenness centrality are important in the network
Examples: 4 nodes are connected
1
/ \
2-----3
\ /
4
- Here, node 2 has the highest betweenness centrality because it lies on the shortest path between all other nodes
- if node in not between any other nodes, its betweenness centrality is 0
- Betweenness Centrality show the dependency of the network on a node
"""

# - Closeness Centrality: Measures average length of the shortest path from a node to all other nodes
closeness_centrality = nx.closeness_centrality(G)

"""
- Closeness Centrality: Measures average length of the shortest path from a node to all other nodes
- Nodes with high closeness centrality are important in the network
Examples: 4 nodes are connected
0
/ | \
2--1--3
- Here, node 0, 1 (1.0) has the highest closeness centrality because it is connected to all other nodes (node 2, 3 = 0.75)
- Closeness Centrality show the average distance of a node to all other nodes in the network
"""

# - Eigenvector Centrality: Measures influence of a node in a network
eigenvector_centrality = nx.eigenvector_centrality(G)

"""
- Eigenvector Centrality: Measures influence of a node in a network
- Nodes with high eigenvector centrality are important in the network
Examples: 4 nodes are connected
1
/ \
2-----3
\ /
4
- Here, node 3 has the highest eigenvector centrality because it is connected to node 2 which has high eigenvector centrality
- Eigenvector Centrality show the influence of a node in the network
- Eigenvector Centrality is similar to PageRank algorithm
- in this measure every node has some values and the values are updated based on the values of the connected nodes (node value != 0)
"""

# Community Structure
# - Greedy modularity communities (Clauset-Newman-Moore algorithm) for community detection
communities = list(nx.community.greedy_modularity_communities(G))
community_sizes = [len(community) for community in communities]
num_communities = len(communities)
# Communities can reveal modular structures in the graph
"""
- Community Detection: Identifying groups of nodes that are more connected to each other than to the rest of the network
- Communities can reveal modular structures in the graph
- Communities can be used to identify groups of nodes that are more connected to each other than to the rest of the network
Examples: 7 nodes are connected
1
/ \
2-----3
\ / 5
4-----/ \
6-----7
- Here, nodes 1, 2, 3, 4 are in one community and nodes 5, 6, 7 are in another community
"""

# Graph Connectivity
# - Check if the graph is connected
is_connected = nx.is_connected(G)
# - Calculate diameter: Longest shortest path between any two nodes
diameter = nx.diameter(G) if is_connected else float('inf')
# - Average shortest path length: Average of all shortest paths in the graph
average_shortest_path_length = nx.average_shortest_path_length(G) if is_connected else float('inf')

# Clustering Coefficient
# - Measures the degree to which nodes tend to cluster together
average_clustering_coefficient = nx.average_clustering(G)

# Assortativity
# - Measures the similarity of connections in the graph with respect to node degree
assortativity = nx.degree_assortativity_coefficient(G)

# Graph Radius
# - Radius: Minimum eccentricity of any node (the diameter was computed above)
radius = nx.radius(G) if is_connected else float('inf')

# Graph Transitivity
# - Measures the overall probability for the network to have adjacent nodes interconnected
transitivity = nx.transitivity(G)

# Return a dictionary containing the structural information
graph_info = {
"num_nodes": num_nodes,
"num_edges": num_edges,
"density": density,
"average_degree": average_degree,
"degree_distribution": degree_distribution,
"degree_centrality": degree_centrality,
"betweenness_centrality": betweenness_centrality,
"closeness_centrality": closeness_centrality,
"eigenvector_centrality": eigenvector_centrality,
"num_communities": num_communities,
"community_sizes": community_sizes,
"is_connected": is_connected,
"diameter": diameter,
"average_shortest_path_length": average_shortest_path_length,
"average_clustering_coefficient": average_clustering_coefficient,
"assortativity": assortativity,
"radius": radius,
"transitivity": transitivity
}

return graph_info


def print_graph_info(graph_info):
"""Prints the graph information in a formatted and readable way.
Args:
graph_info: A dictionary containing information about the graph's structure.
"""

print(json.dumps(graph_info, indent=4))


# Directory containing the stored graphs (path relative to the repository root)
graph_directory = "./Project/backend/codebase/.media/graphs/"

with os.scandir(graph_directory) as it:
for entry in it:
if entry.name.endswith(".gml") and entry.is_file():
print("-----------------------")
print(f"Filename: {entry.name}")
graph = nx.read_gml(entry.path)
graph_info = analyze_graph_structure(graph)
print_graph_info(graph_info)
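
As a quick sanity check of the centrality examples in the docstrings above, here is a minimal standalone sketch that reproduces the numbers with networkx:

import networkx as nx

# Path graph from the degree-centrality example: node1 - node2 - node3
P = nx.path_graph([1, 2, 3])
print(nx.degree_centrality(P))  # {1: 0.5, 2: 1.0, 3: 0.5}

# Diamond graph from the betweenness/eigenvector examples: 1-2, 1-3, 2-3, 2-4, 3-4
D = nx.Graph([(1, 2), (1, 3), (2, 3), (2, 4), (3, 4)])
print(nx.betweenness_centrality(D))  # nodes 2 and 3 tie for the highest value
print(nx.eigenvector_centrality(D))  # nodes 2 and 3 are the most influential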
@@ -91,7 +91,7 @@ def create_and_store_graph(uuid, entities_and_relations, chunks):
graph_db_service = netx_graphdb.NetXGraphDB()

# read entities and relations
graph = graph_db_service.create_graph_from_df(uuid, combined)
graph = graph_db_service.create_graph_from_df(combined)

# save graph as file
graph_db_service.save_graph(uuid, graph)
25 changes: 24 additions & 1 deletion Project/backend/codebase/graph_creator/router.py
@@ -16,10 +16,11 @@
from graph_creator.gemini import process_chunks
import shutil
import mimetypes
from graph_creator.schemas.graph_vis import GraphVisData
from graph_creator.schemas.graph_vis import GraphVisData, QueryInputData, GraphQueryOutput
from graph_creator.services.netx_graphdb import NetXGraphDB

import graph_creator.graph_creator_main as graph_creator_main
from graph_creator.services.query_graph import GraphQuery
from graph_creator.utils.const import GraphStatus

router = APIRouter()
@@ -284,3 +285,25 @@ async def get_graph_data_for_visualization(
return netx_services.graph_data_for_visualization(
g_job.id, node=node, adj_depth=adj_depth
)


@router.post("/query_graph/{graph_job_id}")
async def query_graph(
graph_job_id: uuid.UUID,
input_data: QueryInputData,
graph_job_dao: GraphJobDAO = Depends(),
netx_services: NetXGraphDB = Depends(),
graph_query_services: GraphQuery = Depends(),
) -> GraphQueryOutput:
g_job = await graph_job_dao.get_graph_job_by_id(graph_job_id)

if not g_job:
raise HTTPException(status_code=404, detail="Graph job not found")
if g_job.status != GraphStatus.GRAPH_READY:
raise HTTPException(
status_code=400,
detail=f"No graph created for this job!",
)
graph = netx_services.load_graph(graph_job_id=graph_job_id)
data = graph_query_services.query_graph(graph=graph, query=input_data.text)
return data
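
A request against this endpoint might look like the sketch below (base URL, port, and job id are placeholders, and any router prefix configured elsewhere is omitted; the response shape follows GraphQueryOutput):

import requests

BASE_URL = "http://localhost:8000"  # placeholder; adjust to the deployed backend
graph_job_id = "00000000-0000-0000-0000-000000000000"  # placeholder UUID of a ready graph job

resp = requests.post(
    f"{BASE_URL}/query_graph/{graph_job_id}",
    json={"text": "How is node A related to node B?"},
)
resp.raise_for_status()
payload = resp.json()
print(payload["llm_nodes"])       # entities the LLM extracted from the query
print(payload["spacy_nodes"])     # entities spaCy extracted from the query
print(payload["retrieved_info"])  # entity -> list of (relationship, neighbor) pairs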
10 changes: 10 additions & 0 deletions Project/backend/codebase/graph_creator/schemas/graph_vis.py
@@ -21,3 +21,13 @@ class GraphEdge(BaseModel):
class GraphVisData(BaseModel):
nodes: list[GraphNode]
edges: list[GraphEdge]


class QueryInputData(BaseModel):
text: str


class GraphQueryOutput(BaseModel):
llm_nodes: list[str]
spacy_nodes: list[str]
retrieved_info: dict[str, list[tuple[str, str]]]
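
A minimal sketch of how the two models pair up (all values invented for illustration):

from graph_creator.schemas.graph_vis import QueryInputData, GraphQueryOutput

query = QueryInputData(text="How is Alice related to Bob?")
result = GraphQueryOutput(
    llm_nodes=["Alice", "Bob"],
    spacy_nodes=["Alice", "Bob"],
    retrieved_info={"Alice": [("works with", "Bob")]},
)
print(result.model_dump())  # pydantic v2; use .dict() on pydantic v1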
@@ -6,7 +6,7 @@
from graph_creator.schemas.graph_vis import GraphVisData, GraphNode, GraphEdge

# Scale range for min-max scaling the node sizes
scale_range = [1, 15]
scale_range = [15, 35]

class NetXGraphDB:
"""
@@ -16,9 +16,7 @@ class NetXGraphDB:
All the graphs operations will happen in memory.
"""

def create_graph_from_df(
self, graph_job_id: uuid.UUID, data: pd.DataFrame = None
) -> nx.Graph:
def create_graph_from_df(self, data: pd.DataFrame = None) -> nx.Graph:
df = pd.DataFrame(data)
graph = nx.Graph()

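
With the new signature, call sites pass only the DataFrame (as create_and_store_graph above now does). A sketch under the assumption that each row describes one edge; the actual column names live in the truncated part of this file and the ones below are hypothetical:

import pandas as pd
from graph_creator.services.netx_graphdb import NetXGraphDB

# Hypothetical column names, for illustration only.
edges = pd.DataFrame({"node_1": ["Alice"], "node_2": ["Bob"], "edge": ["works with"]})

service = NetXGraphDB()
graph = service.create_graph_from_df(edges)  # graph_job_id is no longer required here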
80 changes: 80 additions & 0 deletions Project/backend/codebase/graph_creator/services/query_graph.py
@@ -0,0 +1,80 @@
import os

import networkx as nx
import spacy
from langchain.chains.graph_qa.base import GraphQAChain
from langchain_community.graphs import NetworkxEntityGraph
from langchain_groq import ChatGroq
from networkx import NetworkXError

from graph_creator.llama3 import configure_groq
from graph_creator.schemas.graph_vis import GraphQueryOutput


class GraphQuery:

def query_graph(self, graph: nx.Graph, query: str) -> GraphQueryOutput:
entities_from_llm = self.retrieve_entities_from_llm(query)
entities_from_spacy = self.retrieve_entities_from_spacy(query)
all_entities = set()
all_entities.update(entities_from_llm)
all_entities.update(entities_from_spacy)

entities_relationships = {}

for node in all_entities:
edges_info = []
try:
for neighbor in graph.neighbors(node):
edge_data = graph.get_edge_data(node, neighbor)
relationship = edge_data.get('relation', 'is connected to')
edges_info.append((relationship, neighbor))
entities_relationships[node] = edges_info
except NetworkXError:
continue

return GraphQueryOutput(
llm_nodes=entities_from_llm,
spacy_nodes=entities_from_spacy,
retrieved_info=entities_relationships,
)

@staticmethod
def retrieve_entities_from_llm(query: str):
groq_client = configure_groq()
SYS_PROMPT = """
The user has a knowledge graph and wants to query it. For that, entities are needed.
Your task is to extract all entities from the query below.
Format your answer as in the following example:
entity1, entity2, entity3
"""
USER_PROMPT = f"user input: ```{query}``` \n\n output: "
chat_completion = groq_client.chat.completions.create(
messages=[
{"role": "system", "content": SYS_PROMPT},
{"role": "user", "content": USER_PROMPT},
],
model="llama3-8b-8192",
)
response = chat_completion.choices[0].message.content
return [entity.strip() for entity in response.split(",")]

@staticmethod
def retrieve_entities_from_spacy(query: str):
nlp = spacy.load('en_core_web_sm')
doc = nlp(query)

# Return entity texts only so the result matches GraphQueryOutput.spacy_nodes (list[str])
# and mixes cleanly with the LLM entities when looking up graph nodes.
entities = [ent.text for ent in doc.ents]
return entities

def query_graph_via_langchain(self, query: str, graph_path: str):
graph1 = NetworkxEntityGraph.from_gml(graph_path)
chain = GraphQAChain.from_llm(
ChatGroq(
temperature=0,
model="llama3-8b-8192",
api_key=os.getenv("GROQ_API_KEY")
), graph=graph1, verbose=True
)
response = chain.invoke(query)
return response
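
The neighbor-walk at the core of query_graph can be exercised without any API keys. A minimal sketch with a toy graph and a hand-picked entity list (both invented for illustration):

import networkx as nx

g = nx.Graph()
g.add_edge("Alice", "Bob", relation="works with")
g.add_edge("Alice", "AMOS", relation="contributes to")

for node in ["Alice", "Carol"]:  # "Carol" is absent from the graph and gets skipped
    try:
        for neighbor in g.neighbors(node):
            relationship = g.get_edge_data(node, neighbor).get("relation", "is connected to")
            print(f"{node} --{relationship}--> {neighbor}")
    except nx.NetworkXError:
        continue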
4 changes: 3 additions & 1 deletion Project/backend/codebase/requirements.txt
@@ -83,4 +83,6 @@ tomli==2.0.1
typing-inspect==0.9.0
urllib3==2.2.1
google-generativeai==0.5.4
groq==0.8.0
groq==0.8.0
spacy
langchain-groq
1 change: 1 addition & 0 deletions Project/backend/config/api/Dockerfile
@@ -11,6 +11,7 @@ RUN : \
&& pip install pip --upgrade \
&& pip install setuptools --upgrade \
&& pip install -r requirements.txt \
&& python -m spacy download en_core_web_sm \
&& :

COPY codebase/ ./