Skip to content

Commit

Permalink
Improved FromBuffer() to accept Memory<byte> for better memory manage…
Browse files Browse the repository at this point in the history
…ment (#316)

* Transformed FromBuffer() to FromMemory() for better memory management

* Kept FromBuffer naming + fixed a test
  • Loading branch information
sandrohanea authored Jan 4, 2025
1 parent 99dc7b6 commit 38a817c
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 21 deletions.
1 change: 0 additions & 1 deletion Whisper.net.Demo/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ await Parser.Default.ParseArguments<Options>(args)
.WithParsedAsync(Demo);

async Task Demo(Options opt)

{
if (!File.Exists(opt.ModelName))
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
// Licensed under the MIT license: https://opensource.org/licenses/MIT

using System.Buffers;
using System.Runtime.InteropServices;
using Whisper.net.Internals.Native;
using Whisper.net.Native;

namespace Whisper.net.Internals.ModelLoader;

internal class WhisperProcessorModelBufferLoader : IWhisperProcessorModelLoader
internal class WhisperProcessorModelMemoryLoader : IWhisperProcessorModelLoader
{
private readonly GCHandle pinnedBuffer;
private readonly MemoryHandle pinnedMemory;
private readonly WhisperAheads aHeads;
private readonly GCHandle? aheadsHandle;
private readonly UIntPtr bufferLength;

private readonly WhisperFactoryOptions options;

public WhisperProcessorModelBufferLoader(byte[] buffer, WhisperFactoryOptions options)
public WhisperProcessorModelMemoryLoader(Memory<byte> buffer, WhisperFactoryOptions options)
{
this.options = options;

pinnedBuffer = GCHandle.Alloc(buffer, GCHandleType.Pinned);
pinnedMemory = buffer.Pin();
aHeads = ModelLoaderUtils.GetWhisperAlignmentHeads(options.CustomAlignmentHeads, out aheadsHandle);
bufferLength = new UIntPtr((uint)buffer.Length);
}

public void Dispose()
{
pinnedBuffer.Free();
pinnedMemory.Dispose();
if (aheadsHandle.HasValue)
{
aheadsHandle.Value.Free();
}
}

public IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
public unsafe IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
{
return nativeWhisper.Whisper_Init_From_Buffer_With_Params_No_State(pinnedBuffer.AddrOfPinnedObject(), bufferLength,
return nativeWhisper.Whisper_Init_From_Buffer_With_Params_No_State((IntPtr)pinnedMemory.Pointer, bufferLength,
new WhisperContextParams()
{
UseGpu = options.UseGpu.AsByte(),
Expand Down
16 changes: 8 additions & 8 deletions Whisper.net/WhisperFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,30 +116,30 @@ public static WhisperFactory FromPath(string path, WhisperFactoryOptions options
}

/// <summary>
/// Creates a factory that uses the ggml model from a buffer in order to create <seealso cref="WhisperProcessorBuilder"/>.
/// Creates a factory that uses the ggml model from a buffer in memory in order to create <seealso cref="WhisperProcessorBuilder"/>.
/// </summary>
/// <param name="buffer">The buffer with the model.</param>
/// <param name="memory">The memory buffer with the model.</param>
/// <returns>An instance to the same builder.</returns>
/// <remarks>
/// If you don't know where to find a ggml model, you can use <seealso cref="Ggml.WhisperGgmlDownloader"/> which is downloading a model from huggingface.co.
/// </remarks>
public static WhisperFactory FromBuffer(byte[] buffer)
public static WhisperFactory FromBuffer(Memory<byte> memory)
{
return FromBuffer(buffer, WhisperFactoryOptions.Default);
return FromBuffer(memory, WhisperFactoryOptions.Default);
}

/// <summary>
/// Creates a factory that uses the ggml model from a buffer in order to create <seealso cref="WhisperProcessorBuilder"/>.
/// Creates a factory that uses the ggml model from a buffer in memory in order to create <seealso cref="WhisperProcessorBuilder"/>.
/// </summary>
/// <param name="buffer">The buffer with the model.</param>
/// <param name="memory">The memory buffer with the model.</param>
/// <param name="options">Thhe options for the factory and the loading of the model.</param>
/// <returns>An instance to the same builder.</returns>
/// <remarks>
/// If you don't know where to find a ggml model, you can use <seealso cref="Ggml.WhisperGgmlDownloader"/> which is downloading a model from huggingface.co.
/// </remarks>
public static WhisperFactory FromBuffer(byte[] buffer, WhisperFactoryOptions options)
public static WhisperFactory FromBuffer(Memory<byte> memory, WhisperFactoryOptions options)
{
return new WhisperFactory(new WhisperProcessorModelBufferLoader(buffer, options), options.DelayInitialization);
return new WhisperFactory(new WhisperProcessorModelMemoryLoader(memory, options), options.DelayInitialization);
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion tests/Whisper.net.Maui.Tests/MainPage.xaml.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ private async void ContentPage_Loaded(object sender, EventArgs e)
await mauiStream.CopyToAsync(audioFileStream);
audioFileStream.Seek(0, SeekOrigin.Begin);

using var whisperFactory = WhisperFactory.FromBuffer(memoryStream.ToArray());
using var whisperFactory = WhisperFactory.FromBuffer(memoryStream.GetBuffer().AsMemory(0, (int)memoryStream.Length));
var whisperBuilder = whisperFactory.CreateBuilder();
using var whisperProcessor = whisperBuilder.Build();
LblResult.Text = string.Empty;
Expand Down
6 changes: 3 additions & 3 deletions tests/Whisper.net.Tests/FactoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ public void CreateBuilder_WithFileModel_ShouldReturnBuilder()
}

[Fact]
public void CreateBuilder_WithBufferedModel_ShouldReturnBuilder()
public void CreateBuilder_WithMemoryModel_ShouldReturnBuilder()
{
var buffer = File.ReadAllBytes(model.ModelFile);
using var factory = WhisperFactory.FromBuffer(buffer);
var memoryBuffer = File.ReadAllBytes(model.ModelFile);
using var factory = WhisperFactory.FromBuffer(memoryBuffer);
var builder = factory.CreateBuilder();
builder.Should().NotBeNull();
}
Expand Down

0 comments on commit 38a817c

Please sign in to comment.