This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 1af42d7b Add semantics checker for KQIR (#2207)
1af42d7b is described below
commit 1af42d7b2ea8244182127c3f87522701bc5fc76b
Author: Twice <[email protected]>
AuthorDate: Fri Mar 29 23:45:15 2024 +0900
Add semantics checker for KQIR (#2207)
---
src/search/ir_dot_dumper.h | 2 +
src/search/ir_sema_checker.h | 145 ++++++++++++++++++++++++++++++++++
tests/cppunit/ir_sema_checker_test.cc | 82 +++++++++++++++++++
3 files changed, 229 insertions(+)
diff --git a/src/search/ir_dot_dumper.h b/src/search/ir_dot_dumper.h
index 5bb6dc7b..a0bcfb72 100644
--- a/src/search/ir_dot_dumper.h
+++ b/src/search/ir_dot_dumper.h
@@ -28,6 +28,8 @@ namespace kqir {
struct DotDumper {
std::ostream &os;
+ explicit DotDumper(std::ostream &os) : os(os) {}
+
void Dump(Node *node) {
os << "digraph {\n";
dump(node);
diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h
new file mode 100644
index 00000000..77427c7c
--- /dev/null
+++ b/src/search/ir_sema_checker.h
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <map>
+#include <memory>
+
+#include "ir.h"
+#include "search_encoding.h"
+#include "storage/redis_metadata.h"
+
+namespace kqir {
+
+struct IndexInfo;
+
+// One field of a search index: its name, a back-pointer to the index it was
+// added to, and the owned metadata object describing the field.
+struct FieldInfo {
+  std::string name;
+  // Null until IndexInfo::Add() registers this field with an index.
+  IndexInfo *index = nullptr;
+  std::unique_ptr<redis::SearchFieldMetadata> metadata;
+
+  // Takes ownership of both the name and the metadata object.
+  FieldInfo(std::string name, std::unique_ptr<redis::SearchFieldMetadata> &&metadata)
+      : name{std::move(name)}, metadata{std::move(metadata)} {}
+};
+
+// A search index: its name, its metadata and the fields registered with it.
+//
+// Every FieldInfo stored in `fields` keeps a raw back-pointer to the IndexInfo
+// that owns it. Compiler-generated move operations would transfer the map but
+// leave those pointers aimed at the moved-from object, so the move operations
+// are spelled out here and re-point every field at its new owner.
+struct IndexInfo {
+  using FieldMap = std::map<std::string, FieldInfo>;
+
+  std::string name;
+  SearchMetadata metadata;
+  FieldMap fields;
+
+  IndexInfo(std::string name, SearchMetadata metadata) : name(std::move(name)), metadata(std::move(metadata)) {}
+
+  // Not copyable: each FieldInfo owns its metadata through a unique_ptr.
+  IndexInfo(const IndexInfo &) = delete;
+  IndexInfo &operator=(const IndexInfo &) = delete;
+
+  IndexInfo(IndexInfo &&other)
+      : name(std::move(other.name)), metadata(std::move(other.metadata)), fields(std::move(other.fields)) {
+    rebindFields();
+  }
+
+  IndexInfo &operator=(IndexInfo &&other) {
+    if (this != &other) {
+      name = std::move(other.name);
+      metadata = std::move(other.metadata);
+      fields = std::move(other.fields);
+      rebindFields();
+    }
+    return *this;
+  }
+
+  // Registers `field` under its own name and records this index as its owner.
+  // If a field with the same name already exists the map is left unchanged
+  // (std::map::emplace does not overwrite).
+  void Add(FieldInfo &&field) {
+    field.index = this;
+    // Copy the key up front so it never aliases the object being moved from.
+    auto key = field.name;
+    fields.emplace(std::move(key), std::move(field));
+  }
+
+ private:
+  // Points every stored field's back-pointer at this object.
+  void rebindFields() {
+    for (auto &entry : fields) {
+      entry.second.index = this;
+    }
+  }
+};
+
+// All known indexes, keyed by index name.
+using IndexMap = std::map<std::string, IndexInfo>;
+
+// Semantic checker for KQIR trees.
+//
+// Check() walks an IR tree, verifies that the referenced index and fields exist
+// in `index_map` and that each field has the kind its expression requires, and
+// records in `result` which IndexInfo / FieldInfo each reference node resolved
+// to.
+struct SemaChecker {
+  const IndexMap &index_map;
+
+  // Index selected by the most recently accepted SearchStmt; null before that.
+  const IndexInfo *current_index = nullptr;
+
+  using Result = std::map<const Node *, std::variant<const FieldInfo *, const IndexInfo *>>;
+  // Resolution table filled in by Check(). It is never cleared, so entries
+  // accumulate across calls on the same checker.
+  Result result;
+
+  explicit SemaChecker(const IndexMap &index_map) : index_map(index_map) {}
+
+  // Recursively validates `node`.
+  //
+  // Returns Status::OK() when the subtree is well-formed. Returns NotOK with a
+  // human-readable message when the index or a field is unknown, a field has
+  // the wrong kind for its expression, a tag is empty or contains the field's
+  // separator, `node` is null, or the node type is not handled here.
+  //
+  // Field-level nodes are resolved against `current_index`, which is only set
+  // by checking a SearchStmt; checking them on a fresh checker is reported as
+  // an error.
+  Status Check(Node *node) {
+    if (!node) {
+      return {Status::NotOK, "unexpected null IR node"};
+    }
+
+    if (auto v = dynamic_cast<SearchStmt *>(node)) {
+      const auto &index_name = v->index->name;
+      auto iter = index_map.find(index_name);
+      if (iter == index_map.end()) {
+        return {Status::NotOK, fmt::format("index `{}` not found", index_name)};
+      }
+
+      current_index = &iter->second;
+      result.emplace(v->index.get(), current_index);
+
+      GET_OR_RET(Check(v->select_expr.get()));
+      if (v->query_expr) {
+        GET_OR_RET(Check(v->query_expr.get()));
+      }
+      if (v->limit) {
+        GET_OR_RET(Check(v->limit.get()));
+      }
+      if (v->sort_by) {
+        GET_OR_RET(Check(v->sort_by.get()));
+      }
+    } else if (dynamic_cast<Limit *>(node)) {
+      // Nothing to resolve for a limit clause.
+      return Status::OK();
+    } else if (auto v = dynamic_cast<SortBy *>(node)) {
+      const FieldInfo *field = nullptr;
+      GET_OR_RET(resolveField(v->field->name, &field));
+      result.emplace(v->field.get(), field);
+    } else if (auto v = dynamic_cast<AndExpr *>(node)) {
+      for (const auto &n : v->inners) {
+        GET_OR_RET(Check(n.get()));
+      }
+    } else if (auto v = dynamic_cast<OrExpr *>(node)) {
+      for (const auto &n : v->inners) {
+        GET_OR_RET(Check(n.get()));
+      }
+    } else if (auto v = dynamic_cast<NotExpr *>(node)) {
+      GET_OR_RET(Check(v->inner.get()));
+    } else if (auto v = dynamic_cast<TagContainExpr *>(node)) {
+      const FieldInfo *field = nullptr;
+      GET_OR_RET(resolveField(v->field->name, &field));
+
+      auto meta = dynamic_cast<redis::SearchTagFieldMetadata *>(field->metadata.get());
+      if (!meta) {
+        return {Status::NotOK, fmt::format("field `{}` is not a tag field", v->field->name)};
+      }
+
+      // The field reference is recorded before the tag value is validated.
+      result.emplace(v->field.get(), field);
+
+      if (v->tag->val.empty()) {
+        return {Status::NotOK, "tag cannot be an empty string"};
+      }
+
+      if (v->tag->val.find(meta->separator) != std::string::npos) {
+        return {Status::NotOK, fmt::format("tag cannot contain the separator `{}`", meta->separator)};
+      }
+    } else if (auto v = dynamic_cast<NumericCompareExpr *>(node)) {
+      const FieldInfo *field = nullptr;
+      GET_OR_RET(resolveField(v->field->name, &field));
+
+      if (!dynamic_cast<redis::SearchNumericFieldMetadata *>(field->metadata.get())) {
+        return {Status::NotOK, fmt::format("field `{}` is not a numeric field", v->field->name)};
+      }
+      result.emplace(v->field.get(), field);
+    } else if (auto v = dynamic_cast<SelectExpr *>(node)) {
+      for (const auto &n : v->fields) {
+        const FieldInfo *field = nullptr;
+        GET_OR_RET(resolveField(n->name, &field));
+        result.emplace(n.get(), field);
+      }
+    } else if (dynamic_cast<BoolLiteral *>(node)) {
+      // Literals carry no references to resolve.
+      return Status::OK();
+    } else {
+      return {Status::NotOK, fmt::format("unexpected IR node type: {}", node->Name())};
+    }
+
+    return Status::OK();
+  }
+
+ private:
+  // Looks `field_name` up in `current_index` and stores the match in `*out`.
+  // Building the "not found" message in one place keeps it identical for
+  // every expression kind and guarantees both placeholders get an argument.
+  Status resolveField(const std::string &field_name, const FieldInfo **out) const {
+    if (!current_index) {
+      return {Status::NotOK, fmt::format("field `{}` cannot be resolved: no index has been selected", field_name)};
+    }
+
+    auto iter = current_index->fields.find(field_name);
+    if (iter == current_index->fields.end()) {
+      return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", field_name, current_index->name)};
+    }
+
+    *out = &iter->second;
+    return Status::OK();
+  }
+};
+
+} // namespace kqir
diff --git a/tests/cppunit/ir_sema_checker_test.cc
b/tests/cppunit/ir_sema_checker_test.cc
new file mode 100644
index 00000000..db222f6c
--- /dev/null
+++ b/tests/cppunit/ir_sema_checker_test.cc
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "search/ir_sema_checker.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "gtest/gtest.h"
+#include "search/search_encoding.h"
+#include "search/sql_transformer.h"
+#include "storage/redis_metadata.h"
+
+using namespace kqir;
+
+// Test helper: feeds the SQL text `in` to sql::ParseToIR and returns its result
+// unchanged; the tests below dereference that result to reach the IR root.
+// NOTE(review): "test" is presumably the source label used in parser
+// diagnostics; confirm against peg::string_input.
+static auto Parse(const std::string& in) { return
sql::ParseToIR(peg::string_input(in, "test")); }
+
+// Builds the fixture used by the tests: a single index "ia" holding a tag
+// field "f1" and two numeric fields "f2" and "f3".
+IndexMap MakeIndexMap() {
+  IndexInfo index("ia", SearchMetadata());
+  index.Add(FieldInfo("f1", std::make_unique<redis::SearchTagFieldMetadata>()));
+  index.Add(FieldInfo("f2", std::make_unique<redis::SearchNumericFieldMetadata>()));
+  index.Add(FieldInfo("f3", std::make_unique<redis::SearchNumericFieldMetadata>()));
+
+  IndexMap indexes;
+  // The map key is the index's own name.
+  auto key = index.name;
+  indexes.emplace(key, std::move(index));
+  return indexes;
+}
+
+using testing::MatchesRegex;
+
+// Exercises SemaChecker::Check against the fixture from MakeIndexMap().
+TEST(SemaCheckerTest, Simple) {
+ auto index_map = MakeIndexMap();
+
+ // One checker is reused for every query in this block; each assertion pins
+ // the message of the Status returned for that query ("ok" on success).
+ {
+ SemaChecker checker(index_map);
+ ASSERT_EQ(checker.Check(Parse("select a from b")->get()).Msg(), "index `b`
not found");
+ ASSERT_EQ(checker.Check(Parse("select a from ia")->get()).Msg(), "field
`a` not found in index `ia`");
+ ASSERT_EQ(checker.Check(Parse("select f1 from ia")->get()).Msg(), "ok");
+ ASSERT_EQ(checker.Check(Parse("select f1 from ia where b =
1")->get()).Msg(), "field `b` not found in index `ia`");
+ ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 =
1")->get()).Msg(), "field `f1` is not a numeric field");
+ ASSERT_EQ(checker.Check(Parse("select f1 from ia where f2 hastag
\"a\"")->get()).Msg(),
+ "field `f2` is not a tag field");
+ ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 hastag \"a\" and
f2 = 1")->get()).Msg(), "ok");
+ ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 hastag
\"\"")->get()).Msg(),
+ "tag cannot be an empty string");
+ ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 hastag
\",\"")->get()).Msg(),
+ "tag cannot contain the separator `,`");
+ ASSERT_EQ(checker.Check(Parse("select f1 from ia order by
a")->get()).Msg(), "field `a` not found in index `ia`");
+ }
+
+ // A fresh checker on a fully valid statement: the result table should hold
+ // one entry for the index reference plus one per field reference (select f1,
+ // hastag f1, f2 comparison, order by f3), i.e. five in total.
+ {
+ SemaChecker checker(index_map);
+ auto root = *Parse("select f1 from ia where f1 hastag \"a\" and f2 = 1
order by f3");
+
+ ASSERT_EQ(checker.Check(root.get()).Msg(), "ok");
+
+ ASSERT_EQ(checker.result.size(), 5);
+ }
+}