diff --git a/Whisper.net/Ggml/GgmlType.cs b/Whisper.net/Ggml/GgmlType.cs
index 06549aaa..4bf1f008 100644
--- a/Whisper.net/Ggml/GgmlType.cs
+++ b/Whisper.net/Ggml/GgmlType.cs
@@ -35,4 +35,16 @@ public enum WhisperAlignmentHeadsPreset
     LargeV2,
     LargeV3,
     LargeV3Turbo
-}
\ No newline at end of file
+}
+
+public class WhisperAlignmentHead
+{
+    public int TextLayer;
+    public int Head;
+
+    public WhisperAlignmentHead(int textLayer, int head)
+    {
+        TextLayer = textLayer;
+        Head = head;
+    }
+}
diff --git a/Whisper.net/Internals/ModelLoader/WhisperProcessorModelBufferLoader.cs b/Whisper.net/Internals/ModelLoader/WhisperProcessorModelBufferLoader.cs
index 62465014..8c4a48f9 100644
--- a/Whisper.net/Internals/ModelLoader/WhisperProcessorModelBufferLoader.cs
+++ b/Whisper.net/Internals/ModelLoader/WhisperProcessorModelBufferLoader.cs
@@ -10,15 +10,23 @@ namespace Whisper.net.Internals.ModelLoader;
 internal class WhisperProcessorModelBufferLoader(byte[] buffer) : IWhisperProcessorModelLoader
 {
     private readonly GCHandle pinnedBuffer = GCHandle.Alloc(buffer, GCHandleType.Pinned);
+    private GCHandle aheadsHandle;
 
     public void Dispose()
     {
         pinnedBuffer.Free();
+        if (aheadsHandle.IsAllocated)
+        {
+            aheadsHandle.Free();
+        }
     }
 
     public IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
     {
         var bufferLength = new UIntPtr((uint)buffer.Length);
+
+        var aHeads = WhisperProcessorModelFileLoader.GetWhisperAlignmentHeads(RuntimeOptions.Instance.CustomAlignmentHeads, ref aheadsHandle);
+
         return nativeWhisper.Whisper_Init_From_Buffer_With_Params_No_State(pinnedBuffer.AddrOfPinnedObject(), bufferLength,
             new WhisperContextParams()
             {
@@ -28,11 +36,7 @@ public IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
                 DtwTokenLevelTimestamp = RuntimeOptions.Instance.UseDtwTimeStamps ? (byte)1 : (byte)0,
                 HeadsPreset = (WhisperAlignmentHeadsPreset)RuntimeOptions.Instance.HeadsPreset,
                 DtwNTop = -1,
-                WhisperAheads = new WhisperAheads()
-                {
-                    NHeads = 0,
-                    Heads = IntPtr.Zero
-                },
+                WhisperAheads = aHeads,
                 Dtw_mem_size = 1024 * 1024 * 128,
             });
     }
diff --git a/Whisper.net/Internals/ModelLoader/WhisperProcessorModelFileLoader.cs b/Whisper.net/Internals/ModelLoader/WhisperProcessorModelFileLoader.cs
index 8f35cd98..1fd258da 100644
--- a/Whisper.net/Internals/ModelLoader/WhisperProcessorModelFileLoader.cs
+++ b/Whisper.net/Internals/ModelLoader/WhisperProcessorModelFileLoader.cs
@@ -1,5 +1,6 @@
 // Licensed under the MIT license: https://opensource.org/licenses/MIT
 
+using System.Runtime.InteropServices;
 using Whisper.net.Internals.Native;
 using Whisper.net.LibraryLoader;
 using Whisper.net.Native;
@@ -8,13 +9,49 @@ namespace Whisper.net.Internals.ModelLoader;
 
 internal sealed class WhisperProcessorModelFileLoader(string pathModel) : IWhisperProcessorModelLoader
 {
+    private GCHandle aheadsHandle;
+
     public void Dispose()
     {
+        if (aheadsHandle.IsAllocated)
+        {
+            aheadsHandle.Free();
+        }
+    }
+
+    public static WhisperAheads GetWhisperAlignmentHeads(Ggml.WhisperAlignmentHead[]? alignmentHeads, ref GCHandle aHeadsHandle)
+    {
+        var aHeadsPtr = IntPtr.Zero;
+        var nHeads = alignmentHeads?.Length ?? 0;
+        if (nHeads > 0)
+        {
+            var aHeads = new int[nHeads * 2];
+            if (aHeadsHandle.IsAllocated)
+            {
+                aHeadsHandle.Free();
+            }
+            aHeadsHandle = GCHandle.Alloc(aHeads, GCHandleType.Pinned);
+            aHeadsPtr = aHeadsHandle.AddrOfPinnedObject();
+
+            for (var i = 0; i < nHeads; i++)
+            {
+                aHeads[i * 2] = alignmentHeads![i].TextLayer;
+                aHeads[i * 2 + 1] = alignmentHeads[i].Head;
+            }
+        }
+
+        return new WhisperAheads()
+        {
+            NHeads = (nuint)nHeads,
+            Heads = aHeadsPtr
+        };
     }
 
     public IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
     {
+        var aHeads = GetWhisperAlignmentHeads(RuntimeOptions.Instance.CustomAlignmentHeads, ref aheadsHandle);
+
         return nativeWhisper.Whisper_Init_From_File_With_Params_No_State(pathModel,
             new WhisperContextParams()
             {
@@ -24,11 +61,7 @@ public IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
                 DtwTokenLevelTimestamp = RuntimeOptions.Instance.UseDtwTimeStamps ? (byte)1 : (byte)0,
                 HeadsPreset = (WhisperAlignmentHeadsPreset)RuntimeOptions.Instance.HeadsPreset,
                 DtwNTop = -1,
-                WhisperAheads = new WhisperAheads()
-                {
-                    NHeads = 0,
-                    Heads = IntPtr.Zero
-                },
+                WhisperAheads = aHeads,
                 Dtw_mem_size = 1024 * 1024 * 128,
             });
     }
diff --git a/Whisper.net/LibraryLoader/RuntimeOptions.cs b/Whisper.net/LibraryLoader/RuntimeOptions.cs
index 410e6060..973d71df 100644
--- a/Whisper.net/LibraryLoader/RuntimeOptions.cs
+++ b/Whisper.net/LibraryLoader/RuntimeOptions.cs
@@ -13,6 +13,7 @@ public class RuntimeOptions
     internal bool UseFlashAttention { get; private set; }
     internal bool UseDtwTimeStamps { get; private set; }
     internal WhisperAlignmentHeadsPreset HeadsPreset { get; private set; }
+    internal WhisperAlignmentHead[]? CustomAlignmentHeads { get; private set; }
     internal int GpuDevice { get; private set; }
     internal List<RuntimeLibrary> RuntimeLibraryOrder { get; private set; }
     internal RuntimeLibrary? LoadedLibrary { get; private set; }
@@ -27,6 +28,7 @@ private RuntimeOptions()
         UseFlashAttention = false;
         UseDtwTimeStamps = false;
         HeadsPreset = WhisperAlignmentHeadsPreset.None;
+        CustomAlignmentHeads = null;
         RuntimeLibraryOrder = defaultRuntimeOrder;
         GpuDevice = 0;
     }
@@ -127,6 +129,17 @@ public void SetHeadsPreset(WhisperAlignmentHeadsPreset headsPreset)
         HeadsPreset = headsPreset;
     }
 
+    /// <summary>
+    /// Sets custom attention heads array for DTW.
+    /// </summary>
+    /// <remarks>
+    /// By default, it is null. Required when using DTW with models which don't have a matching WhisperAlignmentHeadsPreset.
+    /// </remarks>
+    public void SetAlignmentHeads(WhisperAlignmentHead[]? alignmentHeads)
+    {
+        CustomAlignmentHeads = alignmentHeads;
+    }
+
     /// <summary>
     /// Resets the runtime options to their default values.
     /// </summary>
@@ -138,6 +151,7 @@ public void Reset()
         UseFlashAttention = false;
         UseDtwTimeStamps = false;
         HeadsPreset = WhisperAlignmentHeadsPreset.None;
+        CustomAlignmentHeads = null;
         RuntimeLibraryOrder = defaultRuntimeOrder;
         GpuDevice = 0;
     }
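
Below is a minimal usage sketch of the new SetAlignmentHeads API; it is not part of the patch. It assumes the RuntimeOptions.Instance singleton is reachable from calling code, that WhisperAlignmentHead lives in Whisper.net.Ggml and RuntimeOptions in Whisper.net.LibraryLoader (as the file paths suggest), and the text-layer/head pairs are placeholders rather than real alignment heads for any particular model.

using Whisper.net.Ggml;
using Whisper.net.LibraryLoader;

// Placeholder text-layer/head pairs; real values depend on the fine-tuned model.
var customHeads = new[]
{
    new WhisperAlignmentHead(textLayer: 3, head: 1),
    new WhisperAlignmentHead(textLayer: 4, head: 2),
};

// Register the custom heads so the model loaders pass them to the native context.
RuntimeOptions.Instance.SetAlignmentHeads(customHeads);

With this set, both the file and buffer loaders pin the flattened int pairs via GCHandle and hand the native context a populated WhisperAheads struct instead of the previous hard-coded NHeads = 0 / Heads = IntPtr.Zero.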