This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new f841fb6a9e5 branch-4.0: [Fix](AI_Func) Fix the error prompt-building and response-parsing process of Local_Adapter #58492 (#58796)
f841fb6a9e5 is described below

commit f841fb6a9e5d3b9aafc303033b06aa4d24c18cd5
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Dec 9 17:44:19 2025 +0800

    branch-4.0: [Fix](AI_Func) Fix the error prompt-building and response-parsing process of Local_Adapter #58492 (#58796)
    
    Cherry-picked from #58492
    
    Co-authored-by: linrrarity <[email protected]>
---
 be/src/vec/functions/ai/ai_adapter.h   | 194 +++++++++++++++++++++++++--------
 be/src/vec/functions/ai/ai_functions.h |  14 +++
 be/test/ai/ai_adapter_test.cpp         | 164 +++++++++++++++++++++++-----
 be/test/ai/ai_function_test.cpp        |  24 ++++
 be/test/ai/embed_test.cpp              |  11 ++
 5 files changed, 332 insertions(+), 75 deletions(-)

diff --git a/be/src/vec/functions/ai/ai_adapter.h b/be/src/vec/functions/ai/ai_adapter.h
index 8faffc200e6..fa06c9551d3 100644
--- a/be/src/vec/functions/ai/ai_adapter.h
+++ b/be/src/vec/functions/ai/ai_adapter.h
@@ -268,33 +268,14 @@ public:
         doc.SetObject();
         auto& allocator = doc.GetAllocator();
 
-        if (!_config.model_name.empty()) {
-            doc.AddMember("model", 
rapidjson::Value(_config.model_name.c_str(), allocator),
-                          allocator);
-        }
-
-        // If 'temperature' and 'max_tokens' are set, add them to the request body.
-        if (_config.temperature != -1) {
-            doc.AddMember("temperature", _config.temperature, allocator);
-        }
-        if (_config.max_tokens != -1) {
-            doc.AddMember("max_tokens", _config.max_tokens, allocator);
-        }
-
-        rapidjson::Value messages(rapidjson::kArrayType);
-        if (system_prompt && *system_prompt) {
-            rapidjson::Value sys_msg(rapidjson::kObjectType);
-            sys_msg.AddMember("role", "system", allocator);
-            sys_msg.AddMember("content", rapidjson::Value(system_prompt, 
allocator), allocator);
-            messages.PushBack(sys_msg, allocator);
-        }
-        for (const auto& input : inputs) {
-            rapidjson::Value message(rapidjson::kObjectType);
-            message.AddMember("role", "user", allocator);
-            message.AddMember("content", rapidjson::Value(input.c_str(), 
allocator), allocator);
-            messages.PushBack(message, allocator);
+        std::string end_point = _config.endpoint;
+        if (end_point.ends_with("chat") || end_point.ends_with("generate")) {
+            RETURN_IF_ERROR(
+                    build_ollama_request(doc, allocator, inputs, system_prompt, request_body));
+        } else {
+            RETURN_IF_ERROR(
+                    build_default_request(doc, allocator, inputs, system_prompt, request_body));
         }
-        doc.AddMember("messages", messages, allocator);
 
         rapidjson::StringBuffer buffer;
         rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
@@ -329,30 +310,23 @@ public:
                     results.emplace_back(choices[i]["text"].GetString());
                 }
             }
-
-            if (!results.empty()) {
-                return Status::OK();
-            }
-        }
-
-        // Format 2: Simple response with just "text" or "content" field
-        if (doc.HasMember("text") && doc["text"].IsString()) {
+        } else if (doc.HasMember("text") && doc["text"].IsString()) {
+            // Format 2: Simple response with just "text" or "content" field
             results.emplace_back(doc["text"].GetString());
-            return Status::OK();
-        }
-
-        if (doc.HasMember("content") && doc["content"].IsString()) {
+        } else if (doc.HasMember("content") && doc["content"].IsString()) {
             results.emplace_back(doc["content"].GetString());
-            return Status::OK();
-        }
-
-        // Format 3: Response field (Ollama format)
-        if (doc.HasMember("response") && doc["response"].IsString()) {
+        } else if (doc.HasMember("response") && doc["response"].IsString()) {
+            // Format 3: Response field (Ollama `generate` format)
             results.emplace_back(doc["response"].GetString());
-            return Status::OK();
+        } else if (doc.HasMember("message") && doc["message"].IsObject() &&
+                   doc["message"].HasMember("content") && 
doc["message"]["content"].IsString()) {
+            // Format 4: message/content field (Ollama `chat` format)
+            results.emplace_back(doc["message"]["content"].GetString());
+        } else {
+            return Status::NotSupported("Unsupported response format from local AI.");
         }
 
-        return Status::NotSupported("Unsupported response format from local AI.");
+        return Status::OK();
     }
 
     Status build_embedding_request(const std::vector<std::string>& inputs,
@@ -408,6 +382,15 @@ public:
                                std::back_inserter(results.emplace_back()),
                                [](const auto& val) { return val.GetFloat(); });
             }
+        } else if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) 
{
+            // "embeddings":[[0.1, 0.2, ...]]
+            results.reserve(1);
+            for (int i = 0; i < doc["embeddings"].Size(); i++) {
+                embedding = doc["embeddings"][i];
+                std::transform(embedding.Begin(), embedding.End(),
+                               std::back_inserter(results.emplace_back()),
+                               [](const auto& val) { return val.GetFloat(); });
+            }
         } else if (doc.HasMember("embedding") && doc["embedding"].IsArray()) {
             // "embedding":[0.1, 0.2, ...]
             results.reserve(1);
@@ -422,6 +405,127 @@ public:
 
         return Status::OK();
     }
+
+private:
+    Status build_ollama_request(rapidjson::Document& doc,
+                                rapidjson::Document::AllocatorType& allocator,
+                                const std::vector<std::string>& inputs,
+                                const char* const system_prompt, std::string& request_body) const {
+        /*
+        for endpoints ending with `/chat`, like 'http://localhost:11434/api/chat':
+        {
+            "model": <model_name>,
+            "stream": false,
+            "think": false,
+            "options": {
+                "temperature": <temperature>,
+                "max_token": <max_token>
+            },
+            "messages": [
+                {"role": "system", "content": <system_prompt>},
+                {"role": "user", "content": <user_prompt>}
+            ]
+        }
+        
+        for endpoints ending with `/generate`, like 'http://localhost:11434/api/generate':
+        {
+            "model": <model_name>,
+            "stream": false,
+            "think": false
+            "options": {
+                "temperature": <temperature>,
+                "max_token": <max_token>
+            },
+            "system": <system_prompt>,
+            "prompt": <user_prompt>
+        }
+        */
+
+        // For Ollama, only the prompt section ("system" + "prompt" or "role" + "content") is affected by the endpoint;
+        // the rest of the request remains identical.
+        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), 
allocator), allocator);
+        doc.AddMember("stream", false, allocator);
+        doc.AddMember("think", false, allocator);
+
+        // options section
+        rapidjson::Value options(rapidjson::kObjectType);
+        if (_config.temperature != -1) {
+            options.AddMember("temperature", _config.temperature, allocator);
+        }
+        if (_config.max_tokens != -1) {
+            options.AddMember("max_token", _config.max_tokens, allocator);
+        }
+        doc.AddMember("options", options, allocator);
+
+        // prompt section
+        if (_config.endpoint.ends_with("chat")) {
+            rapidjson::Value messages(rapidjson::kArrayType);
+            if (system_prompt && *system_prompt) {
+                rapidjson::Value sys_msg(rapidjson::kObjectType);
+                sys_msg.AddMember("role", "system", allocator);
+                sys_msg.AddMember("content", rapidjson::Value(system_prompt, 
allocator), allocator);
+                messages.PushBack(sys_msg, allocator);
+            }
+            for (const auto& input : inputs) {
+                rapidjson::Value message(rapidjson::kObjectType);
+                message.AddMember("role", "user", allocator);
+                message.AddMember("content", rapidjson::Value(input.c_str(), 
allocator), allocator);
+                messages.PushBack(message, allocator);
+            }
+            doc.AddMember("messages", messages, allocator);
+        } else {
+            if (system_prompt && *system_prompt) {
+                doc.AddMember("system", rapidjson::Value(system_prompt, 
allocator), allocator);
+            }
+            doc.AddMember("prompt", rapidjson::Value(inputs[0].c_str(), 
allocator), allocator);
+        }
+
+        return Status::OK();
+    }
+
+    Status build_default_request(rapidjson::Document& doc,
+                                 rapidjson::Document::AllocatorType& allocator,
+                                 const std::vector<std::string>& inputs,
+                                 const char* const system_prompt, std::string& request_body) const {
+        /*
+        Default format (OpenAI-compatible):
+        {
+            "model": <model_name>,
+            "temperature": <temperature>,
+            "max_tokens": <max_tokens>,
+            "messages": [
+                {"role": "system", "content": <system_prompt>},
+                {"role": "user", "content": <user_prompt>}
+            ]
+        }
+        */
+
+        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), 
allocator), allocator);
+
+        // If 'temperature' and 'max_tokens' are set, add them to the request body.
+        if (_config.temperature != -1) {
+            doc.AddMember("temperature", _config.temperature, allocator);
+        }
+        if (_config.max_tokens != -1) {
+            doc.AddMember("max_tokens", _config.max_tokens, allocator);
+        }
+
+        rapidjson::Value messages(rapidjson::kArrayType);
+        if (system_prompt && *system_prompt) {
+            rapidjson::Value sys_msg(rapidjson::kObjectType);
+            sys_msg.AddMember("role", "system", allocator);
+            sys_msg.AddMember("content", rapidjson::Value(system_prompt, 
allocator), allocator);
+            messages.PushBack(sys_msg, allocator);
+        }
+        for (const auto& input : inputs) {
+            rapidjson::Value message(rapidjson::kObjectType);
+            message.AddMember("role", "user", allocator);
+            message.AddMember("content", rapidjson::Value(input.c_str(), 
allocator), allocator);
+            messages.PushBack(message, allocator);
+        }
+        doc.AddMember("messages", messages, allocator);
+        return Status::OK();
+    }
 };
 
 // The OpenAI API format can be reused with some compatible AIs.
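The dispatch above keys purely off the endpoint suffix: endpoints ending in "chat" or "generate" get Ollama's stream/think/options body, everything else falls back to the OpenAI-compatible "messages" body. A minimal standalone sketch of that routing, assuming C++20 for std::string::ends_with (the helper name and main() are illustrative, not part of the commit):

    #include <iostream>
    #include <string>

    // Mirrors the suffix check in LocalAdapter::build_request_payload:
    // Ollama-native endpoints take the stream/think/options body, anything
    // else takes the OpenAI-compatible "messages" body.
    static const char* payload_style(const std::string& endpoint) {
        if (endpoint.ends_with("chat") || endpoint.ends_with("generate")) {
            return "ollama";
        }
        return "openai-compatible";
    }

    int main() {
        std::cout << payload_style("http://localhost:11434/api/chat") << '\n';       // ollama
        std::cout << payload_style("http://localhost:11434/api/generate") << '\n';   // ollama
        std::cout << payload_style("http://localhost:8000/v1/chat/completions") << '\n'; // openai-compatible
    }
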
diff --git a/be/src/vec/functions/ai/ai_functions.h b/be/src/vec/functions/ai/ai_functions.h
index 01efe65f938..4729faafc33 100644
--- a/be/src/vec/functions/ai/ai_functions.h
+++ b/be/src/vec/functions/ai/ai_functions.h
@@ -170,6 +170,18 @@ public:
         return Status::OK();
     }
 
+protected:
+    // The endpoint `v1/completions` does not support `system_prompt`.
+    // To ensure a clear structure and stable AI results,
+    // convert from `v1/completions` to `v1/chat/completions`.
+    static void normalize_endpoint(TAIResource& config) {
+        if (config.endpoint.ends_with("v1/completions")) {
+            static constexpr std::string_view legacy_suffix = "v1/completions";
+            config.endpoint.replace(config.endpoint.size() - legacy_suffix.size(),
+                                    legacy_suffix.size(), "v1/chat/completions");
+        }
+    }
+
 private:
     // Trim whitespace and newlines from string
     static void trim_string(std::string& str) {
@@ -201,6 +213,8 @@ private:
         }
         config = it->second;
 
+        normalize_endpoint(config);
+
         // 2. Create an adapter based on provider_type
         adapter = AIAdapterFactory::create_adapter(config.provider_type);
         if (!adapter) {
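normalize_endpoint above is a pure suffix swap, and it is naturally idempotent: "v1/chat/completions" does not itself end with the bare "v1/completions" suffix, so already-normalized endpoints pass through untouched. A self-contained sketch of the same logic on a plain std::string (hypothetical free function, C++20):

    #include <cassert>
    #include <string>
    #include <string_view>

    // Hypothetical free-function version of normalize_endpoint: swap the
    // legacy "v1/completions" suffix for "v1/chat/completions".
    std::string normalize(std::string endpoint) {
        static constexpr std::string_view legacy = "v1/completions";
        if (endpoint.ends_with(legacy)) {
            endpoint.replace(endpoint.size() - legacy.size(), legacy.size(),
                             "v1/chat/completions");
        }
        return endpoint;
    }

    int main() {
        assert(normalize("https://api.openai.com/v1/completions") ==
               "https://api.openai.com/v1/chat/completions");
        // Idempotent: a second pass (or an already-chat endpoint) is a no-op.
        assert(normalize("https://api.openai.com/v1/chat/completions") ==
               "https://api.openai.com/v1/chat/completions");
    }
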
diff --git a/be/test/ai/ai_adapter_test.cpp b/be/test/ai/ai_adapter_test.cpp
index e5adb6346d5..fb0e3f8e849 100644
--- a/be/test/ai/ai_adapter_test.cpp
+++ b/be/test/ai/ai_adapter_test.cpp
@@ -41,12 +41,13 @@ private:
     std::string _content_type;
 };
 
-TEST(AI_ADAPTER_TEST, local_adapter_request) {
+TEST(AI_ADAPTER_TEST, local_adapter_request_chat_endpoint) {
     LocalAdapter adapter;
     TAIResource config;
     config.model_name = "ollama";
     config.temperature = 0.7;
     config.max_tokens = 128;
+    config.endpoint = "http://localhost:11434/api/chat";;
     adapter.init(config);
 
     // header test
@@ -67,40 +68,89 @@ TEST(AI_ADAPTER_TEST, local_adapter_request) {
     ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
     ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
 
-    // model name
+    // general flags
     ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
     ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
     ASSERT_STREQ(doc["model"].GetString(), "ollama");
+    ASSERT_TRUE(doc.HasMember("stream")) << "Missing stream field";
+    ASSERT_TRUE(doc["stream"].IsBool()) << "Stream field is not a bool";
+    ASSERT_FALSE(doc["stream"].GetBool());
+    ASSERT_TRUE(doc.HasMember("think")) << "Missing think field";
+    ASSERT_TRUE(doc["think"].IsBool()) << "Think field is not a bool";
+    ASSERT_FALSE(doc["think"].GetBool());
+
+    // options (temperature + max_token)
+    ASSERT_FALSE(doc.HasMember("temperature")) << "Temperature should be 
nested in options";
+    ASSERT_FALSE(doc.HasMember("max_tokens")) << "Max tokens should be nested 
in options";
+    ASSERT_TRUE(doc.HasMember("options")) << "Missing options field";
+    ASSERT_TRUE(doc["options"].IsObject()) << "Options is not an object";
+    const auto& options = doc["options"];
+    ASSERT_TRUE(options.HasMember("temperature")) << "Missing 
options.temperature field";
+    ASSERT_TRUE(options["temperature"].IsNumber()) << "options.temperature is 
not a number";
+    ASSERT_DOUBLE_EQ(options["temperature"].GetDouble(), 0.7);
+    ASSERT_TRUE(options.HasMember("max_token")) << "Missing options.max_token 
field";
+    ASSERT_TRUE(options["max_token"].IsInt()) << "options.max_token is not an 
integer";
+    ASSERT_EQ(options["max_token"].GetInt(), 128);
 
-    // temperature
-    ASSERT_TRUE(doc.HasMember("temperature")) << "Missing temperature field";
-    ASSERT_TRUE(doc["temperature"].IsNumber()) << "Temperature field is not a 
number";
-    ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.7);
+    // content
+    ASSERT_TRUE(doc.HasMember("messages")) << "Missing messages field";
+    ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array";
+    ASSERT_GE(doc["messages"].Size(), 2) << "Messages should contain system 
and user prompts";
+    const auto& first_message = doc["messages"][0];
+    ASSERT_TRUE(first_message.HasMember("role")) << "Message missing role 
field";
+    ASSERT_TRUE(first_message["role"].IsString()) << "Role field is not a 
string";
+    ASSERT_STREQ(first_message["role"].GetString(), "system");
+    ASSERT_TRUE(first_message.HasMember("content")) << "Message missing 
content field";
+    ASSERT_TRUE(first_message["content"].IsString()) << "Content field is not 
a string";
+    ASSERT_STREQ(first_message["content"].GetString(), 
FunctionAISummarize::system_prompt);
+
+    const auto& user_message = doc["messages"][doc["messages"].Size() - 1];
+    ASSERT_TRUE(user_message.HasMember("role")) << "User message missing role 
field";
+    ASSERT_TRUE(user_message["role"].IsString()) << "User role field is not a 
string";
+    ASSERT_STREQ(user_message["role"].GetString(), "user");
+    ASSERT_TRUE(user_message.HasMember("content")) << "User message missing 
content field";
+    ASSERT_TRUE(user_message["content"].IsString()) << "User content field is 
not a string";
+    ASSERT_STREQ(user_message["content"].GetString(), inputs[0].c_str());
+}
 
-    // max token
-    ASSERT_TRUE(doc.HasMember("max_tokens")) << "Missing max_tokens field";
-    ASSERT_TRUE(doc["max_tokens"].IsInt()) << "Max_tokens field is not an 
integer";
-    ASSERT_EQ(doc["max_tokens"].GetInt(), 128);
+TEST(AI_ADAPTER_TEST, local_adapter_request_generate_endpoint) {
+    LocalAdapter adapter;
+    TAIResource config;
+    config.model_name = "ollama";
+    config.temperature = 0.8;
+    config.max_tokens = 64;
+    config.endpoint = "http://localhost:11434/api/generate";;
+    adapter.init(config);
 
-    // content
-    if (doc.HasMember("messages")) {
-        ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array";
-        ASSERT_GT(doc["messages"].Size(), 0) << "Messages array is empty";
-        // system_prompt
-        const auto& first_message = doc["messages"][0];
-        ASSERT_TRUE(first_message.HasMember("role")) << "Message missing role 
field";
-        ASSERT_TRUE(first_message["role"].IsString()) << "Role field is not a 
string";
-        ASSERT_STREQ(first_message["role"].GetString(), "system");
-        ASSERT_STREQ(first_message["content"].GetString(), 
FunctionAISummarize::system_prompt);
-
-        const auto& last_message = doc["messages"][doc["messages"].Size() - 1];
-        ASSERT_TRUE(last_message.HasMember("content")) << "Message missing 
content field";
-        ASSERT_TRUE(last_message["content"].IsString()) << "Content field is 
not a string";
-        ASSERT_STREQ(last_message["content"].GetString(), inputs[0].c_str());
-    } else if (doc.HasMember("prompt")) {
-        ASSERT_TRUE(doc["prompt"].IsString()) << "Prompt field is not a 
string";
-        ASSERT_STREQ(doc["prompt"].GetString(), inputs[0].c_str());
-    }
+    std::vector<std::string> inputs = {"hello world"};
+    std::string request_body;
+    Status st =
+            adapter.build_request_payload(inputs, FunctionAISummarize::system_prompt, request_body);
+    ASSERT_TRUE(st.ok());
+
+    rapidjson::Document doc;
+    doc.Parse(request_body.c_str());
+    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
+    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
+
+    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
+    ASSERT_STREQ(doc["model"].GetString(), "ollama");
+    ASSERT_TRUE(doc.HasMember("system")) << "Missing system field";
+    ASSERT_TRUE(doc["system"].IsString()) << "System field is not a string";
+    ASSERT_STREQ(doc["system"].GetString(), 
FunctionAISummarize::system_prompt);
+    ASSERT_TRUE(doc.HasMember("prompt")) << "Missing prompt field";
+    ASSERT_TRUE(doc["prompt"].IsString()) << "Prompt field is not a string";
+    ASSERT_STREQ(doc["prompt"].GetString(), inputs[0].c_str());
+
+    ASSERT_FALSE(doc.HasMember("messages")) << "Generate endpoint should not 
include messages";
+
+    ASSERT_TRUE(doc.HasMember("options")) << "Missing options field";
+    ASSERT_TRUE(doc["options"].IsObject()) << "Options is not an object";
+    const auto& options = doc["options"];
+    ASSERT_TRUE(options.HasMember("temperature")) << "Missing 
options.temperature field";
+    ASSERT_DOUBLE_EQ(options["temperature"].GetDouble(), 0.8);
+    ASSERT_TRUE(options.HasMember("max_token")) << "Missing options.max_token 
field";
+    ASSERT_EQ(options["max_token"].GetInt(), 64);
 }
 
 TEST(AI_ADAPTER_TEST, local_adapter_parse_response) {
@@ -128,13 +178,67 @@ TEST(AI_ADAPTER_TEST, local_adapter_parse_response) {
     ASSERT_EQ(results.size(), 1);
     ASSERT_EQ(results[0], "simple result");
 
-    // Ollama type
+    // Ollama response type
     resp = R"({"response":"ollama result"})";
     results.clear();
     st = adapter.parse_response(resp, results);
     ASSERT_TRUE(st.ok());
     ASSERT_EQ(results.size(), 1);
     ASSERT_EQ(results[0], "ollama result");
+
+    // Ollama chat message type
+    resp = R"({"message":{"content":"ollama chat"}})";
+    results.clear();
+    st = adapter.parse_response(resp, results);
+    ASSERT_TRUE(st.ok());
+    ASSERT_EQ(results.size(), 1);
+    ASSERT_EQ(results[0], "ollama chat");
+}
+
+TEST(AI_ADAPTER_TEST, local_adapter_request_default_endpoint) {
+    LocalAdapter adapter;
+    TAIResource config;
+    config.model_name = "local-default";
+    config.temperature = 0.3;
+    config.max_tokens = 42;
+    config.endpoint = "http://localhost:8000/v1/completions";;
+    adapter.init(config);
+
+    std::vector<std::string> inputs = {"default prompt"};
+    std::string request_body;
+    Status st =
+            adapter.build_request_payload(inputs, FunctionAISummarize::system_prompt, request_body);
+    ASSERT_TRUE(st.ok()) << st.to_string();
+
+    rapidjson::Document doc;
+    doc.Parse(request_body.c_str());
+    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
+    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
+
+    ASSERT_TRUE(doc.HasMember("model"));
+    ASSERT_STREQ(doc["model"].GetString(), "local-default");
+
+    ASSERT_TRUE(doc.HasMember("temperature"));
+    ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.3);
+
+    ASSERT_TRUE(doc.HasMember("max_tokens"));
+    ASSERT_EQ(doc["max_tokens"].GetInt(), 42);
+
+    ASSERT_TRUE(doc.HasMember("messages"));
+    ASSERT_TRUE(doc["messages"].IsArray());
+    ASSERT_GE(doc["messages"].Size(), 2);
+
+    const auto& system_msg = doc["messages"][0];
+    ASSERT_TRUE(system_msg.HasMember("role"));
+    ASSERT_STREQ(system_msg["role"].GetString(), "system");
+    ASSERT_TRUE(system_msg.HasMember("content"));
+    ASSERT_STREQ(system_msg["content"].GetString(), 
FunctionAISummarize::system_prompt);
+
+    const auto& user_msg = doc["messages"][doc["messages"].Size() - 1];
+    ASSERT_TRUE(user_msg.HasMember("role"));
+    ASSERT_STREQ(user_msg["role"].GetString(), "user");
+    ASSERT_TRUE(user_msg.HasMember("content"));
+    ASSERT_STREQ(user_msg["content"].GetString(), inputs[0].c_str());
 }
 
 TEST(AI_ADAPTER_TEST, openai_adapter_completions_request) {
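The tests above pin down the parse order the adapter now uses: "choices" first, then bare "text"/"content", then Ollama's "response" (generate) and "message"/"content" (chat), with NotSupported only when every format misses. A minimal rapidjson sketch of the new Format-4 branch (hypothetical helper; error handling trimmed to the checks the diff performs):

    #include <cassert>
    #include <string>
    #include <rapidjson/document.h>

    // Hypothetical extractor for the Ollama `chat` response shape,
    // {"message":{"content":"..."}}; returns empty on any mismatch.
    std::string extract_chat_content(const char* json) {
        rapidjson::Document doc;
        doc.Parse(json);
        if (!doc.HasParseError() && doc.IsObject() && doc.HasMember("message") &&
            doc["message"].IsObject() && doc["message"].HasMember("content") &&
            doc["message"]["content"].IsString()) {
            return doc["message"]["content"].GetString();
        }
        return {};
    }

    int main() {
        assert(extract_chat_content(R"({"message":{"content":"ollama chat"}})") ==
               "ollama chat");
    }
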
diff --git a/be/test/ai/ai_function_test.cpp b/be/test/ai/ai_function_test.cpp
index 14409781417..a7e86c4724d 100644
--- a/be/test/ai/ai_function_test.cpp
+++ b/be/test/ai/ai_function_test.cpp
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+#include <gen_cpp/PaloInternalService_types.h>
 #include <gtest/gtest.h>
 
 #include <string>
@@ -638,4 +639,27 @@ TEST(AIFunctionTest, ReturnTypeTest) {
     ASSERT_EQ(ret_type->get_family_name(), "Array");
 }
 
+class FunctionAISentimentTestHelper : public FunctionAISentiment {
+public:
+    using FunctionAISentiment::normalize_endpoint;
+};
+
+TEST(AIFunctionTest, NormalizeLegacyCompletionsEndpoint) {
+    TAIResource resource;
+    resource.endpoint = "https://api.openai.com/v1/completions";;
+    FunctionAISentimentTestHelper::normalize_endpoint(resource);
+    ASSERT_EQ(resource.endpoint, "https://api.openai.com/v1/chat/completions";);
+}
+
+TEST(AIFunctionTest, NormalizeEndpointNoopForOtherPaths) {
+    TAIResource resource;
+    resource.endpoint = "https://api.openai.com/v1/chat/completions";;
+    FunctionAISentimentTestHelper::normalize_endpoint(resource);
+    ASSERT_EQ(resource.endpoint, "https://api.openai.com/v1/chat/completions";);
+
+    resource.endpoint = "https://localhost/v1/responses";;
+    FunctionAISentimentTestHelper::normalize_endpoint(resource);
+    ASSERT_EQ(resource.endpoint, "https://localhost/v1/responses";);
+}
+
 } // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/test/ai/embed_test.cpp b/be/test/ai/embed_test.cpp
index 6e4a818484c..d5b70292a9c 100644
--- a/be/test/ai/embed_test.cpp
+++ b/be/test/ai/embed_test.cpp
@@ -186,6 +186,17 @@ TEST(EMBED_TEST, local_adapter_parse_embedding_response) {
     ASSERT_EQ(results[0].size(), 3);
     ASSERT_FLOAT_EQ(results[0][0], 0.1F);
     ASSERT_FLOAT_EQ(results[0][2], 0.3F);
+
+    std::string resp3 = R"({
+        "embeddings": [[0.6, 0.7]]
+    })";
+    results.clear();
+    st = adapter.parse_embedding_response(resp3, results);
+    ASSERT_TRUE(st.ok()) << "Format 3 failed: " << st.to_string();
+    ASSERT_EQ(results.size(), 1);
+    ASSERT_EQ(results[0].size(), 2);
+    ASSERT_FLOAT_EQ(results[0][0], 0.6F);
+    ASSERT_FLOAT_EQ(results[0][1], 0.7F);
 }
 
 TEST(EMBED_TEST, openai_adapter_embedding_request) {
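The new embed_test case exercises the nested "embeddings":[[...]] shape (returned, for example, by Ollama's embedding endpoint) alongside the flat "embedding":[...] form. A standalone rapidjson sketch of that nested parse (hypothetical helper; assumes well-formed input, as the test does):

    #include <cassert>
    #include <vector>
    #include <rapidjson/document.h>

    // Hypothetical parser for the nested form: {"embeddings": [[0.6, 0.7]]}.
    std::vector<std::vector<float>> parse_nested(const char* json) {
        rapidjson::Document doc;
        doc.Parse(json);
        std::vector<std::vector<float>> out;
        for (const auto& row : doc["embeddings"].GetArray()) {
            auto& vec = out.emplace_back();
            for (const auto& val : row.GetArray()) {
                vec.push_back(val.GetFloat());
            }
        }
        return out;
    }

    int main() {
        auto rows = parse_nested(R"({"embeddings": [[0.6, 0.7]]})");
        assert(rows.size() == 1 && rows[0].size() == 2);
        assert(rows[0][0] > 0.59F && rows[0][0] < 0.61F);
    }
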


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
