diff --git a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs index beb1503..8b07bf4 100644 --- a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs +++ b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs @@ -64,13 +64,15 @@ public async Task Main() ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)2 }), BinaryEmbedding = new BitArray(new bool[] { true, true, true }), SparseEmbedding = new SparseVector(new float[] { 1, 1, 2 }) }); ctx.SaveChanges(); - // vector - var embedding = new Vector(new float[] { 1, 1, 1 }); var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync(); Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray()); + Assert.Equal(new BitArray(new bool[] { false, false, false }), items[0].BinaryEmbedding!); Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray()); + Assert.Equal(new float[] { 1, 1, 1 }, items[0].SparseEmbedding!.ToArray()); + + // vector distance functions items = await ctx.Items.OrderBy(x => x.Embedding!.L2Distance(embedding)).Take(5).ToListAsync(); Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); @@ -85,7 +87,7 @@ public async Task Main() items = await ctx.Items.OrderBy(x => x.Embedding!.L1Distance(embedding)).Take(5).ToListAsync(); Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); - // halfvec + // halfvec distance functions var halfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }); items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L2Distance(halfEmbedding)).Take(5).ToListAsync(); @@ -100,7 +102,7 @@ public async Task Main() items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L1Distance(halfEmbedding)).Take(5).ToListAsync(); Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); - // sparsevec + // sparsevec distance functions var sparseEmbedding = new SparseVector(new float[] { 1, 1, 1 }); items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L2Distance(sparseEmbedding)).Take(5).ToListAsync(); @@ -115,7 +117,7 @@ public async Task Main() items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L1Distance(sparseEmbedding)).Take(5).ToListAsync(); Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); - // bit + // bit distance functions var binaryEmbedding = new BitArray(new bool[] { true, false, true }); items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.HammingDistance(binaryEmbedding)).Take(5).ToListAsync();