This is an automated email from the ASF dual-hosted git repository.
arawat pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git
The following commit(s) were added to refs/heads/master by this push:
new 3668a9517 IMPALA-13131: Azure OpenAI API expects 'api-key' instead of
'Authorization' in the request header
3668a9517 is described below
commit 3668a9517c4d8097591ed3b6fa672bf87faa77f6
Author: Abhishek Rawat <[email protected]>
AuthorDate: Fri Jun 7 07:13:58 2024 -0700
IMPALA-13131: Azure OpenAI API expects 'api-key' instead of 'Authorization'
in the request header
Updated the POST request sent to Azure OpenAI endpoints. For Azure, the
request header now includes 'api-key: <api-key>' instead of
'Authorization: Bearer <api-key>'; requests to the public OpenAI endpoint
still use the 'Authorization: Bearer <api-key>' header.
Also, 'model' is no longer included in the request payload for Azure
OpenAI API calls, because the Azure endpoint URL already contains the
deployment name, and each deployment is mapped to a model.
Testing:
- Updated the existing unit tests to match the Azure OpenAI API reference.
- Manually tested the built-in 'ai_generate_text' function using an Azure
  OpenAI deployment.
Change-Id: If9cc07940ce355d511bcf0ee615ff31042d13eb5
Reviewed-on: http://gerrit.cloudera.org:8080/21493
Reviewed-by: Impala Public Jenkins <[email protected]>
Tested-by: Impala Public Jenkins <[email protected]>
---
be/src/exprs/ai-functions-ir.cc | 59 +++++++++++++++++++++++--
be/src/exprs/ai-functions.h | 26 ++++++++++-
be/src/exprs/ai-functions.inline.h | 88 +++++++++++++++++++++-----------------
be/src/exprs/expr-test.cc | 78 +++++++++++++++++++++++----------
4 files changed, 183 insertions(+), 68 deletions(-)
diff --git a/be/src/exprs/ai-functions-ir.cc b/be/src/exprs/ai-functions-ir.cc
index 6def1a010..2c9f17398 100644
--- a/be/src/exprs/ai-functions-ir.cc
+++ b/be/src/exprs/ai-functions-ir.cc
@@ -60,6 +60,10 @@ const string
AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR =
string AiFunctions::ai_api_key_;
const char* AiFunctions::OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER =
"Content-Type: application/json";
+const char* AiFunctions::OPEN_AI_REQUEST_AUTH_HEADER =
+ "Authorization: Bearer ";
+const char* AiFunctions::AZURE_OPEN_AI_REQUEST_AUTH_HEADER =
+ "api-key: ";
// other constants
static const StringVal NULL_STRINGVAL = StringVal::null();
@@ -85,6 +89,22 @@ bool AiFunctions::is_api_endpoint_supported(const
std::string_view& endpoint) {
gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, endpoint.size())
!= nullptr);
}
+AiFunctions::AI_PLATFORM AiFunctions::GetAiPlatformFromEndpoint(
+ const std::string_view& endpoint) {
+ // Only OpenAI endpoints are supported.
+ if (gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, endpoint.size())
!= nullptr)
+ return AiFunctions::AI_PLATFORM::OPEN_AI;
+ if (gstrncasestr(endpoint.data(), OPEN_AI_AZURE_ENDPOINT, endpoint.size())
!= nullptr)
+ return AiFunctions::AI_PLATFORM::AZURE_OPEN_AI;
+ return AiFunctions::AI_PLATFORM::UNSUPPORTED;
+}
+
+StringVal AiFunctions::copyErrorMessage(FunctionContext* ctx, const string&
errorMsg) {
+ return StringVal::CopyFrom(ctx,
+ reinterpret_cast<const uint8_t*>(errorMsg.c_str()),
+ errorMsg.length());
+}
+
string AiFunctions::AiGenerateTextParseOpenAiResponse(
const std::string_view& response) {
rapidjson::Document document;
@@ -123,17 +143,48 @@ string AiFunctions::AiGenerateTextParseOpenAiResponse(
return message[OPEN_AI_RESPONSE_FIELD_CONTENT].GetString();
}
+template <bool fastpath>
+StringVal AiFunctions::AiGenerateTextHelper(FunctionContext* ctx,
+ const StringVal& endpoint, const StringVal& prompt, const StringVal& model,
+ const StringVal& api_key_jceks_secret, const StringVal& params) {
+ std::string_view endpoint_sv(FLAGS_ai_endpoint);
+ // endpoint validation
+ if (!fastpath && endpoint.ptr != nullptr && endpoint.len != 0) {
+ endpoint_sv = std::string_view(reinterpret_cast<char*>(endpoint.ptr),
endpoint.len);
+ // Simple validation for endpoint. It should start with https://
+ if (!is_api_endpoint_valid(endpoint_sv)) {
+ LOG(ERROR) << "AI Generate Text: \ninvalid protocol: " << endpoint_sv;
+ return StringVal(AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR.c_str());
+ }
+ }
+ AI_PLATFORM platform = GetAiPlatformFromEndpoint(endpoint_sv);
+ switch(platform) {
+ case AI_PLATFORM::OPEN_AI:
+ return AiGenerateTextInternal<fastpath, AI_PLATFORM::OPEN_AI>(
+ ctx, endpoint_sv, prompt, model, api_key_jceks_secret, params,
false);
+ case AI_PLATFORM::AZURE_OPEN_AI:
+ return AiGenerateTextInternal<fastpath, AI_PLATFORM::AZURE_OPEN_AI>(
+ ctx, endpoint_sv, prompt, model, api_key_jceks_secret, params,
false);
+ default:
+ if (fastpath) {
+ DCHECK(false) << "Default endpoint " << FLAGS_ai_endpoint << "must be
supported";
+ }
+ LOG(ERROR) << "AI Generate Text: \nunsupported endpoint: " <<
endpoint_sv;
+ return StringVal(AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR.c_str());
+ }
+}
+
StringVal AiFunctions::AiGenerateText(FunctionContext* ctx, const StringVal&
endpoint,
const StringVal& prompt, const StringVal& model,
const StringVal& api_key_jceks_secret, const StringVal& params) {
- return AiGenerateTextInternal<false>(
- ctx, endpoint, prompt, model, api_key_jceks_secret, params, false);
+ return AiGenerateTextHelper<false>(
+ ctx, endpoint, prompt, model, api_key_jceks_secret, params);
}
StringVal AiFunctions::AiGenerateTextDefault(
FunctionContext* ctx, const StringVal& prompt) {
- return AiGenerateTextInternal<true>(
- ctx, NULL_STRINGVAL, prompt, NULL_STRINGVAL, NULL_STRINGVAL,
NULL_STRINGVAL, false);
+ return AiGenerateTextHelper<true>(
+ ctx, NULL_STRINGVAL, prompt, NULL_STRINGVAL, NULL_STRINGVAL,
NULL_STRINGVAL);
}
} // namespace impala
diff --git a/be/src/exprs/ai-functions.h b/be/src/exprs/ai-functions.h
index 0e6396b40..1e3fcf8fd 100644
--- a/be/src/exprs/ai-functions.h
+++ b/be/src/exprs/ai-functions.h
@@ -37,6 +37,16 @@ class AiFunctions {
static const string AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR;
static const string AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR;
static const char* OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER;
+ static const char* OPEN_AI_REQUEST_AUTH_HEADER;
+ static const char* AZURE_OPEN_AI_REQUEST_AUTH_HEADER;
+ enum AI_PLATFORM {
+ /// Unsupported platform
+ UNSUPPORTED,
+ /// OpenAI public platform
+ OPEN_AI,
+ /// Azure OpenAI platform
+ AZURE_OPEN_AI
+ };
/// Sends a prompt to the input AI endpoint using the input model, api_key
and
/// optional params.
static StringVal AiGenerateText(FunctionContext* ctx, const StringVal&
endpoint,
@@ -58,14 +68,26 @@ class AiFunctions {
/// Internal function which implements the logic of parsing user input and
sending
/// request to the external API endpoint. If 'dry_run' is set, the POST
request is
/// returned. 'dry_run' mode is used only for unit tests.
- template <bool fastpath>
- static StringVal AiGenerateTextInternal(FunctionContext* ctx, const
StringVal& endpoint,
+ template <bool fastpath, AI_PLATFORM platform>
+ static StringVal AiGenerateTextInternal(
+ FunctionContext* ctx, const std::string_view& endpoint,
const StringVal& prompt, const StringVal& model,
const StringVal& api_key_jceks_secret, const StringVal& params, const
bool dry_run);
+ /// Helper function for calling AiGenerateTextInternal with common code for
both
+ /// fastpath and regular path.
+ template <bool fastpath>
+ static StringVal AiGenerateTextHelper(
+ FunctionContext* ctx, const StringVal& endpoint, const StringVal& prompt,
+ const StringVal& model, const StringVal& api_key_jceks_secret,
+ const StringVal& params);
/// Internal helper function for parsing OPEN AI's API response. Input
parameter is the
/// json representation of the OPEN AI's API response.
static std::string AiGenerateTextParseOpenAiResponse(
const std::string_view& reponse);
+ /// Helper function for getting AI Platform from the endpoint
+ static AI_PLATFORM GetAiPlatformFromEndpoint(const std::string_view&
endpoint);
+ /// Helper functions for deep copying error message
+ static StringVal copyErrorMessage(FunctionContext* ctx, const string&
errorMsg);
friend class ExprTest_AiFunctionsTest_Test;
};
diff --git a/be/src/exprs/ai-functions.inline.h
b/be/src/exprs/ai-functions.inline.h
index 9f143e2df..bd39a5002 100644
--- a/be/src/exprs/ai-functions.inline.h
+++ b/be/src/exprs/ai-functions.inline.h
@@ -42,55 +42,69 @@ DECLARE_int32(ai_connection_timeout_s);
namespace impala {
-template <bool fastpath>
+#define RETURN_STRINGVAL_IF_ERROR(ctx, stmt) \
+ do { \
+ const ::impala::Status& _status = (stmt); \
+ if (UNLIKELY(!_status.ok())) { \
+ return copyErrorMessage(ctx, _status.msg().msg()); \
+ } \
+ } while (false)
+
+template<AiFunctions::AI_PLATFORM platform>
+Status getAuthorizationHeader(string& authHeader, const string& api_key) {
+ switch(platform) {
+ case AiFunctions::AI_PLATFORM::OPEN_AI:
+ authHeader = AiFunctions::OPEN_AI_REQUEST_AUTH_HEADER + api_key;
+ return Status::OK();
+ case AiFunctions::AI_PLATFORM::AZURE_OPEN_AI:
+ authHeader = AiFunctions::AZURE_OPEN_AI_REQUEST_AUTH_HEADER + api_key;
+ return Status::OK();
+ default:
+ DCHECK(false) <<
+ "AiGenerateTextInternal should only be called for Supported
Platforms";
+ return Status(AiFunctions::AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR);
+ }
+}
+
+template <bool fastpath, AiFunctions::AI_PLATFORM platform>
StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
- const StringVal& endpoint, const StringVal& prompt, const StringVal& model,
+ const std::string_view& endpoint_sv, const StringVal& prompt, const
StringVal& model,
const StringVal& api_key_jceks_secret, const StringVal& params, const bool
dry_run) {
- std::string_view endpoint_sv(FLAGS_ai_endpoint);
- // endpoint validation
- if (!fastpath && endpoint.ptr != nullptr && endpoint.len != 0) {
- endpoint_sv = std::string_view(reinterpret_cast<char*>(endpoint.ptr),
endpoint.len);
- // Simple validation for endpoint. It should start with https://
- if (!is_api_endpoint_valid(endpoint_sv)) {
- LOG(ERROR) << "AI Generate Text: \ninvalid protocol: " << endpoint_sv;
- return StringVal(AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR.c_str());
- }
- // Only OpenAI endpoints are supported.
- if (!is_api_endpoint_supported(endpoint_sv)) {
- LOG(ERROR) << "AI Generate Text: \nunsupported endpoint: " <<
endpoint_sv;
- return StringVal(AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR.c_str());
- }
- }
// Generate the header for the POST request
vector<string> headers;
headers.emplace_back(OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER);
+ string authHeader;
if (!fastpath && api_key_jceks_secret.ptr != nullptr &&
api_key_jceks_secret.len != 0) {
string api_key;
string api_key_secret(
reinterpret_cast<char*>(api_key_jceks_secret.ptr),
api_key_jceks_secret.len);
- Status status = ExecEnv::GetInstance()->frontend()->GetSecretFromKeyStore(
- api_key_secret, &api_key);
- if (!status.ok()) {
- return StringVal::CopyFrom(ctx,
- reinterpret_cast<const uint8_t*>(status.msg().msg().c_str()),
- status.msg().msg().length());
- }
- headers.emplace_back("Authorization: Bearer " + api_key);
+ RETURN_STRINGVAL_IF_ERROR(ctx,
+ ExecEnv::GetInstance()->frontend()->GetSecretFromKeyStore(
+ api_key_secret, &api_key));
+ RETURN_STRINGVAL_IF_ERROR(ctx,
+ getAuthorizationHeader<platform>(authHeader, api_key));
} else {
- headers.emplace_back("Authorization: Bearer " + ai_api_key_);
+ RETURN_STRINGVAL_IF_ERROR(ctx,
+ getAuthorizationHeader<platform>(authHeader, ai_api_key_));
}
+ headers.emplace_back(authHeader);
// Generate the payload for the POST request
Document payload;
payload.SetObject();
Document::AllocatorType& payload_allocator = payload.GetAllocator();
- if (!fastpath && model.ptr != nullptr && model.len != 0) {
- payload.AddMember("model",
- rapidjson::StringRef(reinterpret_cast<char*>(model.ptr), model.len),
- payload_allocator);
- } else {
- payload.AddMember("model",
- rapidjson::StringRef(FLAGS_ai_model.c_str(), FLAGS_ai_model.length()),
- payload_allocator);
+ // Azure Open AI endpoint doesn't expect model as a separate param since it's
+ // embedded in the endpoint. The 'deployment_name' below maps to a model.
+ //
https://<resource_name>.openai.azure.com/openai/deployments/<deployment_name>/..
+ if (platform != AI_PLATFORM::AZURE_OPEN_AI) {
+ if (!fastpath && model.ptr != nullptr && model.len != 0) {
+ payload.AddMember("model",
+ rapidjson::StringRef(reinterpret_cast<char*>(model.ptr), model.len),
+ payload_allocator);
+ } else {
+ payload.AddMember("model",
+ rapidjson::StringRef(FLAGS_ai_model.c_str(),
FLAGS_ai_model.length()),
+ payload_allocator);
+ }
}
Value message_array(rapidjson::kArrayType);
Value message(rapidjson::kObjectType);
@@ -169,11 +183,7 @@ StringVal
AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
status = curl.PostToURL(endpoint_str, payload_str, &resp, headers);
}
VLOG(2) << "AI Generate Text: \noriginal response: " << resp.ToString();
- if (!status.ok()) {
- string msg = status.ToString();
- return StringVal::CopyFrom(
- ctx, reinterpret_cast<const uint8_t*>(msg.c_str()), msg.size());
- }
+ if (UNLIKELY(!status.ok())) return copyErrorMessage(ctx, status.ToString());
// Parse the JSON response string
std::string response = AiGenerateTextParseOpenAiResponse(
std::string_view(reinterpret_cast<char*>(resp.data()), resp.size()));
diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc
index 61e53933f..e57dbeeb7 100644
--- a/be/src/exprs/expr-test.cc
+++ b/be/src/exprs/expr-test.cc
@@ -11227,7 +11227,10 @@ TEST_P(ExprTest, AiFunctionsTest) {
string secret_key("do_not_share");
AiFunctions::set_api_key(secret_key);
// valid endpoint
- StringVal openai_endpoint("https://openai.azure.com");
+ std::string_view
openai_endpoint("https://api.openai.com/v1/chat/completions");
+ std::string_view azure_openai_endpoint(
+ "https://resource.openai.azure.com/openai/deployments/"
+ "deployment/completions?api-version=2024-02-01");
// empty jceks secret key
StringVal jceks_secret("");
// dummy model.
@@ -11239,9 +11242,19 @@ TEST_P(ExprTest, AiFunctionsTest) {
// dry_run to receive HTTP request header and body
bool dry_run = true;
+ // Test GetAiPlatformFromEndpoint
+ EXPECT_EQ(AiFunctions::AI_PLATFORM::OPEN_AI,
+ AiFunctions::GetAiPlatformFromEndpoint(openai_endpoint));
+ EXPECT_EQ(AiFunctions::AI_PLATFORM::AZURE_OPEN_AI,
+ AiFunctions::GetAiPlatformFromEndpoint(azure_openai_endpoint));
+ EXPECT_EQ(AiFunctions::AI_PLATFORM::UNSUPPORTED,
+ AiFunctions::GetAiPlatformFromEndpoint("https://qwerty.com"));
+
// Test fastpath
- StringVal result = AiFunctions::AiGenerateTextInternal<true>(ctx,
StringVal::null(),
- prompt, StringVal::null(), StringVal::null(), StringVal::null(),
dry_run);
+ StringVal result =
+ AiFunctions::AiGenerateTextInternal<true,
AiFunctions::AI_PLATFORM::OPEN_AI>(
+ ctx, FLAGS_ai_endpoint, prompt, StringVal::null(), StringVal::null(),
+ StringVal::null(), dry_run);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
string("https://api.openai.com/v1/chat/completions"
"\nContent-Type: application/json"
@@ -11249,22 +11262,34 @@ TEST_P(ExprTest, AiFunctionsTest) {
"\n{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":"
"\"hello!\"}]}"));
+ result =
+ AiFunctions::AiGenerateTextInternal<true,
AiFunctions::AI_PLATFORM::AZURE_OPEN_AI>(
+ ctx, azure_openai_endpoint, prompt, StringVal::null(),
StringVal::null(),
+ StringVal::null(), dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ string("https://resource.openai.azure.com/openai/deployments/"
+ "deployment/completions?api-version=2024-02-01"
+ "\nContent-Type: application/json"
+ "\napi-key: do_not_share"
+ "\n{\"messages\":[{\"role\":\"user\",\"content\":"
+ "\"hello!\"}]}"));
+
// Test endpoints.
// endpoints must begin with https.
- result = AiFunctions::AiGenerateTextInternal<false>(
- ctx, StringVal("http://ai.com"), prompt, model, jceks_secret,
json_params, dry_run);
+ result = AiFunctions::AiGenerateText(
+ ctx, StringVal("http://ai.com"), prompt, model, jceks_secret,
json_params);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
AiFunctions::AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR);
// only OpenAI endpoints are supported.
- result = AiFunctions::AiGenerateTextInternal<false>(ctx,
StringVal("https://ai.com"),
- prompt, model, jceks_secret, json_params, dry_run);
+ result = AiFunctions::AiGenerateText(
+ ctx, "https://ai.com", prompt, model, jceks_secret, json_params);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
AiFunctions::AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR);
// valid request using OpenAI endpoint.
- result = AiFunctions::AiGenerateTextInternal<false>(
+ result = AiFunctions::AiGenerateTextInternal<false,
AiFunctions::AI_PLATFORM::OPEN_AI>(
ctx, openai_endpoint, prompt, model, jceks_secret, json_params, dry_run);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
- string("https://openai.azure.com"
+ string("https://api.openai.com/v1/chat/completions"
"\nContent-Type: application/json"
"\nAuthorization: Bearer do_not_share"
"\n{\"model\":\"bot\",\"messages\":[{\"role\":\"user\",\"content\":\"hello!"
@@ -11273,21 +11298,27 @@ TEST_P(ExprTest, AiFunctionsTest) {
// Test prompt.
// prompt cannot be empty.
StringVal invalid_prompt("");
- result = AiFunctions::AiGenerateTextInternal<false>(
+ result = AiFunctions::AiGenerateTextInternal<false,
AiFunctions::AI_PLATFORM::OPEN_AI>(
ctx, openai_endpoint, invalid_prompt, model, jceks_secret, json_params,
dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
+ result = AiFunctions::AiGenerateTextDefault(ctx, invalid_prompt);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
// prompt cannot be null.
invalid_prompt = StringVal::null();
- result = AiFunctions::AiGenerateTextInternal<false>(
+ result = AiFunctions::AiGenerateTextInternal<false,
AiFunctions::AI_PLATFORM::OPEN_AI>(
ctx, openai_endpoint, invalid_prompt, model, jceks_secret, json_params,
dry_run);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
+ result = AiFunctions::AiGenerateTextDefault(ctx, invalid_prompt);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
// Test override/additional params
// invalid json results in error.
StringVal invalid_json_params("{\"temperature\": 0.49, \"stop\":
[\"*\",::,]}");
- result = AiFunctions::AiGenerateTextInternal<false>(
+ result = AiFunctions::AiGenerateTextInternal<false,
AiFunctions::AI_PLATFORM::OPEN_AI>(
ctx, openai_endpoint, prompt, model, jceks_secret, invalid_json_params,
dry_run);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
@@ -11295,10 +11326,10 @@ TEST_P(ExprTest, AiFunctionsTest) {
// like 'temperature' and 'stop'.
StringVal valid_json_params(
"{\"model\": \"gpt\", \"temperature\": 0.49, \"stop\": [\"*\", \"%\"]}");
- result = AiFunctions::AiGenerateTextInternal<false>(
+ result = AiFunctions::AiGenerateTextInternal<false,
AiFunctions::AI_PLATFORM::OPEN_AI>(
ctx, openai_endpoint, prompt, model, jceks_secret, valid_json_params,
dry_run);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
- string("https://openai.azure.com"
+ string("https://api.openai.com/v1/chat/completions"
"\nContent-Type: application/json"
"\nAuthorization: Bearer do_not_share"
"\n{\"model\":\"gpt\",\"messages\":[{\"role\":\"user\",\"content\":\"hello!"
@@ -11306,44 +11337,45 @@ TEST_P(ExprTest, AiFunctionsTest) {
// messages cannot be overriden, as they we constructed from the prompt.
StringVal forbidden_msg_override(
"{\"messages\": [{\"role\":\"system\",\"content\":\"howdy!\"}]}");
- result = AiFunctions::AiGenerateTextInternal<false>(
+ result = AiFunctions::AiGenerateTextInternal<false,
AiFunctions::AI_PLATFORM::OPEN_AI>(
ctx, openai_endpoint, prompt, model, jceks_secret,
forbidden_msg_override, dry_run);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR);
// 'n != 1' cannot be overriden as additional params
StringVal forbidden_n_value("{\"n\": 2}");
- result = AiFunctions::AiGenerateTextInternal<false>(
+ result = AiFunctions::AiGenerateTextInternal<false,
AiFunctions::AI_PLATFORM::OPEN_AI>(
ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_value,
dry_run);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
// non integer value of 'n' cannot be overriden as additional params
StringVal forbidden_n_type("{\"n\": \"1\"}");
- result = AiFunctions::AiGenerateTextInternal<false>(
+ result = AiFunctions::AiGenerateTextInternal<false,
AiFunctions::AI_PLATFORM::OPEN_AI>(
ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_type,
dry_run);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
// accept 'n=1' override as additional params
StringVal allowed_n_override("{\"n\": 1}");
- result = AiFunctions::AiGenerateTextInternal<false>(
+ result = AiFunctions::AiGenerateTextInternal<false,
AiFunctions::AI_PLATFORM::OPEN_AI>(
ctx, openai_endpoint, prompt, model, jceks_secret, allowed_n_override,
dry_run);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
- string("https://openai.azure.com"
+ string("https://api.openai.com/v1/chat/completions"
"\nContent-Type: application/json"
"\nAuthorization: Bearer do_not_share"
"\n{\"model\":\"bot\",\"messages\":[{\"role\":\"user\",\"content\":\"hello!"
"\"}],\"n\":1}"));
// Test flag file options are used when input is empty/null
- result = AiFunctions::AiGenerateTextInternal<false>(ctx, StringVal::null(),
prompt,
- StringVal::null(), jceks_secret, json_params, dry_run);
+ result = AiFunctions::AiGenerateTextInternal<false,
AiFunctions::AI_PLATFORM::OPEN_AI>(
+ ctx, FLAGS_ai_endpoint, prompt, StringVal::null(), jceks_secret,
json_params,
+ dry_run);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
string("https://api.openai.com/v1/chat/completions"
"\nContent-Type: application/json"
"\nAuthorization: Bearer do_not_share"
"\n{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":"
"\"hello!\"}]}"));
- result = AiFunctions::AiGenerateTextInternal<false>(
- ctx, StringVal(""), prompt, StringVal(""), jceks_secret, json_params,
dry_run);
+ result = AiFunctions::AiGenerateTextInternal<false,
AiFunctions::AI_PLATFORM::OPEN_AI>(
+ ctx, FLAGS_ai_endpoint, prompt, StringVal(""), jceks_secret,
json_params, dry_run);
EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
string("https://api.openai.com/v1/chat/completions"
"\nContent-Type: application/json"