From b5c3c620a840b7f0654660bcfbd8decb5d7e6cbb Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Fri, 17 May 2024 09:29:22 -0400 Subject: [PATCH] Added support for halfvec type --- src/Pgvector.Dapper/CHANGELOG.md | 5 + src/Pgvector.Dapper/HalfvecTypeHandler.cs | 28 ++++++ src/Pgvector.Dapper/Pgvector.Dapper.csproj | 2 +- src/Pgvector.EntityFrameworkCore/CHANGELOG.md | 1 + .../HalfvecTypeMapping.cs | 19 ++++ .../HalfvecTypeMappingSourcePlugin.cs | 11 +++ .../VectorDbContextOptionsExtension.cs | 1 + .../VectorDbFunctionsExtensions.cs | 8 +- .../VectorDbFunctionsTranslatorPlugin.cs | 16 ++-- src/Pgvector/CHANGELOG.md | 5 + src/Pgvector/HalfVector.cs | 44 +++++++++ src/Pgvector/Npgsql/HalfvecConverter.cs | 91 +++++++++++++++++++ .../Npgsql/HalfvecTypeInfoResolverFactory.cs | 42 +++++++++ src/Pgvector/Npgsql/VectorExtensions.cs | 1 + src/Pgvector/Pgvector.csproj | 2 +- tests/Pgvector.CSharp.Tests/DapperTests.cs | 10 +- .../EntityFrameworkCoreTests.cs | 14 ++- .../Pgvector.CSharp.Tests/HalfVectorTests.cs | 41 +++++++++ tests/Pgvector.CSharp.Tests/NpgsqlTests.cs | 15 ++- 19 files changed, 335 insertions(+), 21 deletions(-) create mode 100644 src/Pgvector.Dapper/HalfvecTypeHandler.cs create mode 100644 src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs create mode 100644 src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs create mode 100644 src/Pgvector/HalfVector.cs create mode 100644 src/Pgvector/Npgsql/HalfvecConverter.cs create mode 100644 src/Pgvector/Npgsql/HalfvecTypeInfoResolverFactory.cs create mode 100644 tests/Pgvector.CSharp.Tests/HalfVectorTests.cs diff --git a/src/Pgvector.Dapper/CHANGELOG.md b/src/Pgvector.Dapper/CHANGELOG.md index 8f204e0..6bb01a2 100644 --- a/src/Pgvector.Dapper/CHANGELOG.md +++ b/src/Pgvector.Dapper/CHANGELOG.md @@ -1,3 +1,8 @@ +## 0.3.0 (unreleased) + +- Added support for `halfvec` type +- Dropped support for .NET Standard + ## 0.2.0 (2024-04-17) - Added support for Npgsql 8 diff --git a/src/Pgvector.Dapper/HalfvecTypeHandler.cs b/src/Pgvector.Dapper/HalfvecTypeHandler.cs new file mode 100644 index 0000000..f9a7f0f --- /dev/null +++ b/src/Pgvector.Dapper/HalfvecTypeHandler.cs @@ -0,0 +1,28 @@ +using Dapper; +using Pgvector; +using System; +using System.Data; +using System.Data.SqlClient; + +namespace Pgvector.Dapper; + +public class HalfvecTypeHandler : SqlMapper.TypeHandler +{ + public override HalfVector? Parse(object value) + => value switch + { + null or DBNull => null, + HalfVector vec => vec, + _ => value.ToString() is string s ? new HalfVector(s) : null + }; + + public override void SetValue(IDbDataParameter parameter, HalfVector? value) + { + parameter.Value = value is null ? DBNull.Value : value; + + if (parameter is SqlParameter sqlParameter) + { + sqlParameter.UdtTypeName = "halfvec"; + } + } +} diff --git a/src/Pgvector.Dapper/Pgvector.Dapper.csproj b/src/Pgvector.Dapper/Pgvector.Dapper.csproj index 181f9dd..c6d8a15 100644 --- a/src/Pgvector.Dapper/Pgvector.Dapper.csproj +++ b/src/Pgvector.Dapper/Pgvector.Dapper.csproj @@ -9,7 +9,7 @@ https://github.com/pgvector/pgvector-dotnet README.md - netstandard2.0;net6.0 + net6.0 enable latest diff --git a/src/Pgvector.EntityFrameworkCore/CHANGELOG.md b/src/Pgvector.EntityFrameworkCore/CHANGELOG.md index a6a8f70..9b81ac1 100644 --- a/src/Pgvector.EntityFrameworkCore/CHANGELOG.md +++ b/src/Pgvector.EntityFrameworkCore/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.2.1 (unreleased) +- Added support for `halfvec` type - Added support for compiled models - Added `L1Distance` function diff --git a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs new file mode 100644 index 0000000..d1d4bfa --- /dev/null +++ b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs @@ -0,0 +1,19 @@ +using Microsoft.EntityFrameworkCore.Storage; +using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping; +using NpgsqlTypes; + +namespace Pgvector.EntityFrameworkCore; + +public class HalfvecTypeMapping : RelationalTypeMapping +{ + public static HalfvecTypeMapping Default { get; } = new(); + + public HalfvecTypeMapping() : base("halfvec", typeof(HalfVector)) { } + + public HalfvecTypeMapping(string storeType) : base(storeType, typeof(HalfVector)) { } + + protected HalfvecTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { } + + protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters) + => new HalfvecTypeMapping(parameters); +} diff --git a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs new file mode 100644 index 0000000..663bf71 --- /dev/null +++ b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs @@ -0,0 +1,11 @@ +using Microsoft.EntityFrameworkCore.Storage; + +namespace Pgvector.EntityFrameworkCore; + +public class HalfvecTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin +{ + public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo) + => mappingInfo.ClrType == typeof(HalfVector) + ? new HalfvecTypeMapping(mappingInfo.StoreTypeName ?? "halfvec") + : null; +} diff --git a/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs b/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs index 570e4dc..96c4225 100644 --- a/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs +++ b/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs @@ -17,6 +17,7 @@ public void ApplyServices(IServiceCollection services) .TryAdd(); services.AddSingleton(); + services.AddSingleton(); } public void Validate(IDbContextOptions options) { } diff --git a/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsExtensions.cs b/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsExtensions.cs index 131df01..1077c16 100644 --- a/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsExtensions.cs +++ b/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsExtensions.cs @@ -4,15 +4,15 @@ namespace Pgvector.EntityFrameworkCore; public static class VectorDbFunctionsExtensions { - public static double L2Distance(this Vector a, Vector b) + public static double L2Distance(this object a, object b) => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(L2Distance))); - public static double MaxInnerProduct(this Vector a, Vector b) + public static double MaxInnerProduct(this object a, object b) => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(MaxInnerProduct))); - public static double CosineDistance(this Vector a, Vector b) + public static double CosineDistance(this object a, object b) => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(CosineDistance))); - public static double L1Distance(this Vector a, Vector b) + public static double L1Distance(this object a, object b) => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(L1Distance))); } diff --git a/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsTranslatorPlugin.cs b/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsTranslatorPlugin.cs index b19d00a..498dc12 100644 --- a/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsTranslatorPlugin.cs +++ b/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsTranslatorPlugin.cs @@ -31,29 +31,29 @@ private class VectorDbFunctionsTranslator : IMethodCallTranslator private static readonly MethodInfo _methodL2Distance = typeof(VectorDbFunctionsExtensions) .GetRuntimeMethod(nameof(VectorDbFunctionsExtensions.L2Distance), new[] { - typeof(Vector), - typeof(Vector), + typeof(object), + typeof(object), })!; private static readonly MethodInfo _methodMaxInnerProduct = typeof(VectorDbFunctionsExtensions) .GetRuntimeMethod(nameof(VectorDbFunctionsExtensions.MaxInnerProduct), new[] { - typeof(Vector), - typeof(Vector), + typeof(object), + typeof(object), })!; private static readonly MethodInfo _methodCosineDistance = typeof(VectorDbFunctionsExtensions) .GetRuntimeMethod(nameof(VectorDbFunctionsExtensions.CosineDistance), new[] { - typeof(Vector), - typeof(Vector), + typeof(object), + typeof(object), })!; private static readonly MethodInfo _methodL1Distance = typeof(VectorDbFunctionsExtensions) .GetRuntimeMethod(nameof(VectorDbFunctionsExtensions.L1Distance), new[] { - typeof(Vector), - typeof(Vector), + typeof(object), + typeof(object), })!; public VectorDbFunctionsTranslator( diff --git a/src/Pgvector/CHANGELOG.md b/src/Pgvector/CHANGELOG.md index c01f02e..20d17ac 100644 --- a/src/Pgvector/CHANGELOG.md +++ b/src/Pgvector/CHANGELOG.md @@ -1,3 +1,8 @@ +## 0.3.0 (unreleased) + +- Added support for `halfvec` type +- Dropped support for .NET Standard + ## 0.2.0 (2023-11-24) - Added support for Npgsql 8 diff --git a/src/Pgvector/HalfVector.cs b/src/Pgvector/HalfVector.cs new file mode 100644 index 0000000..5f2ca81 --- /dev/null +++ b/src/Pgvector/HalfVector.cs @@ -0,0 +1,44 @@ +using System; +using System.Globalization; +using System.Linq; + +namespace Pgvector; + +public class HalfVector : IEquatable +{ + public ReadOnlyMemory Memory { get; } + + public HalfVector(ReadOnlyMemory v) + => Memory = v; + + public HalfVector(string s) + => Memory = Array.ConvertAll(s.Substring(1, s.Length - 2).Split(','), v => Half.Parse(v, CultureInfo.InvariantCulture)); + + public override string ToString() + => string.Concat("[", string.Join(",", Memory.ToArray().Select(v => v.ToString(CultureInfo.InvariantCulture))), "]"); + + public Half[] ToArray() + => Memory.ToArray(); + + public bool Equals(HalfVector? other) + => other is not null && Memory.Span.SequenceEqual(other.Memory.Span); + + public override bool Equals(object? obj) + => obj is HalfVector vector && Equals(vector); + + public static bool operator ==(HalfVector? x, HalfVector? y) + => (x is null && y is null) || (x is not null && x.Equals(y)); + + public static bool operator !=(HalfVector? x, HalfVector? y) => !(x == y); + + public override int GetHashCode() + { + var hashCode = new HashCode(); + var span = Memory.Span; + + for (var i = 0; i < span.Length; i++) + hashCode.Add(span[i]); + + return hashCode.ToHashCode(); + } +} diff --git a/src/Pgvector/Npgsql/HalfvecConverter.cs b/src/Pgvector/Npgsql/HalfvecConverter.cs new file mode 100644 index 0000000..6c000c9 --- /dev/null +++ b/src/Pgvector/Npgsql/HalfvecConverter.cs @@ -0,0 +1,91 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; + +namespace Pgvector.Npgsql; + +public class HalfvecConverter : PgStreamingConverter +{ + public override HalfVector Read(PgReader reader) + { + if (reader.ShouldBuffer(2 * sizeof(ushort))) + reader.Buffer(2 * sizeof(ushort)); + + var dim = reader.ReadUInt16(); + var unused = reader.ReadUInt16(); + if (unused != 0) + throw new InvalidCastException("expected unused to be 0"); + + var vec = new Half[dim]; + for (var i = 0; i < dim; i++) + { + if (reader.ShouldBuffer(sizeof(ushort))) + reader.Buffer(sizeof(ushort)); + vec[i] = BitConverter.UInt16BitsToHalf(reader.ReadUInt16()); + } + + return new HalfVector(vec); + } + + public override async ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + { + if (reader.ShouldBuffer(2 * sizeof(ushort))) + await reader.BufferAsync(2 * sizeof(ushort), cancellationToken).ConfigureAwait(false); + + var dim = reader.ReadUInt16(); + var unused = reader.ReadUInt16(); + if (unused != 0) + throw new InvalidCastException("expected unused to be 0"); + + var vec = new Half[dim]; + for (var i = 0; i < dim; i++) + { + if (reader.ShouldBuffer(sizeof(ushort))) + await reader.BufferAsync(sizeof(ushort), cancellationToken).ConfigureAwait(false); + vec[i] = BitConverter.UInt16BitsToHalf(reader.ReadUInt16()); + } + + return new HalfVector(vec); + } + + public override Size GetSize(SizeContext context, HalfVector value, ref object? writeState) + => sizeof(ushort) * 2 + sizeof(ushort) * value.ToArray().Length; + + public override void Write(PgWriter writer, HalfVector value) + { + if (writer.ShouldFlush(sizeof(ushort) * 2)) + writer.Flush(); + + var span = value.Memory.Span; + var dim = span.Length; + writer.WriteUInt16(Convert.ToUInt16(dim)); + writer.WriteUInt16(0); + + for (int i = 0; i < dim; i++) + { + if (writer.ShouldFlush(sizeof(ushort))) + writer.Flush(); + writer.WriteUInt16(BitConverter.HalfToUInt16Bits(span[i])); + } + } + + public override async ValueTask WriteAsync( + PgWriter writer, HalfVector value, CancellationToken cancellationToken = default) + { + if (writer.ShouldFlush(sizeof(ushort) * 2)) + await writer.FlushAsync(cancellationToken); + + var memory = value.Memory; + var dim = memory.Length; + writer.WriteUInt16(Convert.ToUInt16(dim)); + writer.WriteUInt16(0); + + for (int i = 0; i < dim; i++) + { + if (writer.ShouldFlush(sizeof(ushort))) + await writer.FlushAsync(cancellationToken); + writer.WriteUInt16(BitConverter.HalfToUInt16Bits(memory.Span[i])); + } + } +} diff --git a/src/Pgvector/Npgsql/HalfvecTypeInfoResolverFactory.cs b/src/Pgvector/Npgsql/HalfvecTypeInfoResolverFactory.cs new file mode 100644 index 0000000..36054ad --- /dev/null +++ b/src/Pgvector/Npgsql/HalfvecTypeInfoResolverFactory.cs @@ -0,0 +1,42 @@ +using System; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; + +namespace Pgvector.Npgsql; + +public class HalfvecTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateResolver() => new Resolver(); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(); + + class Resolver : IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + mappings.AddType("halfvec", + static (options, mapping, _) => mapping.CreateInfo(options, new HalfvecConverter()), isDefault: true); + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + mappings.AddArrayType("halfvec"); + return mappings; + } + } +} diff --git a/src/Pgvector/Npgsql/VectorExtensions.cs b/src/Pgvector/Npgsql/VectorExtensions.cs index 7835d69..7162bd9 100644 --- a/src/Pgvector/Npgsql/VectorExtensions.cs +++ b/src/Pgvector/Npgsql/VectorExtensions.cs @@ -8,6 +8,7 @@ public static class VectorExtensions public static INpgsqlTypeMapper UseVector(this INpgsqlTypeMapper mapper) { mapper.AddTypeInfoResolverFactory(new VectorTypeInfoResolverFactory()); + mapper.AddTypeInfoResolverFactory(new HalfvecTypeInfoResolverFactory()); return mapper; } } diff --git a/src/Pgvector/Pgvector.csproj b/src/Pgvector/Pgvector.csproj index 84e9616..711a7a0 100644 --- a/src/Pgvector/Pgvector.csproj +++ b/src/Pgvector/Pgvector.csproj @@ -9,7 +9,7 @@ https://github.com/pgvector/pgvector-dotnet README.md - netstandard2.0;net6.0 + net6.0 enable latest diff --git a/tests/Pgvector.CSharp.Tests/DapperTests.cs b/tests/Pgvector.CSharp.Tests/DapperTests.cs index c644b12..f11cd73 100644 --- a/tests/Pgvector.CSharp.Tests/DapperTests.cs +++ b/tests/Pgvector.CSharp.Tests/DapperTests.cs @@ -7,6 +7,7 @@ public class DapperItem { public int Id { get; set; } public Vector? Embedding { get; set; } + public HalfVector? HalfEmbedding { get; set; } } public class DapperTests @@ -15,6 +16,7 @@ public class DapperTests public async Task Main() { SqlMapper.AddTypeHandler(new VectorTypeHandler()); + SqlMapper.AddTypeHandler(new HalfvecTypeHandler()); var connString = "Host=localhost;Database=pgvector_dotnet_test"; @@ -28,17 +30,21 @@ public async Task Main() conn.ReloadTypes(); conn.Execute("DROP TABLE IF EXISTS dapper_items"); - conn.Execute("CREATE TABLE dapper_items (id serial PRIMARY KEY, embedding vector(3))"); + conn.Execute("CREATE TABLE dapper_items (id serial PRIMARY KEY, embedding vector(3), halfembedding halfvec(3))"); var embedding1 = new Vector(new float[] { 1, 1, 1 }); var embedding2 = new Vector(new float[] { 2, 2, 2 }); var embedding3 = new Vector(new float[] { 1, 1, 2 }); - conn.Execute(@"INSERT INTO dapper_items (embedding) VALUES (@embedding1), (@embedding2), (@embedding3)", new { embedding1, embedding2, embedding3 }); + var halfEmbedding1 = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }); + var halfEmbedding2 = new HalfVector(new Half[] { (Half)2, (Half)2, (Half)2 }); + var halfEmbedding3 = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)2 }); + conn.Execute(@"INSERT INTO dapper_items (embedding, halfembedding) VALUES (@embedding1, @halfEmbedding1), (@embedding2, @halfEmbedding2), (@embedding3, @halfEmbedding3)", new { embedding1, halfEmbedding1, embedding2, halfEmbedding2, embedding3, halfEmbedding3 }); var embedding = new Vector(new float[] { 1, 1, 1 }); var items = conn.Query("SELECT * FROM dapper_items ORDER BY embedding <-> @embedding LIMIT 5", new { embedding }).AsList(); Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray()); + Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray()); conn.Execute("CREATE INDEX ON dapper_items USING ivfflat (embedding vector_l2_ops) WITH (lists = 100)"); } diff --git a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs index 2de7ec8..a31188c 100644 --- a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs +++ b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs @@ -36,6 +36,9 @@ public class Item [Column("embedding", TypeName = "vector(3)")] public Vector? Embedding { get; set; } + + [Column("half_embedding", TypeName = "halfvec(3)")] + public HalfVector? HalfEmbedding { get; set; } } public class EntityFrameworkCoreTests @@ -49,15 +52,16 @@ public async Task Main() var databaseCreator = ctx.GetService(); databaseCreator.CreateTables(); - ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 1 }) }); - ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 2, 2, 2 }) }); - ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 2 }) }); + ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 1 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }) }); + ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 2, 2, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)2, (Half)2, (Half)2 }) }); + ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)2 }) }); ctx.SaveChanges(); var embedding = new Vector(new float[] { 1, 1, 1 }); var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync(); Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray()); + Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray()); items = await ctx.Items.OrderBy(x => x.Embedding!.L2Distance(embedding)).Take(5).ToListAsync(); Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); @@ -72,6 +76,10 @@ public async Task Main() items = await ctx.Items.OrderBy(x => x.Embedding!.L1Distance(embedding)).Take(5).ToListAsync(); Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); + var halfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }); + items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L2Distance(halfEmbedding)).Take(5).ToListAsync(); + Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); + items = await ctx.Items .OrderBy(x => x.Id) .Where(x => x.Embedding!.L2Distance(embedding) < 1.5) diff --git a/tests/Pgvector.CSharp.Tests/HalfVectorTests.cs b/tests/Pgvector.CSharp.Tests/HalfVectorTests.cs new file mode 100644 index 0000000..79ee0a1 --- /dev/null +++ b/tests/Pgvector.CSharp.Tests/HalfVectorTests.cs @@ -0,0 +1,41 @@ +using Pgvector; + +namespace Pgvector.Tests; + +public class HalfVectorTests +{ + [Fact] + public void StringConstructor() + { + var v = new HalfVector("[1,2,3]"); + Assert.Equal("[1,2,3]", v.ToString()); + } + + [Fact] + public void ArrayConstructor() + { + var v = new HalfVector(new Half[] { (Half)1, (Half)2, (Half)3 }); + Assert.Equal(new Half[] { (Half)1, (Half)2, (Half)3 }, v.ToArray()); + } + + [Fact] + public void Equal() + { + var a = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }); + var b = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }); + var c = new HalfVector(new Half[] { (Half)1, (Half)2, (Half)3 }); + + Assert.Equal(a, b); + Assert.NotEqual(a, c); + + Assert.True(a == b); + Assert.False(a == c); + + Assert.False(a != b); + Assert.True(a != c); + + Assert.False(a == null); + Assert.False(null == a); + Assert.True((HalfVector?)null == null); + } +} diff --git a/tests/Pgvector.CSharp.Tests/NpgsqlTests.cs b/tests/Pgvector.CSharp.Tests/NpgsqlTests.cs index dfbb9f7..86328c9 100644 --- a/tests/Pgvector.CSharp.Tests/NpgsqlTests.cs +++ b/tests/Pgvector.CSharp.Tests/NpgsqlTests.cs @@ -25,19 +25,25 @@ public async Task Main() await cmd.ExecuteNonQueryAsync(); } - await using (var cmd = new NpgsqlCommand("CREATE TABLE items (id serial PRIMARY KEY, embedding vector(3))", conn)) + await using (var cmd = new NpgsqlCommand("CREATE TABLE items (id serial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3))", conn)) { await cmd.ExecuteNonQueryAsync(); } - await using (var cmd = new NpgsqlCommand("INSERT INTO items (embedding) VALUES ($1), ($2), ($3)", conn)) + await using (var cmd = new NpgsqlCommand("INSERT INTO items (embedding, half_embedding) VALUES ($1, $2), ($3, $4), ($5, $6)", conn)) { var embedding1 = new Vector(new float[] { 1, 1, 1 }); var embedding2 = new Vector(new float[] { 2, 2, 2 }); var embedding3 = new Vector(new float[] { 1, 1, 2 }); + var halfEmbedding1 = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }); + var halfEmbedding2 = new HalfVector(new Half[] { (Half)2, (Half)2, (Half)2 }); + var halfEmbedding3 = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)2 }); cmd.Parameters.AddWithValue(embedding1); + cmd.Parameters.AddWithValue(halfEmbedding1); cmd.Parameters.AddWithValue(embedding2); + cmd.Parameters.AddWithValue(halfEmbedding2); cmd.Parameters.AddWithValue(embedding3); + cmd.Parameters.AddWithValue(halfEmbedding3); await cmd.ExecuteNonQueryAsync(); } @@ -50,17 +56,22 @@ public async Task Main() { var ids = new List(); var embeddings = new List(); + var halfEmbeddings = new List(); while (await reader.ReadAsync()) { ids.Add((int)reader.GetValue(0)); embeddings.Add((Vector)reader.GetValue(1)); + halfEmbeddings.Add((HalfVector)reader.GetValue(2)); } Assert.Equal(new int[] { 1, 3, 2 }, ids.ToArray()); Assert.Equal(new float[] { 1, 1, 1 }, embeddings[0].ToArray()); Assert.Equal(new float[] { 1, 1, 2 }, embeddings[1].ToArray()); Assert.Equal(new float[] { 2, 2, 2 }, embeddings[2].ToArray()); + Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, halfEmbeddings[0].ToArray()); + Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)2 }, halfEmbeddings[1].ToArray()); + Assert.Equal(new Half[] { (Half)2, (Half)2, (Half)2 }, halfEmbeddings[2].ToArray()); } }