This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 1af42d7b Add semantics checker for KQIR (#2207)
1af42d7b is described below

commit 1af42d7b2ea8244182127c3f87522701bc5fc76b
Author: Twice <[email protected]>
AuthorDate: Fri Mar 29 23:45:15 2024 +0900

    Add semantics checker for KQIR (#2207)
---
 src/search/ir_dot_dumper.h            |   2 +
 src/search/ir_sema_checker.h          | 145 ++++++++++++++++++++++++++++++++++
 tests/cppunit/ir_sema_checker_test.cc |  82 +++++++++++++++++++
 3 files changed, 229 insertions(+)

diff --git a/src/search/ir_dot_dumper.h b/src/search/ir_dot_dumper.h
index 5bb6dc7b..a0bcfb72 100644
--- a/src/search/ir_dot_dumper.h
+++ b/src/search/ir_dot_dumper.h
@@ -28,6 +28,8 @@ namespace kqir {
 struct DotDumper {
   std::ostream &os;
 
+  explicit DotDumper(std::ostream &os) : os(os) {}
+
   void Dump(Node *node) {
     os << "digraph {\n";
     dump(node);
diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h
new file mode 100644
index 00000000..77427c7c
--- /dev/null
+++ b/src/search/ir_sema_checker.h
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <map>
+#include <memory>
+
+#include "ir.h"
+#include "search_encoding.h"
+#include "storage/redis_metadata.h"
+
+namespace kqir {
+
+struct IndexInfo;
+
+// Describes one field of a search index: the field name, a back-pointer to the
+// index that contains it, and the metadata object attached to the field.
+struct FieldInfo {
+  std::string name;
+  // Null until the field is handed to IndexInfo::Add, which points it at that index.
+  IndexInfo *index = nullptr;
+  std::unique_ptr<redis::SearchFieldMetadata> metadata;
+
+  // Takes ownership of both arguments; `index` is left null.
+  FieldInfo(std::string name, std::unique_ptr<redis::SearchFieldMetadata> &&metadata)
+      : name{std::move(name)}, metadata{std::move(metadata)} {}
+};
+
+// Describes one search index: its name, its metadata and the fields it contains,
+// keyed by field name.
+//
+// Every FieldInfo stored in `fields` carries a back-pointer (`FieldInfo::index`) to the
+// IndexInfo that owns it. The compiler-generated move operations would transfer the map
+// but leave those back-pointers aimed at the moved-from object, which dangles as soon as
+// that object is destroyed (e.g. a local IndexInfo moved into an IndexMap). The move
+// operations are therefore spelled out and re-point every field at the new owner.
+struct IndexInfo {
+  using FieldMap = std::map<std::string, FieldInfo>;
+
+  std::string name;
+  SearchMetadata metadata;
+  FieldMap fields;
+
+  IndexInfo(std::string name, SearchMetadata metadata) : name(std::move(name)), metadata(std::move(metadata)) {}
+
+  // Move construction: steal the members, then fix up the fields' back-pointers.
+  IndexInfo(IndexInfo &&other)
+      : name(std::move(other.name)), metadata(std::move(other.metadata)), fields(std::move(other.fields)) {
+    rebindFields();
+  }
+
+  // Move assignment: same fix-up as move construction; self-assignment is a no-op.
+  IndexInfo &operator=(IndexInfo &&other) {
+    if (this != &other) {
+      name = std::move(other.name);
+      metadata = std::move(other.metadata);
+      fields = std::move(other.fields);
+      rebindFields();
+    }
+    return *this;
+  }
+
+  // Takes ownership of `field`, marks this index as its owner and stores it under the
+  // field's name. If a field with that name already exists, std::map::emplace keeps the
+  // existing entry and the new one is dropped.
+  void Add(FieldInfo &&field) {
+    // The key is copied out of `field.name` before `field` itself is moved into the map value.
+    const auto &key = field.name;
+    field.index = this;
+    fields.emplace(key, std::move(field));
+  }
+
+ private:
+  // Points the back-pointer of every stored field at this object.
+  void rebindFields() {
+    for (auto &entry : fields) {
+      entry.second.index = this;
+    }
+  }
+};
+
+using IndexMap = std::map<std::string, IndexInfo>;
+
+// Semantic checker for KQIR trees.
+//
+// Check() walks an IR tree, resolves the index named by a SearchStmt and the fields named
+// by its sub-expressions against `index_map`, validates field kinds (tag vs. numeric), and
+// records every successful resolution in `result`.
+struct SemaChecker {
+  // Known indexes keyed by index name. Held by reference; not owned by the checker.
+  const IndexMap &index_map;
+
+  // Index resolved by the most recent SearchStmt whose index was found; null before that.
+  const IndexInfo *current_index = nullptr;
+
+  // Maps a resolved IR node (index reference or field reference) to what it resolved to.
+  using Result = std::map<const Node *, std::variant<const FieldInfo *, const IndexInfo *>>;
+  Result result;
+
+  explicit SemaChecker(const IndexMap &index_map) : index_map(index_map) {}
+
+  // Checks `node` and, recursively, its children.
+  //
+  // Returns Status::OK() when everything resolves, otherwise a Status::NotOK whose message
+  // describes the first problem encountered (unknown index, unknown field, wrong field kind,
+  // invalid tag value, or an IR node type this checker does not handle).
+  //
+  // NOTE(review): every branch except SearchStmt/Limit/BoolLiteral/And/Or/Not dereferences
+  // `current_index`, so such nodes must only be checked after a SearchStmt has resolved an index.
+  Status Check(Node *node) {
+    if (auto v = dynamic_cast<SearchStmt *>(node)) {
+      const auto &index_name = v->index->name;
+      if (auto iter = index_map.find(index_name); iter != index_map.end()) {
+        current_index = &iter->second;
+        result.emplace(v->index.get(), current_index);
+
+        // The select list is checked unconditionally; the remaining clauses only when present.
+        GET_OR_RET(Check(v->select_expr.get()));
+        if (v->query_expr) GET_OR_RET(Check(v->query_expr.get()));
+        if (v->limit) GET_OR_RET(Check(v->limit.get()));
+        if (v->sort_by) GET_OR_RET(Check(v->sort_by.get()));
+      } else {
+        return {Status::NotOK, fmt::format("index `{}` not found", index_name)};
+      }
+    } else if (auto v [[maybe_unused]] = dynamic_cast<Limit *>(node)) {
+      return Status::OK();
+    } else if (auto v = dynamic_cast<SortBy *>(node)) {
+      if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
+        return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
+      } else {
+        result.emplace(v->field.get(), &iter->second);
+      }
+    } else if (auto v = dynamic_cast<AndExpr *>(node)) {
+      for (const auto &n : v->inners) {
+        GET_OR_RET(Check(n.get()));
+      }
+    } else if (auto v = dynamic_cast<OrExpr *>(node)) {
+      for (const auto &n : v->inners) {
+        GET_OR_RET(Check(n.get()));
+      }
+    } else if (auto v = dynamic_cast<NotExpr *>(node)) {
+      GET_OR_RET(Check(v->inner.get()));
+    } else if (auto v = dynamic_cast<TagContainExpr *>(node)) {
+      if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
+        // Both placeholders need an argument: the field name and the index name.
+        return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
+      } else if (auto meta = dynamic_cast<redis::SearchTagFieldMetadata *>(iter->second.metadata.get()); !meta) {
+        return {Status::NotOK, fmt::format("field `{}` is not a tag field", v->field->name)};
+      } else {
+        // The field is recorded before the tag value is validated, so it stays in `result`
+        // even when one of the tag checks below fails.
+        result.emplace(v->field.get(), &iter->second);
+
+        if (v->tag->val.empty()) {
+          return {Status::NotOK, "tag cannot be an empty string"};
+        }
+
+        if (v->tag->val.find(meta->separator) != std::string::npos) {
+          return {Status::NotOK, fmt::format("tag cannot contain the separator `{}`", meta->separator)};
+        }
+      }
+    } else if (auto v = dynamic_cast<NumericCompareExpr *>(node)) {
+      if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
+        return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
+      } else if (!dynamic_cast<redis::SearchNumericFieldMetadata *>(iter->second.metadata.get())) {
+        return {Status::NotOK, fmt::format("field `{}` is not a numeric field", v->field->name)};
+      } else {
+        result.emplace(v->field.get(), &iter->second);
+      }
+    } else if (auto v = dynamic_cast<SelectExpr *>(node)) {
+      for (const auto &n : v->fields) {
+        if (auto iter = current_index->fields.find(n->name); iter == current_index->fields.end()) {
+          return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", n->name, current_index->name)};
+        } else {
+          result.emplace(n.get(), &iter->second);
+        }
+      }
+    } else if (auto v [[maybe_unused]] = dynamic_cast<BoolLiteral *>(node)) {
+      return Status::OK();
+    } else {
+      return {Status::NotOK, fmt::format("unexpected IR node type: {}", node->Name())};
+    }
+
+    return Status::OK();
+  }
+};
+
+}  // namespace kqir
diff --git a/tests/cppunit/ir_sema_checker_test.cc b/tests/cppunit/ir_sema_checker_test.cc
new file mode 100644
index 00000000..db222f6c
--- /dev/null
+++ b/tests/cppunit/ir_sema_checker_test.cc
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "search/ir_sema_checker.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "gtest/gtest.h"
+#include "search/search_encoding.h"
+#include "search/sql_transformer.h"
+#include "storage/redis_metadata.h"
+
+using namespace kqir;
+
+// Test helper: feeds the SQL text `in` to sql::ParseToIR through a peg::string_input
+// constructed with "test" as its second argument, and returns whatever ParseToIR returns.
+static auto Parse(const std::string& in) {
+  return sql::ParseToIR(peg::string_input(in, "test"));
+}
+
+// Builds the fixture used by the tests below: a single index `ia` holding one tag field
+// (`f1`) and two numeric fields (`f2`, `f3`).
+IndexMap MakeIndexMap() {
+  IndexInfo ia("ia", SearchMetadata());
+  ia.Add(FieldInfo("f1", std::make_unique<redis::SearchTagFieldMetadata>()));
+  ia.Add(FieldInfo("f2", std::make_unique<redis::SearchNumericFieldMetadata>()));
+  ia.Add(FieldInfo("f3", std::make_unique<redis::SearchNumericFieldMetadata>()));
+
+  IndexMap res;
+  // The key is read from `ia.name` while `ia` itself is moved into the mapped value.
+  const auto& key = ia.name;
+  res.emplace(key, std::move(ia));
+  return res;
+}
+
+using testing::MatchesRegex;
+
+TEST(SemaCheckerTest, Simple) {
+  auto index_map = MakeIndexMap();
+
+  {
+    SemaChecker checker(index_map);
+    ASSERT_EQ(checker.Check(Parse("select a from b")->get()).Msg(), "index `b` 
not found");
+    ASSERT_EQ(checker.Check(Parse("select a from ia")->get()).Msg(), "field 
`a` not found in index `ia`");
+    ASSERT_EQ(checker.Check(Parse("select f1 from ia")->get()).Msg(), "ok");
+    ASSERT_EQ(checker.Check(Parse("select f1 from ia where b = 
1")->get()).Msg(), "field `b` not found in index `ia`");
+    ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 = 
1")->get()).Msg(), "field `f1` is not a numeric field");
+    ASSERT_EQ(checker.Check(Parse("select f1 from ia where f2 hastag 
\"a\"")->get()).Msg(),
+              "field `f2` is not a tag field");
+    ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 hastag \"a\" and 
f2 = 1")->get()).Msg(), "ok");
+    ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 hastag 
\"\"")->get()).Msg(),
+              "tag cannot be an empty string");
+    ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 hastag 
\",\"")->get()).Msg(),
+              "tag cannot contain the separator `,`");
+    ASSERT_EQ(checker.Check(Parse("select f1 from ia order by 
a")->get()).Msg(), "field `a` not found in index `ia`");
+  }
+
+  {
+    SemaChecker checker(index_map);
+    auto root = *Parse("select f1 from ia where f1 hastag \"a\" and f2 = 1 
order by f3");
+
+    ASSERT_EQ(checker.Check(root.get()).Msg(), "ok");
+
+    ASSERT_EQ(checker.result.size(), 5);
+  }
+}

Reply via email to