Skip to content

Commit

Permalink
Added support for halfvec type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 17, 2024
1 parent 97d9601 commit b5c3c62
Show file tree
Hide file tree
Showing 19 changed files with 335 additions and 21 deletions.
5 changes: 5 additions & 0 deletions src/Pgvector.Dapper/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.3.0 (unreleased)

- Added support for `halfvec` type
- Dropped support for .NET Standard

## 0.2.0 (2024-04-17)

- Added support for Npgsql 8
Expand Down
28 changes: 28 additions & 0 deletions src/Pgvector.Dapper/HalfvecTypeHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using Dapper;
using Pgvector;
using System;
using System.Data;
using System.Data.SqlClient;

namespace Pgvector.Dapper;

public class HalfvecTypeHandler : SqlMapper.TypeHandler<HalfVector?>
{
public override HalfVector? Parse(object value)
=> value switch
{
null or DBNull => null,
HalfVector vec => vec,
_ => value.ToString() is string s ? new HalfVector(s) : null
};

public override void SetValue(IDbDataParameter parameter, HalfVector? value)
{
parameter.Value = value is null ? DBNull.Value : value;

if (parameter is SqlParameter sqlParameter)
{
sqlParameter.UdtTypeName = "halfvec";
}
}
}
2 changes: 1 addition & 1 deletion src/Pgvector.Dapper/Pgvector.Dapper.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<PackageProjectUrl>https://github.com/pgvector/pgvector-dotnet</PackageProjectUrl>
<PackageReadmeFile>README.md</PackageReadmeFile>

<TargetFrameworks>netstandard2.0;net6.0</TargetFrameworks>
<TargetFrameworks>net6.0</TargetFrameworks>
<Nullable>enable</Nullable>
<LangVersion>latest</LangVersion>
</PropertyGroup>
Expand Down
1 change: 1 addition & 0 deletions src/Pgvector.EntityFrameworkCore/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## 0.2.1 (unreleased)

- Added support for `halfvec` type
- Added support for compiled models
- Added `L1Distance` function

Expand Down
19 changes: 19 additions & 0 deletions src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Microsoft.EntityFrameworkCore.Storage;
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
using NpgsqlTypes;

namespace Pgvector.EntityFrameworkCore;

public class HalfvecTypeMapping : RelationalTypeMapping
{
public static HalfvecTypeMapping Default { get; } = new();

public HalfvecTypeMapping() : base("halfvec", typeof(HalfVector)) { }

public HalfvecTypeMapping(string storeType) : base(storeType, typeof(HalfVector)) { }

protected HalfvecTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { }

protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters)
=> new HalfvecTypeMapping(parameters);
}
11 changes: 11 additions & 0 deletions src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using Microsoft.EntityFrameworkCore.Storage;

namespace Pgvector.EntityFrameworkCore;

public class HalfvecTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin
{
public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo)
=> mappingInfo.ClrType == typeof(HalfVector)
? new HalfvecTypeMapping(mappingInfo.StoreTypeName ?? "halfvec")
: null;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public void ApplyServices(IServiceCollection services)
.TryAdd<IMethodCallTranslatorPlugin, VectorDbFunctionsTranslatorPlugin>();

services.AddSingleton<IRelationalTypeMappingSourcePlugin, VectorTypeMappingSourcePlugin>();
services.AddSingleton<IRelationalTypeMappingSourcePlugin, HalfvecTypeMappingSourcePlugin>();
}

public void Validate(IDbContextOptions options) { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ namespace Pgvector.EntityFrameworkCore;

public static class VectorDbFunctionsExtensions
{
public static double L2Distance(this Vector a, Vector b)
public static double L2Distance(this object a, object b)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(L2Distance)));

public static double MaxInnerProduct(this Vector a, Vector b)
public static double MaxInnerProduct(this object a, object b)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(MaxInnerProduct)));

public static double CosineDistance(this Vector a, Vector b)
public static double CosineDistance(this object a, object b)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(CosineDistance)));

public static double L1Distance(this Vector a, Vector b)
public static double L1Distance(this object a, object b)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(L1Distance)));
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,29 @@ private class VectorDbFunctionsTranslator : IMethodCallTranslator
private static readonly MethodInfo _methodL2Distance = typeof(VectorDbFunctionsExtensions)
.GetRuntimeMethod(nameof(VectorDbFunctionsExtensions.L2Distance), new[]
{
typeof(Vector),
typeof(Vector),
typeof(object),
typeof(object),
})!;

private static readonly MethodInfo _methodMaxInnerProduct = typeof(VectorDbFunctionsExtensions)
.GetRuntimeMethod(nameof(VectorDbFunctionsExtensions.MaxInnerProduct), new[]
{
typeof(Vector),
typeof(Vector),
typeof(object),
typeof(object),
})!;

private static readonly MethodInfo _methodCosineDistance = typeof(VectorDbFunctionsExtensions)
.GetRuntimeMethod(nameof(VectorDbFunctionsExtensions.CosineDistance), new[]
{
typeof(Vector),
typeof(Vector),
typeof(object),
typeof(object),
})!;

private static readonly MethodInfo _methodL1Distance = typeof(VectorDbFunctionsExtensions)
.GetRuntimeMethod(nameof(VectorDbFunctionsExtensions.L1Distance), new[]
{
typeof(Vector),
typeof(Vector),
typeof(object),
typeof(object),
})!;

public VectorDbFunctionsTranslator(
Expand Down
5 changes: 5 additions & 0 deletions src/Pgvector/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.3.0 (unreleased)

- Added support for `halfvec` type
- Dropped support for .NET Standard

## 0.2.0 (2023-11-24)

- Added support for Npgsql 8
Expand Down
44 changes: 44 additions & 0 deletions src/Pgvector/HalfVector.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using System;
using System.Globalization;
using System.Linq;

namespace Pgvector;

public class HalfVector : IEquatable<HalfVector>
{
public ReadOnlyMemory<Half> Memory { get; }

public HalfVector(ReadOnlyMemory<Half> v)
=> Memory = v;

public HalfVector(string s)
=> Memory = Array.ConvertAll(s.Substring(1, s.Length - 2).Split(','), v => Half.Parse(v, CultureInfo.InvariantCulture));

public override string ToString()
=> string.Concat("[", string.Join(",", Memory.ToArray().Select(v => v.ToString(CultureInfo.InvariantCulture))), "]");

public Half[] ToArray()
=> Memory.ToArray();

public bool Equals(HalfVector? other)
=> other is not null && Memory.Span.SequenceEqual(other.Memory.Span);

public override bool Equals(object? obj)
=> obj is HalfVector vector && Equals(vector);

public static bool operator ==(HalfVector? x, HalfVector? y)
=> (x is null && y is null) || (x is not null && x.Equals(y));

public static bool operator !=(HalfVector? x, HalfVector? y) => !(x == y);

public override int GetHashCode()
{
var hashCode = new HashCode();
var span = Memory.Span;

for (var i = 0; i < span.Length; i++)
hashCode.Add(span[i]);

return hashCode.ToHashCode();
}
}
91 changes: 91 additions & 0 deletions src/Pgvector/Npgsql/HalfvecConverter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Npgsql.Internal;

namespace Pgvector.Npgsql;

public class HalfvecConverter : PgStreamingConverter<HalfVector>
{
public override HalfVector Read(PgReader reader)
{
if (reader.ShouldBuffer(2 * sizeof(ushort)))
reader.Buffer(2 * sizeof(ushort));

var dim = reader.ReadUInt16();
var unused = reader.ReadUInt16();
if (unused != 0)
throw new InvalidCastException("expected unused to be 0");

var vec = new Half[dim];
for (var i = 0; i < dim; i++)
{
if (reader.ShouldBuffer(sizeof(ushort)))
reader.Buffer(sizeof(ushort));
vec[i] = BitConverter.UInt16BitsToHalf(reader.ReadUInt16());
}

return new HalfVector(vec);
}

public override async ValueTask<HalfVector> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
{
if (reader.ShouldBuffer(2 * sizeof(ushort)))
await reader.BufferAsync(2 * sizeof(ushort), cancellationToken).ConfigureAwait(false);

var dim = reader.ReadUInt16();
var unused = reader.ReadUInt16();
if (unused != 0)
throw new InvalidCastException("expected unused to be 0");

var vec = new Half[dim];
for (var i = 0; i < dim; i++)
{
if (reader.ShouldBuffer(sizeof(ushort)))
await reader.BufferAsync(sizeof(ushort), cancellationToken).ConfigureAwait(false);
vec[i] = BitConverter.UInt16BitsToHalf(reader.ReadUInt16());
}

return new HalfVector(vec);
}

public override Size GetSize(SizeContext context, HalfVector value, ref object? writeState)
=> sizeof(ushort) * 2 + sizeof(ushort) * value.ToArray().Length;

public override void Write(PgWriter writer, HalfVector value)
{
if (writer.ShouldFlush(sizeof(ushort) * 2))
writer.Flush();

var span = value.Memory.Span;
var dim = span.Length;
writer.WriteUInt16(Convert.ToUInt16(dim));
writer.WriteUInt16(0);

for (int i = 0; i < dim; i++)
{
if (writer.ShouldFlush(sizeof(ushort)))
writer.Flush();
writer.WriteUInt16(BitConverter.HalfToUInt16Bits(span[i]));
}
}

public override async ValueTask WriteAsync(
PgWriter writer, HalfVector value, CancellationToken cancellationToken = default)
{
if (writer.ShouldFlush(sizeof(ushort) * 2))
await writer.FlushAsync(cancellationToken);

var memory = value.Memory;
var dim = memory.Length;
writer.WriteUInt16(Convert.ToUInt16(dim));
writer.WriteUInt16(0);

for (int i = 0; i < dim; i++)
{
if (writer.ShouldFlush(sizeof(ushort)))
await writer.FlushAsync(cancellationToken);
writer.WriteUInt16(BitConverter.HalfToUInt16Bits(memory.Span[i]));
}
}
}
42 changes: 42 additions & 0 deletions src/Pgvector/Npgsql/HalfvecTypeInfoResolverFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using System;
using Npgsql.Internal;
using Npgsql.Internal.Postgres;

namespace Pgvector.Npgsql;

public class HalfvecTypeInfoResolverFactory : PgTypeInfoResolverFactory
{
public override IPgTypeInfoResolver CreateResolver() => new Resolver();
public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver();

class Resolver : IPgTypeInfoResolver
{
TypeInfoMappingCollection? _mappings;
protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new());

public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options)
=> Mappings.Find(type, dataTypeName, options);

static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings)
{
mappings.AddType<HalfVector>("halfvec",
static (options, mapping, _) => mapping.CreateInfo(options, new HalfvecConverter()), isDefault: true);
return mappings;
}
}

sealed class ArrayResolver : Resolver, IPgTypeInfoResolver
{
TypeInfoMappingCollection? _mappings;
new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings));

public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options)
=> Mappings.Find(type, dataTypeName, options);

static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings)
{
mappings.AddArrayType<HalfVector>("halfvec");
return mappings;
}
}
}
1 change: 1 addition & 0 deletions src/Pgvector/Npgsql/VectorExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ public static class VectorExtensions
public static INpgsqlTypeMapper UseVector(this INpgsqlTypeMapper mapper)
{
mapper.AddTypeInfoResolverFactory(new VectorTypeInfoResolverFactory());
mapper.AddTypeInfoResolverFactory(new HalfvecTypeInfoResolverFactory());
return mapper;
}
}
2 changes: 1 addition & 1 deletion src/Pgvector/Pgvector.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<PackageProjectUrl>https://github.com/pgvector/pgvector-dotnet</PackageProjectUrl>
<PackageReadmeFile>README.md</PackageReadmeFile>

<TargetFrameworks>netstandard2.0;net6.0</TargetFrameworks>
<TargetFrameworks>net6.0</TargetFrameworks>
<Nullable>enable</Nullable>
<LangVersion>latest</LangVersion>
</PropertyGroup>
Expand Down
10 changes: 8 additions & 2 deletions tests/Pgvector.CSharp.Tests/DapperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ public class DapperItem
{
public int Id { get; set; }
public Vector? Embedding { get; set; }
public HalfVector? HalfEmbedding { get; set; }
}

public class DapperTests
Expand All @@ -15,6 +16,7 @@ public class DapperTests
public async Task Main()
{
SqlMapper.AddTypeHandler(new VectorTypeHandler());
SqlMapper.AddTypeHandler(new HalfvecTypeHandler());

var connString = "Host=localhost;Database=pgvector_dotnet_test";

Expand All @@ -28,17 +30,21 @@ public async Task Main()
conn.ReloadTypes();

conn.Execute("DROP TABLE IF EXISTS dapper_items");
conn.Execute("CREATE TABLE dapper_items (id serial PRIMARY KEY, embedding vector(3))");
conn.Execute("CREATE TABLE dapper_items (id serial PRIMARY KEY, embedding vector(3), halfembedding halfvec(3))");

var embedding1 = new Vector(new float[] { 1, 1, 1 });
var embedding2 = new Vector(new float[] { 2, 2, 2 });
var embedding3 = new Vector(new float[] { 1, 1, 2 });
conn.Execute(@"INSERT INTO dapper_items (embedding) VALUES (@embedding1), (@embedding2), (@embedding3)", new { embedding1, embedding2, embedding3 });
var halfEmbedding1 = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 });
var halfEmbedding2 = new HalfVector(new Half[] { (Half)2, (Half)2, (Half)2 });
var halfEmbedding3 = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)2 });
conn.Execute(@"INSERT INTO dapper_items (embedding, halfembedding) VALUES (@embedding1, @halfEmbedding1), (@embedding2, @halfEmbedding2), (@embedding3, @halfEmbedding3)", new { embedding1, halfEmbedding1, embedding2, halfEmbedding2, embedding3, halfEmbedding3 });

var embedding = new Vector(new float[] { 1, 1, 1 });
var items = conn.Query<DapperItem>("SELECT * FROM dapper_items ORDER BY embedding <-> @embedding LIMIT 5", new { embedding }).AsList();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());
Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray());

conn.Execute("CREATE INDEX ON dapper_items USING ivfflat (embedding vector_l2_ops) WITH (lists = 100)");
}
Expand Down
Loading

0 comments on commit b5c3c62

Please sign in to comment.