This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 7f812c7c feat(search): add support of HNSW vector fields for FT.CREATE (#2477)
7f812c7c is described below

commit 7f812c7c0b4b9e893ba4ece4088dc555d2b2e4e8
Author: Rebecca Zhou <[email protected]>
AuthorDate: Thu Aug 8 15:44:38 2024 -0700

    feat(search): add support of HNSW vector fields for FT.CREATE (#2477)
    
    Co-authored-by: Twice <[email protected]>
---
 src/commands/cmd_search.cc              | 79 ++++++++++++++++++++++++++++++++-
 src/commands/error_constants.h          |  1 +
 tests/gocase/unit/search/search_test.go | 78 +++++++++++++++++++++++++-------
 3 files changed, 142 insertions(+), 16 deletions(-)

diff --git a/src/commands/cmd_search.cc b/src/commands/cmd_search.cc
index f543c80c..210938f9 100644
--- a/src/commands/cmd_search.cc
+++ b/src/commands/cmd_search.cc
@@ -85,12 +85,27 @@ class CommandFTCreate : public Commander {
         }
 
         std::unique_ptr<redis::IndexFieldMetadata> field_meta;
+        std::unique_ptr<HnswIndexCreationState> hnsw_state;
         if (parser.EatEqICase("TAG")) {
           field_meta = std::make_unique<redis::TagFieldMetadata>();
         } else if (parser.EatEqICase("NUMERIC")) {
           field_meta = std::make_unique<redis::NumericFieldMetadata>();
+        } else if (parser.EatEqICase("VECTOR")) {
+          if (parser.EatEqICase("HNSW")) {
+            field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
+            auto num_attributes = GET_OR_RET(parser.TakeInt<uint8_t>());
+            if (num_attributes < 6) {
+              return {Status::NotOK, errInvalidNumOfAttributes};
+            }
+            if (num_attributes % 2 != 0) {
+              return {Status::NotOK, "number of attributes must be multiple of 
2"};
+            }
+            hnsw_state = 
std::make_unique<HnswIndexCreationState>(num_attributes);
+          } else {
+            return {Status::RedisParseErr, "only support HNSW algorithm for 
vector field"};
+          }
         } else {
-          return {Status::RedisParseErr, "expect field type TAG or NUMERIC"};
+          return {Status::RedisParseErr, "expect field type TAG, NUMERIC or 
VECTOR"};
         }
 
         while (parser.Good()) {
@@ -110,11 +125,51 @@ class CommandFTCreate : public Commander {
             } else {
               break;
             }
+          } else if (auto vector = dynamic_cast<redis::HnswVectorFieldMetadata 
*>(field_meta.get())) {
+            if (hnsw_state->num_attributes <= 0) break;
+
+            if (parser.EatEqICase("TYPE")) {
+              if (parser.EatEqICase("FLOAT64")) {
+                vector->vector_type = VectorType::FLOAT64;
+              } else {
+                return {Status::RedisParseErr, "unsupported vector type"};
+              }
+              hnsw_state->type_set = true;
+            } else if (parser.EatEqICase("DIM")) {
+              vector->dim = GET_OR_RET(parser.TakeInt<uint16_t>());
+              hnsw_state->dim_set = true;
+            } else if (parser.EatEqICase("DISTANCE_METRIC")) {
+              if (parser.EatEqICase("L2")) {
+                vector->distance_metric = DistanceMetric::L2;
+              } else if (parser.EatEqICase("IP")) {
+                vector->distance_metric = DistanceMetric::IP;
+              } else if (parser.EatEqICase("COSINE")) {
+                vector->distance_metric = DistanceMetric::COSINE;
+              } else {
+                return {Status::RedisParseErr, "unsupported distance metric"};
+              }
+              hnsw_state->distance_metric_set = true;
+            } else if (parser.EatEqICase("M")) {
+              vector->m = GET_OR_RET(parser.TakeInt<uint16_t>());
+            } else if (parser.EatEqICase("EF_CONSTRUCTION")) {
+              vector->ef_construction = GET_OR_RET(parser.TakeInt<uint32_t>());
+            } else if (parser.EatEqICase("EF_RUNTIME")) {
+              vector->ef_runtime = GET_OR_RET(parser.TakeInt<uint32_t>());
+            } else if (parser.EatEqICase("EPSILON")) {
+              vector->epsilon = GET_OR_RET(parser.TakeFloat<double>());
+            } else {
+              break;
+            }
+            hnsw_state->num_attributes -= 2;
           } else {
             break;
           }
         }
 
+        if (auto vector_meta [[maybe_unused]] = 
dynamic_cast<redis::HnswVectorFieldMetadata *>(field_meta.get())) {
+          GET_OR_RET(hnsw_state->Validate());
+        }
+
         kqir::FieldInfo field_info(field_name, std::move(field_meta));
 
         index_info_->Add(std::move(field_info));
@@ -140,6 +195,28 @@ class CommandFTCreate : public Commander {
   };
 
  private:
+  struct HnswIndexCreationState {
+    uint8_t num_attributes;
+    bool type_set;
+    bool dim_set;
+    bool distance_metric_set;
+
+    explicit HnswIndexCreationState(uint8_t num_attributes)
+        : num_attributes(num_attributes), type_set(false), dim_set(false), 
distance_metric_set(false) {}
+
+    Status Validate() const {
+      if (!type_set) {
+        return {Status::RedisParseErr, "VECTOR field requires TYPE to be set"};
+      }
+      if (!dim_set) {
+        return {Status::RedisParseErr, "VECTOR field requires DIM to be set"};
+      }
+      if (!distance_metric_set) {
+        return {Status::RedisParseErr, "VECTOR field requires DISTANCE_METRIC 
to be set"};
+      }
+      return Status::OK();
+    }
+  };
   std::unique_ptr<kqir::IndexInfo> index_info_;
 };
 
diff --git a/src/commands/error_constants.h b/src/commands/error_constants.h
index ea2c38b7..713f31d0 100644
--- a/src/commands/error_constants.h
+++ b/src/commands/error_constants.h
@@ -26,6 +26,7 @@ inline constexpr const char *errNotImplemented = "not 
implemented";
 inline constexpr const char *errInvalidSyntax = "syntax error";
 inline constexpr const char *errInvalidExpireTime = "invalid expire time";
 inline constexpr const char *errWrongNumOfArguments = "wrong number of 
arguments";
+inline constexpr const char *errInvalidNumOfAttributes = "number of attributes 
is not as required";
 inline constexpr const char *errValueNotInteger = "value is not an integer or 
out of range";
 inline constexpr const char *errAdminPermissionRequired = "admin permission 
required to perform the command";
 inline constexpr const char *errValueMustBePositive = "value is out of range, 
must be positive";
diff --git a/tests/gocase/unit/search/search_test.go 
b/tests/gocase/unit/search/search_test.go
index bb87c19f..ef9a554c 100644
--- a/tests/gocase/unit/search/search_test.go
+++ b/tests/gocase/unit/search/search_test.go
@@ -20,7 +20,9 @@
 package search
 
 import (
+       "bytes"
        "context"
+       "encoding/binary"
        "testing"
 
        "github.com/apache/kvrocks/tests/gocase/util"
@@ -28,6 +30,18 @@ import (
        "github.com/stretchr/testify/require"
 )
 
+func SetBinaryBuffer(buf *bytes.Buffer, vec []float64) error {
+       buf.Reset()
+
+       for _, v := range vec {
+               if err := binary.Write(buf, binary.LittleEndian, v); err != nil 
{
+                       return err
+               }
+       }
+
+       return nil
+}
+
 func TestSearch(t *testing.T) {
        srv := util.StartServer(t, map[string]string{})
        defer srv.Close()
@@ -37,7 +51,8 @@ func TestSearch(t *testing.T) {
        defer func() { require.NoError(t, rdb.Close()) }()
 
        t.Run("FT.CREATE", func(t *testing.T) {
-               require.NoError(t, rdb.Do(ctx, "FT.CREATE", "testidx1", "ON", 
"JSON", "PREFIX", "1", "test1:", "SCHEMA", "a", "TAG", "b", "NUMERIC").Err())
+               require.NoError(t, rdb.Do(ctx, "FT.CREATE", "testidx1", "ON", 
"JSON", "PREFIX", "1", "test1:", "SCHEMA", "a", "TAG", "b", "NUMERIC",
+                       "c", "VECTOR", "HNSW", "6", "TYPE", "FLOAT64", "DIM", 
"3", "DISTANCE_METRIC", "L2").Err())
 
                verify := func(t *testing.T) {
                        require.Equal(t, []interface{}{"testidx1"}, rdb.Do(ctx, 
"FT._LIST").Val())
@@ -53,6 +68,7 @@ func TestSearch(t *testing.T) {
                        require.Equal(t, "fields", idxInfo[6])
                        require.Equal(t, []interface{}{"a", "tag"}, 
idxInfo[7].([]interface{})[0])
                        require.Equal(t, []interface{}{"b", "numeric"}, 
idxInfo[7].([]interface{})[1])
+                       require.Equal(t, []interface{}{"c", "vector"}, 
idxInfo[7].([]interface{})[2])
                }
                verify(t)
 
@@ -61,10 +77,10 @@ func TestSearch(t *testing.T) {
        })
 
        t.Run("FT.SEARCH", func(t *testing.T) {
-               require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k1", "$", 
`{"a": "x,y", "b": 11}`).Err())
-               require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k2", "$", 
`{"a": "x,z", "b": 22}`).Err())
-               require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k3", "$", 
`{"a": "y,z", "b": 33}`).Err())
-               require.NoError(t, rdb.Do(ctx, "JSON.SET", "test2:k4", "$", 
`{"a": "x,y,z", "b": 44}`).Err())
+               require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k1", "$", 
`{"a": "x,y", "b": 11, "c": [2,3,4]}`).Err())
+               require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k2", "$", 
`{"a": "x,z", "b": 22, "c": [12,13,14]}`).Err())
+               require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k3", "$", 
`{"a": "y,z", "b": 33, "c": [23,24,25]}`).Err())
+               require.NoError(t, rdb.Do(ctx, "JSON.SET", "test2:k4", "$", 
`{"a": "x,y,z", "b": 44, "c": [33,34,35]}`).Err())
 
                verify := func(t *testing.T, res *redis.Cmd) {
                        require.NoError(t, res.Err())
@@ -85,24 +101,56 @@ func TestSearch(t *testing.T) {
                        require.Equal(t, 3, len(res.Val().([]interface{})))
                        require.Equal(t, int64(1), res.Val().([]interface{})[0])
                        require.Equal(t, "test1:k2", 
res.Val().([]interface{})[1])
+
                        fields := res.Val().([]interface{})[2].([]interface{})
-                       if fields[0] == "a" {
-                               require.Equal(t, "x,z", fields[1])
-                               require.Equal(t, "b", fields[2])
-                               require.Equal(t, "22", fields[3])
-                       } else if fields[0] == "b" {
-                               require.Equal(t, "22", fields[1])
-                               require.Equal(t, "a", fields[2])
-                               require.Equal(t, "x,z", fields[3])
-                       } else {
-                               require.Fail(t, "not started with a or b")
+                       fieldMap := make(map[string]string)
+                       for i := 0; i < len(fields); i += 2 {
+                               fieldMap[fields[i].(string)] = 
fields[i+1].(string)
                        }
+
+                       _, aExists := fieldMap["a"]
+                       _, bExists := fieldMap["b"]
+                       _, cExists := fieldMap["c"]
+
+                       require.True(t, aExists, "'a' should exist in the 
result")
+                       require.True(t, bExists, "'b' should exist in the 
result")
+                       require.True(t, cExists, "'c' should exist in the 
result")
+
+                       require.Equal(t, "x,z", fieldMap["a"])
+                       require.Equal(t, "22", fieldMap["b"])
+                       require.Equal(t, "12.000000, 13.000000, 14.000000", 
fieldMap["c"])
                }
 
                res = rdb.Do(ctx, "FT.SEARCHSQL", `select * from testidx1 where 
a hastag "z" and b < 30`)
                verify(t, res)
                res = rdb.Do(ctx, "FT.SEARCH", "testidx1", `@a:{z} @b:[-inf 
(30]`)
                verify(t, res)
+               res = rdb.Do(ctx, "FT.SEARCHSQL", `select * from testidx1 order 
by c <-> [13,14,15] limit 1`)
+               verify(t, res)
+               res = rdb.Do(ctx, "FT.SEARCHSQL", `select * from testidx1 where 
c <-> [16,17,18] < 7`)
+               verify(t, res)
+               res = rdb.Do(ctx, "FT.SEARCHSQL", `select * from testidx1 where 
a hastag "z" and c <-> [2,3,4] < 18`)
+               verify(t, res)
+
+               var buf bytes.Buffer
+
+               vec := []float64{13, 14, 15}
+               require.NoError(t, SetBinaryBuffer(&buf, vec), "Failed to set 
binary buffer")
+               vecBinary := buf.Bytes()
+               res = rdb.Do(ctx, "FT.SEARCH", "testidx1", `*=>[KNN 1 @c 
$BLOB]`, "PARAMS", "2", "BLOB", vecBinary)
+               verify(t, res)
+
+               vec = []float64{16, 17, 18}
+               require.NoError(t, SetBinaryBuffer(&buf, vec), "Failed to set 
binary buffer")
+               vecBinary = buf.Bytes()
+               res = rdb.Do(ctx, "FT.SEARCH", "testidx1", `@c:[VECTOR_RANGE 7 
$BLOB]`, "PARAMS", "2", "BLOB", vecBinary)
+               verify(t, res)
+
+               vec = []float64{2, 3, 4}
+               require.NoError(t, SetBinaryBuffer(&buf, vec), "Failed to set 
binary buffer")
+               vecBinary = buf.Bytes()
+               res = rdb.Do(ctx, "FT.SEARCH", "testidx1", `@a:{z} 
@c:[VECTOR_RANGE 18 $BLOB]`, "PARAMS", "2", "BLOB", vecBinary)
+               verify(t, res)
        })
 
        t.Run("FT.DROPINDEX", func(t *testing.T) {

Reply via email to