Skip to content

Commit

Permalink
Added support for halfvec type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 17, 2024
1 parent e3bb664 commit 75d32cf
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/Pgvector.Dapper/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.3.0 (unreleased)

- Added support for `halfvec` type
- Dropped support for .NET Standard

## 0.2.0 (2024-04-17)

- Added support for Npgsql 8
Expand Down
2 changes: 1 addition & 1 deletion src/Pgvector.Dapper/Pgvector.Dapper.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<PackageProjectUrl>https://github.com/pgvector/pgvector-dotnet</PackageProjectUrl>
<PackageReadmeFile>README.md</PackageReadmeFile>

<TargetFrameworks>netstandard2.0;net6.0</TargetFrameworks>
<TargetFrameworks>net6.0</TargetFrameworks>
<Nullable>enable</Nullable>
<LangVersion>latest</LangVersion>
</PropertyGroup>
Expand Down
5 changes: 5 additions & 0 deletions src/Pgvector/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.3.0 (unreleased)

- Added support for `halfvec` type
- Dropped support for .NET Standard

## 0.2.0 (2023-11-24)

- Added support for Npgsql 8
Expand Down
44 changes: 44 additions & 0 deletions src/Pgvector/HalfVector.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using System;
using System.Globalization;
using System.Linq;

namespace Pgvector;

/// <summary>
/// A vector of half-precision floating point values with element-wise value equality.
/// </summary>
public class HalfVector : IEquatable<HalfVector>
{
    /// <summary>Gets the elements of the vector.</summary>
    public ReadOnlyMemory<Half> Memory { get; }

    /// <summary>Creates a vector over the given elements.</summary>
    /// <param name="v">The elements of the vector.</param>
    public HalfVector(ReadOnlyMemory<Half> v)
    {
        Memory = v;
    }

    /// <summary>
    /// Creates a vector from its text form, e.g. <c>[1,2,3]</c>.
    /// </summary>
    /// <param name="s">
    /// Text whose first and last characters are dropped and whose remainder is
    /// split on commas; each piece is parsed with the invariant culture.
    /// </param>
    public HalfVector(string s)
    {
        // Strip the enclosing delimiter characters, then split the element list.
        var inner = s.Substring(1, s.Length - 2);
        var tokens = inner.Split(',');

        var values = new Half[tokens.Length];
        for (var i = 0; i < tokens.Length; i++)
        {
            values[i] = Half.Parse(tokens[i], CultureInfo.InvariantCulture);
        }

        Memory = values;
    }

    /// <summary>
    /// Formats the vector as a bracketed, comma-separated list using the invariant culture.
    /// </summary>
    public override string ToString()
    {
        var elements = Memory.ToArray();
        var parts = new string[elements.Length];
        for (var i = 0; i < elements.Length; i++)
        {
            parts[i] = elements[i].ToString(CultureInfo.InvariantCulture);
        }

        return "[" + string.Join(",", parts) + "]";
    }

    /// <summary>Copies the elements into a new array.</summary>
    public Half[] ToArray()
    {
        return Memory.ToArray();
    }

    /// <summary>
    /// Returns true when <paramref name="other"/> is non-null and holds the same
    /// sequence of elements.
    /// </summary>
    public bool Equals(HalfVector? other)
    {
        if (other is null)
        {
            return false;
        }

        return Memory.Span.SequenceEqual(other.Memory.Span);
    }

    /// <inheritdoc/>
    public override bool Equals(object? obj)
    {
        if (obj is not HalfVector candidate)
        {
            return false;
        }

        return Equals(candidate);
    }

    /// <summary>Value equality; two nulls compare equal.</summary>
    public static bool operator ==(HalfVector? x, HalfVector? y)
    {
        return x is null ? y is null : x.Equals(y);
    }

    /// <summary>Negation of the equality operator.</summary>
    public static bool operator !=(HalfVector? x, HalfVector? y)
    {
        return !(x == y);
    }

    /// <summary>Combines the hash of every element, in order.</summary>
    public override int GetHashCode()
    {
        var combined = new HashCode();

        foreach (var element in Memory.Span)
        {
            combined.Add(element);
        }

        return combined.ToHashCode();
    }
}
91 changes: 91 additions & 0 deletions src/Pgvector/Npgsql/HalfvecConverter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Npgsql.Internal;

namespace Pgvector.Npgsql;

/// <summary>
/// Npgsql streaming converter between <see cref="HalfVector"/> and the binary
/// representation of the <c>halfvec</c> type.
/// </summary>
/// <remarks>
/// Layout handled by this converter: a 16-bit element count, a 16-bit field
/// that must be zero, then one 16-bit value per element. Every 16-bit field is
/// moved through <c>ReadUInt16</c>/<c>WriteUInt16</c> so that reads and writes
/// use the same byte order regardless of host endianness.
/// </remarks>
public class HalfvecConverter : PgStreamingConverter<HalfVector>
{
    // Bytes occupied by the header: element count + unused field.
    const int HeaderSize = 2 * sizeof(ushort);

    // Bytes occupied by a single half-precision element.
    const int ElementSize = sizeof(ushort);

    /// <summary>Reads a <see cref="HalfVector"/> synchronously.</summary>
    /// <param name="reader">Reader positioned at the start of the value.</param>
    /// <returns>The decoded vector.</returns>
    /// <exception cref="InvalidCastException">The unused header field is not zero.</exception>
    public override HalfVector Read(PgReader reader)
    {
        if (reader.ShouldBuffer(HeaderSize))
            reader.Buffer(HeaderSize);

        var dim = reader.ReadUInt16();
        var unused = reader.ReadUInt16();
        if (unused != 0)
            throw new InvalidCastException("expected unused to be 0");

        var vec = new Half[dim];
        for (var i = 0; i < dim; i++)
        {
            if (reader.ShouldBuffer(ElementSize))
                reader.Buffer(ElementSize);

            // Reinterpret the 16 bits directly; avoids a byte[] allocation per element.
            vec[i] = BitConverter.UInt16BitsToHalf(reader.ReadUInt16());
        }

        return new HalfVector(vec);
    }

    /// <summary>Reads a <see cref="HalfVector"/> asynchronously.</summary>
    /// <param name="reader">Reader positioned at the start of the value.</param>
    /// <param name="cancellationToken">Token passed to buffering operations.</param>
    /// <returns>The decoded vector.</returns>
    /// <exception cref="InvalidCastException">The unused header field is not zero.</exception>
    public override async ValueTask<HalfVector> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
    {
        if (reader.ShouldBuffer(HeaderSize))
            await reader.BufferAsync(HeaderSize, cancellationToken).ConfigureAwait(false);

        var dim = reader.ReadUInt16();
        var unused = reader.ReadUInt16();
        if (unused != 0)
            throw new InvalidCastException("expected unused to be 0");

        var vec = new Half[dim];
        for (var i = 0; i < dim; i++)
        {
            if (reader.ShouldBuffer(ElementSize))
                await reader.BufferAsync(ElementSize, cancellationToken).ConfigureAwait(false);

            vec[i] = BitConverter.UInt16BitsToHalf(reader.ReadUInt16());
        }

        return new HalfVector(vec);
    }

    /// <summary>
    /// Returns the number of bytes <see cref="Write"/> emits for <paramref name="value"/>:
    /// the header plus two bytes per element.
    /// </summary>
    public override Size GetSize(SizeContext context, HalfVector value, ref object? writeState)
        // Elements are 16-bit halves, not 32-bit floats; read the length without copying.
        => HeaderSize + ElementSize * value.Memory.Length;

    /// <summary>Writes a <see cref="HalfVector"/> synchronously.</summary>
    /// <param name="writer">Destination writer.</param>
    /// <param name="value">Vector to encode.</param>
    /// <exception cref="OverflowException">The vector has more elements than fit in a 16-bit count.</exception>
    public override void Write(PgWriter writer, HalfVector value)
    {
        if (writer.ShouldFlush(HeaderSize))
            writer.Flush();

        var span = value.Memory.Span;
        var dim = span.Length;
        writer.WriteUInt16(Convert.ToUInt16(dim));
        writer.WriteUInt16(0);

        for (int i = 0; i < dim; i++)
        {
            if (writer.ShouldFlush(ElementSize))
                writer.Flush();

            // Go through WriteUInt16 (the mirror of ReadUInt16 used when reading)
            // instead of dumping host-order bytes from BitConverter.GetBytes.
            writer.WriteUInt16(BitConverter.HalfToUInt16Bits(span[i]));
        }
    }

    /// <summary>Writes a <see cref="HalfVector"/> asynchronously.</summary>
    /// <param name="writer">Destination writer.</param>
    /// <param name="value">Vector to encode.</param>
    /// <param name="cancellationToken">Token passed to flush operations.</param>
    /// <exception cref="OverflowException">The vector has more elements than fit in a 16-bit count.</exception>
    public override async ValueTask WriteAsync(
        PgWriter writer, HalfVector value, CancellationToken cancellationToken = default)
    {
        if (writer.ShouldFlush(HeaderSize))
            await writer.FlushAsync(cancellationToken).ConfigureAwait(false);

        var memory = value.Memory;
        var dim = memory.Length;
        writer.WriteUInt16(Convert.ToUInt16(dim));
        writer.WriteUInt16(0);

        for (int i = 0; i < dim; i++)
        {
            if (writer.ShouldFlush(ElementSize))
                await writer.FlushAsync(cancellationToken).ConfigureAwait(false);

            // The span is re-fetched per element because it is not held across an await.
            writer.WriteUInt16(BitConverter.HalfToUInt16Bits(memory.Span[i]));
        }
    }
}
42 changes: 42 additions & 0 deletions src/Pgvector/Npgsql/HalfvecTypeInfoResolverFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using System;
using Npgsql.Internal;
using Npgsql.Internal.Postgres;

namespace Pgvector.Npgsql;

/// <summary>
/// Resolver factory that maps <see cref="HalfVector"/> to the <c>halfvec</c>
/// data type, with a separate resolver that additionally maps arrays of it.
/// </summary>
public class HalfvecTypeInfoResolverFactory : PgTypeInfoResolverFactory
{
    /// <summary>Creates the resolver for the scalar <c>halfvec</c> mapping.</summary>
    public override IPgTypeInfoResolver CreateResolver()
    {
        return new Resolver();
    }

    /// <summary>Creates the resolver that also covers arrays of <c>halfvec</c>.</summary>
    public override IPgTypeInfoResolver CreateArrayResolver()
    {
        return new ArrayResolver();
    }

    // Resolves the scalar halfvec mapping; the mapping collection is built on first use.
    class Resolver : IPgTypeInfoResolver
    {
        TypeInfoMappingCollection? _mappings;

        protected TypeInfoMappingCollection Mappings
        {
            get
            {
                if (_mappings is null)
                {
                    _mappings = AddMappings(new TypeInfoMappingCollection());
                }

                return _mappings;
            }
        }

        public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options)
        {
            return Mappings.Find(type, dataTypeName, options);
        }

        // Registers HalfVector <-> "halfvec" as the default mapping, backed by HalfvecConverter.
        static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings)
        {
            mappings.AddType<HalfVector>(
                "halfvec",
                static (serializerOptions, typeMapping, _) =>
                    typeMapping.CreateInfo(serializerOptions, new HalfvecConverter()),
                isDefault: true);

            return mappings;
        }
    }

    // Extends the scalar mappings with the array mapping. The interface is listed
    // again so that interface calls reach the GetTypeInfo declared here.
    sealed class ArrayResolver : Resolver, IPgTypeInfoResolver
    {
        TypeInfoMappingCollection? _mappings;

        new TypeInfoMappingCollection Mappings
        {
            get
            {
                if (_mappings is null)
                {
                    // Start from a collection constructed over the base resolver's mappings.
                    _mappings = AddMappings(new TypeInfoMappingCollection(base.Mappings));
                }

                return _mappings;
            }
        }

        public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options)
        {
            return Mappings.Find(type, dataTypeName, options);
        }

        static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings)
        {
            mappings.AddArrayType<HalfVector>("halfvec");

            return mappings;
        }
    }
}
1 change: 1 addition & 0 deletions src/Pgvector/Npgsql/VectorExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ public static class VectorExtensions
/// <summary>
/// Adds the vector and halfvec type-info resolver factories to the mapper.
/// </summary>
/// <param name="mapper">The mapper to add the factories to.</param>
/// <returns>The same <paramref name="mapper"/> instance, to allow chaining.</returns>
public static INpgsqlTypeMapper UseVector(this INpgsqlTypeMapper mapper)
{
    // The vector factory is added first, then the halfvec factory.
    var vectorFactory = new VectorTypeInfoResolverFactory();
    mapper.AddTypeInfoResolverFactory(vectorFactory);

    var halfvecFactory = new HalfvecTypeInfoResolverFactory();
    mapper.AddTypeInfoResolverFactory(halfvecFactory);

    return mapper;
}
}
2 changes: 1 addition & 1 deletion src/Pgvector/Pgvector.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<PackageProjectUrl>https://github.com/pgvector/pgvector-dotnet</PackageProjectUrl>
<PackageReadmeFile>README.md</PackageReadmeFile>

<TargetFrameworks>netstandard2.0;net6.0</TargetFrameworks>
<TargetFrameworks>net6.0</TargetFrameworks>
<Nullable>enable</Nullable>
<LangVersion>latest</LangVersion>
</PropertyGroup>
Expand Down
41 changes: 41 additions & 0 deletions tests/Pgvector.CSharp.Tests/HalfVectorTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using Pgvector;

namespace Pgvector.Tests;

/// <summary>
/// Tests for <see cref="HalfVector"/> construction, text formatting and equality.
/// </summary>
public class HalfVectorTests
{
    /// <summary>A vector built from text formats back to the same text.</summary>
    [Fact]
    public void StringConstructor()
    {
        var vector = new HalfVector("[1,2,3]");

        var text = vector.ToString();

        Assert.Equal("[1,2,3]", text);
    }

    /// <summary>A vector built from an array returns the same elements from ToArray.</summary>
    [Fact]
    public void ArrayConstructor()
    {
        Half[] expected = new Half[] { (Half)1, (Half)2, (Half)3 };
        var vector = new HalfVector(new Half[] { (Half)1, (Half)2, (Half)3 });

        Half[] actual = vector.ToArray();

        Assert.Equal(expected, actual);
    }

    /// <summary>
    /// Equality is by element values, for Equals and for both operators,
    /// including comparisons against null.
    /// </summary>
    [Fact]
    public void Equal()
    {
        var ones = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 });
        var onesAgain = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 });
        var ascending = new HalfVector(new Half[] { (Half)1, (Half)2, (Half)3 });

        // Equals-based comparison.
        Assert.Equal(ones, onesAgain);
        Assert.NotEqual(ones, ascending);

        // Equality operator.
        Assert.True(ones == onesAgain);
        Assert.False(ones == ascending);

        // Inequality operator.
        Assert.False(ones != onesAgain);
        Assert.True(ones != ascending);

        // Null on either side, and on both sides.
        Assert.False(ones == null);
        Assert.False(null == ones);
        Assert.True((HalfVector?)null == null);
    }
}

0 comments on commit 75d32cf

Please sign in to comment.