Skip to content

Commit

Permalink
fixed: knn_vector field invalid encoder format
Browse files Browse the repository at this point in the history
  • Loading branch information
t83714 committed Aug 5, 2024
1 parent a0ffe00 commit 32bf083
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class KnnVectorFieldTest extends AnyFlatSpec with Matchers with ElasticApi {
|"dimension":512,
|"method":{"name":"hnsw","engine":"faiss","space_type":"l2",
|"parameters":{"ef_construction":100,"m":50,"ef_search":50,
|"encoder":{"name":"pq","encoder":{"m":50,"code_size":100}}}}}""".stripMargin.replace(
|"encoder":{"name":"pq","parameters":{"m":50,"code_size":100}}}}}""".stripMargin.replace(
"\n",
""
)
Expand Down Expand Up @@ -161,7 +161,7 @@ class KnnVectorFieldTest extends AnyFlatSpec with Matchers with ElasticApi {
|"dimension":512,
|"method":{"name":"hnsw","engine":"faiss","space_type":"innerproduct",
|"parameters":{"ef_construction":100,"m":50,"ef_search":50,
|"encoder":{"name":"sq","encoder":{"clip":true,"type":"fp16"}}}}}""".stripMargin.replace(
|"encoder":{"name":"sq","parameters":{"clip":true,"type":"fp16"}}}}}""".stripMargin.replace(
"\n",
""
)
Expand All @@ -182,12 +182,11 @@ class KnnVectorFieldTest extends AnyFlatSpec with Matchers with ElasticApi {
KnnVectorField(
name = "myfield123",
dimension = 512,
HnswParameters(
IvfParameters(
engine = Some(KnnEngine.faiss),
spaceType = Some(SpaceType.innerProduct),
efConstruction = Some(100),
m = Some(50),
efSearch = Some(50),
nlist = Some(4),
nprobes = Some(2),
encoder = Some(
FaissEncoder(
Some(FaissEncoderName.sq),
Expand All @@ -201,9 +200,9 @@ class KnnVectorFieldTest extends AnyFlatSpec with Matchers with ElasticApi {
.string shouldBe
"""{"type":"knn_vector",
|"dimension":512,
|"method":{"name":"hnsw","engine":"faiss","space_type":"innerproduct",
|"parameters":{"ef_construction":100,"m":50,"ef_search":50,
|"encoder":{"name":"sq","encoder":{"clip":true,"type":"fp16"}}}}}""".stripMargin.replace(
|"method":{"name":"ivf","engine":"faiss","space_type":"innerproduct",
|"parameters":{"nlist":4,"nprobes":2,
|"encoder":{"name":"sq","parameters":{"clip":true,"type":"fp16"}}}}}""".stripMargin.replace(
"\n",
""
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ object IvfParameters {
engine.get match {
case KnnEngine.faiss =>
// do nothing as all parameters are accepted
if (!encoder.isInstanceOf[FaissEncoder]) {
if (encoder.nonEmpty && !encoder.get.isInstanceOf[FaissEncoder]) {
throw new IllegalArgumentException(
"encoder must be instance of FaissEncoder for faiss engine"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,35 @@ object KnnVectorFieldBuilderFn {
.get("name")
.map(_.asInstanceOf[String])
.flatMap(v => FaissEncoderName.withName(v))
val m = e.get("m").map(_.asInstanceOf[Int])
val codeSize = e.get("code_size").map(_.asInstanceOf[Int])
val sqClip = e.get("clip").map(_.asInstanceOf[Boolean])
val m = e
.get("parameters")
.flatMap(p =>
p.asInstanceOf[Map[String, Any]]
.get("m")
.map(_.asInstanceOf[Int])
)
val codeSize = e
.get("parameters")
.flatMap(p =>
p.asInstanceOf[Map[String, Any]]
.get("code_size")
.map(_.asInstanceOf[Int])
)
val sqClip = e
.get("parameters")
.flatMap(p =>
p.asInstanceOf[Map[String, Any]]
.get("clip")
.map(_.asInstanceOf[Boolean])
)
val sqType = e
.get("type")
.map(_.asInstanceOf[String])
.flatMap(v => FaissScalarQuantizationType.withName(v))
.get("parameters")
.flatMap(p =>
p.asInstanceOf[Map[String, Any]]
.get("type")
.map(_.asInstanceOf[String])
flatMap (v => FaissScalarQuantizationType.withName(v))
)
FaissEncoder(
name = name,
m = m,
Expand Down Expand Up @@ -152,7 +174,7 @@ object KnnVectorFieldBuilderFn {
encoder.name.foreach(v => builder.field("name", v.name))

// start of encoder `parameters` field
builder.startObject("encoder")
builder.startObject("parameters")
encoder match {
case e: FaissEncoder =>
e.m.foreach(v => builder.field("m", v))
Expand All @@ -174,7 +196,7 @@ object KnnVectorFieldBuilderFn {
encoder.name.foreach(v => builder.field("name", v.name))

// start of encoder `parameters` field
builder.startObject("encoder")
builder.startObject("parameters")
encoder match {
case e: FaissEncoder =>
e.m.foreach(v => builder.field("m", v))
Expand Down

0 comments on commit 32bf083

Please sign in to comment.