This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 7f812c7c feat(search): add support of HNSW vector fields for FT.CREATE (#2477)
7f812c7c is described below
commit 7f812c7c0b4b9e893ba4ece4088dc555d2b2e4e8
Author: Rebecca Zhou <[email protected]>
AuthorDate: Thu Aug 8 15:44:38 2024 -0700
feat(search): add support of HNSW vector fields for FT.CREATE (#2477)
Co-authored-by: Twice <[email protected]>
---
src/commands/cmd_search.cc | 79 ++++++++++++++++++++++++++++++++-
src/commands/error_constants.h | 1 +
tests/gocase/unit/search/search_test.go | 78 +++++++++++++++++++++++++-------
3 files changed, 142 insertions(+), 16 deletions(-)
diff --git a/src/commands/cmd_search.cc b/src/commands/cmd_search.cc
index f543c80c..210938f9 100644
--- a/src/commands/cmd_search.cc
+++ b/src/commands/cmd_search.cc
@@ -85,12 +85,27 @@ class CommandFTCreate : public Commander {
}
std::unique_ptr<redis::IndexFieldMetadata> field_meta;
+ std::unique_ptr<HnswIndexCreationState> hnsw_state;
if (parser.EatEqICase("TAG")) {
field_meta = std::make_unique<redis::TagFieldMetadata>();
} else if (parser.EatEqICase("NUMERIC")) {
field_meta = std::make_unique<redis::NumericFieldMetadata>();
+ } else if (parser.EatEqICase("VECTOR")) {
+ if (parser.EatEqICase("HNSW")) {
+ field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
+ auto num_attributes = GET_OR_RET(parser.TakeInt<uint8_t>());
+ if (num_attributes < 6) {
+ return {Status::NotOK, errInvalidNumOfAttributes};
+ }
+ if (num_attributes % 2 != 0) {
+ return {Status::NotOK, "number of attributes must be multiple of 2"};
+ }
+ hnsw_state =
std::make_unique<HnswIndexCreationState>(num_attributes);
+ } else {
+ return {Status::RedisParseErr, "only support HNSW algorithm for vector field"};
+ }
} else {
- return {Status::RedisParseErr, "expect field type TAG or NUMERIC"};
+ return {Status::RedisParseErr, "expect field type TAG, NUMERIC or VECTOR"};
}
while (parser.Good()) {
@@ -110,11 +125,51 @@ class CommandFTCreate : public Commander {
} else {
break;
}
+ } else if (auto vector = dynamic_cast<redis::HnswVectorFieldMetadata
*>(field_meta.get())) {
+ if (hnsw_state->num_attributes <= 0) break;
+
+ if (parser.EatEqICase("TYPE")) {
+ if (parser.EatEqICase("FLOAT64")) {
+ vector->vector_type = VectorType::FLOAT64;
+ } else {
+ return {Status::RedisParseErr, "unsupported vector type"};
+ }
+ hnsw_state->type_set = true;
+ } else if (parser.EatEqICase("DIM")) {
+ vector->dim = GET_OR_RET(parser.TakeInt<uint16_t>());
+ hnsw_state->dim_set = true;
+ } else if (parser.EatEqICase("DISTANCE_METRIC")) {
+ if (parser.EatEqICase("L2")) {
+ vector->distance_metric = DistanceMetric::L2;
+ } else if (parser.EatEqICase("IP")) {
+ vector->distance_metric = DistanceMetric::IP;
+ } else if (parser.EatEqICase("COSINE")) {
+ vector->distance_metric = DistanceMetric::COSINE;
+ } else {
+ return {Status::RedisParseErr, "unsupported distance metric"};
+ }
+ hnsw_state->distance_metric_set = true;
+ } else if (parser.EatEqICase("M")) {
+ vector->m = GET_OR_RET(parser.TakeInt<uint16_t>());
+ } else if (parser.EatEqICase("EF_CONSTRUCTION")) {
+ vector->ef_construction = GET_OR_RET(parser.TakeInt<uint32_t>());
+ } else if (parser.EatEqICase("EF_RUNTIME")) {
+ vector->ef_runtime = GET_OR_RET(parser.TakeInt<uint32_t>());
+ } else if (parser.EatEqICase("EPSILON")) {
+ vector->epsilon = GET_OR_RET(parser.TakeFloat<double>());
+ } else {
+ break;
+ }
+ hnsw_state->num_attributes -= 2;
} else {
break;
}
}
+ if (auto vector_meta [[maybe_unused]] =
dynamic_cast<redis::HnswVectorFieldMetadata *>(field_meta.get())) {
+ GET_OR_RET(hnsw_state->Validate());
+ }
+
kqir::FieldInfo field_info(field_name, std::move(field_meta));
index_info_->Add(std::move(field_info));
@@ -140,6 +195,28 @@ class CommandFTCreate : public Commander {
};
private:
+ struct HnswIndexCreationState {
+ uint8_t num_attributes;
+ bool type_set;
+ bool dim_set;
+ bool distance_metric_set;
+
+ explicit HnswIndexCreationState(uint8_t num_attributes)
+ : num_attributes(num_attributes), type_set(false), dim_set(false),
distance_metric_set(false) {}
+
+ Status Validate() const {
+ if (!type_set) {
+ return {Status::RedisParseErr, "VECTOR field requires TYPE to be set"};
+ }
+ if (!dim_set) {
+ return {Status::RedisParseErr, "VECTOR field requires DIM to be set"};
+ }
+ if (!distance_metric_set) {
+ return {Status::RedisParseErr, "VECTOR field requires DISTANCE_METRIC to be set"};
+ }
+ return Status::OK();
+ }
+ };
std::unique_ptr<kqir::IndexInfo> index_info_;
};
diff --git a/src/commands/error_constants.h b/src/commands/error_constants.h
index ea2c38b7..713f31d0 100644
--- a/src/commands/error_constants.h
+++ b/src/commands/error_constants.h
@@ -26,6 +26,7 @@ inline constexpr const char *errNotImplemented = "not implemented";
inline constexpr const char *errInvalidSyntax = "syntax error";
inline constexpr const char *errInvalidExpireTime = "invalid expire time";
inline constexpr const char *errWrongNumOfArguments = "wrong number of arguments";
+inline constexpr const char *errInvalidNumOfAttributes = "number of attributes is not as required";
inline constexpr const char *errValueNotInteger = "value is not an integer or out of range";
inline constexpr const char *errAdminPermissionRequired = "admin permission required to perform the command";
inline constexpr const char *errValueMustBePositive = "value is out of range, must be positive";
diff --git a/tests/gocase/unit/search/search_test.go b/tests/gocase/unit/search/search_test.go
index bb87c19f..ef9a554c 100644
--- a/tests/gocase/unit/search/search_test.go
+++ b/tests/gocase/unit/search/search_test.go
@@ -20,7 +20,9 @@
package search
import (
+ "bytes"
"context"
+ "encoding/binary"
"testing"
"github.com/apache/kvrocks/tests/gocase/util"
@@ -28,6 +30,18 @@ import (
"github.com/stretchr/testify/require"
)
+func SetBinaryBuffer(buf *bytes.Buffer, vec []float64) error {
+ buf.Reset()
+
+ for _, v := range vec {
+ if err := binary.Write(buf, binary.LittleEndian, v); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
func TestSearch(t *testing.T) {
srv := util.StartServer(t, map[string]string{})
defer srv.Close()
@@ -37,7 +51,8 @@ func TestSearch(t *testing.T) {
defer func() { require.NoError(t, rdb.Close()) }()
t.Run("FT.CREATE", func(t *testing.T) {
- require.NoError(t, rdb.Do(ctx, "FT.CREATE", "testidx1", "ON",
"JSON", "PREFIX", "1", "test1:", "SCHEMA", "a", "TAG", "b", "NUMERIC").Err())
+ require.NoError(t, rdb.Do(ctx, "FT.CREATE", "testidx1", "ON",
"JSON", "PREFIX", "1", "test1:", "SCHEMA", "a", "TAG", "b", "NUMERIC",
+ "c", "VECTOR", "HNSW", "6", "TYPE", "FLOAT64", "DIM",
"3", "DISTANCE_METRIC", "L2").Err())
verify := func(t *testing.T) {
require.Equal(t, []interface{}{"testidx1"}, rdb.Do(ctx,
"FT._LIST").Val())
@@ -53,6 +68,7 @@ func TestSearch(t *testing.T) {
require.Equal(t, "fields", idxInfo[6])
require.Equal(t, []interface{}{"a", "tag"},
idxInfo[7].([]interface{})[0])
require.Equal(t, []interface{}{"b", "numeric"},
idxInfo[7].([]interface{})[1])
+ require.Equal(t, []interface{}{"c", "vector"},
idxInfo[7].([]interface{})[2])
}
verify(t)
@@ -61,10 +77,10 @@ func TestSearch(t *testing.T) {
})
t.Run("FT.SEARCH", func(t *testing.T) {
- require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k1", "$",
`{"a": "x,y", "b": 11}`).Err())
- require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k2", "$",
`{"a": "x,z", "b": 22}`).Err())
- require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k3", "$",
`{"a": "y,z", "b": 33}`).Err())
- require.NoError(t, rdb.Do(ctx, "JSON.SET", "test2:k4", "$",
`{"a": "x,y,z", "b": 44}`).Err())
+ require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k1", "$",
`{"a": "x,y", "b": 11, "c": [2,3,4]}`).Err())
+ require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k2", "$",
`{"a": "x,z", "b": 22, "c": [12,13,14]}`).Err())
+ require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k3", "$",
`{"a": "y,z", "b": 33, "c": [23,24,25]}`).Err())
+ require.NoError(t, rdb.Do(ctx, "JSON.SET", "test2:k4", "$",
`{"a": "x,y,z", "b": 44, "c": [33,34,35]}`).Err())
verify := func(t *testing.T, res *redis.Cmd) {
require.NoError(t, res.Err())
@@ -85,24 +101,56 @@ func TestSearch(t *testing.T) {
require.Equal(t, 3, len(res.Val().([]interface{})))
require.Equal(t, int64(1), res.Val().([]interface{})[0])
require.Equal(t, "test1:k2",
res.Val().([]interface{})[1])
+
fields := res.Val().([]interface{})[2].([]interface{})
- if fields[0] == "a" {
- require.Equal(t, "x,z", fields[1])
- require.Equal(t, "b", fields[2])
- require.Equal(t, "22", fields[3])
- } else if fields[0] == "b" {
- require.Equal(t, "22", fields[1])
- require.Equal(t, "a", fields[2])
- require.Equal(t, "x,z", fields[3])
- } else {
- require.Fail(t, "not started with a or b")
+ fieldMap := make(map[string]string)
+ for i := 0; i < len(fields); i += 2 {
+ fieldMap[fields[i].(string)] =
fields[i+1].(string)
}
+
+ _, aExists := fieldMap["a"]
+ _, bExists := fieldMap["b"]
+ _, cExists := fieldMap["c"]
+
+ require.True(t, aExists, "'a' should exist in the
result")
+ require.True(t, bExists, "'b' should exist in the
result")
+ require.True(t, cExists, "'c' should exist in the
result")
+
+ require.Equal(t, "x,z", fieldMap["a"])
+ require.Equal(t, "22", fieldMap["b"])
+ require.Equal(t, "12.000000, 13.000000, 14.000000",
fieldMap["c"])
}
res = rdb.Do(ctx, "FT.SEARCHSQL", `select * from testidx1 where
a hastag "z" and b < 30`)
verify(t, res)
res = rdb.Do(ctx, "FT.SEARCH", "testidx1", `@a:{z} @b:[-inf
(30]`)
verify(t, res)
+ res = rdb.Do(ctx, "FT.SEARCHSQL", `select * from testidx1 order
by c <-> [13,14,15] limit 1`)
+ verify(t, res)
+ res = rdb.Do(ctx, "FT.SEARCHSQL", `select * from testidx1 where
c <-> [16,17,18] < 7`)
+ verify(t, res)
+ res = rdb.Do(ctx, "FT.SEARCHSQL", `select * from testidx1 where
a hastag "z" and c <-> [2,3,4] < 18`)
+ verify(t, res)
+
+ var buf bytes.Buffer
+
+ vec := []float64{13, 14, 15}
+ require.NoError(t, SetBinaryBuffer(&buf, vec), "Failed to set
binary buffer")
+ vecBinary := buf.Bytes()
+ res = rdb.Do(ctx, "FT.SEARCH", "testidx1", `*=>[KNN 1 @c
$BLOB]`, "PARAMS", "2", "BLOB", vecBinary)
+ verify(t, res)
+
+ vec = []float64{16, 17, 18}
+ require.NoError(t, SetBinaryBuffer(&buf, vec), "Failed to set
binary buffer")
+ vecBinary = buf.Bytes()
+ res = rdb.Do(ctx, "FT.SEARCH", "testidx1", `@c:[VECTOR_RANGE 7
$BLOB]`, "PARAMS", "2", "BLOB", vecBinary)
+ verify(t, res)
+
+ vec = []float64{2, 3, 4}
+ require.NoError(t, SetBinaryBuffer(&buf, vec), "Failed to set
binary buffer")
+ vecBinary = buf.Bytes()
+ res = rdb.Do(ctx, "FT.SEARCH", "testidx1", `@a:{z}
@c:[VECTOR_RANGE 18 $BLOB]`, "PARAMS", "2", "BLOB", vecBinary)
+ verify(t, res)
})
t.Run("FT.DROPINDEX", func(t *testing.T) {