This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 45ba475a feat(search): support query parameters in SQL and RediSearch query (#2443)
45ba475a is described below

commit 45ba475a20637e2177901f9cd3e5118cc2e95896
Author: Twice <[email protected]>
AuthorDate: Wed Jul 24 11:24:19 2024 +0900

    feat(search): support query parameters in SQL and RediSearch query (#2443)
---
 src/commands/cmd_search.cc               | 85 +++++++++++++++++++++++++-------
 src/commands/command_parser.h            |  3 ++
 src/search/common_transformer.h          | 19 +++++++
 src/search/redis_query_parser.h          | 10 ++--
 src/search/redis_query_transformer.h     | 59 +++++++++++++++-------
 src/search/sql_parser.h                  |  8 ++-
 src/search/sql_transformer.h             | 50 ++++++++++++++-----
 src/search/value.h                       | 13 ++++-
 tests/cppunit/redis_query_parser_test.cc | 13 ++++-
 tests/cppunit/sql_parser_test.cc         | 15 +++++-
 10 files changed, 221 insertions(+), 54 deletions(-)

diff --git a/src/commands/cmd_search.cc b/src/commands/cmd_search.cc
index 1928672c..f543c80c 100644
--- a/src/commands/cmd_search.cc
+++ b/src/commands/cmd_search.cc
@@ -24,6 +24,7 @@
 
 #include "commander.h"
 #include "commands/command_parser.h"
+#include "search/common_transformer.h"
 #include "search/index_info.h"
 #include "search/ir.h"
 #include "search/ir_dot_dumper.h"
@@ -155,31 +156,56 @@ static void DumpQueryResult(const 
std::vector<kqir::ExecutorContext::RowType> &r
   }
 }
 
+using CommandParserWithNode = 
std::pair<CommandParserFromConst<std::vector<std::string>>, 
std::unique_ptr<kqir::Node>>;
+
+static StatusOr<CommandParserWithNode> ParseSQLQuery(const 
std::vector<std::string> &args) {
+  CommandParser parser(args, 1);
+
+  auto sql = GET_OR_RET(parser.TakeStr());
+
+  kqir::ParamMap param_map;
+  if (parser.EatEqICase("PARAMS")) {
+    auto nargs = GET_OR_RET(parser.TakeInt<size_t>());
+    if (nargs % 2 != 0) {
+      return {Status::NotOK, "nargs of PARAMS must be multiple of 2"};
+    }
+
+    for (size_t i = 0; i < nargs / 2; ++i) {
+      auto key = GET_OR_RET(parser.TakeStr());
+      auto val = GET_OR_RET(parser.TakeStr());
+
+      param_map.emplace(key, val);
+    }
+  }
+
+  auto ir = GET_OR_RET(kqir::sql::ParseToIR(kqir::peg::string_input(sql, 
"ft.searchsql"), param_map));
+  return std::make_pair(parser, std::move(ir));
+}
+
 class CommandFTExplainSQL : public Commander {
   Status Parse(const std::vector<std::string> &args) override {
-    if (args.size() == 3) {
-      if (util::EqualICase(args[2], "simple")) {
+    auto [parser, ir] = GET_OR_RET(ParseSQLQuery(args_));
+    ir_ = std::move(ir);
+
+    if (parser.Good()) {
+      if (parser.EatEqICase("simple")) {
         format_ = SIMPLE;
-      } else if (util::EqualICase(args[2], "dot")) {
+      } else if (parser.EatEqICase("dot")) {
         format_ = DOT_GRAPH;
       } else {
         return {Status::NotOK, "output format should be SIMPLE or DOT"};
       }
     }
 
-    if (args.size() > 3) {
-      return {Status::NotOK, "more arguments than expected"};
+    if (parser.Good()) {
+      return {Status::NotOK, "unexpected arguments in the end"};
     }
 
     return Status::OK();
   }
 
   Status Execute(Server *srv, Connection *conn, std::string *output) override {
-    const auto &sql = args_[1];
-
-    auto ir = GET_OR_RET(kqir::sql::ParseToIR(kqir::peg::string_input(sql, 
"ft.explainsql")));
-
-    auto plan = GET_OR_RET(srv->index_mgr.GeneratePlan(std::move(ir), 
conn->GetNamespace()));
+    auto plan = GET_OR_RET(srv->index_mgr.GeneratePlan(std::move(ir_), 
conn->GetNamespace()));
 
     if (format_ == SIMPLE) {
       output->append(BulkString(plan->Dump()));
@@ -195,20 +221,30 @@ class CommandFTExplainSQL : public Commander {
   };
 
   enum OutputFormat { SIMPLE, DOT_GRAPH } format_ = SIMPLE;
+  std::unique_ptr<kqir::Node> ir_;
 };
 
 class CommandFTSearchSQL : public Commander {
-  Status Execute(Server *srv, Connection *conn, std::string *output) override {
-    const auto &sql = args_[1];
+  Status Parse(const std::vector<std::string> &args) override {
+    auto [parser, ir] = GET_OR_RET(ParseSQLQuery(args));
+    ir_ = std::move(ir);
 
-    auto ir = GET_OR_RET(kqir::sql::ParseToIR(kqir::peg::string_input(sql, 
"ft.searchsql")));
+    if (parser.Good()) {
+      return {Status::NotOK, "unexpected arguments in the end"};
+    }
 
-    auto results = GET_OR_RET(srv->index_mgr.Search(std::move(ir), 
conn->GetNamespace()));
+    return Status::OK();
+  }
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    auto results = GET_OR_RET(srv->index_mgr.Search(std::move(ir_), 
conn->GetNamespace()));
 
     DumpQueryResult(results, output);
 
     return Status::OK();
   };
+
+ private:
+  std::unique_ptr<kqir::Node> ir_;
 };
 
 static StatusOr<std::unique_ptr<kqir::Node>> ParseRediSearchQuery(const 
std::vector<std::string> &args) {
@@ -218,12 +254,12 @@ static StatusOr<std::unique_ptr<kqir::Node>> 
ParseRediSearchQuery(const std::vec
   auto query_str = GET_OR_RET(parser.TakeStr());
 
   auto index_ref = std::make_unique<kqir::IndexRef>(index_name);
-  auto query = kqir::Node::MustAs<kqir::QueryExpr>(
-      
GET_OR_RET(kqir::redis_query::ParseToIR(kqir::peg::string_input(query_str, 
"ft.search"))));
 
   auto select = 
std::make_unique<kqir::SelectClause>(std::vector<std::unique_ptr<kqir::FieldRef>>{});
   std::unique_ptr<kqir::SortByClause> sort_by;
   std::unique_ptr<kqir::LimitClause> limit;
+
+  kqir::ParamMap param_map;
   while (parser.Good()) {
     if (parser.EatEqICase("RETURNS")) {
       auto count = GET_OR_RET(parser.TakeInt<size_t>());
@@ -247,11 +283,26 @@ static StatusOr<std::unique_ptr<kqir::Node>> 
ParseRediSearchQuery(const std::vec
       auto count = GET_OR_RET(parser.TakeInt<size_t>());
 
       limit = std::make_unique<kqir::LimitClause>(offset, count);
+    } else if (parser.EatEqICase("PARAMS")) {
+      auto nargs = GET_OR_RET(parser.TakeInt<size_t>());
+      if (nargs % 2 != 0) {
+        return {Status::NotOK, "nargs of PARAMS must be multiple of 2"};
+      }
+
+      for (size_t i = 0; i < nargs / 2; ++i) {
+        auto key = GET_OR_RET(parser.TakeStr());
+        auto val = GET_OR_RET(parser.TakeStr());
+
+        param_map.emplace(key, val);
+      }
     } else {
       return parser.InvalidSyntax();
     }
   }
 
+  auto query = kqir::Node::MustAs<kqir::QueryExpr>(
+      
GET_OR_RET(kqir::redis_query::ParseToIR(kqir::peg::string_input(query_str, 
"ft.search"), param_map)));
+
   return std::make_unique<kqir::SearchExpr>(std::move(index_ref), 
std::move(query), std::move(limit),
                                             std::move(sort_by), 
std::move(select));
 }
@@ -359,7 +410,7 @@ class CommandFTDrop : public Commander {
 };
 
 REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandFTCreate>("ft.create", -2, "write 
exclusive no-multi no-script", 0, 0, 0),
-                        MakeCmdAttr<CommandFTSearchSQL>("ft.searchsql", 2, 
"read-only", 0, 0, 0),
+                        MakeCmdAttr<CommandFTSearchSQL>("ft.searchsql", -2, 
"read-only", 0, 0, 0),
                         MakeCmdAttr<CommandFTSearch>("ft.search", -3, 
"read-only", 0, 0, 0),
                         MakeCmdAttr<CommandFTExplainSQL>("ft.explainsql", -2, 
"read-only", 0, 0, 0),
                         MakeCmdAttr<CommandFTExplain>("ft.explain", -3, 
"read-only", 0, 0, 0),
diff --git a/src/commands/command_parser.h b/src/commands/command_parser.h
index c13d682b..a4e06e1c 100644
--- a/src/commands/command_parser.h
+++ b/src/commands/command_parser.h
@@ -177,3 +177,6 @@ CommandParser(const Container&, size_t = 0) -> 
CommandParser<typename Container:
 
 template <typename Container>
 CommandParser(Container&&, size_t = 0) -> CommandParser<MoveIterator<typename 
Container::iterator>>;
+
+template <typename Container>
+using CommandParserFromConst = CommandParser<typename 
Container::const_iterator>;
diff --git a/src/search/common_transformer.h b/src/search/common_transformer.h
index 8febbb4c..5ebbcff6 100644
--- a/src/search/common_transformer.h
+++ b/src/search/common_transformer.h
@@ -20,6 +20,7 @@
 
 #pragma once
 
+#include <map>
 #include <tao/pegtl/contrib/parse_tree.hpp>
 #include <tao/pegtl/contrib/unescape.hpp>
 #include <tao/pegtl/demangle.hpp>
@@ -29,9 +30,27 @@
 
 namespace kqir {
 
+using ParamMap = std::map<std::string, std::string, std::less<>>;
+
 struct TreeTransformer {
   using TreeNode = std::unique_ptr<peg::parse_tree::node>;
 
+  const ParamMap& param_map;
+
+  explicit TreeTransformer(const ParamMap& param_map) : param_map(param_map) {}
+
+  StatusOr<std::string> GetParam(const TreeNode& node) {
+    // node->type must be Param here
+    auto name = node->string_view().substr(1);
+
+    auto iter = param_map.find(name);
+    if (iter == param_map.end()) {
+      return {Status::NotOK, fmt::format("parameter with name `{}` not found", 
name)};
+    }
+
+    return iter->second;
+  }
+
   template <typename T>
   static bool Is(const TreeNode& node) {
     return node->type == peg::demangle<T>();
diff --git a/src/search/redis_query_parser.h b/src/search/redis_query_parser.h
index d64d0bf0..5fe03046 100644
--- a/src/search/redis_query_parser.h
+++ b/src/search/redis_query_parser.h
@@ -32,12 +32,16 @@ using namespace peg;
 
 struct Field : seq<one<'@'>, Identifier> {};
 
-struct Tag : sor<Identifier, StringL> {};
+struct Param : seq<one<'$'>, Identifier> {};
+
+struct Tag : sor<Identifier, StringL, Param> {};
 struct TagList : seq<one<'{'>, WSPad<Tag>, star<seq<one<'|'>, WSPad<Tag>>>, 
one<'}'>> {};
 
+struct NumberOrParam : sor<Number, Param> {};
+
 struct Inf : seq<opt<one<'+', '-'>>, string<'i', 'n', 'f'>> {};
-struct ExclusiveNumber : seq<one<'('>, Number> {};
-struct NumericRangePart : sor<Inf, ExclusiveNumber, Number> {};
+struct ExclusiveNumber : seq<one<'('>, NumberOrParam> {};
+struct NumericRangePart : sor<Inf, ExclusiveNumber, NumberOrParam> {};
 struct NumericRange : seq<one<'['>, WSPad<NumericRangePart>, 
WSPad<NumericRangePart>, one<']'>> {};
 
 struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<TagList, 
NumericRange>>> {};
diff --git a/src/search/redis_query_transformer.h 
b/src/search/redis_query_transformer.h
index 0928ed59..6ff1581b 100644
--- a/src/search/redis_query_transformer.h
+++ b/src/search/redis_query_transformer.h
@@ -36,7 +36,7 @@ namespace ir = kqir;
 
 template <typename Rule>
 using TreeSelector =
-    parse_tree::selector<Rule, parse_tree::store_content::on<Number, StringL, 
Identifier, Inf>,
+    parse_tree::selector<Rule, parse_tree::store_content::on<Number, StringL, 
Param, Identifier, Inf>,
                          parse_tree::remove_content::on<TagList, NumericRange, 
ExclusiveNumber, FieldQuery, NotExpr,
                                                         AndExpr, OrExpr, 
Wildcard>>;
 
@@ -51,7 +51,9 @@ StatusOr<std::unique_ptr<parse_tree::node>> 
ParseToTree(Input&& in) {
 }
 
 struct Transformer : ir::TreeTransformer {
-  static auto Transform(const TreeNode& node) -> 
StatusOr<std::unique_ptr<Node>> {
+  explicit Transformer(const ParamMap& param_map) : TreeTransformer(param_map) 
{}
+
+  auto Transform(const TreeNode& node) -> StatusOr<std::unique_ptr<Node>> {
     if (Is<Number>(node)) {
       return Node::Create<ir::NumericLiteral>(*ParseFloat(node->string()));
     } else if (Is<Wildcard>(node)) {
@@ -66,7 +68,17 @@ struct Transformer : ir::TreeTransformer {
         std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
 
         for (const auto& tag : query->children) {
-          auto tag_str = Is<Identifier>(tag) ? tag->string() : 
GET_OR_RET(UnescapeString(tag->string()));
+          std::string tag_str;
+          if (Is<Identifier>(tag)) {
+            tag_str = tag->string();
+          } else if (Is<StringL>(tag)) {
+            tag_str = GET_OR_RET(UnescapeString(tag->string()));
+          } else if (Is<Param>(tag)) {
+            tag_str = GET_OR_RET(GetParam(tag));
+          } else {
+            return {Status::NotOK, "encountered invalid tag"};
+          }
+
           
exprs.push_back(std::make_unique<ir::TagContainExpr>(std::make_unique<FieldRef>(field),
                                                                
std::make_unique<StringLiteral>(tag_str)));
         }
@@ -82,14 +94,27 @@ struct Transformer : ir::TreeTransformer {
         const auto& lhs = query->children[0];
         const auto& rhs = query->children[1];
 
+        auto number_or_param = [this](const TreeNode& node) -> 
StatusOr<std::unique_ptr<NumericLiteral>> {
+          if (Is<Number>(node)) {
+            return 
Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
+          } else if (Is<Param>(node)) {
+            auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
+                                      .Prefixed(fmt::format("parameter {} is 
not a number", node->string_view())));
+
+            return std::make_unique<ir::NumericLiteral>(val);
+          } else {
+            return {Status::NotOK,
+                    fmt::format("expected a number or a parameter in numeric 
comparison but got {}", node->type)};
+          }
+        };
+
         if (Is<ExclusiveNumber>(lhs)) {
+          
exprs.push_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GT,
+                                                               
std::make_unique<FieldRef>(field),
+                                                               
GET_OR_RET(number_or_param(lhs->children[0]))));
+        } else if (Is<Number>(lhs) || Is<Param>(lhs)) {
           exprs.push_back(std::make_unique<NumericCompareExpr>(
-              NumericCompareExpr::GT, std::make_unique<FieldRef>(field),
-              
Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(lhs->children[0])))));
-        } else if (Is<Number>(lhs)) {
-          exprs.push_back(
-              std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET, 
std::make_unique<FieldRef>(field),
-                                                   
Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(lhs)))));
+              NumericCompareExpr::GET, std::make_unique<FieldRef>(field), 
GET_OR_RET(number_or_param(lhs))));
         } else {  // Inf
           if (lhs->string_view() == "+inf") {
             return {Status::NotOK, "it's not allowed to set the lower bound as 
positive infinity"};
@@ -97,13 +122,12 @@ struct Transformer : ir::TreeTransformer {
         }
 
         if (Is<ExclusiveNumber>(rhs)) {
+          
exprs.push_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::LT,
+                                                               
std::make_unique<FieldRef>(field),
+                                                               
GET_OR_RET(number_or_param(rhs->children[0]))));
+        } else if (Is<Number>(rhs) || Is<Param>(rhs)) {
           exprs.push_back(std::make_unique<NumericCompareExpr>(
-              NumericCompareExpr::LT, std::make_unique<FieldRef>(field),
-              
Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(rhs->children[0])))));
-        } else if (Is<Number>(rhs)) {
-          exprs.push_back(
-              std::make_unique<NumericCompareExpr>(NumericCompareExpr::LET, 
std::make_unique<FieldRef>(field),
-                                                   
Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(rhs)))));
+              NumericCompareExpr::LET, std::make_unique<FieldRef>(field), 
GET_OR_RET(number_or_param(rhs))));
         } else {  // Inf
           if (rhs->string_view() == "-inf") {
             return {Status::NotOK, "it's not allowed to set the upper bound as 
negative infinity"};
@@ -150,8 +174,9 @@ struct Transformer : ir::TreeTransformer {
 };
 
 template <typename Input>
-StatusOr<std::unique_ptr<ir::Node>> ParseToIR(Input&& in) {
-  return 
Transformer::Transform(GET_OR_RET(ParseToTree(std::forward<Input>(in))));
+StatusOr<std::unique_ptr<ir::Node>> ParseToIR(Input&& in, const ParamMap& 
param_map = {}) {
+  Transformer transformer(param_map);
+  return 
transformer.Transform(GET_OR_RET(ParseToTree(std::forward<Input>(in))));
 }
 
 }  // namespace redis_query
diff --git a/src/search/sql_parser.h b/src/search/sql_parser.h
index 9a611d9b..22b985fd 100644
--- a/src/search/sql_parser.h
+++ b/src/search/sql_parser.h
@@ -30,10 +30,14 @@ namespace sql {
 
 using namespace peg;
 
+struct Param : seq<one<'@'>, Identifier> {};
+struct StringOrParam : sor<StringL, Param> {};
+struct NumberOrParam : sor<Number, Param> {};
+
 struct HasTag : string<'h', 'a', 's', 't', 'a', 'g'> {};
-struct HasTagExpr : WSPad<seq<Identifier, WSPad<HasTag>, StringL>> {};
+struct HasTagExpr : WSPad<seq<Identifier, WSPad<HasTag>, StringOrParam>> {};
 
-struct NumericAtomExpr : WSPad<sor<Number, Identifier>> {};
+struct NumericAtomExpr : WSPad<sor<NumberOrParam, Identifier>> {};
 struct NumericCompareOp : sor<string<'!', '='>, string<'<', '='>, string<'>', 
'='>, one<'=', '<', '>'>> {};
 struct NumericCompareExpr : seq<NumericAtomExpr, NumericCompareOp, 
NumericAtomExpr> {};
 
diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h
index 72aa0de7..d2ed8c21 100644
--- a/src/search/sql_transformer.h
+++ b/src/search/sql_transformer.h
@@ -25,6 +25,7 @@
 #include <variant>
 
 #include "common_transformer.h"
+#include "fmt/format.h"
 #include "ir.h"
 #include "parse_util.h"
 #include "sql_parser.h"
@@ -38,7 +39,8 @@ namespace ir = kqir;
 template <typename Rule>
 using TreeSelector = parse_tree::selector<
     Rule,
-    parse_tree::store_content::on<Boolean, Number, StringL, Identifier, 
NumericCompareOp, AscOrDesc, UnsignedInteger>,
+    parse_tree::store_content::on<Boolean, Number, StringL, Param, Identifier, 
NumericCompareOp, AscOrDesc,
+                                  UnsignedInteger>,
     parse_tree::remove_content::on<HasTagExpr, NumericCompareExpr, NotExpr, 
AndExpr, OrExpr, Wildcard, SelectExpr,
                                    FromExpr, WhereClause, OrderByClause, 
LimitClause, SearchStmt>>;
 
@@ -53,7 +55,9 @@ StatusOr<std::unique_ptr<parse_tree::node>> 
ParseToTree(Input&& in) {
 }
 
 struct Transformer : ir::TreeTransformer {
-  static auto Transform(const TreeNode& node) -> 
StatusOr<std::unique_ptr<Node>> {
+  explicit Transformer(const ParamMap& param_map) : TreeTransformer(param_map) 
{}
+
+  auto Transform(const TreeNode& node) -> StatusOr<std::unique_ptr<Node>> {
     if (Is<Boolean>(node)) {
       return Node::Create<ir::BoolLiteral>(node->string_view() == "true");
     } else if (Is<Number>(node)) {
@@ -63,23 +67,46 @@ struct Transformer : ir::TreeTransformer {
     } else if (Is<HasTagExpr>(node)) {
       CHECK(node->children.size() == 2);
 
-      return Node::Create<ir::TagContainExpr>(
-          std::make_unique<ir::FieldRef>(node->children[0]->string()),
-          
Node::MustAs<ir::StringLiteral>(GET_OR_RET(Transform(node->children[1]))));
+      const auto& tag = node->children[1];
+      std::unique_ptr<ir::StringLiteral> res;
+      if (Is<StringL>(tag)) {
+        res = Node::MustAs<ir::StringLiteral>(GET_OR_RET(Transform(tag)));
+      } else if (Is<Param>(tag)) {
+        res = std::make_unique<ir::StringLiteral>(GET_OR_RET(GetParam(tag)));
+      } else {
+        return {Status::NotOK, "encountered invalid tag"};
+      }
+
+      return 
Node::Create<ir::TagContainExpr>(std::make_unique<ir::FieldRef>(node->children[0]->string()),
+                                              std::move(res));
     } else if (Is<NumericCompareExpr>(node)) {
       CHECK(node->children.size() == 3);
 
       const auto& lhs = node->children[0];
       const auto& rhs = node->children[2];
 
+      auto number_or_param = [this](const TreeNode& node) -> 
StatusOr<std::unique_ptr<NumericLiteral>> {
+        if (Is<Number>(node)) {
+          return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
+        } else if (Is<Param>(node)) {
+          auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
+                                    .Prefixed(fmt::format("parameter {} is not 
a number", node->string_view())));
+
+          return std::make_unique<ir::NumericLiteral>(val);
+        } else {
+          return {Status::NotOK,
+                  fmt::format("expected a number or a parameter in numeric 
comparison but got {}", node->type)};
+        }
+      };
+
       auto op = 
ir::NumericCompareExpr::FromOperator(node->children[1]->string_view()).value();
-      if (Is<Identifier>(lhs) && Is<Number>(rhs)) {
+      if (Is<Identifier>(lhs) && (Is<Number>(rhs) || Is<Param>(rhs))) {
         return Node::Create<ir::NumericCompareExpr>(op, 
std::make_unique<ir::FieldRef>(lhs->string()),
-                                                    
Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(rhs))));
-      } else if (Is<Number>(lhs) && Is<Identifier>(rhs)) {
+                                                    
GET_OR_RET(number_or_param(rhs)));
+      } else if ((Is<Number>(lhs) || Is<Param>(lhs)) && Is<Identifier>(rhs)) {
         return 
Node::Create<ir::NumericCompareExpr>(ir::NumericCompareExpr::Flip(op),
                                                     
std::make_unique<ir::FieldRef>(rhs->string()),
-                                                    
Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(lhs))));
+                                                    
GET_OR_RET(number_or_param(lhs)));
       } else {
         return {Status::NotOK, "the left and right side of numeric comparison 
should be an identifier and a number"};
       }
@@ -181,8 +208,9 @@ struct Transformer : ir::TreeTransformer {
 };
 
 template <typename Input>
-StatusOr<std::unique_ptr<ir::Node>> ParseToIR(Input&& in) {
-  return 
Transformer::Transform(GET_OR_RET(ParseToTree(std::forward<Input>(in))));
+StatusOr<std::unique_ptr<ir::Node>> ParseToIR(Input&& in, const ParamMap& 
param_map = {}) {
+  Transformer transformer{param_map};
+  return 
transformer.Transform(GET_OR_RET(ParseToTree(std::forward<Input>(in))));
 }
 
 }  // namespace sql
diff --git a/src/search/value.h b/src/search/value.h
index f3395717..8978b22f 100644
--- a/src/search/value.h
+++ b/src/search/value.h
@@ -40,8 +40,8 @@ using String = std::string;  // e.g. a single tag
 using NumericArray = std::vector<Numeric>;  // used for vector fields
 using StringArray = std::vector<String>;    // used for tag fields, e.g. a 
list for tags
 
-struct Value : std::variant<Null, Numeric, StringArray, NumericArray> {
-  using Base = std::variant<Null, Numeric, StringArray, NumericArray>;
+struct Value : std::variant<Null, Numeric, String, StringArray, NumericArray> {
+  using Base = std::variant<Null, Numeric, String, StringArray, NumericArray>;
 
   using Base::Base;
 
@@ -52,6 +52,11 @@ struct Value : std::variant<Null, Numeric, StringArray, 
NumericArray> {
     return std::holds_alternative<T>(*this);
   }
 
+  template <typename T>
+  bool IsOrNull() const {
+    return Is<T>() || IsNull();
+  }
+
   template <typename T>
   const auto &Get() const {
     CHECK(Is<T>());
@@ -69,6 +74,8 @@ struct Value : std::variant<Null, Numeric, StringArray, 
NumericArray> {
       return "";
     } else if (Is<Numeric>()) {
       return fmt::format("{}", Get<Numeric>());
+    } else if (Is<String>()) {
+      return Get<String>();
     } else if (Is<StringArray>()) {
       return util::StringJoin(
           Get<StringArray>(), [](const auto &v) -> decltype(auto) { return v; 
}, sep);
@@ -85,6 +92,8 @@ struct Value : std::variant<Null, Numeric, StringArray, 
NumericArray> {
       return "";
     } else if (Is<Numeric>()) {
       return fmt::format("{}", Get<Numeric>());
+    } else if (Is<String>()) {
+      return Get<String>();
     } else if (Is<StringArray>()) {
       auto tag = dynamic_cast<redis::TagFieldMetadata *>(meta);
       char sep = tag ? tag->separator : ',';
diff --git a/tests/cppunit/redis_query_parser_test.cc 
b/tests/cppunit/redis_query_parser_test.cc
index a31051a5..bd66d41a 100644
--- a/tests/cppunit/redis_query_parser_test.cc
+++ b/tests/cppunit/redis_query_parser_test.cc
@@ -21,11 +21,14 @@
 #include <gtest/gtest.h>
 #include <search/redis_query_transformer.h>
 
+#include "search/common_transformer.h"
 #include "tao/pegtl/string_input.hpp"
 
 using namespace kqir::redis_query;
 
-static auto Parse(const std::string& in) { return ParseToIR(string_input(in, 
"test")); }
+static auto Parse(const std::string& in, const kqir::ParamMap& pm = {}) {
+  return ParseToIR(string_input(in, "test"), pm);
+}
 
 #define AssertSyntaxError(node) ASSERT_EQ(node.Msg(), "invalid syntax");  // 
NOLINT
 
@@ -90,3 +93,11 @@ TEST(RedisQueryParserTest, Simple) {
   AssertIR(Parse("* *"), "(and true, true)");
   AssertIR(Parse("*|*"), "(or true, true)");
 }
+
+TEST(RedisQueryParserTest, Params) {
+  AssertIR(Parse("@c:[$left ($right]", {{"left", "1"}, {"right", "2"}}), "(and 
c >= 1, c < 2)");
+  AssertIR(Parse("@c:[($x $x]", {{"x", "2"}}), "(and c > 2, c <= 2)");
+  AssertIR(Parse("@c:{$y}", {{"y", "hello"}}), "c hastag \"hello\"");
+  AssertIR(Parse("@c:{$y} @d:[$zzz inf]", {{"y", "hello"}, {"zzz", "3"}}), 
"(and c hastag \"hello\", d >= 3)");
+  ASSERT_EQ(Parse("@c:{$y}", {{"z", "hello"}}).Msg(), "parameter with name `y` 
not found");
+}
diff --git a/tests/cppunit/sql_parser_test.cc b/tests/cppunit/sql_parser_test.cc
index c6fc96ba..e85368ce 100644
--- a/tests/cppunit/sql_parser_test.cc
+++ b/tests/cppunit/sql_parser_test.cc
@@ -21,11 +21,15 @@
 #include <gtest/gtest.h>
 #include <search/sql_transformer.h>
 
+#include "search/common_transformer.h"
+#include "search/value.h"
 #include "tao/pegtl/string_input.hpp"
 
 using namespace kqir::sql;
 
-static auto Parse(const std::string& in) { return ParseToIR(string_input(in, 
"test")); }
+static auto Parse(const std::string& in, const kqir::ParamMap& pm = {}) {
+  return ParseToIR(string_input(in, "test"), pm);
+}
 
 #define AssertSyntaxError(node) ASSERT_EQ(node.Msg(), "invalid syntax");  // 
NOLINT
 
@@ -133,3 +137,12 @@ TEST(SQLParserTest, Simple) {
   AssertIR(Parse("select a from b where c = 1 or d hastag \"x\" and 2 <= e 
order by e asc limit 0, 10"),
            "select a from b where (or c = 1, (and d hastag \"x\", e >= 2)) 
sortby e, asc limit 0, 10");
 }
+
+TEST(SQLParserTest, Params) {
+  AssertIR(Parse("select a from b where c = @what", {{"what", "1"}}), "select 
a from b where c = 1");
+  AssertIR(Parse("select a from b where @x = c", {{"x", "2"}}), "select a from 
b where c = 2");
+  AssertIR(Parse("select a from b where c hastag @y", {{"y", "hello"}}), 
"select a from b where c hastag \"hello\"");
+  AssertIR(Parse("select a from b where c hastag @y and @zzz = d", {{"y", 
"hello"}, {"zzz", "3"}}),
+           "select a from b where (and c hastag \"hello\", d = 3)");
+  ASSERT_EQ(Parse("select a from b where c hastag @y", {{"z", 
"hello"}}).Msg(), "parameter with name `y` not found");
+}

Reply via email to