Skip to content

Commit

Permalink
add 增加搜索节点接口
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzeyu91 committed Aug 28, 2024
1 parent 6ee5d58 commit 722bcd7
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project>
<!-- See https://aka.ms/dotnet/msbuild/customize for more details on customizing your build -->
<PropertyGroup>
<Version>0.1.27</Version>
<Version>0.1.28</Version>
<SKVersion>1.17.1</SKVersion>
</PropertyGroup>
</Project>
16 changes: 16 additions & 0 deletions src/GraphRag.Net/Domain/Interface/IGraphService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ public interface IGraphService
/// <param name="input"></param>
/// <returns></returns>
Task InsertGraphDataAsync(string index, string input);

/// <summary>
/// 搜索递归获取节点相关的所有边和节点
/// </summary>
/// <param name="index"></param>
/// <param name="input"></param>
/// <returns></returns>
Task<GraphModel> SearchGraphModel(string index, string input);

/// <summary>
/// 搜索递归获取节点相关的所有边和节点进行图谱对话
/// </summary>
Expand All @@ -38,6 +47,13 @@ public interface IGraphService
/// <returns></returns>
Task<string> SearchGraphAsync(string index, string input);
/// <summary>
/// 通过社区算法匹配相关节点信息
/// </summary>
/// <param name="index"></param>
/// <param name="input"></param>
/// <returns></returns>
Task<GraphModel> SearchGraphCommunityModel(string index, string input);
/// <summary>
/// 搜索递归获取节点相关的所有边和节点进行图谱对话,流式返回
/// </summary>
/// <param name="index"></param>
Expand Down
104 changes: 56 additions & 48 deletions src/GraphRag.Net/Domain/Service/GraphService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -265,48 +265,85 @@ public async Task InsertGraphDataAsync(string index, string input)
}

/// <summary>
/// 搜索递归获取节点相关的所有边和节点进行图谱对话
/// 检索相关节点
/// </summary>
/// <param name="index"></param>
/// <param name="input"></param>
/// <returns></returns>
public async Task<string> SearchGraphAsync(string index, string input)
{
/// <exception cref="ArgumentException"></exception>
public async Task<GraphModel> SearchGraphModel(string index, string input) {
if (string.IsNullOrWhiteSpace(index) || string.IsNullOrWhiteSpace(input))
{
throw new ArgumentException("Values required for index and input cannot be null.");
}
string answer = "";

var textMemModelList = await RetrieveTextMemModelList(index, input);

if (textMemModelList.Any())
{
var nodes = _nodes_Repositories.GetList(p => p.Index == index && textMemModelList.Select(c => c.Id).Contains(p.Id));
var graphModel = GetGraphAllRecursion(index, nodes); ;

answer = await _semanticService.GetGraphAnswerAsync(JsonConvert.SerializeObject(graphModel), input);
return graphModel;
}
else
{
return new GraphModel();
}
return answer;
}

/// <summary>
/// 搜索递归获取节点相关的所有边和节点进行图谱对话,流式返回
/// 通过社区算法匹配相关节点信息
/// </summary>
/// <param name="index"></param>
/// <param name="input"></param>
/// <returns></returns>
public async IAsyncEnumerable<StreamingKernelContent> SearchGraphStreamAsync(string index, string input)
/// <exception cref="ArgumentException"></exception>
public async Task<GraphModel> SearchGraphCommunityModel(string index, string input)
{
if (string.IsNullOrWhiteSpace(index) || string.IsNullOrWhiteSpace(input))
{
throw new ArgumentException("Values required for index and input cannot be null.");
}
string answer = "";
var textMemModelList = await RetrieveTextMemModelList(index, input);

if (textMemModelList.Any())
if (textMemModelList.Count() > 0)
{
var nodes = _nodes_Repositories.GetList(p => p.Index == index && textMemModelList.Select(c => c.Id).Contains(p.Id));
var answerStream = GetFilteredGraphModelStream(index, input, nodes);
//匹配到节点信息
var graphModel = GetGraphAllCommunitiesRecursion(index, nodes);
return graphModel;
}
else
{
return new GraphModel();
}
}

/// <summary>
/// 搜索递归获取节点相关的所有边和节点进行图谱对话
/// </summary>
/// <param name="index"></param>
/// <param name="input"></param>
/// <returns></returns>
public async Task<string> SearchGraphAsync(string index, string input)
{
var graphModel = await SearchGraphModel(index, input);
string answer = await _semanticService.GetGraphAnswerAsync(JsonConvert.SerializeObject(graphModel), input);
return answer;
}

/// <summary>
/// 搜索递归获取节点相关的所有边和节点进行图谱对话,流式返回
/// </summary>
/// <param name="index"></param>
/// <param name="input"></param>
/// <returns></returns>
public async IAsyncEnumerable<StreamingKernelContent> SearchGraphStreamAsync(string index, string input)
{
var graphModel = await SearchGraphModel(index, input);
if (graphModel.Nodes.Count() > 0)
{
var answerStream = _semanticService.GetGraphAnswerStreamAsync(JsonConvert.SerializeObject(graphModel), input);
await foreach (var content in answerStream)
{
yield return content;
Expand All @@ -322,19 +359,11 @@ public async IAsyncEnumerable<StreamingKernelContent> SearchGraphStreamAsync(str
/// <returns></returns>
public async Task<string> SearchGraphCommunityAsync(string index, string input)
{
if (string.IsNullOrWhiteSpace(index) || string.IsNullOrWhiteSpace(input))
{
throw new ArgumentException("Values required for index and input cannot be null.");
}
string answer = "";
var textMemModelList = await RetrieveTextMemModelList(index, input);
var graphModel = await SearchGraphCommunityModel(index, input);
var global = _globals_Repositories.GetFirst(p => p.Index == index)?.Summaries;
if (textMemModelList.Count() > 0)
{
var nodes = _nodes_Repositories.GetList(p => p.Index == index && textMemModelList.Select(c => c.Id).Contains(p.Id));
//匹配到节点信息
var graphModel = GetGraphAllCommunitiesRecursion(index, nodes);

if (graphModel.Nodes.Count()>0)
{
var community = string.Join(Environment.NewLine, _communities_Repositories.GetDB().Queryable<Communities>().Where(p => p.Index == index).Select(p => p.Summaries).ToList());

//这里数据有点多,要通过语义进行一次过滤
Expand All @@ -356,19 +385,14 @@ public async Task<string> SearchGraphCommunityAsync(string index, string input)
/// <returns></returns>
public async IAsyncEnumerable<StreamingKernelContent> SearchGraphCommunityStreamAsync(string index, string input)
{
if (string.IsNullOrWhiteSpace(index) || string.IsNullOrWhiteSpace(input))
{
throw new ArgumentException("Values required for index and input cannot be null.");
}
var textMemModelList = await RetrieveTextMemModelList(index, input);

var global = _globals_Repositories.GetFirst(p => p.Index == index)?.Summaries;
IAsyncEnumerable<StreamingKernelContent> answer;
if (textMemModelList.Count() > 0)

//匹配到节点信息
var graphModel = await SearchGraphCommunityModel(index, input);
if (graphModel.Nodes.Count() > 0)
{
var nodes = _nodes_Repositories.GetList(p => p.Index == index && textMemModelList.Select(c => c.Id).Contains(p.Id));
//匹配到节点信息
var graphModel = GetGraphAllCommunitiesRecursion(index, nodes);
var community = string.Join(Environment.NewLine, _communities_Repositories.GetDB().Queryable<Communities>().Where(p => p.Index == index).Select(p => p.Summaries).ToList());
//这里数据有点多,要通过语义进行一次过滤
answer = _semanticService.GetGraphCommunityAnswerStreamAsync(JsonConvert.SerializeObject(graphModel), community, global, input);
Expand Down Expand Up @@ -508,20 +532,6 @@ private async Task<List<TextMemModel>> RetrieveTextMemModelList(string index, st
return textMemModelList;
}

/// <summary>
/// 使用基于输入条件的语义过滤来过滤图模型。流式返回
/// </summary>
/// <param name="index"></param>
/// <param name="input"></param>
/// <param name="nodes"></param>
/// <returns></returns>
private IAsyncEnumerable<StreamingKernelContent> GetFilteredGraphModelStream(string index, string input, List<Nodes> nodes)
{
var graphModel = GetGraphAllRecursion(index, nodes);
var answerStream = _semanticService.GetGraphAnswerStreamAsync(JsonConvert.SerializeObject(graphModel), input);
return answerStream;
}

/// <summary>
/// 递归获取节点相关的所有边和节点
/// </summary>
Expand Down Expand Up @@ -654,8 +664,6 @@ private List<Nodes> GetNodes(string index, List<Edges> edges)
return nodes;
}



#endregion
}
}

0 comments on commit 722bcd7

Please sign in to comment.