Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
dclipca committed Jan 24, 2025
1 parent 12d15a4 commit fa3e2fc
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public async Task GetPerfInfo_ShouldReturnValidData()

// Assert
result.Should().NotBeNull();
result.TotalGenerations.Should().BeGreaterOrEqualTo(0);
result.TotalGenerations.Should().BeGreaterThanOrEqualTo(0);
result.Uptime.Should().BeGreaterThan(0);

// Log performance data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ protected IntegrationTestBase(ITestOutputHelper output)
{
HttpClient = new HttpClient
{
BaseAddress = new Uri(TestConfig.NativeApiBaseUrl)
BaseAddress = new Uri(TestConfig.BaseUrl)
},
BaseUrl = TestConfig.BaseUrl,
Logger = Logger,
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="FluentAssertions" Version="7.0.0" />
<PackageReference Include="MartinCostello.Logging.XUnit" Version="0.4.0" />
<PackageReference Include="FluentAssertions" Version="8.0.1" />
<PackageReference Include="MartinCostello.Logging.XUnit" Version="0.5.1" />
<PackageReference Include="MicroElements.Testing.xUnit" Version="0.4.1" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging.Debug" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="9.0.1" />
<PackageReference Include="Microsoft.Extensions.Logging.Debug" Version="9.0.1" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.12.0" />
<PackageReference Include="Moq" Version="4.20.72" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
<PackageReference Include="WireMock.Net" Version="1.6.9" />
<PackageReference Include="xunit" Version="2.9.2" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2">
<PackageReference Include="xunit" Version="2.9.3" />
<PackageReference Include="xunit.runner.visualstudio" Version="3.0.1">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="coverlet.collector" Version="6.0.2">
<PackageReference Include="coverlet.collector" Version="6.0.4">
<IncludeAssets>runtime; build; native; contentfiles; analyzers</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
Expand Down
1 change: 0 additions & 1 deletion SpongeEngine.KoboldSharp.Tests/Unit/UnitTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ protected UnitTestBase(ITestOutputHelper output)
{
BaseAddress = new Uri(Server.Urls[0])
},
BaseUrl = TestConfig.BaseUrl,
Logger = LoggerFactory
.Create(builder => builder.AddXUnit(output))
.CreateLogger(GetType()),
Expand Down
64 changes: 62 additions & 2 deletions SpongeEngine.KoboldSharp/KoboldSharpClient.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,74 @@
using SpongeEngine.LLMSharp.Core;
using System.Runtime.CompilerServices;
using SpongeEngine.LLMSharp.Core;
using SpongeEngine.LLMSharp.Core.Interfaces;
using SpongeEngine.LLMSharp.Core.Models;

namespace SpongeEngine.KoboldSharp
{
public partial class KoboldSharpClient : LlmClientBase
public partial class KoboldSharpClient : LlmClientBase, ICompletionService
{
/// <summary>
/// The options instance this client was constructed with; assigned once in the constructor.
/// </summary>
public override KoboldSharpClientOptions Options { get; }

/// <summary>
/// Creates a client, forwarding <paramref name="options"/> to the base class
/// and exposing the same instance through <see cref="Options"/>.
/// </summary>
/// <param name="options">Client options passed to the base constructor and stored on this instance.</param>
public KoboldSharpClient(KoboldSharpClientOptions options)
    : base(options)
    => Options = options;

/// <summary>
/// Runs a single non-streaming completion and returns the generated text
/// together with token usage figures.
/// </summary>
/// <param name="request">The completion request to translate into a <c>KoboldSharpRequest</c>.</param>
/// <param name="cancellationToken">Token forwarded to the generation call and to both token-count calls.</param>
/// <returns>
/// A <see cref="CompletionResult"/> holding the first result's text, the request's model id,
/// the time spent in the generation call, and prompt/completion/total token counts.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="request"/> is <c>null</c>.</exception>
/// <exception cref="InvalidOperationException">The generation response contained no result text.</exception>
public async Task<CompletionResult> CompleteAsync(CompletionRequest request, CancellationToken cancellationToken = new CancellationToken())
{
    ArgumentNullException.ThrowIfNull(request);

    var koboldRequest = new KoboldSharpRequest
    {
        Prompt = request.Prompt,
        // Falls back to 80 when the caller did not specify a token limit.
        MaxLength = request.MaxTokens ?? 80,
        Temperature = request.Temperature,
        TopP = request.TopP,
        StopSequences = request.StopSequences.ToList(),
        Stream = false
    };

    // Time only the generation call, using a monotonic clock, so the reported
    // duration is not inflated by the token-count round trips below and is
    // not affected by system clock adjustments.
    var stopwatch = System.Diagnostics.Stopwatch.StartNew();
    var response = await GenerateAsync(koboldRequest, cancellationToken);
    stopwatch.Stop();

    // Read the first result once and fail with a descriptive error instead of
    // an index/null-reference exception when the response carries no result.
    var generatedText = response?.Results?.FirstOrDefault()?.Text
        ?? throw new InvalidOperationException("The generation response contained no completion text.");

    // Get token counts using KoboldCpp's token counting API
    var promptTokens = await CountTokensAsync(new CountTokensRequest { Prompt = request.Prompt }, cancellationToken);
    var responseTokens = await CountTokensAsync(new CountTokensRequest { Prompt = generatedText }, cancellationToken);

    return new CompletionResult
    {
        Text = generatedText,
        ModelId = request.ModelId,
        GenerationTime = stopwatch.Elapsed,
        TokenUsage = new CompletionTokenUsage
        {
            PromptTokens = promptTokens.Count,
            CompletionTokens = responseTokens.Count,
            TotalTokens = promptTokens.Count + responseTokens.Count
        }
    };
}

/// <summary>
/// Streams a completion: each piece of text produced by <c>GenerateStreamAsync</c>
/// is yielded as a <see cref="CompletionToken"/> along with the token count
/// reported by <c>CountTokensAsync</c> for that piece.
/// </summary>
/// <param name="request">The completion request to translate into a streaming <c>KoboldSharpRequest</c>.</param>
/// <param name="cancellationToken">Token forwarded to the streaming call and to every token-count call.</param>
/// <returns>An asynchronous sequence of streamed text pieces with their token counts.</returns>
public async IAsyncEnumerable<CompletionToken> StreamCompletionAsync(CompletionRequest request, [EnumeratorCancellation] CancellationToken cancellationToken = new CancellationToken())
{
    // Same mapping as the non-streaming path, but with streaming switched on.
    var streamingRequest = new KoboldSharpRequest
    {
        Prompt = request.Prompt,
        MaxLength = request.MaxTokens ?? 80,
        Temperature = request.Temperature,
        TopP = request.TopP,
        StopSequences = request.StopSequences.ToList(),
        Stream = true
    };

    var pieces = GenerateStreamAsync(streamingRequest, cancellationToken);

    await foreach (var piece in pieces)
    {
        // One token-count call is issued per streamed piece.
        var countRequest = new CountTokensRequest { Prompt = piece };
        var counted = await CountTokensAsync(countRequest, cancellationToken);

        var streamed = new CompletionToken
        {
            Text = piece,
            TokenCount = counted.Count
        };

        yield return streamed;
    }
}
}
}
16 changes: 8 additions & 8 deletions SpongeEngine.KoboldSharp/SpongeEngine.KoboldSharp.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@

<!-- Package Dependencies -->
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="9.0.0" Condition="'$(TargetFramework)' == 'net6.0'" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="9.0.0" Condition="'$(TargetFramework)' == 'net7.0'" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="9.0.0" Condition="'$(TargetFramework)' == 'net8.0'" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="9.0.1" Condition="'$(TargetFramework)' == 'net6.0'" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="9.0.1" Condition="'$(TargetFramework)' == 'net7.0'" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="9.0.1" Condition="'$(TargetFramework)' == 'net8.0'" />

<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="9.0.0" Condition="'$(TargetFramework)' == 'net6.0'" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="9.0.0" Condition="'$(TargetFramework)' == 'net7.0'" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="9.0.0" Condition="'$(TargetFramework)' == 'net8.0'" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="9.0.1" Condition="'$(TargetFramework)' == 'net6.0'" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="9.0.1" Condition="'$(TargetFramework)' == 'net7.0'" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="9.0.1" Condition="'$(TargetFramework)' == 'net8.0'" />
<PackageReference Include="OllamaSharp" Version="4.0.11" />
<PackageReference Include="Polly" Version="8.5.0" />
<PackageReference Include="SpongeEngine.LLMSharp.Core" Version="1.1.2" />
<PackageReference Include="Polly" Version="8.5.1" />
<PackageReference Include="SpongeEngine.LLMSharp.Core" Version="2.0.1" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
</ItemGroup>

Expand Down

0 comments on commit fa3e2fc

Please sign in to comment.