Merge pull request #71 from AIDotNet/feature_qa
Feature qa
xuzeyu91 authored Apr 15, 2024
2 parents 79326de + 64e949a commit 1cc56dd
Showing 11 changed files with 330 additions and 27 deletions.
6 changes: 6 additions & 0 deletions src/AntSK.Domain/AntSK.Domain.xml

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions src/AntSK.Domain/Domain/Interface/IKernelService.cs
@@ -6,6 +6,8 @@ namespace AntSK.Domain.Domain.Interface
public interface IKernelService
{
Kernel GetKernelByApp(Apps app);

Kernel GetKernelByAIModelID(string modelid);
void ImportFunctionsByApp(Apps app, Kernel _kernel);
Task<string> HistorySummarize(Kernel _kernel, string questions, string history);
}
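The new GetKernelByAIModelID member builds a Kernel from a chat model id alone, without an Apps record; the QA pipeline relies on it because QAHandler only carries the model id (via QAModel.ChatModelId, added below). A minimal caller-side sketch, assuming constructor injection as used elsewhere in AntSK; the class and method names here are illustrative and not part of this diff:

// Hypothetical consumer (not in this diff); only IKernelService.GetKernelByAIModelID comes from the PR.
using System.Threading.Tasks;
using AntSK.Domain.Domain.Interface;
using Microsoft.SemanticKernel;

public class QaKernelConsumerSketch
{
    private readonly IKernelService _kernelService;   // injected, as in the rest of AntSK

    public QaKernelConsumerSketch(IKernelService kernelService) => _kernelService = kernelService;

    public async Task<string?> AskAsync(string chatModelId, string question)
    {
        // Build a kernel bound to one AIModels record and run a plain prompt against it.
        Kernel kernel = _kernelService.GetKernelByAIModelID(chatModelId);
        var result = await kernel.InvokePromptAsync(question);
        return result.GetValue<string>();
    }
}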
9 changes: 9 additions & 0 deletions src/AntSK.Domain/Domain/Model/ImportKMSTaskReq.cs
@@ -17,11 +17,14 @@ public class ImportKMSTaskDTO
public string FilePath { get; set; } = "";

public string FileName { get; set; } = "";

public bool IsQA { get; set; } = false;
}


public class ImportKMSTaskReq : ImportKMSTaskDTO
{
public bool IsQA { get; set; } = false;
public KmsDetails KmsDetail { get; set; } = new KmsDetails();
}

@@ -32,4 +35,10 @@ public enum ImportType
Text = 3,
Excel = 4
}

public class QAModel
{
public string ChatModelId { get; set; }
public string Context { get; set; }
}
}
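QAModel is the request body that the new QAHandler (next file) posts to the QA endpoint: ChatModelId doubles as the pipeline step name, and Context carries the extracted document text. A sketch of the serialized payload; the values are illustrative:

// Illustrative payload only; property names come from QAModel above, the values are made up.
using AntSK.Domain.Domain.Model;
using Newtonsoft.Json;

var payload = JsonConvert.SerializeObject(new QAModel
{
    ChatModelId = "your-chat-model-id",                       // the KMS record's ChatModelID
    Context = "Extracted plain text of the uploaded document" // produced by the extract_text step
});
// payload == {"ChatModelId":"your-chat-model-id","Context":"Extracted plain text of the uploaded document"}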
154 changes: 154 additions & 0 deletions src/AntSK.Domain/Domain/Other/QAHandler.cs
@@ -0,0 +1,154 @@
using AntSK.Domain.Domain.Model;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI.OpenAI;
using Microsoft.KernelMemory.Configuration;
using Microsoft.KernelMemory.DataFormats.Text;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.Extensions;
using Microsoft.KernelMemory.Pipeline;
using Newtonsoft.Json;
using RestSharp;
using System.Security.Policy;
using System.Text;

namespace AntSK.Domain.Domain.Other
{
public class QAHandler : IPipelineStepHandler
{
private readonly TextPartitioningOptions _options;
private readonly IPipelineOrchestrator _orchestrator;
private readonly ILogger<QAHandler> _log;
private readonly TextChunker.TokenCounter _tokenCounter;
public QAHandler(
string stepName,
IPipelineOrchestrator orchestrator,
TextPartitioningOptions? options = null,
ILogger<QAHandler>? log = null
)
{
this.StepName = stepName;
this._orchestrator = orchestrator;
this._options = options ?? new TextPartitioningOptions();
this._options.Validate();

this._log = log ?? DefaultLogger<QAHandler>.Instance;
this._tokenCounter = DefaultGPTTokenizer.StaticCountTokens;
}

/// <inheritdoc />
public string StepName { get; }

/// <inheritdoc />
public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync(
DataPipeline pipeline, CancellationToken cancellationToken = default)
{
this._log.LogDebug("Partitioning text, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId);

if (pipeline.Files.Count == 0)
{
this._log.LogWarning("Pipeline '{0}/{1}': there are no files to process, moving to next pipeline step.", pipeline.Index, pipeline.DocumentId);
return (true, pipeline);
}

foreach (DataPipeline.FileDetails uploadedFile in pipeline.Files)
{
// Track new files being generated (cannot edit originalFile.GeneratedFiles while looping it)
Dictionary<string, DataPipeline.GeneratedFileDetails> newFiles = new();

foreach (KeyValuePair<string, DataPipeline.GeneratedFileDetails> generatedFile in uploadedFile.GeneratedFiles)
{
var file = generatedFile.Value;
if (file.AlreadyProcessedBy(this))
{
this._log.LogTrace("File {0} already processed by this handler", file.Name);
continue;
}

// Partition only the original text
if (file.ArtifactType != DataPipeline.ArtifactTypes.ExtractedText)
{
this._log.LogTrace("Skipping file {0} (not original text)", file.Name);
continue;
}

// Use a different partitioning strategy depending on the file type
List<string> partitions;
List<string> sentences;
BinaryData partitionContent = await this._orchestrator.ReadFileAsync(pipeline, file.Name, cancellationToken).ConfigureAwait(false);

// Skip empty partitions. Also: partitionContent.ToString() throws an exception if there are no bytes.
if (partitionContent.ToArray().Length == 0) { continue; }

switch (file.MimeType)
{
case MimeTypes.PlainText:
case MimeTypes.MarkDown:
{
this._log.LogDebug("Partitioning text file {0}", file.Name);
string content = partitionContent.ToString();

using (HttpClient httpclient = new HttpClient())
{
httpclient.Timeout = TimeSpan.FromMinutes(10);
StringContent scontent = new StringContent(JsonConvert.SerializeObject(new QAModel() { ChatModelId = StepName, Context = content }), Encoding.UTF8, "application/json");
HttpResponseMessage response = await httpclient.PostAsync("http://localhost:5000/api/KMS/QA", scontent);
List<string> qaList = JsonConvert.DeserializeObject<List<string>>( await response.Content.ReadAsStringAsync());
sentences = qaList;
partitions = qaList;
}
break;
}
default:
this._log.LogWarning("File {0} cannot be partitioned, type '{1}' not supported", file.Name, file.MimeType);
// Don't partition other files
continue;
}

if (partitions.Count == 0) { continue; }

this._log.LogDebug("Saving {0} file partitions", partitions.Count);
for (int partitionNumber = 0; partitionNumber < partitions.Count; partitionNumber++)
{
// TODO: turn partitions in objects with more details, e.g. page number
string text = partitions[partitionNumber];
int sectionNumber = 0; // TODO: use this to store the page number (if any)
BinaryData textData = new(text);

int tokenCount = this._tokenCounter(text);
this._log.LogDebug("Partition size: {0} tokens", tokenCount);

var destFile = uploadedFile.GetPartitionFileName(partitionNumber);
await this._orchestrator.WriteFileAsync(pipeline, destFile, textData, cancellationToken).ConfigureAwait(false);

var destFileDetails = new DataPipeline.GeneratedFileDetails
{
Id = Guid.NewGuid().ToString("N"),
ParentId = uploadedFile.Id,
Name = destFile,
Size = text.Length,
MimeType = MimeTypes.PlainText,
ArtifactType = DataPipeline.ArtifactTypes.TextPartition,
PartitionNumber = partitionNumber,
SectionNumber = sectionNumber,
Tags = pipeline.Tags,
ContentSHA256 = textData.CalculateSHA256(),
};
newFiles.Add(destFile, destFileDetails);
destFileDetails.MarkProcessedBy(this);
}

file.MarkProcessedBy(this);
}

// Add new files to pipeline status
foreach (var file in newFiles)
{
uploadedFile.GeneratedFiles.Add(file.Key, file.Value);
}
}

return (true, pipeline);
}
}
}
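QAHandler replaces the stock partitioning step: for plain-text and Markdown extractions it posts the whole document to http://localhost:5000/api/KMS/QA and expects a JSON array of strings back, one per generated question/answer chunk, each of which is then written out as its own TextPartition file with tags, token count, and SHA256. The controller behind that route is not part of this diff, so the sketch below is only a hypothetical shape of the contract — route attribute, prompt wording, and splitting logic are assumptions; only the QAModel request and List&lt;string&gt; response are implied by the handler:

// Hypothetical endpoint sketch; only "QAModel in, List<string> out" is dictated by QAHandler.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model;
using Microsoft.AspNetCore.Mvc;
using Microsoft.SemanticKernel;

[ApiController]
public class KmsQaControllerSketch : ControllerBase
{
    private readonly IKernelService _kernelService;

    public KmsQaControllerSketch(IKernelService kernelService) => _kernelService = kernelService;

    [HttpPost("api/KMS/QA")]
    public async Task<List<string>> QA([FromBody] QAModel model)
    {
        // Build a kernel for the requested chat model (new IKernelService member from this PR)
        // and ask it to rewrite the context as question/answer pairs.
        Kernel kernel = _kernelService.GetKernelByAIModelID(model.ChatModelId);
        var result = await kernel.InvokePromptAsync(
            "Split the following text into standalone question/answer pairs, " +
            "one pair per blank-line-separated block:\n" + model.Context);

        // QAHandler deserializes the response body as List<string>, so return one entry per pair.
        return (result.GetValue<string>() ?? string.Empty)
            .Split("\n\n", StringSplitOptions.RemoveEmptyEntries)
            .ToList();
    }
}

Note that QAHandler reuses StepName as the chat model id, which is why ImportKMSService (below) registers the handler under km.ChatModelID rather than a fixed step name.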
61 changes: 49 additions & 12 deletions src/AntSK.Domain/Domain/Service/ImportKMSService.cs
@@ -24,18 +24,40 @@ public void ImportKMSTask(ImportKMSTaskReq req)
try
{
var km = _kmss_Repositories.GetFirst(p => p.Id == req.KmsId);

var _memory = _kMService.GetMemoryByKMS(km.Id);
string fileid = req.KmsDetail.Id;
List<string> step = new List<string>();
if (req.IsQA)
{
_memory.Orchestrator.AddHandler<TextExtractionHandler>("extract_text");
_memory.Orchestrator.AddHandler<QAHandler>(km.ChatModelID);
_memory.Orchestrator.AddHandler<GenerateEmbeddingsHandler>("generate_embeddings");
_memory.Orchestrator.AddHandler<SaveRecordsHandler>("save_memory_records");
step.Add("extract_text");
step.Add(km.ChatModelID);
step.Add("generate_embeddings");
step.Add("save_memory_records");
}

switch (req.ImportType)
{
case ImportType.File:
// Import file
{
var importResult = _memory.ImportDocumentAsync(new Document(fileid)
.AddFile(req.FilePath)
.AddTag(KmsConstantcs.KmsIdTag, req.KmsId)
, index: KmsConstantcs.KmsIndex).Result;
// Import file
if (req.IsQA)
{
var importResult = _memory.ImportDocumentAsync(new Document(fileid)
.AddFile(req.FilePath)
.AddTag(KmsConstantcs.KmsIdTag, req.KmsId)
,index: KmsConstantcs.KmsIndex ,steps: step.ToArray()).Result;
}
else
{
var importResult = _memory.ImportDocumentAsync(new Document(fileid)
.AddFile(req.FilePath)
.AddTag(KmsConstantcs.KmsIdTag, req.KmsId)
, index: KmsConstantcs.KmsIndex).Result;
}
// Query the document count
var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
string fileGuidName = Path.GetFileName(req.FilePath);
@@ -48,8 +70,16 @@ public void ImportKMSTask(ImportKMSTaskReq req)
case ImportType.Url:
{
// Import URL
var importResult = _memory.ImportWebPageAsync(req.Url, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex).Result;
if (req.IsQA)
{
var importResult = _memory.ImportWebPageAsync(req.Url, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex, steps: step.ToArray()).Result;
}
else
{
var importResult = _memory.ImportWebPageAsync(req.Url, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex).Result;
}
// Query the document count
var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
req.KmsDetail.Url = req.Url;
@@ -59,8 +89,16 @@ public void ImportKMSTask(ImportKMSTaskReq req)
case ImportType.Text:
// Import text
{
var importResult = _memory.ImportTextAsync(req.Text, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex).Result;
if (req.IsQA)
{
var importResult = _memory.ImportTextAsync(req.Text, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex, steps: step.ToArray()).Result;
}
else
{
var importResult = _memory.ImportTextAsync(req.Text, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex).Result;
}
// Query the document count
var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
req.KmsDetail.Url = req.Url;
@@ -71,8 +109,7 @@ public void ImportKMSTask(ImportKMSTaskReq req)
case ImportType.Excel:
using (var fs = File.OpenRead(req.FilePath))
{
var excelList= ExeclHelper.ExcelToList<KMSExcelModel>(fs);

var excelList= ExeclHelper.ExcelToList<KMSExcelModel>(fs);
_memory.Orchestrator.AddHandler<TextExtractionHandler>("extract_text");
_memory.Orchestrator.AddHandler<KMExcelHandler>("antsk_excel_split");
_memory.Orchestrator.AddHandler<GenerateEmbeddingsHandler>("generate_embeddings");
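With IsQA set, ImportKMSService swaps the default text-partitioning step for the QAHandler registered under km.ChatModelID and passes the explicit step list to ImportDocumentAsync / ImportWebPageAsync / ImportTextAsync; otherwise the default pipeline runs. A sketch of how a caller might queue a QA-mode file import — the field values are illustrative, and _importKMSService stands in for the injected import service, whose registration and interface name are outside this diff:

// Illustrative request only. KmsId and ImportType are read by the service code above,
// but their declarations live in parts of the DTO not shown in this diff.
using System;
using AntSK.Domain.Domain.Model;

var req = new ImportKMSTaskReq
{
    KmsId = "kms-record-id",
    ImportType = ImportType.File,
    FilePath = "/app/uploads/product-manual.txt",
    FileName = "product-manual.txt",
    IsQA = true,   // route the document through QAHandler instead of the default partitioner
    KmsDetail = new KmsDetails { Id = Guid.NewGuid().ToString("N") }   // becomes the document/file id
};
_importKMSService.ImportKMSTask(req);   // _importKMSService: assumed DI-resolved import service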
19 changes: 16 additions & 3 deletions src/AntSK.Domain/Domain/Service/KernelService.cs
@@ -18,6 +18,8 @@
using AntSK.LLM.LLamaFactory;
using System.Reflection;
using DocumentFormat.OpenXml.Drawing;
using Microsoft.KernelMemory;
using OpenCvSharp.ML;

namespace AntSK.Domain.Domain.Service
{
@@ -57,7 +59,7 @@ public Kernel GetKernelByApp(Apps app)
var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);

var builder = Kernel.CreateBuilder();
WithTextGenerationByAIType(builder, app, chatModel, chatHttpClient);
WithTextGenerationByAIType(builder, chatModel, chatHttpClient);

_kernel = builder.Build();
RegisterPluginsWithKernel(_kernel);
@@ -69,7 +71,18 @@ public Kernel GetKernelByApp(Apps app)
//}
}

private void WithTextGenerationByAIType(IKernelBuilder builder, Apps app, AIModels chatModel, HttpClient chatHttpClient)
public Kernel GetKernelByAIModelID(string modelid)
{
var chatModel = _aIModels_Repositories.GetById(modelid);
var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);
var builder = Kernel.CreateBuilder();
WithTextGenerationByAIType(builder, chatModel, chatHttpClient);
_kernel = builder.Build();
RegisterPluginsWithKernel(_kernel);
return _kernel;
}

private void WithTextGenerationByAIType(IKernelBuilder builder, AIModels chatModel, HttpClient chatHttpClient)
{
switch (chatModel.AIType)
{
@@ -96,7 +109,7 @@ private void WithTextGenerationByAIType(IKernelBuilder builder, Apps app, AIMode

case Model.Enum.AIType.SparkDesk:
var options = new SparkDeskOptions { AppId = chatModel.EndPoint, ApiSecret = chatModel.ModelKey, ApiKey = chatModel.ModelName, ModelVersion = Sdcb.SparkDesk.ModelVersion.V3_5 };
builder.Services.AddKeyedSingleton<ITextGenerationService>("spark-desk", new SparkDeskTextCompletion(options, app.Id));
builder.Services.AddKeyedSingleton<ITextGenerationService>("spark-desk", new SparkDeskTextCompletion(options, chatModel.Id));
break;

case Model.Enum.AIType.DashScope: