This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new abc105b94d6 [Enhancement](ai) relax the matching restriction of some
AI-Functions (#58077)
abc105b94d6 is described below
commit abc105b94d6d94c79de4d84bb43ad34d5daa2a52
Author: linrrarity <[email protected]>
AuthorDate: Tue Nov 18 16:13:47 2025 +0800
[Enhancement](ai) relax the matching restriction of some AI-Functions
(#58077)
### What problem does this PR solve?
Issue Number: close #xxx
Related PR: #xxx
Problem Summary:
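`AI_FILTER` previously accepted only the exact strings `0` or `1`, so an
AI response with surrounding whitespace (for example `0\n` or `\t1 `)
caused a parsing error. The response is now trimmed of leading and
trailing whitespace before it is matched.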
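`AI_SIMILARITY` is handled the same way: the response is trimmed before
it is parsed as a float, and a failed parse is reported as a
`RuntimeError` status.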
### Release note
None
### Check List (For Author)
- Test <!-- At least one of them must be included. -->
- [ ] Regression test
- [x] Unit Test
- [ ] Manual test (add detailed scripts or steps below)
- [ ] No need to test or manual test. Explain why:
- [ ] This is a refactor/code format and no logic has been changed.
- [ ] Previous test can cover this change.
- [ ] No code files have been changed.
- [ ] Other reason <!-- Add your reason? -->
- Behavior changed:
- [ ] No.
- [ ] Yes. <!-- Explain the behavior change -->
- Does this need documentation?
- [x] No.
- [ ] Yes. <!-- Add document PR link here. eg:
https://github.com/apache/doris-website/pull/1214 -->
### Check List (For Reviewer who merge this PR)
- [ ] Confirm the release note
- [ ] Confirm test cases
- [ ] Confirm document
- [ ] Add branch pick label <!-- Add branch pick label that this PR
should merge into -->
---
be/src/vec/functions/ai/ai_functions.h | 38 +++++++-
be/test/ai/ai_function_test.cpp | 154 +++++++++++++++++++++++++++++++++
2 files changed, 190 insertions(+), 2 deletions(-)
diff --git a/be/src/vec/functions/ai/ai_functions.h
b/be/src/vec/functions/ai/ai_functions.h
index 7bb4909b457..c31a7659614 100644
--- a/be/src/vec/functions/ai/ai_functions.h
+++ b/be/src/vec/functions/ai/ai_functions.h
@@ -20,6 +20,9 @@
#include <gen_cpp/FrontendService.h>
#include <gen_cpp/PaloInternalService_types.h>
+#include <algorithm>
+#include <cctype>
+#include <cstdlib>
#include <memory>
#include <string>
#include <type_traits>
@@ -122,8 +125,14 @@ public:
}
case PrimitiveType::TYPE_BOOLEAN: { // boolean for AI_FILTER
#ifdef BE_TEST
- string_result = "0";
+ const char* test_result = std::getenv("AI_TEST_RESULT");
+ if (test_result != nullptr) {
+ string_result = test_result;
+ } else {
+ string_result = "0";
+ }
#endif
+ trim_string(string_result);
if (string_result != "1" && string_result != "0") {
return Status::RuntimeError("Failed to parse boolean
value: " +
string_result);
@@ -133,7 +142,22 @@ public:
break;
}
case PrimitiveType::TYPE_FLOAT: { // float for AI_SIMILARITY
-
assert_cast<ColumnFloat32&>(*col_result).insert_value(std::stof(string_result));
+#ifdef BE_TEST
+ const char* test_result = std::getenv("AI_TEST_RESULT");
+ if (test_result != nullptr) {
+ string_result = test_result;
+ } else {
+ string_result = "0.0";
+ }
+#endif
+ trim_string(string_result);
+ try {
+ float float_value = std::stof(string_result);
+
assert_cast<ColumnFloat32&>(*col_result).insert_value(float_value);
+ } catch (...) {
+ return Status::RuntimeError("Failed to parse float
value: " +
+ string_result);
+ }
break;
}
default:
@@ -147,6 +171,16 @@ public:
}
private:
+ // Trim whitespace and newlines from string
+ static void trim_string(std::string& str) {
+ str.erase(str.begin(), std::find_if(str.begin(), str.end(),
+ [](unsigned char ch) { return
!std::isspace(ch); }));
+ str.erase(std::find_if(str.rbegin(), str.rend(),
+ [](unsigned char ch) { return
!std::isspace(ch); })
+ .base(),
+ str.end());
+ }
+
// The ai resource must be literal
Status _init_from_resource(FunctionContext* context, const Block& block,
const ColumnNumbers& arguments, TAIResource&
config,
diff --git a/be/test/ai/ai_function_test.cpp b/be/test/ai/ai_function_test.cpp
index f288a4b7b92..74a46240be8 100644
--- a/be/test/ai/ai_function_test.cpp
+++ b/be/test/ai/ai_function_test.cpp
@@ -295,6 +295,79 @@ TEST(AIFunctionTest, AISimilarityTest) {
ASSERT_EQ(prompt, "Text 1: I like this dish\nText 2: This dish is very
good");
}
+TEST(AIFunctionTest, AISimilarityExecuteTest) {
+ auto runtime_state = std::make_unique<MockRuntimeState>();
+ auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
+
+ std::vector<std::string> resources = {"mock_resource"};
+ std::vector<std::string> text1 = {"I like this dish"};
+ std::vector<std::string> text2 = {"This dish is very good"};
+ auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
+ auto col_text1 = ColumnHelper::create_column<DataTypeString>(text1);
+ auto col_text2 = ColumnHelper::create_column<DataTypeString>(text2);
+
+ Block block;
+ block.insert({std::move(col_resource), std::make_shared<DataTypeString>(),
"resource"});
+ block.insert({std::move(col_text1), std::make_shared<DataTypeString>(),
"text1"});
+ block.insert({std::move(col_text2), std::make_shared<DataTypeString>(),
"text2"});
+ block.insert({nullptr, std::make_shared<DataTypeFloat32>(), "result"});
+
+ ColumnNumbers arguments = {0, 1, 2};
+ size_t result_idx = 3;
+
+ auto similarity_func = FunctionAISimilarity::create();
+ Status exec_status =
+ similarity_func->execute_impl(ctx.get(), block, arguments,
result_idx, text1.size());
+
+ ASSERT_TRUE(exec_status.ok());
+}
+
+TEST(AIFunctionTest, AISimilarityTrimWhitespace) {
+ auto runtime_state = std::make_unique<MockRuntimeState>();
+ auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
+
+ std::vector<std::pair<std::string, float>> test_cases = {
+ {"0.5", 0.5f}, {"1.0", 1.0f}, {"0.0", 0.0f},
{" 0.5", 0.5f},
+ {"0.5 ", 0.5f}, {" 0.5 ", 0.5f}, {"\n0.8", 0.8f},
{"0.3\n", 0.3f},
+ {"\n0.7\n", 0.7f}, {"\t0.2\t", 0.2f}, {" \n\t0.9 \n\t", 0.9f},
{" 0.1 ", 0.1f},
+ {"\r\n0.6\r\n", 0.6f}};
+
+ for (const auto& test_case : test_cases) {
+ setenv("AI_TEST_RESULT", test_case.first.c_str(), 1);
+
+ std::vector<std::string> resources = {"mock_resource"};
+ std::vector<std::string> text1 = {"Test text 1"};
+ std::vector<std::string> text2 = {"Test text 2"};
+ auto col_resource =
ColumnHelper::create_column<DataTypeString>(resources);
+ auto col_text1 = ColumnHelper::create_column<DataTypeString>(text1);
+ auto col_text2 = ColumnHelper::create_column<DataTypeString>(text2);
+
+ Block block;
+ block.insert({std::move(col_resource),
std::make_shared<DataTypeString>(), "resource"});
+ block.insert({std::move(col_text1),
std::make_shared<DataTypeString>(), "text1"});
+ block.insert({std::move(col_text2),
std::make_shared<DataTypeString>(), "text2"});
+ block.insert({nullptr, std::make_shared<DataTypeFloat32>(), "result"});
+
+ ColumnNumbers arguments = {0, 1, 2};
+ size_t result_idx = 3;
+
+ auto similarity_func = FunctionAISimilarity::create();
+ Status exec_status = similarity_func->execute_impl(ctx.get(), block,
arguments, result_idx,
+ text1.size());
+
+ ASSERT_TRUE(exec_status.ok()) << "Failed for test case: '" <<
test_case.first << "'";
+
+ const auto& res_col =
+ assert_cast<const
ColumnFloat32&>(*block.get_by_position(result_idx).column);
+ float val = res_col.get_data()[0];
+ ASSERT_FLOAT_EQ(val, test_case.second)
+ << "Failed for test case: '" << test_case.first
+ << "', expected: " << test_case.second << ", got: " << val;
+ }
+
+ unsetenv("AI_TEST_RESULT");
+}
+
TEST(AIFunctionTest, AIFilterTest) {
FunctionAIFilter function;
@@ -343,6 +416,87 @@ TEST(AIFunctionTest, AIFilterExecuteTest) {
ASSERT_TRUE(val == 0);
}
+TEST(AIFunctionTest, AIFilterTrimWhitespace) {
+ auto runtime_state = std::make_unique<MockRuntimeState>();
+ auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
+
+ std::vector<std::pair<std::string, UInt8>> test_cases = {
+ {"0", 0}, {"1", 1}, {" 0", 0}, {"0 ", 0},
+ {" 0 ", 0}, {"\n0", 0}, {"0\n", 0}, {"\n0\n", 0},
+ {"\t1\t", 1}, {" \n\t1 \n\t", 1}, {" 1 ", 1}, {"\r\n0\r\n", 0}};
+
+ for (const auto& test_case : test_cases) {
+ setenv("AI_TEST_RESULT", test_case.first.c_str(), 1);
+
+ std::vector<std::string> resources = {"mock_resource"};
+ std::vector<std::string> texts = {"Test input"};
+ auto col_resource =
ColumnHelper::create_column<DataTypeString>(resources);
+ auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
+
+ Block block;
+ block.insert({std::move(col_resource),
std::make_shared<DataTypeString>(), "resource"});
+ block.insert({std::move(col_text), std::make_shared<DataTypeString>(),
"text"});
+ block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
+
+ ColumnNumbers arguments = {0, 1};
+ size_t result_idx = 2;
+
+ auto filter_func = FunctionAIFilter::create();
+ Status exec_status =
+ filter_func->execute_impl(ctx.get(), block, arguments,
result_idx, texts.size());
+
+ ASSERT_TRUE(exec_status.ok()) << "Failed for test case: '" <<
test_case.first << "'";
+
+ const auto& res_col =
+ assert_cast<const
ColumnUInt8&>(*block.get_by_position(result_idx).column);
+ UInt8 val = res_col.get_data()[0];
+ ASSERT_EQ(val, test_case.second)
+ << "Failed for test case: '" << test_case.first
+ << "', expected: " << (int)test_case.second << ", got: " <<
(int)val;
+ }
+
+ unsetenv("AI_TEST_RESULT");
+}
+
+TEST(AIFunctionTest, AIFilterInvalidValue) {
+ auto runtime_state = std::make_unique<MockRuntimeState>();
+ auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
+
+ std::vector<std::string> invalid_cases = {
+ "2", "maybe", "ok", "", " ", "01", "0.5", "sure",
"truee", "falsee",
+ "yess", "noo", "true", "false", "yes", "no", "TRUE", "FALSE",
"YES", "NO"};
+
+ for (const auto& invalid_value : invalid_cases) {
+ setenv("AI_TEST_RESULT", invalid_value.c_str(), 1);
+
+ std::vector<std::string> resources = {"mock_resource"};
+ std::vector<std::string> texts = {"Test input"};
+ auto col_resource =
ColumnHelper::create_column<DataTypeString>(resources);
+ auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
+
+ Block block;
+ block.insert({std::move(col_resource),
std::make_shared<DataTypeString>(), "resource"});
+ block.insert({std::move(col_text), std::make_shared<DataTypeString>(),
"text"});
+ block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
+
+ ColumnNumbers arguments = {0, 1};
+ size_t result_idx = 2;
+
+ auto filter_func = FunctionAIFilter::create();
+ Status exec_status =
+ filter_func->execute_impl(ctx.get(), block, arguments,
result_idx, texts.size());
+
+ ASSERT_FALSE(exec_status.ok())
+ << "Should have failed for invalid value: '" << invalid_value
<< "'";
+ ASSERT_TRUE(exec_status.to_string().find("Failed to parse boolean
value") !=
+ std::string::npos)
+ << "Error message should mention boolean parsing for value: '"
<< invalid_value
+ << "'";
+ }
+
+ unsetenv("AI_TEST_RESULT");
+}
+
TEST(AIFunctionTest, ResourceNotFound) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]