This is an automated email from the ASF dual-hosted git repository. hulk pushed a commit to branch 2.12 in repository https://gitbox.apache.org/repos/asf/kvrocks.git
commit e8fecff5e2c667da650b7447c0cd60339d0833eb
Author: hulk <[email protected]>
AuthorDate: Tue Apr 15 09:11:27 2025 +0800

    fix(protocol): inline mode should allow the quoted string (#2873)
---
 src/common/string_util.cc                   | 118 +++++++++++++++++++++++++++++
 src/common/string_util.h                    |   1 +
 src/server/redis_request.cc                 |   7 +-
 tests/cppunit/string_util_test.cc           |  66 +++++++++++++++++
 tests/gocase/unit/protocol/protocol_test.go |  22 +++++++
 5 files changed, 213 insertions(+), 1 deletion(-)

diff --git a/src/common/string_util.cc b/src/common/string_util.cc
index c1e6f7e3f..f05e286fc 100644
--- a/src/common/string_util.cc
+++ b/src/common/string_util.cc
@@ -276,6 +276,124 @@ std::pair<std::string, std::string> SplitGlob(std::string_view glob) {
   return {prefix, ""};
 }
 
+static int HexDigitToInt(const char c) {
+  if (c >= '0' && c <= '9') {
+    return c - '0';
+  } else if (c >= 'a' && c <= 'f') {
+    return c - 'a' + 10;
+  } else if (c >= 'A' && c <= 'F') {
+    return c - 'A' + 10;
+  }
+  return 0;
+}
+
+// Split an inline protocol line into arguments, honoring quoted strings:
+//
+// - Unquoted arguments are separated by any whitespace character.
+// - Inside double quotes, the escapes \xHH (two hex digits), \n, \r, \t, \b and
+//   \a are interpreted; any other escaped character is kept without the backslash.
+// - Inside single quotes, only \' is interpreted; everything else is literal.
+// - A closing quote must be followed by whitespace or the end of the input.
+// - An empty quoted string ("" or '') yields an empty argument.
+StatusOr<std::vector<std::string>> SplitArguments(std::string_view in) {
+  std::vector<std::string> arguments;
+  std::string current_string;
+  // Tracked separately from current_string.empty() so that "" and '' still
+  // produce an (empty) argument instead of being silently dropped.
+  bool in_argument = false;
+
+  enum State { NORMAL, DOUBLE_QUOTED, SINGLE_QUOTED, ESCAPE } state = NORMAL;
+
+  // <cctype> classifiers require a value representable as unsigned char;
+  // passing a negative char (bytes >= 0x80 where char is signed) is UB.
+  auto is_space = [](char ch) { return std::isspace(static_cast<unsigned char>(ch)) != 0; };
+  auto is_hex_digit = [](char ch) { return std::isxdigit(static_cast<unsigned char>(ch)) != 0; };
+
+  for (size_t i = 0; i < in.size(); i++) {
+    const auto c = in[i];
+    switch (state) {
+      case NORMAL:
+        // is_space() also covers '\t', '\n', '\r', '\v' and '\f'.
+        if (is_space(c)) {
+          if (in_argument) {
+            arguments.emplace_back(std::move(current_string));
+            current_string.clear();
+            in_argument = false;
+          }
+        } else if (c == '"') {
+          state = DOUBLE_QUOTED;
+          in_argument = true;
+        } else if (c == '\'') {
+          state = SINGLE_QUOTED;
+          in_argument = true;
+        } else {
+          current_string.push_back(c);
+          in_argument = true;
+        }
+        break;
+      case SINGLE_QUOTED:
+        if (c == '\\' && (i + 1) < in.size() && in[i + 1] == '\'') {
+          current_string.push_back('\'');
+          i++;
+        } else if (c == '\'') {
+          // The closing quote must end the argument.
+          if (i + 1 < in.size() && !is_space(in[i + 1])) {
+            return {Status::NotOK, "the closed single quote must be followed by a space"};
+          }
+          state = NORMAL;
+        } else {
+          current_string.push_back(c);
+        }
+        break;
+      case DOUBLE_QUOTED:
+        if (c == '\\') {
+          state = ESCAPE;
+        } else if (c == '"') {
+          // The closing quote must end the argument.
+          if (i + 1 < in.size() && !is_space(in[i + 1])) {
+            return {Status::NotOK, "the closed double quote must be followed by a space"};
+          }
+          state = NORMAL;
+        } else {
+          current_string.push_back(c);
+        }
+        break;
+      case ESCAPE:
+        // \xHH: two hex digits encode a single byte.
+        if (c == 'x' && (i + 2) < in.size() && is_hex_digit(in[i + 1]) && is_hex_digit(in[i + 2])) {
+          auto hex_byte = static_cast<char>(HexDigitToInt(in[i + 1]) * 16 | HexDigitToInt(in[i + 2]));
+          current_string.push_back(hex_byte);
+          i += 2;
+        } else if (c == 'n') {
+          current_string.push_back('\n');
+        } else if (c == 'r') {
+          current_string.push_back('\r');
+        } else if (c == 't') {
+          current_string.push_back('\t');
+        } else if (c == 'b') {
+          current_string.push_back('\b');
+        } else if (c == 'a') {
+          current_string.push_back('\a');
+        } else {
+          // Any other character (including '"', '\'' and '\\') is taken literally.
+          current_string.push_back(c);
+        }
+        state = DOUBLE_QUOTED;
+        break;
+    }
+  }
+  if (state == DOUBLE_QUOTED || state == SINGLE_QUOTED) {
+    return {Status::NotOK, "unclosed quote string"};
+  }
+  if (state == ESCAPE) {
+    return {Status::NotOK, "unexpected trailing escape character"};
+  }
+  if (in_argument) {
+    arguments.emplace_back(std::move(current_string));
+  }
+  return arguments;
+}
+
 std::vector<std::string> RegexMatch(const std::string &str, const std::string &regex) {
   std::regex base_regex(regex);
   std::smatch pieces_match;
diff --git a/src/common/string_util.h b/src/common/string_util.h
index fe66be5a2..32c0f4b8e 100644
--- a/src/common/string_util.h
+++ b/src/common/string_util.h
@@ -50,6 +50,7 @@ Iter FindICase(Iter begin, Iter end, std::string_view expected) {
 Status ValidateGlob(std::string_view glob);
 bool StringMatch(std::string_view glob, std::string_view str, bool ignore_case = false);
 std::pair<std::string, std::string> SplitGlob(std::string_view glob);
+StatusOr<std::vector<std::string>> SplitArguments(std::string_view in);
 std::vector<std::string> RegexMatch(const std::string &str, const std::string &regex);
 
 std::string StringToHex(std::string_view input);
diff --git a/src/server/redis_request.cc b/src/server/redis_request.cc
index f1e2eb993..f729e85c2 100644
--- a/src/server/redis_request.cc
+++ b/src/server/redis_request.cc
@@ -86,7 +86,12 @@ Status Request::Tokenize(evbuffer *input) {
             return {Status::NotOK, "Protocol error: invalid bulk length"};
           }
 
-          tokens_ = util::Split(std::string(line.get(), line.length), " \t");
+          // Bound the input by line.length (as before) instead of re-scanning it with strlen.
+          auto arguments = util::SplitArguments(std::string_view(line.get(), line.length));
+          if (!arguments.IsOK()) {
+            return {Status::NotOK, "Protocol error: " + arguments.Msg()};
+          }
+          tokens_ = std::move(arguments.GetValue());
           if (tokens_.empty()) continue;
           commands_.emplace_back(std::move(tokens_));
           state_ = ArrayLen;
diff --git a/tests/cppunit/string_util_test.cc b/tests/cppunit/string_util_test.cc
index 1d24cf594..d31a99615 100644
--- a/tests/cppunit/string_util_test.cc
+++ b/tests/cppunit/string_util_test.cc
@@ -268,3 +268,69 @@ TEST(StringUtil, RegexMatchExtractSSTFile) {
     ASSERT_TRUE(match_results[1] == "/000038.sst");
   }
 }
+
+TEST(StringUtil, SplitArguments) {
+  std::map<std::string, std::vector<std::string>> valid_cases = {
+      // With ' ' only
+      {"a b c", {"a", "b", "c"}},
+      // Other whitespace characters should work
+      {"a\tb\nc\fd", {"a", "b", "c", "d"}},
+
+      // With double quote escape characters
+      {R"(hello "a b" c)", {"hello", "a b", "c"}},
+      // With single quote escape characters
+      {R"('a b' c)", {"a b", "c"}},
+      // With both single and double quote escape characters
+      {R"(a 'b c' " d e ")", {"a", "b c", " d e "}},
+      // Same as above with the quote kinds swapped
+      {R"(a " b c " 'd e')", {"a", " b c ", "d e"}},
+
+      // With the single quote escape characters
+      {R"('a\' b' c)", {"a' b", "c"}},
+      {R"('a\n\t\r\'b' c)", {R"(a\n\t\r'b)", "c"}},
+
+      // With the double quote escape characters
+      {R"("a\"b" c)", {"a\"b", "c"}},
+      {R"("a\n\t\qb\g" c)", {"a\n\tqbg", "c"}},
+
+      // Escape with the hex digits
+      {R"(\x61 \x62 \x63)", {R"(\x61)", R"(\x62)", R"(\x63)"}},
+      {R"("a \x61\x62" "\x63")", {"a ab", "c"}},
+      // '\' is dropped from '\xT0' because it is neither a valid hex escape nor a known escape sequence
+      {R"("a \xT0\x62" "\x63")", {R"(a xT0b)", "c"}},
+      {R"("a b\x6Fc" "d\x63e")", {"a boc", "dce"}},
+
+      // Empty quoted strings produce empty arguments
+      {R"(a "" b)", {"a", "", "b"}},
+      {R"('' "")", {"", ""}},
+      // Non-ASCII bytes are preserved as-is
+      {"caf\xc3\xa9 b", {"caf\xc3\xa9", "b"}},
+
+  };
+  for (const auto &item : valid_cases) {
+    const std::string &input = item.first;
+    const std::vector<std::string> &expected = item.second;
+    auto result = util::SplitArguments(input);
+    ASSERT_TRUE(result.IsOK());
+    ASSERT_EQ(result.GetValue(), expected);
+  }
+
+  // invalid cases
+  std::map<std::string, std::string> invalid_cases = {
+      {R"(a "b c)", "unclosed quote string"},
+      {R"(a 'b c)", "unclosed quote string"},
+      {R"(a "b' c)", "unclosed quote string"},
+      {R"(a 'b" c)", "unclosed quote string"},
+      {R"(a b 'c\)", "unclosed quote string"},
+      {R"(a b "c\)", "unexpected trailing escape character"},
+      {R"(a b "c"d)", "the closed double quote must be followed by a space"},
+      {R"(a 'b'c)", "the closed single quote must be followed by a space"},
+  };
+  for (const auto &item : invalid_cases) {
+    const std::string &input = item.first;
+    const std::string &expected_error = item.second;
+    auto result = util::SplitArguments(input);
+    ASSERT_FALSE(result.IsOK());
+    ASSERT_EQ(result.Msg(), expected_error);
+  }
+}
diff --git a/tests/gocase/unit/protocol/protocol_test.go b/tests/gocase/unit/protocol/protocol_test.go
index 6be669bb8..a533bfde9 100644
--- a/tests/gocase/unit/protocol/protocol_test.go
+++ b/tests/gocase/unit/protocol/protocol_test.go
@@ -114,6 +114,28 @@ func TestProtocolNetwork(t *testing.T) {
 		c.MustRead(t, "+OK")
 	})
 
+	t.Run("inline protocol with quoted string", func(t *testing.T) {
+		c := srv.NewTCPClient()
+		LF := "\n"
+		defer func() { require.NoError(t, c.Close()) }()
+		require.NoError(t, c.Write("RPUSH my_list a 'b c' d"+LF))
+		c.MustRead(t, ":3")
+		require.NoError(t, c.Write(`RPUSH my_list "foo \x61\x62"`+LF))
+		c.MustRead(t, ":4")
+		require.NoError(t, c.Write(`RPUSH my_list "bar \"\g\t\n\q"`+LF))
+		c.MustRead(t, ":5")
+		require.NoError(t, c.Write(`RPUSH my_list ' a b' "c d e " `+LF))
+		c.MustRead(t, ":7")
+		require.NoError(t, c.Write(`RPUSH my_list "" ''`+LF))
+		c.MustRead(t, ":9")
+
+		rdb := srv.NewClient()
+		defer func() { require.NoError(t, rdb.Close()) }()
+		values, err := rdb.LRange(context.Background(), "my_list", 0, -1).Result()
+		require.NoError(t, err)
+		require.Equal(t, []string{"a", "b c", "d", "foo ab", "bar \"g\t\nq", " a b", "c d e ", "", ""}, values)
+	})
+
 	t.Run("mix LF/CRLF protocol separator", func(t *testing.T) {
 		c := srv.NewTCPClient()
 		defer func() { require.NoError(t, c.Close()) }()
