Skip to content

Commit

Permalink
TairVector: support cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
DuanxinCao committed Dec 6, 2022
1 parent f621593 commit 7db35ac
Show file tree
Hide file tree
Showing 14 changed files with 1,193 additions and 477 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>com.aliyun.tair</groupId>
<artifactId>alibabacloud-tairjedis-sdk</artifactId>
<version>2.4.0-SNAPSHOT</version>
<version>3.0.4</version>
<packaging>jar</packaging>

<name>alibabacloud-tairjedis-sdk</name>
Expand Down
317 changes: 107 additions & 210 deletions src/main/java/com/aliyun/tair/tairvector/TairVector.java

Large diffs are not rendered by default.

65 changes: 34 additions & 31 deletions src/main/java/com/aliyun/tair/tairvector/TairVectorCluster.java
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
package com.aliyun.tair.tairvector;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.aliyun.tair.ModuleCommand;
import com.aliyun.tair.tairhash.factory.HashBuilderFactory;
import com.aliyun.tair.tairvector.factory.VectorBuilderFactory;
import com.aliyun.tair.tairvector.params.DistanceMethod;
import com.aliyun.tair.tairvector.params.HscanParams;
Expand All @@ -19,15 +11,19 @@
import redis.clients.jedis.ScanResult;
import redis.clients.jedis.util.SafeEncoder;

import java.util.*;
import java.util.stream.Collectors;

import static redis.clients.jedis.Protocol.toByteArray;

public class TairVectorCluster {
public class TairVectorCluster implements VectorShard {
private JedisCluster jc;

public TairVectorCluster(JedisCluster jc) {
this.jc = jc;
}

@Override
public void quit() {
if (jc != null) {
jc.close();
Expand All @@ -46,11 +42,13 @@ public void quit() {
* @param attrs other columns, optional
* @return Success: +OK; Fail: error
*/
@Override
public String tvscreateindex(final String index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final String... attrs) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), SafeEncoder.encodeMany(attrs)));
return BuilderFactory.STRING.build(obj);
}

@Override
public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final byte[]... params) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(index, toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), params));
return BuilderFactory.BYTE_ARRAY.build(obj);
Expand All @@ -64,11 +62,13 @@ public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, D
* @param index index name
* @return Success: string_map, Fail: empty
*/
@Override
public Map<String, String> tvsgetindex(final String index) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSGETINDEX, SafeEncoder.encode(index));
return BuilderFactory.STRING_MAP.build(obj);
}

@Override
public Map<byte[], byte[]> tvsgetindex(byte[] index) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSGETINDEX, index);
return BuilderFactory.BYTE_ARRAY_MAP.build(obj);
Expand All @@ -82,37 +82,18 @@ public Map<byte[], byte[]> tvsgetindex(byte[] index) {
* @param index index name
* @return Success: 1; Fail: 0
*/
@Override
public Long tvsdelindex(final String index) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSDELINDEX, SafeEncoder.encode(index));
return BuilderFactory.LONG.build(obj);
}

@Override
public Long tvsdelindex(byte[] index) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSDELINDEX, index);
return BuilderFactory.LONG.build(obj);
}


/**
* TVS.SCANINDEX TVS.SCANINDEX index_name
* <p>
* scan index
*
* @param cursor start offset
* @param params the params: [MATCH pattern] [COUNT count]
* `MATCH` - Set the pattern which is used to filter the results
* `COUNT` - Set the number of fields in a single scan (default is 10)
* `NOVAL` - The return result contains no data portion, only cursor information
* @return A ScanResult. {@link HashBuilderFactory#EXHSCAN_RESULT_STRING}
*/
public ScanResult<String> tvsscanindex(Long cursor, HscanParams params) {
final List<byte[]> args = new ArrayList<byte[]>();
args.add(toByteArray(cursor));
args.addAll(params.getParams());
Object obj = jc.sendCommand(toByteArray(cursor), ModuleCommand.TVSSCANINDEX, args.toArray(new byte[args.size()][]));
return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj);
}

/**
* TVS.HSET TVS.HSET index entityid vector [(attribute_key attribute_value) ...]
* <p>
Expand All @@ -126,11 +107,13 @@ public ScanResult<String> tvsscanindex(Long cursor, HscanParams params) {
* {@literal k} if success, k is the number of fields that were added..
* throw error like "(error) Illegal vector dimensions" if error
*/
@Override
public Long tvshset(final String index, final String entityid, final String vector, final String... params) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHSET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), SafeEncoder.encode(vector), SafeEncoder.encodeMany(params)));
return BuilderFactory.LONG.build(obj);
}

@Override
public Long tvshset(byte[] index, byte[] entityid, byte[] vector, final byte[]... params) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSHSET, JoinParameters.joinParameters(index, entityid, SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), vector, params));
return BuilderFactory.LONG.build(obj);
Expand All @@ -145,11 +128,13 @@ public Long tvshset(byte[] index, byte[] entityid, byte[] vector, final byte[]..
* @param entityid entity id
* @return Map, an empty list when {@code entityid} does not exist.
*/
@Override
public Map<String, String> tvshgetall(final String index, final String entityid) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHGETALL, SafeEncoder.encode(index), SafeEncoder.encode(entityid));
return BuilderFactory.STRING_MAP.build(obj);
}

@Override
public Map<byte[], byte[]> tvshgetall(byte[] index, byte[] entityid) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSHGETALL, index, entityid);
return BuilderFactory.BYTE_ARRAY_MAP.build(obj);
Expand All @@ -165,11 +150,13 @@ public Map<byte[], byte[]> tvshgetall(byte[] index, byte[] entityid) {
* @param attrs attrs
* @return List, an empty list when {@code entityid} or {@code attrs} does not exist .
*/
@Override
public List<String> tvshmget(final String index, final String entityid, final String... attrs) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHMGET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encodeMany(attrs)));
return BuilderFactory.STRING_LIST.build(obj);
}

@Override
public List<byte[]> tvshmget(byte[] index, byte[] entityid, byte[]... attrs) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSHMGET, JoinParameters.joinParameters(index, entityid, attrs));
return BuilderFactory.BYTE_ARRAY_LIST.build(obj);
Expand All @@ -186,11 +173,13 @@ public List<byte[]> tvshmget(byte[] index, byte[] entityid, byte[]... attrs) {
* @return Long integer-reply the number of fields that were removed from the tair-vector
* not including specified but non existing fields.
*/
@Override
public Long tvsdel(final String index, final String entityid) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSDEL, SafeEncoder.encode(index), SafeEncoder.encode(entityid));
return BuilderFactory.LONG.build(obj);
}

@Override
public Long tvsdel(byte[] index, byte[] entityid) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSDEL, index, entityid);
return BuilderFactory.LONG.build(obj);
Expand All @@ -207,13 +196,15 @@ public Long tvsdel(byte[] index, byte[] entityid) {
* @return Long integer-reply the number of fields that were removed from the tair-vector
* not including specified but non existing fields.
*/
@Override
public Long tvshdel(final String index, final String entityid, final String... attrs) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHDEL, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encodeMany(attrs)));
return BuilderFactory.LONG.build(obj);
}

@Override
public Long tvshdel(byte[] index, byte[] entityid, byte[]... attrs) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSHDEL, JoinParameters.joinParameters(index, entityid, attrs));
Object obj = jc.sendCommand(index, ModuleCommand.TVSHDEL, JoinParameters.joinParameters(index, entityid,attrs));
return BuilderFactory.LONG.build(obj);
}

Expand All @@ -231,6 +222,7 @@ public Long tvshdel(byte[] index, byte[] entityid, byte[]... attrs) {
* `NOVAL` - The return result contains no data portion, only cursor information
* @return A ScanResult.
*/
@Override
public ScanResult<String> tvsscan(final String index, Long cursor, HscanParams params) {
final List<byte[]> args = new ArrayList<byte[]>();
args.add(SafeEncoder.encode(index));
Expand All @@ -240,6 +232,7 @@ public ScanResult<String> tvsscan(final String index, Long cursor, HscanParams p
return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj);
}

@Override
public ScanResult<byte[]> tvsscan(byte[] index, Long cursor, HscanParams params) {
final List<byte[]> args = new ArrayList<byte[]>();
args.add(index);
Expand All @@ -261,10 +254,12 @@ public ScanResult<byte[]> tvsscan(byte[] index, Long cursor, HscanParams params)
* ef_search range [0, 1000]
* @return VectorBuilderFactory.Knn<>
*/
@Override
public VectorBuilderFactory.Knn<String> tvsknnsearch(final String index, Long topn, final String vector, final String... params) {
return tvsknnsearchfilter(index, topn, vector, "", params);
}

@Override
public VectorBuilderFactory.Knn<byte[]> tvsknnsearch(byte[] index, Long topn, byte[] vector, final byte[]... params) {
return tvsknnsearchfilter(index, topn, vector, SafeEncoder.encode(""), params);
}
Expand All @@ -282,12 +277,14 @@ public VectorBuilderFactory.Knn<byte[]> tvsknnsearch(byte[] index, Long topn, by
* ef_search range [0, 1000]
* @return VectorBuilderFactory.Knn<>
*/
@Override
public VectorBuilderFactory.Knn<String> tvsknnsearchfilter(final String index, Long topn, final String vector, final String pattern, final String... params) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(topn),
SafeEncoder.encode(vector), SafeEncoder.encode(pattern), SafeEncoder.encodeMany(params)));
return VectorBuilderFactory.STRING_KNN_RESULT.build(obj);
}

@Override
public VectorBuilderFactory.Knn<byte[]> tvsknnsearchfilter(byte[] index, Long topn, byte[] vector, byte[] pattern, final byte[]... params) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(index, toByteArray(topn), vector, pattern, params));
return VectorBuilderFactory.BYTE_KNN_RESULT.build(obj);
Expand All @@ -303,10 +300,12 @@ public VectorBuilderFactory.Knn<byte[]> tvsknnsearchfilter(byte[] index, Long to
* ef_search range [0, 1000]
* @return Collection<>
*/
@Override
public Collection<VectorBuilderFactory.Knn<String>> tvsmknnsearch(final String index, Long topn, Collection<String> vectors, final String... params) {
return tvsmknnsearchfilter(index, topn, vectors, "", params);
}

@Override
public Collection<VectorBuilderFactory.Knn<byte[]>> tvsmknnsearch(byte[] index, Long topn, Collection<byte[]> vectors, final byte[]... params) {
return tvsmknnsearchfilter(index, topn, vectors, SafeEncoder.encode(""), params);
}
Expand All @@ -322,6 +321,7 @@ public Collection<VectorBuilderFactory.Knn<byte[]>> tvsmknnsearch(byte[] index,
* ef_search range [0, 1000]
* @return Collection<>
*/
@Override
public Collection<VectorBuilderFactory.Knn<String>> tvsmknnsearchfilter(final String index, Long topn, Collection<String> vectors, final String pattern, final String... params) {
final List<byte[]> args = new ArrayList<byte[]>();
args.add(SafeEncoder.encode(index));
Expand All @@ -334,6 +334,7 @@ public Collection<VectorBuilderFactory.Knn<String>> tvsmknnsearchfilter(final St
return VectorBuilderFactory.STRING_KNN_BATCH_RESULT.build(obj);
}

@Override
public Collection<VectorBuilderFactory.Knn<byte[]>> tvsmknnsearchfilter(byte[] index, Long topn, Collection<byte[]> vectors, byte[] pattern, final byte[]... params) {
final List<byte[]> args = new ArrayList<byte[]>();
args.add(index);
Expand All @@ -345,4 +346,6 @@ public Collection<VectorBuilderFactory.Knn<byte[]>> tvsmknnsearchfilter(byte[] i
Object obj = jc.sendCommand(index, ModuleCommand.TVSMKNNSEARCH, args.toArray(new byte[args.size()][]));
return VectorBuilderFactory.BYTE_KNN_BATCH_RESULT.build(obj);
}


}
10 changes: 3 additions & 7 deletions src/main/java/com/aliyun/tair/tairvector/TairVectorPipeline.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
package com.aliyun.tair.tairvector;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.aliyun.tair.ModuleCommand;
import com.aliyun.tair.tairhash.factory.HashBuilderFactory;
import com.aliyun.tair.tairvector.factory.VectorBuilderFactory;
Expand All @@ -20,6 +13,9 @@
import redis.clients.jedis.ScanResult;
import redis.clients.jedis.util.SafeEncoder;

import java.util.*;
import java.util.stream.Collectors;

import static redis.clients.jedis.Protocol.toByteArray;

public class TairVectorPipeline extends Pipeline {
Expand Down
Loading

0 comments on commit 7db35ac

Please sign in to comment.