This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new f841fb6a9e5 branch-4.0: [Fix](AI_Func) Fix the error prompt-building and response-parsing process of Local_Adapter #58492 (#58796)
f841fb6a9e5 is described below
commit f841fb6a9e5d3b9aafc303033b06aa4d24c18cd5
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Dec 9 17:44:19 2025 +0800
branch-4.0: [Fix](AI_Func) Fix the error prompt-building and response-parsing process of Local_Adapter #58492 (#58796)
Cherry-picked from #58492
Co-authored-by: linrrarity <[email protected]>
---
be/src/vec/functions/ai/ai_adapter.h | 194 +++++++++++++++++++++++++--------
be/src/vec/functions/ai/ai_functions.h | 14 +++
be/test/ai/ai_adapter_test.cpp | 164 +++++++++++++++++++++++-----
be/test/ai/ai_function_test.cpp | 24 ++++
be/test/ai/embed_test.cpp | 11 ++
5 files changed, 332 insertions(+), 75 deletions(-)
diff --git a/be/src/vec/functions/ai/ai_adapter.h
b/be/src/vec/functions/ai/ai_adapter.h
index 8faffc200e6..fa06c9551d3 100644
--- a/be/src/vec/functions/ai/ai_adapter.h
+++ b/be/src/vec/functions/ai/ai_adapter.h
@@ -268,33 +268,14 @@ public:
doc.SetObject();
auto& allocator = doc.GetAllocator();
- if (!_config.model_name.empty()) {
- doc.AddMember("model",
rapidjson::Value(_config.model_name.c_str(), allocator),
- allocator);
- }
-
- // If 'temperature' and 'max_tokens' are set, add them to the request
body.
- if (_config.temperature != -1) {
- doc.AddMember("temperature", _config.temperature, allocator);
- }
- if (_config.max_tokens != -1) {
- doc.AddMember("max_tokens", _config.max_tokens, allocator);
- }
-
- rapidjson::Value messages(rapidjson::kArrayType);
- if (system_prompt && *system_prompt) {
- rapidjson::Value sys_msg(rapidjson::kObjectType);
- sys_msg.AddMember("role", "system", allocator);
- sys_msg.AddMember("content", rapidjson::Value(system_prompt,
allocator), allocator);
- messages.PushBack(sys_msg, allocator);
- }
- for (const auto& input : inputs) {
- rapidjson::Value message(rapidjson::kObjectType);
- message.AddMember("role", "user", allocator);
- message.AddMember("content", rapidjson::Value(input.c_str(),
allocator), allocator);
- messages.PushBack(message, allocator);
+ std::string end_point = _config.endpoint;
+ if (end_point.ends_with("chat") || end_point.ends_with("generate")) {
+ RETURN_IF_ERROR(
+ build_ollama_request(doc, allocator, inputs,
system_prompt, request_body));
+ } else {
+ RETURN_IF_ERROR(
+ build_default_request(doc, allocator, inputs,
system_prompt, request_body));
}
- doc.AddMember("messages", messages, allocator);
rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
@@ -329,30 +310,23 @@ public:
results.emplace_back(choices[i]["text"].GetString());
}
}
-
- if (!results.empty()) {
- return Status::OK();
- }
- }
-
- // Format 2: Simple response with just "text" or "content" field
- if (doc.HasMember("text") && doc["text"].IsString()) {
+ } else if (doc.HasMember("text") && doc["text"].IsString()) {
+ // Format 2: Simple response with just "text" or "content" field
results.emplace_back(doc["text"].GetString());
- return Status::OK();
- }
-
- if (doc.HasMember("content") && doc["content"].IsString()) {
+ } else if (doc.HasMember("content") && doc["content"].IsString()) {
results.emplace_back(doc["content"].GetString());
- return Status::OK();
- }
-
- // Format 3: Response field (Ollama format)
- if (doc.HasMember("response") && doc["response"].IsString()) {
+ } else if (doc.HasMember("response") && doc["response"].IsString()) {
+ // Format 3: Response field (Ollama `generate` format)
results.emplace_back(doc["response"].GetString());
- return Status::OK();
+ } else if (doc.HasMember("message") && doc["message"].IsObject() &&
+ doc["message"].HasMember("content") &&
doc["message"]["content"].IsString()) {
+ // Format 4: message/content field (Ollama `chat` format)
+ results.emplace_back(doc["message"]["content"].GetString());
+ } else {
+ return Status::NotSupported("Unsupported response format from
local AI.");
}
- return Status::NotSupported("Unsupported response format from local
AI.");
+ return Status::OK();
}
Status build_embedding_request(const std::vector<std::string>& inputs,
@@ -408,6 +382,15 @@ public:
std::back_inserter(results.emplace_back()),
[](const auto& val) { return val.GetFloat(); });
}
+ } else if (doc.HasMember("embeddings") && doc["embeddings"].IsArray())
{
+ // "embeddings":[[0.1, 0.2, ...]]
+ results.reserve(1);
+ for (int i = 0; i < doc["embeddings"].Size(); i++) {
+ embedding = doc["embeddings"][i];
+ std::transform(embedding.Begin(), embedding.End(),
+ std::back_inserter(results.emplace_back()),
+ [](const auto& val) { return val.GetFloat(); });
+ }
} else if (doc.HasMember("embedding") && doc["embedding"].IsArray()) {
// "embedding":[0.1, 0.2, ...]
results.reserve(1);
@@ -422,6 +405,127 @@ public:
return Status::OK();
}
+
+private:
+ Status build_ollama_request(rapidjson::Document& doc,
+ rapidjson::Document::AllocatorType& allocator,
+ const std::vector<std::string>& inputs,
+ const char* const system_prompt, std::string&
request_body) const {
+ /*
+ for endpoints end_with `/chat` like 'http://localhost:11434/api/chat':
+ {
+ "model": <model_name>,
+ "stream": false,
+ "think": false,
+ "options": {
+ "temperature": <temperature>,
+ "max_token": <max_token>
+ },
+ "messages": [
+ {"role": "system", "content": <system_prompt>},
+ {"role": "user", "content": <user_prompt>}
+ ]
+ }
+
+ for endpoints end_with `/generate` like
'http://localhost:11434/api/generate':
+ {
+ "model": <model_name>,
+ "stream": false,
+ "think": false
+ "options": {
+ "temperature": <temperature>,
+ "max_token": <max_token>
+ },
+ "system": <system_prompt>,
+ "prompt": <user_prompt>
+ }
+ */
+
+ // For Ollama, only the prompt section ("system" + "prompt" or "role"
+ "content") is affected by the endpoint;
+ // The rest remains identical.
+ doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(),
allocator), allocator);
+ doc.AddMember("stream", false, allocator);
+ doc.AddMember("think", false, allocator);
+
+ // option section
+ rapidjson::Value options(rapidjson::kObjectType);
+ if (_config.temperature != -1) {
+ options.AddMember("temperature", _config.temperature, allocator);
+ }
+ if (_config.max_tokens != -1) {
+ options.AddMember("max_token", _config.max_tokens, allocator);
+ }
+ doc.AddMember("options", options, allocator);
+
+ // prompt section
+ if (_config.endpoint.ends_with("chat")) {
+ rapidjson::Value messages(rapidjson::kArrayType);
+ if (system_prompt && *system_prompt) {
+ rapidjson::Value sys_msg(rapidjson::kObjectType);
+ sys_msg.AddMember("role", "system", allocator);
+ sys_msg.AddMember("content", rapidjson::Value(system_prompt,
allocator), allocator);
+ messages.PushBack(sys_msg, allocator);
+ }
+ for (const auto& input : inputs) {
+ rapidjson::Value message(rapidjson::kObjectType);
+ message.AddMember("role", "user", allocator);
+ message.AddMember("content", rapidjson::Value(input.c_str(),
allocator), allocator);
+ messages.PushBack(message, allocator);
+ }
+ doc.AddMember("messages", messages, allocator);
+ } else {
+ if (system_prompt && *system_prompt) {
+ doc.AddMember("system", rapidjson::Value(system_prompt,
allocator), allocator);
+ }
+ doc.AddMember("prompt", rapidjson::Value(inputs[0].c_str(),
allocator), allocator);
+ }
+
+ return Status::OK();
+ }
+
+ Status build_default_request(rapidjson::Document& doc,
+ rapidjson::Document::AllocatorType& allocator,
+ const std::vector<std::string>& inputs,
+ const char* const system_prompt, std::string&
request_body) const {
+ /*
+ Default format(OpenAI-compatible):
+ {
+ "model": <model_name>,
+ "temperature": <temperature>,
+ "max_tokens": <max_tokens>,
+ "messages": [
+ {"role": "system", "content": <system_prompt>},
+ {"role": "user", "content": <user_prompt>}
+ ]
+ }
+ */
+
+ doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(),
allocator), allocator);
+
+ // If 'temperature' and 'max_tokens' are set, add them to the request
body.
+ if (_config.temperature != -1) {
+ doc.AddMember("temperature", _config.temperature, allocator);
+ }
+ if (_config.max_tokens != -1) {
+ doc.AddMember("max_tokens", _config.max_tokens, allocator);
+ }
+
+ rapidjson::Value messages(rapidjson::kArrayType);
+ if (system_prompt && *system_prompt) {
+ rapidjson::Value sys_msg(rapidjson::kObjectType);
+ sys_msg.AddMember("role", "system", allocator);
+ sys_msg.AddMember("content", rapidjson::Value(system_prompt,
allocator), allocator);
+ messages.PushBack(sys_msg, allocator);
+ }
+ for (const auto& input : inputs) {
+ rapidjson::Value message(rapidjson::kObjectType);
+ message.AddMember("role", "user", allocator);
+ message.AddMember("content", rapidjson::Value(input.c_str(),
allocator), allocator);
+ messages.PushBack(message, allocator);
+ }
+ doc.AddMember("messages", messages, allocator);
+ return Status::OK();
+ }
};
// The OpenAI API format can be reused with some compatible AIs.
diff --git a/be/src/vec/functions/ai/ai_functions.h
b/be/src/vec/functions/ai/ai_functions.h
index 01efe65f938..4729faafc33 100644
--- a/be/src/vec/functions/ai/ai_functions.h
+++ b/be/src/vec/functions/ai/ai_functions.h
@@ -170,6 +170,18 @@ public:
return Status::OK();
}
+protected:
+ // The endpoint `v1/completions` does not support `system_prompt`.
+ // To ensure a clear structure and stable AI results.
+ // Convert from `v1/completions` to `v1/chat/completions`
+ static void normalize_endpoint(TAIResource& config) {
+ if (config.endpoint.ends_with("v1/completions")) {
+ static constexpr std::string_view legacy_suffix = "v1/completions";
+ config.endpoint.replace(config.endpoint.size() -
legacy_suffix.size(),
+ legacy_suffix.size(),
"v1/chat/completions");
+ }
+ }
+
private:
// Trim whitespace and newlines from string
static void trim_string(std::string& str) {
@@ -201,6 +213,8 @@ private:
}
config = it->second;
+ normalize_endpoint(config);
+
// 2. Create an adapter based on provider_type
adapter = AIAdapterFactory::create_adapter(config.provider_type);
if (!adapter) {
diff --git a/be/test/ai/ai_adapter_test.cpp b/be/test/ai/ai_adapter_test.cpp
index e5adb6346d5..fb0e3f8e849 100644
--- a/be/test/ai/ai_adapter_test.cpp
+++ b/be/test/ai/ai_adapter_test.cpp
@@ -41,12 +41,13 @@ private:
std::string _content_type;
};
-TEST(AI_ADAPTER_TEST, local_adapter_request) {
+TEST(AI_ADAPTER_TEST, local_adapter_request_chat_endpoint) {
LocalAdapter adapter;
TAIResource config;
config.model_name = "ollama";
config.temperature = 0.7;
config.max_tokens = 128;
+ config.endpoint = "http://localhost:11434/api/chat";
adapter.init(config);
// header test
@@ -67,40 +68,89 @@ TEST(AI_ADAPTER_TEST, local_adapter_request) {
ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
- // model name
+ // general flags
ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
ASSERT_STREQ(doc["model"].GetString(), "ollama");
+ ASSERT_TRUE(doc.HasMember("stream")) << "Missing stream field";
+ ASSERT_TRUE(doc["stream"].IsBool()) << "Stream field is not a bool";
+ ASSERT_FALSE(doc["stream"].GetBool());
+ ASSERT_TRUE(doc.HasMember("think")) << "Missing think field";
+ ASSERT_TRUE(doc["think"].IsBool()) << "Think field is not a bool";
+ ASSERT_FALSE(doc["think"].GetBool());
+
+ // options (temperature + max_token)
+ ASSERT_FALSE(doc.HasMember("temperature")) << "Temperature should be
nested in options";
+ ASSERT_FALSE(doc.HasMember("max_tokens")) << "Max tokens should be nested
in options";
+ ASSERT_TRUE(doc.HasMember("options")) << "Missing options field";
+ ASSERT_TRUE(doc["options"].IsObject()) << "Options is not an object";
+ const auto& options = doc["options"];
+ ASSERT_TRUE(options.HasMember("temperature")) << "Missing
options.temperature field";
+ ASSERT_TRUE(options["temperature"].IsNumber()) << "options.temperature is
not a number";
+ ASSERT_DOUBLE_EQ(options["temperature"].GetDouble(), 0.7);
+ ASSERT_TRUE(options.HasMember("max_token")) << "Missing options.max_token
field";
+ ASSERT_TRUE(options["max_token"].IsInt()) << "options.max_token is not an
integer";
+ ASSERT_EQ(options["max_token"].GetInt(), 128);
- // temperature
- ASSERT_TRUE(doc.HasMember("temperature")) << "Missing temperature field";
- ASSERT_TRUE(doc["temperature"].IsNumber()) << "Temperature field is not a
number";
- ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.7);
+ // content
+ ASSERT_TRUE(doc.HasMember("messages")) << "Missing messages field";
+ ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array";
+ ASSERT_GE(doc["messages"].Size(), 2) << "Messages should contain system
and user prompts";
+ const auto& first_message = doc["messages"][0];
+ ASSERT_TRUE(first_message.HasMember("role")) << "Message missing role
field";
+ ASSERT_TRUE(first_message["role"].IsString()) << "Role field is not a
string";
+ ASSERT_STREQ(first_message["role"].GetString(), "system");
+ ASSERT_TRUE(first_message.HasMember("content")) << "Message missing
content field";
+ ASSERT_TRUE(first_message["content"].IsString()) << "Content field is not
a string";
+ ASSERT_STREQ(first_message["content"].GetString(),
FunctionAISummarize::system_prompt);
+
+ const auto& user_message = doc["messages"][doc["messages"].Size() - 1];
+ ASSERT_TRUE(user_message.HasMember("role")) << "User message missing role
field";
+ ASSERT_TRUE(user_message["role"].IsString()) << "User role field is not a
string";
+ ASSERT_STREQ(user_message["role"].GetString(), "user");
+ ASSERT_TRUE(user_message.HasMember("content")) << "User message missing
content field";
+ ASSERT_TRUE(user_message["content"].IsString()) << "User content field is
not a string";
+ ASSERT_STREQ(user_message["content"].GetString(), inputs[0].c_str());
+}
- // max token
- ASSERT_TRUE(doc.HasMember("max_tokens")) << "Missing max_tokens field";
- ASSERT_TRUE(doc["max_tokens"].IsInt()) << "Max_tokens field is not an
integer";
- ASSERT_EQ(doc["max_tokens"].GetInt(), 128);
+TEST(AI_ADAPTER_TEST, local_adapter_request_generate_endpoint) {
+ LocalAdapter adapter;
+ TAIResource config;
+ config.model_name = "ollama";
+ config.temperature = 0.8;
+ config.max_tokens = 64;
+ config.endpoint = "http://localhost:11434/api/generate";
+ adapter.init(config);
- // content
- if (doc.HasMember("messages")) {
- ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array";
- ASSERT_GT(doc["messages"].Size(), 0) << "Messages array is empty";
- // system_prompt
- const auto& first_message = doc["messages"][0];
- ASSERT_TRUE(first_message.HasMember("role")) << "Message missing role
field";
- ASSERT_TRUE(first_message["role"].IsString()) << "Role field is not a
string";
- ASSERT_STREQ(first_message["role"].GetString(), "system");
- ASSERT_STREQ(first_message["content"].GetString(),
FunctionAISummarize::system_prompt);
-
- const auto& last_message = doc["messages"][doc["messages"].Size() - 1];
- ASSERT_TRUE(last_message.HasMember("content")) << "Message missing
content field";
- ASSERT_TRUE(last_message["content"].IsString()) << "Content field is
not a string";
- ASSERT_STREQ(last_message["content"].GetString(), inputs[0].c_str());
- } else if (doc.HasMember("prompt")) {
- ASSERT_TRUE(doc["prompt"].IsString()) << "Prompt field is not a
string";
- ASSERT_STREQ(doc["prompt"].GetString(), inputs[0].c_str());
- }
+ std::vector<std::string> inputs = {"hello world"};
+ std::string request_body;
+ Status st =
+ adapter.build_request_payload(inputs,
FunctionAISummarize::system_prompt, request_body);
+ ASSERT_TRUE(st.ok());
+
+ rapidjson::Document doc;
+ doc.Parse(request_body.c_str());
+ ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
+ ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
+
+ ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
+ ASSERT_STREQ(doc["model"].GetString(), "ollama");
+ ASSERT_TRUE(doc.HasMember("system")) << "Missing system field";
+ ASSERT_TRUE(doc["system"].IsString()) << "System field is not a string";
+ ASSERT_STREQ(doc["system"].GetString(),
FunctionAISummarize::system_prompt);
+ ASSERT_TRUE(doc.HasMember("prompt")) << "Missing prompt field";
+ ASSERT_TRUE(doc["prompt"].IsString()) << "Prompt field is not a string";
+ ASSERT_STREQ(doc["prompt"].GetString(), inputs[0].c_str());
+
+ ASSERT_FALSE(doc.HasMember("messages")) << "Generate endpoint should not
include messages";
+
+ ASSERT_TRUE(doc.HasMember("options")) << "Missing options field";
+ ASSERT_TRUE(doc["options"].IsObject()) << "Options is not an object";
+ const auto& options = doc["options"];
+ ASSERT_TRUE(options.HasMember("temperature")) << "Missing
options.temperature field";
+ ASSERT_DOUBLE_EQ(options["temperature"].GetDouble(), 0.8);
+ ASSERT_TRUE(options.HasMember("max_token")) << "Missing options.max_token
field";
+ ASSERT_EQ(options["max_token"].GetInt(), 64);
}
TEST(AI_ADAPTER_TEST, local_adapter_parse_response) {
@@ -128,13 +178,67 @@ TEST(AI_ADAPTER_TEST, local_adapter_parse_response) {
ASSERT_EQ(results.size(), 1);
ASSERT_EQ(results[0], "simple result");
- // Ollama type
+ // Ollama response type
resp = R"({"response":"ollama result"})";
results.clear();
st = adapter.parse_response(resp, results);
ASSERT_TRUE(st.ok());
ASSERT_EQ(results.size(), 1);
ASSERT_EQ(results[0], "ollama result");
+
+ // Ollama chat message type
+ resp = R"({"message":{"content":"ollama chat"}})";
+ results.clear();
+ st = adapter.parse_response(resp, results);
+ ASSERT_TRUE(st.ok());
+ ASSERT_EQ(results.size(), 1);
+ ASSERT_EQ(results[0], "ollama chat");
+}
+
+TEST(AI_ADAPTER_TEST, local_adapter_request_default_endpoint) {
+ LocalAdapter adapter;
+ TAIResource config;
+ config.model_name = "local-default";
+ config.temperature = 0.3;
+ config.max_tokens = 42;
+ config.endpoint = "http://localhost:8000/v1/completions";
+ adapter.init(config);
+
+ std::vector<std::string> inputs = {"default prompt"};
+ std::string request_body;
+ Status st =
+ adapter.build_request_payload(inputs,
FunctionAISummarize::system_prompt, request_body);
+ ASSERT_TRUE(st.ok()) << st.to_string();
+
+ rapidjson::Document doc;
+ doc.Parse(request_body.c_str());
+ ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
+ ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
+
+ ASSERT_TRUE(doc.HasMember("model"));
+ ASSERT_STREQ(doc["model"].GetString(), "local-default");
+
+ ASSERT_TRUE(doc.HasMember("temperature"));
+ ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.3);
+
+ ASSERT_TRUE(doc.HasMember("max_tokens"));
+ ASSERT_EQ(doc["max_tokens"].GetInt(), 42);
+
+ ASSERT_TRUE(doc.HasMember("messages"));
+ ASSERT_TRUE(doc["messages"].IsArray());
+ ASSERT_GE(doc["messages"].Size(), 2);
+
+ const auto& system_msg = doc["messages"][0];
+ ASSERT_TRUE(system_msg.HasMember("role"));
+ ASSERT_STREQ(system_msg["role"].GetString(), "system");
+ ASSERT_TRUE(system_msg.HasMember("content"));
+ ASSERT_STREQ(system_msg["content"].GetString(),
FunctionAISummarize::system_prompt);
+
+ const auto& user_msg = doc["messages"][doc["messages"].Size() - 1];
+ ASSERT_TRUE(user_msg.HasMember("role"));
+ ASSERT_STREQ(user_msg["role"].GetString(), "user");
+ ASSERT_TRUE(user_msg.HasMember("content"));
+ ASSERT_STREQ(user_msg["content"].GetString(), inputs[0].c_str());
}
TEST(AI_ADAPTER_TEST, openai_adapter_completions_request) {
diff --git a/be/test/ai/ai_function_test.cpp b/be/test/ai/ai_function_test.cpp
index 14409781417..a7e86c4724d 100644
--- a/be/test/ai/ai_function_test.cpp
+++ b/be/test/ai/ai_function_test.cpp
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+#include <gen_cpp/PaloInternalService_types.h>
#include <gtest/gtest.h>
#include <string>
@@ -638,4 +639,27 @@ TEST(AIFunctionTest, ReturnTypeTest) {
ASSERT_EQ(ret_type->get_family_name(), "Array");
}
+class FunctionAISentimentTestHelper : public FunctionAISentiment {
+public:
+ using FunctionAISentiment::normalize_endpoint;
+};
+
+TEST(AIFunctionTest, NormalizeLegacyCompletionsEndpoint) {
+ TAIResource resource;
+ resource.endpoint = "https://api.openai.com/v1/completions";
+ FunctionAISentimentTestHelper::normalize_endpoint(resource);
+ ASSERT_EQ(resource.endpoint, "https://api.openai.com/v1/chat/completions");
+}
+
+TEST(AIFunctionTest, NormalizeEndpointNoopForOtherPaths) {
+ TAIResource resource;
+ resource.endpoint = "https://api.openai.com/v1/chat/completions";
+ FunctionAISentimentTestHelper::normalize_endpoint(resource);
+ ASSERT_EQ(resource.endpoint, "https://api.openai.com/v1/chat/completions");
+
+ resource.endpoint = "https://localhost/v1/responses";
+ FunctionAISentimentTestHelper::normalize_endpoint(resource);
+ ASSERT_EQ(resource.endpoint, "https://localhost/v1/responses");
+}
+
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/test/ai/embed_test.cpp b/be/test/ai/embed_test.cpp
index 6e4a818484c..d5b70292a9c 100644
--- a/be/test/ai/embed_test.cpp
+++ b/be/test/ai/embed_test.cpp
@@ -186,6 +186,17 @@ TEST(EMBED_TEST, local_adapter_parse_embedding_response) {
ASSERT_EQ(results[0].size(), 3);
ASSERT_FLOAT_EQ(results[0][0], 0.1F);
ASSERT_FLOAT_EQ(results[0][2], 0.3F);
+
+ std::string resp3 = R"({
+ "embeddings": [[0.6, 0.7]]
+ })";
+ results.clear();
+ st = adapter.parse_embedding_response(resp3, results);
+ ASSERT_TRUE(st.ok()) << "Format 3 failed: " << st.to_string();
+ ASSERT_EQ(results.size(), 1);
+ ASSERT_EQ(results[0].size(), 2);
+ ASSERT_FLOAT_EQ(results[0][0], 0.6F);
+ ASSERT_FLOAT_EQ(results[0][1], 0.7F);
}
TEST(EMBED_TEST, openai_adapter_embedding_request) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]