This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 37798bf88ca branch-4.0: [Fix](ai) Fix _exec_plan_fragment_impl meet unknown error when call AI_Functions (#58521)
37798bf88ca is described below

commit 37798bf88caffb6eb635f82f9e86ee676b49e341
Author: linrrarity <[email protected]>
AuthorDate: Sat Nov 29 22:33:39 2025 +0800

    branch-4.0: [Fix](ai) Fix _exec_plan_fragment_impl meet unknown error when call AI_Functions (#58521)
    
    pick: #58363
---
 .../segment_v2/ann_index/ann_index_reader.cpp      |   8 +-
 be/src/runtime/query_context.h                     |  11 +-
 be/src/util/quantile_state.h                       |   2 +-
 .../aggregate_function_ai_agg.h                    |  11 +-
 .../exprs/lambda_function/varray_sort_function.cpp | 273 +++++++++++++++++++++
 be/src/vec/functions/ai/ai_functions.h             |   9 +-
 be/test/ai/aggregate_function_ai_agg_test.cpp      |  28 +++
 be/test/ai/ai_function_test.cpp                    |  29 +++
 .../main/java/org/apache/doris/qe/Coordinator.java |  14 ++
 9 files changed, 367 insertions(+), 18 deletions(-)

diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp 
b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
index 35880336580..f993804c072 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
@@ -125,8 +125,8 @@ Status AnnIndexReader::query(io::IOContext* io_ctx, 
AnnTopNParam* param, AnnInde
             
stats->engine_convert_ns.update(index_search_result.engine_convert_ns);
             
stats->engine_prepare_ns.update(index_search_result.engine_prepare_ns);
         } else {
-            throw Status::NotSupported("Unsupported index type: {}",
-                                       ann_index_type_to_string(_index_type));
+            throw Exception(Status::NotSupported("Unsupported index type: {}",
+                                                 
ann_index_type_to_string(_index_type)));
         }
 
         DCHECK(index_search_result.roaring != nullptr);
@@ -174,8 +174,8 @@ Status AnnIndexReader::range_search(const 
AnnRangeSearchParams& params,
             hnsw_param->bounded_queue = custom_params.hnsw_bounded_queue;
             search_param = std::move(hnsw_param);
         } else {
-            throw Status::NotSupported("Unsupported index type: {}",
-                                       ann_index_type_to_string(_index_type));
+            throw Exception(Status::NotSupported("Unsupported index type: {}",
+                                                 
ann_index_type_to_string(_index_type)));
         }
 
         search_param->is_le_or_lt = params.is_le_or_lt;
diff --git a/be/src/runtime/query_context.h b/be/src/runtime/query_context.h
index 7e879e6221e..30cc8fd6016 100644
--- a/be/src/runtime/query_context.h
+++ b/be/src/runtime/query_context.h
@@ -260,14 +260,11 @@ public:
 
     void set_ai_resources(std::map<std::string, TAIResource> ai_resources) {
         _ai_resources =
-                std::make_unique<std::map<std::string, 
TAIResource>>(std::move(ai_resources));
+                std::make_shared<std::map<std::string, 
TAIResource>>(std::move(ai_resources));
     }
 
-    const std::map<std::string, TAIResource>& get_ai_resources() const {
-        if (_ai_resources == nullptr) {
-            throw Status::InternalError("AI resources not found");
-        }
-        return *_ai_resources;
+    const std::shared_ptr<std::map<std::string, TAIResource>>& 
get_ai_resources() const {
+        return _ai_resources;
     }
 
     std::unordered_map<TNetworkAddress, std::shared_ptr<PBackendService_Stub>>
@@ -360,7 +357,7 @@ private:
     std::unordered_map<int, std::vector<std::shared_ptr<TRuntimeProfileTree>>> 
_profile_map;
     std::unordered_map<int, std::shared_ptr<TRuntimeProfileTree>> 
_load_channel_profile_map;
 
-    std::unique_ptr<std::map<std::string, TAIResource>> _ai_resources;
+    std::shared_ptr<std::map<std::string, TAIResource>> _ai_resources;
 
     void _report_query_profile();
 
diff --git a/be/src/util/quantile_state.h b/be/src/util/quantile_state.h
index 2d46989fcc6..23e43860c34 100644
--- a/be/src/util/quantile_state.h
+++ b/be/src/util/quantile_state.h
@@ -64,7 +64,7 @@ public:
     double get_explicit_value_by_percentile(float percentile) const;
 #ifdef BE_TEST
     std::string to_string() const {
-        throw Status::NotSupported("QuantileState::to_string() not 
implemented");
+        throw Exception(Status::NotSupported("QuantileState::to_string() not 
implemented"));
     }
 #endif
     ~QuantileState() = default;
diff --git a/be/src/vec/aggregate_functions/aggregate_function_ai_agg.h 
b/be/src/vec/aggregate_functions/aggregate_function_ai_agg.h
index 89ac8828ba1..d27ceb41312 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_ai_agg.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_ai_agg.h
@@ -141,9 +141,14 @@ public:
             _task = task_ref.to_string();
 
             std::string resource_name = resource_name_ref.to_string();
-            const std::map<std::string, TAIResource>& ai_resources = 
_ctx->get_ai_resources();
-            auto it = ai_resources.find(resource_name);
-            if (it == ai_resources.end()) {
+            const std::shared_ptr<std::map<std::string, TAIResource>>& 
ai_resources =
+                    _ctx->get_ai_resources();
+            if (!ai_resources) {
+                throw Exception(ErrorCode::INTERNAL_ERROR,
+                                "AI resources metadata missing in 
QueryContext");
+            }
+            auto it = ai_resources->find(resource_name);
+            if (it == ai_resources->end()) {
                 throw Exception(ErrorCode::NOT_FOUND, "AI resource not found: 
" + resource_name);
             }
             _ai_config = it->second;
diff --git a/be/src/vec/exprs/lambda_function/varray_sort_function.cpp 
b/be/src/vec/exprs/lambda_function/varray_sort_function.cpp
new file mode 100644
index 00000000000..1cffdcd5ea7
--- /dev/null
+++ b/be/src/vec/exprs/lambda_function/varray_sort_function.cpp
@@ -0,0 +1,273 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <glog/logging.h>
+
+#include "common/status.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_array.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_varbinary.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/block.h"
+#include "vec/core/column_with_type_and_name.h"
+#include "vec/data_types/data_type.h"
+#include "vec/exprs/lambda_function/lambda_function.h"
+#include "vec/exprs/lambda_function/lambda_function_factory.h"
+#include "vec/exprs/vexpr.h"
+#include "vec/utils/util.hpp"
+
+namespace doris::vectorized {
+#include "common/compile_check_begin.h"
+
+class VExprContext;
+// Non-owning pointer to the array's concrete nested data column; bound by get_data_from_type().
+using ConstColumnVariant =
+        std::variant<const ColumnUInt8*, const ColumnInt8*, const ColumnInt16*, const ColumnInt32*,
+                     const ColumnInt64*, const ColumnInt128*, const ColumnFloat32*,
+                     const ColumnFloat64*, const ColumnString*, const ColumnVarbinary*,
+                     const ColumnArray*, const ColumnIPv4*, const ColumnIPv6*,
+                     const ColumnDecimal32*, const ColumnDecimal64*, const ColumnDecimal128V2*,
+                     const ColumnDecimal128V3*, const ColumnDecimal256*, const ColumnDate*,
+                     const ColumnDateTime*, const ColumnDateV2*, const ColumnDateTimeV2*,
+                     const ColumnTime*, const ColumnTimeV2*>;
+// Trait: true iff the column class is some ColumnVector<PrimitiveType> specialization.
+template <typename Col>
+struct is_column_vector : std::bool_constant<false> {};
+
+template <PrimitiveType PType>
+struct is_column_vector<ColumnVector<PType>> : std::bool_constant<true> {};
+
+template <typename Col>
+inline constexpr bool is_column_vector_v = is_column_vector<Col>::value; // for `if constexpr`
+// array_sort(lambda, arr): sorts every array row using a user-supplied comparator lambda.
+class ArraySortFunction : public LambdaFunction {
+    ENABLE_FACTORY_CREATOR(ArraySortFunction);
+
+public:
+    ~ArraySortFunction() override = default;
+    static constexpr auto name = "array_sort";
+
+    static LambdaFunctionPtr create() { return std::make_shared<ArraySortFunction>(); }
+
+    std::string get_name() const override { return name; }
+
+    Status execute(VExprContext* context, const vectorized::Block* block, ColumnPtr& result_column,
+                   const DataTypePtr& result_type, const VExprSPtrs& children) const override {
+        // children[0] is the comparator lambda, children[1] the (possibly nullable) array.
+        // Element a is placed before b iff lambda(a, b) returns a negative value.
+        DCHECK_EQ(children.size(), 2);
+
+        // 1. Evaluate the array argument and its type; const columns are materialized.
+        ColumnPtr column_ptr;
+        RETURN_IF_ERROR(children[1]->execute_column(context, block, column_ptr));
+        DataTypePtr type_ptr = children[1]->execute_type(block);
+
+        auto column = column_ptr->convert_to_full_column_if_const();
+
+        auto input_rows = column->size();
+
+        auto arg_type = type_ptr;
+        auto arg_column = column;
+
+        ColumnPtr outside_null_map = nullptr;
+        // Unwrap a nullable array, remembering its row-level null map for the result.
+        if (arg_column->is_nullable()) {
+            arg_column = assert_cast<const ColumnNullable*>(column.get())->get_nested_column_ptr();
+            outside_null_map =
+                    assert_cast<const ColumnNullable*>(column.get())->get_null_map_column_ptr();
+            arg_type = assert_cast<const DataTypeNullable*>(type_ptr.get())->get_nested_type();
+        }
+
+        const auto& col_array = assert_cast<const ColumnArray&>(*arg_column);
+        const auto& off_data =
+                assert_cast<const ColumnArray::ColumnOffsets&>(col_array.get_offsets_column())
+                        .get_data();
+
+        const auto& nested_nullable_column =
+                assert_cast<const ColumnNullable&>(*col_array.get_data_ptr());
+
+        auto pType = assert_cast<const DataTypeArray*>(arg_type.get())
+                             ->get_nested_type()
+                             ->get_primitive_type();
+
+        // Resolve the concrete nested data column from the element PrimitiveType.
+        ConstColumnVariant src_data;
+        RETURN_IF_ERROR(
+                get_data_from_type(pType, nested_nullable_column.get_nested_column(), src_data));
+
+        const auto& src_nullmap = nested_nullable_column.get_null_map_column();
+
+        const auto& col_type = assert_cast<const DataTypeArray&>(*arg_type);
+
+        // 2. Identity permutation over all nested elements, plus a block for the lambda.
+        auto element_size = nested_nullable_column.size();
+        IColumn::Permutation permutation(element_size);
+        for (size_t i = 0; i < element_size; ++i) {
+            permutation[i] = i;
+        }
+
+        /**
+         *  Example for element type nullable(int): columns 0 and 1 carry the two lambda arguments;
+         *  column 2 is a placeholder, and the Int8 result (-1/0/1) is read back at lambda_res_id:
+         *  col  data  nullmap    type
+         *   0    10      0    nullable(int)
+         *   1    20      1    nullable(int)
+         *   2   1/-1/0  ...     tinyint
+         *  Each argument column holds exactly one value: the element currently being compared.
+         */
+        Block lambda_block;
+        for (int i = 0; i <= 2; i++) {
+            lambda_block.insert(vectorized::ColumnWithTypeAndName(
+                    nested_nullable_column.clone_empty(), col_type.get_nested_type(), "temp"));
+        }
+        // Cache the nested data column and null map of both argument columns, sized to 1.
+        MutableColumnPtr temp_data[2];
+        NullMap* temp_nullmap_data[2] = {nullptr, nullptr};
+        for (int i = 0; i < 2; i++) {
+            auto* temp_column = assert_cast<ColumnNullable*>(
+                    lambda_block.get_by_position(i).column->assume_mutable().get());
+            temp_data[i] = temp_column->get_nested_column_ptr();
+            auto& null_map_col = assert_cast<ColumnUInt8&>(temp_column->get_null_map_column());
+            temp_nullmap_data[i] = &null_map_col.get_data();
+            temp_data[i]->resize(1);
+            temp_nullmap_data[i]->resize(1);
+        }
+
+        int lambda_res_id = 2;
+        // NOTE(review): confirm execute() does not append a new column to lambda_block per call.
+        // 3. Sort each row's slice of the permutation. The comparator loads elements i and j
+        // into argument columns 0/1 (prepare_lambda_input), runs the lambda on lambda_block,
+        // and treats a negative Int8 result as "i sorts before j".
+        std::visit(
+                [&](auto* data) {
+                    using ColumnType = std::decay_t<decltype(*data)>;
+                    ColumnType* data_vec[2] = {assert_cast<ColumnType*>(temp_data[0].get()),
+                                               assert_cast<ColumnType*>(temp_data[1].get())};
+
+                    // ColumnVector<T> values are overwritten in place via get_data()[0];
+                    // other column types are cleared and refilled with insert_from().
+                    auto prepare_lambda_input = [&](size_t i, size_t cid) {
+                        if (src_nullmap.get_data()[i]) {
+                            (*temp_nullmap_data[cid])[0] = 1;
+                        } else {
+                            (*temp_nullmap_data[cid])[0] = 0;
+                            if constexpr (is_column_vector_v<ColumnType>) {
+                                data_vec[cid]->get_data()[0] = data->get_data()[i];
+                            } else {
+                                data_vec[cid]->clear();
+                                data_vec[cid]->insert_from(*data, i);
+                            }
+                        }
+                    };
+
+                    for (int row = 0; row < input_rows; ++row) {
+                        auto start = off_data[row - 1]; // row 0 reads off_data[-1], assumed 0
+                        auto end = off_data[row];
+                        std::sort(&permutation[start], &permutation[end], [&](size_t i, size_t j) {
+                            prepare_lambda_input(i, 0);
+                            prepare_lambda_input(j, 1);
+                            auto status =
+                                    children[0]->execute(context, &lambda_block, &lambda_res_id);
+                            if (!status.ok()) [[unlikely]] {
+                                throw Exception(Status::InternalError(
+                                        "when execute array_sort lambda function: {}",
+                                        status.to_string()));
+                            }
+
+                            // The result may be a ColumnVector or a ColumnConst; materialize it.
+                            ColumnPtr raw_res_col =
+                                    lambda_block.get_by_position(lambda_res_id).column;
+                            ColumnPtr full_res_col = raw_res_col->convert_to_full_column_if_const();
+
+                            // Expected -1, 0 or 1; std::sort needs a consistent (strict weak) order.
+                            long cmp = assert_cast<const ColumnInt8*>(full_res_col.get())
+                                               ->get_data()[0];
+
+                            return cmp < 0;
+                        });
+                    }
+                },
+                src_data);
+
+        // 4. Rebuild the array from the permuted elements, reusing the original offsets;
+        // the outer null map is re-attached when the result type is nullable.
+        if (result_type->is_nullable()) {
+            result_column = ColumnNullable::create(
+                    ColumnArray::create(nested_nullable_column.permute(permutation, 0),
+                                        col_array.get_offsets_ptr()),
+                    outside_null_map);
+
+        } else {
+            DCHECK(!column->is_nullable());
+            result_column = ColumnArray::create(nested_nullable_column.permute(permutation, 0),
+                                                col_array.get_offsets_ptr());
+        }
+
+        return Status::OK();
+    }
+    // Binds column_variant to the concrete column class for the element PrimitiveType.
+#define DISPATCH_PRIMITIVE_TYPE(TYPE, COLUMN_CLASS)                 \
+    case TYPE:                                                      \
+        column_variant = &assert_cast<const COLUMN_CLASS&>(column); \
+        break;
+
+    Status get_data_from_type(PrimitiveType pType, const IColumn& column,
+                              ConstColumnVariant& column_variant) const {
+        switch (pType) {
+            DISPATCH_PRIMITIVE_TYPE(TYPE_BOOLEAN, ColumnUInt8)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_TINYINT, ColumnInt8)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_SMALLINT, ColumnInt16)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_INT, ColumnInt32)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_BIGINT, ColumnInt64)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_LARGEINT, ColumnInt128)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_FLOAT, ColumnFloat32)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_DOUBLE, ColumnFloat64)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_CHAR, ColumnString)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_STRING, ColumnString)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_VARCHAR, ColumnString)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_VARBINARY, ColumnVarbinary)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_ARRAY, ColumnArray)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_IPV4, ColumnIPv4)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_IPV6, ColumnIPv6)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_DECIMAL32, ColumnDecimal32)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_DECIMAL64, ColumnDecimal64)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_DECIMAL128I, ColumnDecimal128V3)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_DECIMALV2, ColumnDecimal128V2)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_DECIMAL256, ColumnDecimal256)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_DATE, ColumnDate)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_DATETIME, ColumnDateTime)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_DATEV2, ColumnDateV2)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_DATETIMEV2, ColumnDateTimeV2)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_TIME, ColumnTime)
+            DISPATCH_PRIMITIVE_TYPE(TYPE_TIMEV2, ColumnTimeV2)
+        default:
+            return Status::InternalError("Unsupported type in array_sort");
+        }
+        return Status::OK();
+    }
+
+#undef DISPATCH_PRIMITIVE_TYPE
+};
+// Registers ArraySortFunction ("array_sort") with the lambda function factory.
+void register_function_array_sort(doris::vectorized::LambdaFunctionFactory& factory) {
+    factory.register_function<ArraySortFunction>();
+}
+
+#include "common/compile_check_end.h"
+} // namespace doris::vectorized
diff --git a/be/src/vec/functions/ai/ai_functions.h 
b/be/src/vec/functions/ai/ai_functions.h
index c31a7659614..01efe65f938 100644
--- a/be/src/vec/functions/ai/ai_functions.h
+++ b/be/src/vec/functions/ai/ai_functions.h
@@ -190,10 +190,13 @@ private:
         StringRef resource_name_ref = resource_column.column->get_data_at(0);
         std::string resource_name = std::string(resource_name_ref.data, 
resource_name_ref.size);
 
-        const std::map<std::string, TAIResource>& ai_resources =
+        const std::shared_ptr<std::map<std::string, TAIResource>>& 
ai_resources =
                 context->state()->get_query_ctx()->get_ai_resources();
-        auto it = ai_resources.find(resource_name);
-        if (it == ai_resources.end()) {
+        if (!ai_resources) {
+            return Status::InternalError("AI resources metadata missing in 
QueryContext");
+        }
+        auto it = ai_resources->find(resource_name);
+        if (it == ai_resources->end()) {
             return Status::InvalidArgument("AI resource not found: " + 
resource_name);
         }
         config = it->second;
diff --git a/be/test/ai/aggregate_function_ai_agg_test.cpp 
b/be/test/ai/aggregate_function_ai_agg_test.cpp
index ff7580552ea..911da8583e8 100644
--- a/be/test/ai/aggregate_function_ai_agg_test.cpp
+++ b/be/test/ai/aggregate_function_ai_agg_test.cpp
@@ -413,4 +413,32 @@ TEST_F(AggregateFunctionAIAggTest, 
mock_resource_send_request_test) {
     _agg_function->destroy(place);
 }
 
+TEST_F(AggregateFunctionAIAggTest, missing_ai_resources_metadata_test) {
+    auto empty_query_ctx = MockQueryContext::create(); // set_ai_resources() is never called
+    _agg_function->set_query_context(empty_query_ctx.get());
+
+    std::vector<std::string> resources = {"resource_name"};
+    std::vector<std::string> texts = {"test input"};
+    std::vector<std::string> task = {"summarize"};
+    auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
+    auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
+    auto col_task = ColumnHelper::create_column<DataTypeString>(task);
+
+    auto memory = std::make_unique<char[]>(_agg_function->size_of_data());
+    AggregateDataPtr place = memory.get();
+    _agg_function->create(place);
+
+    const IColumn* columns[3] = {col_resource.get(), col_text.get(), col_task.get()};
+    // Non-fatal ADD_FAILURE (FAIL() would return) so that destroy(place) below always runs.
+    try {
+        _agg_function->add(place, columns, 0, _arena);
+        ADD_FAILURE() << "Expected exception for missing AI resources";
+    } catch (const Exception& e) {
+        EXPECT_EQ(e.code(), ErrorCode::INTERNAL_ERROR);
+        EXPECT_NE(e.to_string().find("AI resources metadata missing"), std::string::npos);
+    }
+
+    _agg_function->destroy(place);
+}
+
 } // namespace doris::vectorized
diff --git a/be/test/ai/ai_function_test.cpp b/be/test/ai/ai_function_test.cpp
index 74a46240be8..14409781417 100644
--- a/be/test/ai/ai_function_test.cpp
+++ b/be/test/ai/ai_function_test.cpp
@@ -551,6 +551,35 @@ TEST(AIFunctionTest, MockResourceSendRequest) {
     ASSERT_EQ(val, "this is a mock response. test input");
 }
 
+TEST(AIFunctionTest, MissingAIResourcesMetadataTest) {
+    // Backed by a MockQueryContext on which set_ai_resources() is never called.
+    auto mock_query_ctx = MockQueryContext::create();
+    TQueryOptions options;
+    TQueryGlobals globals;
+    RuntimeState state(TUniqueId(), 0, options, globals, nullptr, mock_query_ctx.get());
+    auto fn_ctx = FunctionContext::create_context(&state, {}, {});
+    // Single row: an AI resource name plus the text to analyse.
+    std::vector<std::string> resource_names = {"resource_name"};
+    std::vector<std::string> inputs = {"test"};
+    Block block;
+    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
+                  std::make_shared<DataTypeString>(), "resource"});
+    block.insert({ColumnHelper::create_column<DataTypeString>(inputs),
+                  std::make_shared<DataTypeString>(), "text"});
+    block.insert({nullptr, std::make_shared<DataTypeString>(), "result"});
+
+    const ColumnNumbers argument_positions = {0, 1};
+    const size_t result_position = 2;
+
+    auto ai_sentiment = FunctionAISentiment::create();
+    const Status status = ai_sentiment->execute_impl(fn_ctx.get(), block, argument_positions,
+                                                     result_position, inputs.size());
+
+    // Must surface the missing-metadata error as a Status rather than throw.
+    ASSERT_FALSE(status.ok());
+    ASSERT_NE(status.to_string().find("AI resources metadata missing"), std::string::npos);
+}
+
 TEST(AIFunctionTest, ReturnTypeTest) {
     FunctionAIClassify func_classify;
     DataTypes args;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
index 9d44f15384e..c0b9f2be457 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
@@ -19,8 +19,10 @@ package org.apache.doris.qe;
 
 import org.apache.doris.analysis.DescriptorTable;
 import org.apache.doris.analysis.StorageBackend;
+import org.apache.doris.catalog.AIResource;
 import org.apache.doris.catalog.Env;
 import org.apache.doris.catalog.FsBroker;
+import org.apache.doris.catalog.Resource;
 import org.apache.doris.common.Config;
 import org.apache.doris.common.MarkedCountDownLatch;
 import org.apache.doris.common.Pair;
@@ -89,6 +91,7 @@ import org.apache.doris.system.Backend;
 import org.apache.doris.system.SystemInfoService;
 import org.apache.doris.task.LoadEtlTask;
 import org.apache.doris.thrift.PaloInternalServiceVersion;
+import org.apache.doris.thrift.TAIResource;
 import org.apache.doris.thrift.TBrokerScanRange;
 import org.apache.doris.thrift.TDataSinkType;
 import org.apache.doris.thrift.TDescriptorTable;
@@ -3235,6 +3238,17 @@ public class Coordinator implements CoordInterface {
                     if (ignoreDataDistribution) {
                         params.setParallelInstances(parallelTasksNum);
                     }
+
+                    // Used for AI Functions
+                    Map<String, TAIResource> aiResourceMap = 
Maps.newLinkedHashMap();
+                    for (Resource resource : 
Env.getCurrentEnv().getResourceMgr()
+                                                    
.getResource(Resource.ResourceType.AI)) {
+                        if (resource instanceof AIResource) {
+                            aiResourceMap.put(resource.getName(), 
((AIResource) resource).toThrift());
+                        }
+                    }
+
+                    params.setAiResources(aiResourceMap);
                     res.put(instanceExecParam.host, params);
                     
res.get(instanceExecParam.host).setBucketSeqToInstanceIdx(new HashMap<Integer, 
Integer>());
                     
res.get(instanceExecParam.host).setShuffleIdxToInstanceIdx(new HashMap<Integer, 
Integer>());


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to