This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 73546b773a6 [Fix](AI) remove thread_pool in AI Functions (#56057)
73546b773a6 is described below
commit 73546b773a63f4acf92e7eb979148942a4491218
Author: linrrarity <[email protected]>
AuthorDate: Tue Sep 16 11:52:00 2025 +0800
[Fix](AI) remove thread_pool in AI Functions (#56057)
### What problem does this PR solve?
Issue Number: close #xxx
Related PR: https://github.com/apache/doris/pull/55886
Problem Summary:
The AI function creates a thread pool inside execute_impl, so up to
`llm_max_concurrent_requests * CPU cores / 2` threads can end up calling
the AI API for a single AI function. That is not the intended behavior.
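For a sense of scale, here is a minimal sketch of the arithmetic above, assuming a hypothetical 32-core machine and `llm_max_concurrent_requests = 4` (example values only, not taken from this PR):

```cpp
#include <cstdio>

int main() {
    // Hypothetical numbers for illustration only.
    const int llm_max_concurrent_requests = 4; // per-call thread pool size (old behavior)
    const int cpu_cores = 32;
    const int pipeline_exec_threads = cpu_cores / 2; // execute_impl may run on roughly cores/2 threads

    // Old behavior: every execute_impl call built its own "LLMRequestPool",
    // so one AI function could fan out to this many API-calling threads.
    std::printf("worst case: %d threads\n",
                llm_max_concurrent_requests * pipeline_exec_threads); // 64

    // After this change, requests are issued sequentially inside execute_impl,
    // so concurrency is bounded by the pipeline execution threads alone.
    std::printf("post-fix bound: %d threads\n", pipeline_exec_threads); // 16
    return 0;
}
```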
### Release note
None
### Check List (For Author)
- Test <!-- At least one of them must be included. -->
- [ ] Regression test
- [ ] Unit Test
- [ ] Manual test (add detailed scripts or steps below)
- [x] No need to test or manual test. Explain why:
- [ ] This is a refactor/code format and no logic has been changed.
- [ ] Previous test can cover this change.
- [ ] No code files have been changed.
- [ ] Other reason <!-- Add your reason? -->
- Behavior changed:
- [x] No.
- [ ] Yes. <!-- Explain the behavior change -->
- Does this need documentation?
- [x] No.
- [ ] Yes. <!-- Add document PR link here. eg:
https://github.com/apache/doris-website/pull/1214 -->
### Check List (For Reviewer who merge this PR)
- [ ] Confirm the release note
- [ ] Confirm test cases
- [ ] Confirm document
- [ ] Add branch pick label <!-- Add branch pick label that this PR
should merge into -->
---
be/src/common/config.cpp | 3 -
be/src/common/config.h | 3 -
be/src/vec/functions/ai/ai_functions.h | 149 ++++++---------------
...test.cpp => aggregate_function_ai_agg_test.cpp} | 3 +-
...{build_prompt_test.cpp => ai_function_test.cpp} | 27 ++++
5 files changed, 69 insertions(+), 116 deletions(-)
diff --git a/be/src/common/config.cpp b/be/src/common/config.cpp
index 41d6e356032..ba21728d916 100644
--- a/be/src/common/config.cpp
+++ b/be/src/common/config.cpp
@@ -1579,9 +1579,6 @@
DEFINE_mBool(enable_auto_clone_on_mow_publish_missing_version, "false");
// The maximum csv line reader output buffer size
DEFINE_mInt64(max_csv_line_reader_output_buffer_size, "4294967296");
-// The maximum number of threads supported when executing LLMFunction
-DEFINE_mInt32(llm_max_concurrent_requests, "1");
-
// Maximum number of openmp threads can be used by each doris threads.
// This configuration controls the parallelism level for OpenMP operations within Doris,
// helping to prevent resource contention and ensure stable performance when multiple
diff --git a/be/src/common/config.h b/be/src/common/config.h
index 2196550968f..959b46c9747 100644
--- a/be/src/common/config.h
+++ b/be/src/common/config.h
@@ -1636,9 +1636,6 @@ DECLARE_String(fuzzy_test_type);
// The maximum csv line reader output buffer size
DECLARE_mInt64(max_csv_line_reader_output_buffer_size);
-// The maximum number of threads supported when executing LLMFunction
-DECLARE_mInt32(llm_max_concurrent_requests);
-
// Maximum number of OpenMP threads that can be used by each Doris thread
DECLARE_Int32(omp_threads_limit);
// The capacity of segment partial column cache, used to cache column readers for each segment.
diff --git a/be/src/vec/functions/ai/ai_functions.h b/be/src/vec/functions/ai/ai_functions.h
index ce9e77cc572..1dd8dcf79a3 100644
--- a/be/src/vec/functions/ai/ai_functions.h
+++ b/be/src/vec/functions/ai/ai_functions.h
@@ -75,23 +75,6 @@ public:
assert_cast<const Derived&>(*this).get_return_type_impl(DataTypes());
MutableColumnPtr col_result = return_type_impl->create_column();
- std::unique_ptr<ThreadPool> thread_pool;
- Status st = ThreadPoolBuilder("LLMRequestPool")
- .set_min_threads(1)
- .set_max_threads(config::llm_max_concurrent_requests > 0
- ? config::llm_max_concurrent_requests
- : 1)
- .build(&thread_pool);
- if (!st.ok()) {
- return Status::InternalError("Failed to create thread pool: " + st.to_string());
- }
-
- struct RowResult {
- std::variant<std::string, std::vector<float>> data;
- Status status;
- bool is_null = false;
- };
-
TAIResource config;
std::shared_ptr<AIAdapter> adapter;
if (Status status =
@@ -101,114 +84,62 @@ public:
return status;
}
- std::vector<RowResult> results(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i) {
- Status submit_status = thread_pool->submit_func([this, i, &block, &arguments, &results,
- &adapter, &config, context,
- &return_type_impl]() {
- RowResult& row_result = results[i];
-
- try {
- // Build AI prompt text
- std::string prompt;
- Status status = assert_cast<const Derived&>(*this).build_prompt(
- block, arguments, i, prompt);
-
- if (!status.ok()) {
- row_result.status = status;
- row_result.is_null = true;
- return;
- }
-
- // Execute a single AI request and get the result
- if (return_type_impl->get_primitive_type() == PrimitiveType::TYPE_ARRAY) {
- std::vector<float> float_result;
- status = execute_single_request(prompt, float_result, config, adapter,
- context);
- if (!status.ok()) {
- row_result.status = status;
- row_result.is_null = true;
- return;
- }
- row_result.data = std::move(float_result);
- } else {
- std::string string_result;
- status = execute_single_request(prompt, string_result, config, adapter,
- context);
- if (!status.ok()) {
- row_result.status = status;
- row_result.is_null = true;
- return;
- }
- row_result.data = std::move(string_result);
- }
- row_result.status = Status::OK();
- } catch (const std::exception& e) {
- row_result.status = Status::InternalError("Exception in AI request: " +
- std::string(e.what()));
- row_result.is_null = true;
- }
- });
-
- if (!submit_status.ok()) {
- return Status::InternalError("Failed to submit task to thread
pool: " +
- submit_status.to_string());
- }
- }
-
- thread_pool->wait();
-
- for (size_t i = 0; i < input_rows_count; ++i) {
- const RowResult& row_result = results[i];
-
- if (!row_result.status.ok()) {
- return row_result.status;
- }
+ // Build AI prompt text
+ std::string prompt;
+ RETURN_IF_ERROR(
+ assert_cast<const Derived&>(*this).build_prompt(block, arguments, i, prompt));
+
+ // Execute a single AI request and get the result
+ if (return_type_impl->get_primitive_type() == PrimitiveType::TYPE_ARRAY) {
+ // Array(Float) for AI_EMBED
+ std::vector<float> float_result;
+ RETURN_IF_ERROR(
+ execute_single_request(prompt, float_result, config, adapter, context));
+
+ auto& col_array = assert_cast<ColumnArray&>(*col_result);
+ auto& offsets = col_array.get_offsets();
+ auto& nested_nullable_col = assert_cast<ColumnNullable&>(col_array.get_data());
+ auto& nested_col =
+ assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr()));
+ nested_col.reserve(nested_col.size() + float_result.size());
+
+ size_t current_offset = nested_col.size();
+ nested_col.insert_many_raw_data(reinterpret_cast<const char*>(float_result.data()),
+ float_result.size());
+ offsets.push_back(current_offset + float_result.size());
+ auto& null_map = nested_nullable_col.get_null_map_column();
+ null_map.insert_many_vals(0, float_result.size());
+ } else {
+ std::string string_result;
+ RETURN_IF_ERROR(
+ execute_single_request(prompt, string_result, config, adapter, context));
- if (!row_result.is_null) {
switch (return_type_impl->get_primitive_type()) {
case PrimitiveType::TYPE_STRING: { // string
- const auto& str_data = std::get<std::string>(row_result.data);
assert_cast<ColumnString&>(*col_result)
- .insert_data(str_data.data(), str_data.size());
+ .insert_data(string_result.data(), string_result.size());
break;
}
- case PrimitiveType::TYPE_BOOLEAN: { // boolean
- const auto& bool_data = std::get<std::string>(row_result.data);
- if (bool_data != "true" && bool_data != "false") {
- return Status::RuntimeError("Failed to parse boolean value: " + bool_data);
+ case PrimitiveType::TYPE_BOOLEAN: { // boolean for AI_FILTER
+#ifdef BE_TEST
+ string_result = "false";
+#endif
+ if (string_result != "true" && string_result != "false") {
+ return Status::RuntimeError("Failed to parse boolean
value: " +
+ string_result);
}
assert_cast<ColumnUInt8&>(*col_result)
- .insert_value(static_cast<UInt8>(bool_data == "true"));
- break;
- }
- case PrimitiveType::TYPE_FLOAT: { // float
- const auto& str_data = std::get<std::string>(row_result.data);
- assert_cast<ColumnFloat32&>(*col_result).insert_value(std::stof(str_data));
+ .insert_value(static_cast<UInt8>(string_result == "true"));
break;
}
- case PrimitiveType::TYPE_ARRAY: { // array of floats
- const auto& float_data = std::get<std::vector<float>>(row_result.data);
- auto& col_array = assert_cast<ColumnArray&>(*col_result);
- auto& offsets = col_array.get_offsets();
- auto& nested_nullable_col = assert_cast<ColumnNullable&>(col_array.get_data());
- auto& nested_col = assert_cast<ColumnFloat32&>(
- *(nested_nullable_col.get_nested_column_ptr()));
- nested_col.reserve(nested_col.size() + float_data.size());
-
- size_t current_offset = nested_col.size();
- nested_col.insert_many_raw_data(
- reinterpret_cast<const char*>(float_data.data()), float_data.size());
- offsets.push_back(current_offset + float_data.size());
- auto& null_map = nested_nullable_col.get_null_map_column();
- null_map.insert_many_vals(0, float_data.size());
+ case PrimitiveType::TYPE_FLOAT: { // float for AI_SIMILARITY
+ assert_cast<ColumnFloat32&>(*col_result).insert_value(std::stof(string_result));
break;
}
default:
return Status::InternalError("Unsupported ReturnType for
AIFunction");
}
- } else {
- col_result->insert_default();
}
}
diff --git a/be/test/ai/aggregate_function_llm_agg_test.cpp b/be/test/ai/aggregate_function_ai_agg_test.cpp
similarity index 99%
rename from be/test/ai/aggregate_function_llm_agg_test.cpp
rename to be/test/ai/aggregate_function_ai_agg_test.cpp
index 391f0686547..4374ff16c01 100644
--- a/be/test/ai/aggregate_function_llm_agg_test.cpp
+++ b/be/test/ai/aggregate_function_ai_agg_test.cpp
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+#include "vec/aggregate_functions/aggregate_function_ai_agg.h"
+
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
@@ -26,7 +28,6 @@
#include "runtime/query_context.h"
#include "testutil/column_helper.h"
#include "testutil/mock/mock_runtime_state.h"
-#include "vec/aggregate_functions/aggregate_function_ai_agg.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/columns/column_string.h"
#include "vec/common/arena.h"
diff --git a/be/test/ai/build_prompt_test.cpp b/be/test/ai/ai_function_test.cpp
similarity index 93%
rename from be/test/ai/build_prompt_test.cpp
rename to be/test/ai/ai_function_test.cpp
index 9a38c071784..f288a4b7b92 100644
--- a/be/test/ai/build_prompt_test.cpp
+++ b/be/test/ai/ai_function_test.cpp
@@ -316,6 +316,33 @@ TEST(AIFunctionTest, AIFilterTest) {
ASSERT_EQ(prompt, "This is a valid sentence.");
}
+TEST(AIFunctionTest, AIFilterExecuteTest) {
+ auto runtime_state = std::make_unique<MockRuntimeState>();
+ auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
+
+ std::vector<std::string> resources = {"mock_resource"};
+ std::vector<std::string> texts = {"This is a valid sentence."};
+ auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
+ auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
+
+ Block block;
+ block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
+ block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
+ block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
+
+ ColumnNumbers arguments = {0, 1};
+ size_t result_idx = 2;
+
+ auto filter_func = FunctionAIFilter::create();
+ Status exec_status =
+ filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
+
+ const auto& res_col =
+ assert_cast<const ColumnUInt8&>(*block.get_by_position(result_idx).column);
+ UInt8 val = res_col.get_data()[0];
+ ASSERT_TRUE(val == 0);
+}
+
TEST(AIFunctionTest, ResourceNotFound) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]