github-actions[bot] commented on code in PR #61924:
URL: https://github.com/apache/doris/pull/61924#discussion_r3019386764
##########
be/src/exprs/function/ai/ai_functions.h:
##########
@@ -229,7 +170,7 @@ class AIFunction : public IFunction {
Status do_send_request(HttpClient* client, const std::string& request_body,
std::string& response, const TAIResource& config,
std::shared_ptr<AIAdapter>& adapter,
FunctionContext* context) const {
- RETURN_IF_ERROR(client->init(config.endpoint));
+ RETURN_IF_ERROR(client->init(config.endpoint, false));
QueryContext* query_ctx = context->state()->get_query_ctx();
Review Comment:
Logging the full `request_body`/`response_body` at INFO is risky here. In
the new multimodal flow the request body contains presigned media URLs, and
those URLs are bearer-style credentials for the underlying object. Because this
runs on every AI call, normal production logs will now persist raw external
payloads and temporary access tokens for all `EMBED` queries. Please either
drop these logs or redact the payloads aggressively before emitting them.
##########
be/src/exprs/function/ai/embed.h:
##########
@@ -33,7 +43,270 @@ class FunctionEmbed : public AIFunction<FunctionEmbed> {
return
std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
}
+ Status execute_with_adapter(FunctionContext* context, Block& block,
+ const ColumnNumbers& arguments, uint32_t
result,
+ size_t input_rows_count, const TAIResource&
config,
+ std::shared_ptr<AIAdapter>& adapter) const {
+ if (arguments.size() != 2) {
+ return Status::InvalidArgument("Function EMBED expects 2
arguments, but got {}",
+ arguments.size());
+ }
+
+ PrimitiveType input_type =
+
remove_nullable(block.get_by_position(arguments[1]).type)->get_primitive_type();
+ if (input_type == PrimitiveType::TYPE_JSONB) {
+ return _execute_multimodal_embed(context, block, arguments,
result, input_rows_count,
+ config, adapter);
+ }
+ if (input_type == PrimitiveType::TYPE_STRING || input_type ==
PrimitiveType::TYPE_VARCHAR ||
+ input_type == PrimitiveType::TYPE_CHAR) {
+ return _execute_text_embed(context, block, arguments, result,
input_rows_count, config,
+ adapter);
+ }
+ return Status::InvalidArgument(
+ "Function EMBED expects the second argument to be STRING or
JSON, but got type {}",
+ block.get_by_position(arguments[1]).type->get_name());
+ }
+
// Factory returning a fresh FunctionEmbed instance.
static FunctionPtr create() {
    return std::make_shared<FunctionEmbed>();
}
+
+private:
+ Status _execute_text_embed(FunctionContext* context, Block& block,
+ const ColumnNumbers& arguments, uint32_t result,
+ size_t input_rows_count, const TAIResource&
config,
+ std::shared_ptr<AIAdapter>& adapter) const {
+ auto col_result = ColumnArray::create(
+ ColumnNullable::create(ColumnFloat32::create(),
ColumnUInt8::create()));
+
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ std::string prompt;
+ RETURN_IF_ERROR(build_prompt(block, arguments, i, prompt));
+
+ std::vector<float> float_result;
+ RETURN_IF_ERROR(execute_single_request(prompt, float_result,
config, adapter, context));
+ _insert_embedding_result(*col_result, float_result);
+ }
+
+ block.replace_by_position(result, std::move(col_result));
+ return Status::OK();
+ }
+
+ Status _execute_multimodal_embed(FunctionContext* context, Block& block,
+ const ColumnNumbers& arguments, uint32_t
result,
+ size_t input_rows_count, const
TAIResource& config,
+ std::shared_ptr<AIAdapter>& adapter)
const {
+ auto col_result = ColumnArray::create(
+ ColumnNullable::create(ColumnFloat32::create(),
ColumnUInt8::create()));
+
+ int64_t ttl_seconds = 3600;
+ QueryContext* query_ctx = context->state()->get_query_ctx();
+ if (query_ctx &&
query_ctx->query_options().__isset.embed_presigned_url_ttl_seconds) {
+ ttl_seconds =
query_ctx->query_options().embed_presigned_url_ttl_seconds;
+ if (ttl_seconds <= 0) {
+ ttl_seconds = 3600;
+ }
+ }
+ LOG(INFO) << "[lzq]: EMBED multimodal execute begin,
input_rows_count=" << input_rows_count
+ << ", ttl_seconds=" << ttl_seconds;
+
+ const ColumnWithTypeAndName& file_column =
block.get_by_position(arguments[1]);
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ rapidjson::Document file_input;
+ RETURN_IF_ERROR(_parse_file_input(file_column, i, file_input));
+
+ MultimodalType media_type;
+ RETURN_IF_ERROR(_infer_media_type(file_input, media_type));
+ LOG(INFO) << "[lzq]: EMBED inferred media type, row=" << i
+ << ", media_type=" <<
multimodal_type_to_string(media_type);
+
+ std::string media_url;
+ RETURN_IF_ERROR(_resolve_media_url(file_input, ttl_seconds,
media_url));
+ LOG(INFO) << "[lzq]: EMBED multimodal resolved media, row=" << i
+ << ", media_type=" <<
multimodal_type_to_string(media_type)
+ << ", media_url=" << media_url;
+
+ std::string request_body;
+
RETURN_IF_ERROR(adapter->build_multimodal_embedding_request(media_type,
media_url,
+
request_body));
+ LOG(INFO) << "[lzq]: EMBED multimodal request body built, row=" <<
i
+ << ", request_body=" << request_body;
+
+ std::vector<float> float_result;
+ RETURN_IF_ERROR(execute_embedding_request(request_body,
float_result, config, adapter,
+ context));
+ _insert_embedding_result(*col_result, float_result);
+ }
+
+ block.replace_by_position(result, std::move(col_result));
+ return Status::OK();
+ }
+
+ static void _insert_embedding_result(ColumnArray& col_array,
+ const std::vector<float>&
float_result) {
+ auto& offsets = col_array.get_offsets();
+ auto& nested_nullable_col =
assert_cast<ColumnNullable&>(col_array.get_data());
+ auto& nested_col =
+
assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr()));
+ nested_col.reserve(nested_col.size() + float_result.size());
+
+ size_t current_offset = nested_col.size();
+ nested_col.insert_many_raw_data(reinterpret_cast<const
char*>(float_result.data()),
+ float_result.size());
+ offsets.push_back(current_offset + float_result.size());
+ auto& null_map = nested_nullable_col.get_null_map_column();
+ null_map.insert_many_vals(0, float_result.size());
+ }
+
// Returns true when `s` begins with `prefix`, comparing bytes case-insensitively via
// std::tolower (each byte is widened through unsigned char to avoid UB on negative chars).
// An empty prefix always matches.
static bool _starts_with_ignore_case(std::string_view s, std::string_view prefix) {
    if (prefix.size() > s.size()) {
        return false;
    }
    for (std::size_t idx = 0; idx < prefix.size(); ++idx) {
        const int want = std::tolower(static_cast<unsigned char>(prefix[idx]));
        const int have = std::tolower(static_cast<unsigned char>(s[idx]));
        if (want != have) {
            return false;
        }
    }
    return true;
}
+
+ static Status _infer_media_type(const rapidjson::Value& file_input,
+ MultimodalType& media_type) {
+ std::string content_type;
+ RETURN_IF_ERROR(_get_required_string_field(file_input, "content_type",
content_type));
+
+ if (_starts_with_ignore_case(content_type, "image/")) {
+ media_type = MultimodalType::IMAGE;
+ return Status::OK();
+ }
+ if (_starts_with_ignore_case(content_type, "video/")) {
+ media_type = MultimodalType::VIDEO;
+ return Status::OK();
+ }
+ if (_starts_with_ignore_case(content_type, "audio/")) {
+ media_type = MultimodalType::AUDIO;
+ return Status::OK();
+ }
+ return Status::InvalidArgument("Unsupported content_type for EMBED:
{}", content_type);
+ }
+
+ // Parse the FILE-like JSONB argument into a JSON object for downstream
field reads.
+ static Status _parse_file_input(const ColumnWithTypeAndName& file_column,
size_t row_num,
+ rapidjson::Document& file_input) {
+ std::string file_json =
+
JsonbToJson::jsonb_to_json_string(file_column.column->get_data_at(row_num).data,
+
file_column.column->get_data_at(row_num).size);
+ file_input.Parse(file_json.c_str());
+ if (file_input.HasParseError() || !file_input.IsObject()) {
+ return Status::InvalidArgument(
+ "EMBED file argument must be a valid json object, but got:
{}", file_json);
+ }
+ LOG(INFO) << "[lzq]: EMBED parsed file json, row=" << row_num << ",
file_json=" << file_json;
+ return Status::OK();
+ }
+
+ // [Pending] After support FILE type, We should use the interface provided
by FILE to get the fields
+ // replacing this function
+ static Status _get_required_string_field(const rapidjson::Value& obj,
const char* field_name,
+ std::string& value) {
+ auto iter = obj.FindMember(field_name);
+ if (iter == obj.MemberEnd() || !iter->value.IsString()) {
+ return Status::InvalidArgument(
+ "EMBED file json field '{}' is required and must be a
string", field_name);
+ }
+ value = iter->value.GetString();
+ if (value.empty()) {
+ return Status::InvalidArgument("EMBED file json field '{}' can not
be empty",
+ field_name);
+ }
+ return Status::OK();
+ }
+
+ Status _generate_s3_presigned_url(const rapidjson::Value& file_input,
const std::string& uri,
+ int64_t ttl_seconds, std::string&
presigned_url) const {
+ // Parse S3 URI to extract bucket and key
Review Comment:
This path still hard-depends on `S3URI`, but the surrounding code now
advertises provider flexibility via `provider` +
`obj_storage_type_from_string()`. `S3URI::parse()` only accepts `s3://`,
`http://`, and `https://`, so a FILE JSON like
`{"provider":"OSS","uri":"oss://bucket/key",...}` will fail before
`S3ClientFactory` ever sees the provider. That means the new non-AWS provider
path is functionally broken even though the FE validation now allows providers
such as JINA and the BE maps multiple object-store providers.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]