HappenLee commented on code in PR #57868:
URL: https://github.com/apache/doris/pull/57868#discussion_r2553971260
##########
be/src/udf/python/python_udaf_client.cpp:
##########
@@ -0,0 +1,613 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "udf/python/python_udaf_client.h"
+
+#include "arrow/builder.h"
+#include "arrow/flight/client.h"
+#include "arrow/flight/server.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/type.h"
+#include "common/compiler_util.h"
+#include "common/status.h"
+#include "udf/python/python_udf_meta.h"
+#include "udf/python/python_udf_runtime.h"
+#include "util/arrow/utils.h"
+
+namespace doris {
+
+// Unified Schema for ALL UDAF operations
+// This ensures gRPC Flight Stream uses the same schema for all RecordBatches
+// Fields: [operation: int8, place_id: int64, metadata: binary, data: binary]
+// - operation: UDAFOperation enum value
+// - place_id: Aggregate state identifier
+// - metadata: Serialized metadata (e.g., is_single_place, row_start, row_end, place_offset)
+// - data: Serialized operation-specific data (e.g., input RecordBatch, serialized_state)
+static const std::shared_ptr<arrow::Schema> kUnifiedUDAFSchema = arrow::schema({
+        arrow::field("operation", arrow::int8()),
+        arrow::field("place_id", arrow::int64()),
+        arrow::field("metadata", arrow::binary()),
+        arrow::field("data", arrow::binary()),
+});
+
+// Metadata Schema for ACCUMULATE operation
+// Fields: [is_single_place: bool, row_start: int64, row_end: int64, place_offset: int64]
+static const std::shared_ptr<arrow::Schema> kAccumulateMetadataSchema = arrow::schema({
+        arrow::field("is_single_place", arrow::boolean()),
+        arrow::field("row_start", arrow::int64()),
+        arrow::field("row_end", arrow::int64()),
+        arrow::field("place_offset", arrow::int64()),
+});
+
+// Helper function to serialize RecordBatch to binary
+static Status serialize_record_batch(const arrow::RecordBatch& batch,
+                                     std::shared_ptr<arrow::Buffer>* out) {
+    auto output_stream_result = arrow::io::BufferOutputStream::Create();
+    if (UNLIKELY(!output_stream_result.ok())) {
+        return Status::InternalError("Failed to create buffer output stream: {}",
+                                     output_stream_result.status().message());
+    }
+    auto output_stream = std::move(output_stream_result).ValueOrDie();
+
+    auto writer_result = arrow::ipc::MakeStreamWriter(output_stream, batch.schema());
+    if (UNLIKELY(!writer_result.ok())) {
+        return Status::InternalError("Failed to create IPC writer: {}",
+                                     writer_result.status().message());
+    }
+    auto writer = std::move(writer_result).ValueOrDie();
+
+    RETURN_DORIS_STATUS_IF_ERROR(writer->WriteRecordBatch(batch));
+    RETURN_DORIS_STATUS_IF_ERROR(writer->Close());
+
+    auto buffer_result = output_stream->Finish();
+    if (UNLIKELY(!buffer_result.ok())) {
+        return Status::InternalError("Failed to finish buffer: {}",
+                                     buffer_result.status().message());
+    }
+    *out = std::move(buffer_result).ValueOrDie();
+    return Status::OK();
+}
+
+// Helper function to deserialize RecordBatch from binary
+static Status deserialize_record_batch(const std::shared_ptr<arrow::Buffer>& buffer,
+                                       std::shared_ptr<arrow::RecordBatch>* out) {
+    // Create BufferReader from the input buffer
+    auto input_stream = std::make_shared<arrow::io::BufferReader>(buffer);
+
+    // Open IPC stream reader
+    auto reader_result = arrow::ipc::RecordBatchStreamReader::Open(input_stream);
+    if (UNLIKELY(!reader_result.ok())) {
+        return Status::InternalError("Failed to open IPC reader: {}",
+                                     reader_result.status().message());
+    }
+    auto reader = std::move(reader_result).ValueOrDie();
+
+    // Read the first (and only) RecordBatch
+    auto batch_result = reader->Next();
+    if (UNLIKELY(!batch_result.ok())) {
+        return Status::InternalError("Failed to read RecordBatch: {}",
+                                     batch_result.status().message());
+    }
+
+    *out = std::move(batch_result).ValueOrDie();
+    if (UNLIKELY(!*out)) {
+        return Status::InternalError("Deserialized RecordBatch is null");
+    }
+
+    return Status::OK();
+}
+
+// Helper function to validate and cast Arrow column to expected type
+template <typename ArrowArrayType>
+static Status validate_and_cast_column(const std::shared_ptr<arrow::RecordBatch>& batch,
+                                       int column_index, const std::string& operation_name,
+                                       std::shared_ptr<ArrowArrayType>* out) {
+    if (UNLIKELY(column_index >= batch->num_columns())) {
+        return Status::InternalError("{} response: column index {} out of range (total {} columns)",
+                                     operation_name, column_index, batch->num_columns());
+    }
+
+    auto column = batch->column(column_index);
+
+    // Create expected type instance for comparison
+    std::shared_ptr<arrow::DataType> expected_type;
+    if constexpr (std::is_same_v<ArrowArrayType, arrow::BooleanArray>) {
+        expected_type = arrow::boolean();
+    } else if constexpr (std::is_same_v<ArrowArrayType, arrow::Int64Array>) {
+        expected_type = arrow::int64();
+    } else if constexpr (std::is_same_v<ArrowArrayType, arrow::BinaryArray>) {
+        expected_type = arrow::binary();
+    } else {
+        return Status::InternalError("{} response: unsupported array type", operation_name);
+    }
+
+    if (UNLIKELY(!column->type()->Equals(expected_type))) {
+        return Status::InternalError(
+                "{} response: expected column {} to be of type {}, but got type {}", operation_name,
+                column_index, expected_type->ToString(), column->type()->ToString());
+    }
+
+    *out = std::static_pointer_cast<ArrowArrayType>(column);
+    return Status::OK();
+}
+
+// Helper function to create a unified operation batch
+static Status create_unified_batch(PythonUDAFClient::UDAFOperation operation, int64_t place_id,
+                                   const std::shared_ptr<arrow::Buffer>& metadata,
+                                   const std::shared_ptr<arrow::Buffer>& data,
+                                   std::shared_ptr<arrow::RecordBatch>* out) {
+    arrow::Int8Builder op_builder;
+    arrow::Int64Builder place_builder;
+    arrow::BinaryBuilder metadata_builder;
+    arrow::BinaryBuilder data_builder;
+
+    RETURN_DORIS_STATUS_IF_ERROR(op_builder.Append(static_cast<int8_t>(operation)));
+    RETURN_DORIS_STATUS_IF_ERROR(place_builder.Append(place_id));
+
+    if (metadata) {
+        RETURN_DORIS_STATUS_IF_ERROR(
+                metadata_builder.Append(metadata->data(), static_cast<int32_t>(metadata->size())));
+    } else {
+        RETURN_DORIS_STATUS_IF_ERROR(metadata_builder.AppendNull());
+    }
+
+    if (data) {
+        RETURN_DORIS_STATUS_IF_ERROR(
+                data_builder.Append(data->data(), static_cast<int32_t>(data->size())));
+    } else {
+        RETURN_DORIS_STATUS_IF_ERROR(data_builder.AppendNull());
+    }
+
+    std::shared_ptr<arrow::Array> op_array;
+    std::shared_ptr<arrow::Array> place_array;
+    std::shared_ptr<arrow::Array> metadata_array;
+    std::shared_ptr<arrow::Array> data_array;
+
+    RETURN_DORIS_STATUS_IF_ERROR(op_builder.Finish(&op_array));
+    RETURN_DORIS_STATUS_IF_ERROR(place_builder.Finish(&place_array));
+    RETURN_DORIS_STATUS_IF_ERROR(metadata_builder.Finish(&metadata_array));
+    RETURN_DORIS_STATUS_IF_ERROR(data_builder.Finish(&data_array));
+
+    *out = arrow::RecordBatch::Make(kUnifiedUDAFSchema, 1,
+                                    {op_array, place_array, metadata_array, data_array});
+    return Status::OK();
+}
+
+Status PythonUDAFClient::create(const PythonUDFMeta& func_meta, ProcessPtr process,
+                                PythonUDAFClientPtr* client) {
+    PythonUDAFClientPtr python_udaf_client = std::make_shared<PythonUDAFClient>();
+    RETURN_IF_ERROR(python_udaf_client->init(func_meta, std::move(process)));
+    *client = std::move(python_udaf_client);
+    return Status::OK();
+}
+
+Status PythonUDAFClient::init(const PythonUDFMeta& func_meta, ProcessPtr process) {
+    if (_inited) {
+        return Status::InternalError("PythonUDAFClient has already been initialized");
+    }
+
+    arrow::flight::Location location;
+    RETURN_DORIS_STATUS_IF_RESULT_ERROR(location,
+                                        arrow::flight::Location::Parse(process->get_uri()));
+    RETURN_DORIS_STATUS_IF_RESULT_ERROR(_arrow_client, FlightClient::Connect(location));
+
+    std::string command;
+    RETURN_IF_ERROR(func_meta.serialize_to_json(&command));
+
+    FlightDescriptor descriptor = FlightDescriptor::Command(command);
+    arrow::flight::FlightClient::DoExchangeResult exchange_res;
+    RETURN_DORIS_STATUS_IF_RESULT_ERROR(exchange_res, _arrow_client->DoExchange(descriptor));
+
+    _reader = std::move(exchange_res.reader);
+    _writer = std::move(exchange_res.writer);
+    _process = std::move(process);
+    _inited = true;
+
+    return Status::OK();
+}
+
+Status PythonUDAFClient::create(int64_t place_id) {
+    if (!_process->is_alive()) {
+        return Status::RuntimeError("Python UDAF process is not alive");
+    }
+
+    // Create unified batch for CREATE operation (no metadata or data needed)
+    std::shared_ptr<arrow::RecordBatch> batch;
+    RETURN_IF_ERROR(
+            create_unified_batch(UDAFOperation::CREATE, place_id, nullptr, nullptr, &batch));
+
+    // Send to server and check response
+    std::shared_ptr<arrow::RecordBatch> output;
+    RETURN_IF_ERROR(_send_operation(batch.get(), &output));
+
+    // Validate response: [success: bool]
+    if (output->num_columns() != 1 || output->num_rows() != 1) {
+        return Status::InternalError("Invalid CREATE response from Python UDAF server");
+    }
+
+    std::shared_ptr<arrow::BooleanArray> bool_array;
+    RETURN_IF_ERROR(validate_and_cast_column(output, 0, "CREATE", &bool_array));
+
+    if (!bool_array->Value(0)) {
+        return Status::InternalError("CREATE operation failed for place_id={}", place_id);
+    }
+
+    _created_states.insert(place_id);
+    return Status::OK();
+}
+
+Status PythonUDAFClient::accumulate(int64_t place_id, bool is_single_place,
+                                    const arrow::RecordBatch& input, int64_t row_start,
+                                    int64_t row_end, const int64_t* places, int64_t place_offset) {
+    if (!_process->is_alive()) {
+        return Status::RuntimeError("Python UDAF process is not alive");
+    }
+
+    // Build metadata RecordBatch
+    // Schema: [is_single_place: bool, row_start: int64, row_end: int64, place_offset: int64]
+    arrow::BooleanBuilder single_place_builder;
+    arrow::Int64Builder row_start_builder;
+    arrow::Int64Builder row_end_builder;
+    arrow::Int64Builder offset_builder;
+
+    RETURN_DORIS_STATUS_IF_ERROR(single_place_builder.Append(is_single_place));
+    RETURN_DORIS_STATUS_IF_ERROR(row_start_builder.Append(row_start));
+    RETURN_DORIS_STATUS_IF_ERROR(row_end_builder.Append(row_end));
+    RETURN_DORIS_STATUS_IF_ERROR(offset_builder.Append(place_offset));
+
+    std::shared_ptr<arrow::Array> single_place_array;
+    std::shared_ptr<arrow::Array> row_start_array;
+    std::shared_ptr<arrow::Array> row_end_array;
+    std::shared_ptr<arrow::Array> offset_array;
+
+    RETURN_DORIS_STATUS_IF_ERROR(single_place_builder.Finish(&single_place_array));
+    RETURN_DORIS_STATUS_IF_ERROR(row_start_builder.Finish(&row_start_array));
+    RETURN_DORIS_STATUS_IF_ERROR(row_end_builder.Finish(&row_end_array));
+    RETURN_DORIS_STATUS_IF_ERROR(offset_builder.Finish(&offset_array));
+
+    auto metadata_batch = arrow::RecordBatch::Make(
+            kAccumulateMetadataSchema, 1,
+            {single_place_array, row_start_array, row_end_array, offset_array});
+
+    // Serialize metadata
+    std::shared_ptr<arrow::Buffer> metadata_buffer;
+    RETURN_IF_ERROR(serialize_record_batch(*metadata_batch, &metadata_buffer));
+
+    // Build data RecordBatch (input columns + optional places array)
+    std::vector<std::shared_ptr<arrow::Field>> data_fields;
+    std::vector<std::shared_ptr<arrow::Array>> data_arrays;
+
+    // Add input columns
+    for (int i = 0; i < input.num_columns(); ++i) {
+        data_fields.push_back(input.schema()->field(i));
+        data_arrays.push_back(input.column(i));
+    }
+
+    // Add places array if needed
+    if (!is_single_place && places != nullptr) {
+        arrow::Int64Builder places_builder;
+        for (int64_t i = 0; i < input.num_rows(); ++i) {
+            RETURN_DORIS_STATUS_IF_ERROR(places_builder.Append(places[i]));
+        }
+        std::shared_ptr<arrow::Array> places_array;
+        RETURN_DORIS_STATUS_IF_ERROR(places_builder.Finish(&places_array));
+
+        data_fields.push_back(arrow::field("places", arrow::int64()));
+        data_arrays.push_back(places_array);
+    }
+
+    auto data_schema = arrow::schema(data_fields);
+    auto data_batch = arrow::RecordBatch::Make(data_schema, input.num_rows(), data_arrays);
+
+    // Serialize data
+    std::shared_ptr<arrow::Buffer> data_buffer;
+    RETURN_IF_ERROR(serialize_record_batch(*data_batch, &data_buffer));
+
+    // Create unified batch
+    std::shared_ptr<arrow::RecordBatch> batch;
+    RETURN_IF_ERROR(create_unified_batch(UDAFOperation::ACCUMULATE, place_id, metadata_buffer,
+                                         data_buffer, &batch));
+
+    // Send to server and check rows_processed
+    std::shared_ptr<arrow::RecordBatch> output;
+    RETURN_IF_ERROR(_send_operation(batch.get(), &output));
+
+    // Validate response: [rows_processed: int64]
+    if (output->num_columns() != 1 || output->num_rows() != 1) {
+        return Status::InternalError("Invalid ACCUMULATE response from Python UDAF server");
+    }
+
+    std::shared_ptr<arrow::Int64Array> int64_array;
+    RETURN_IF_ERROR(validate_and_cast_column(output, 0, "ACCUMULATE", &int64_array));
+
+    int64_t rows_processed = int64_array->Value(0);
+    int64_t expected_rows = row_end - row_start;
+
+    if (rows_processed < expected_rows) {
+        return Status::InternalError(
+                "ACCUMULATE operation only processed {} out of {} rows for place_id={}",
+                rows_processed, expected_rows, place_id);
+    }
+
+    return Status::OK();
+}
+
+Status PythonUDAFClient::serialize(int64_t place_id,
+                                   std::shared_ptr<arrow::Buffer>* serialized_state) {
+    if (!_process->is_alive()) {
+        return Status::RuntimeError("Python UDAF process is not alive");
+    }
+
+    // Create unified batch for SERIALIZE operation (no metadata or data needed)
+    std::shared_ptr<arrow::RecordBatch> batch;
+    RETURN_IF_ERROR(
+            create_unified_batch(UDAFOperation::SERIALIZE, place_id, nullptr, nullptr, &batch));
+
+    // Send and receive serialized state
+    std::shared_ptr<arrow::RecordBatch> output;
+    RETURN_IF_ERROR(_send_operation(batch.get(), &output));
+
+    // Extract serialized state from output (should be a binary column)
+    if (output->num_columns() != 1 || output->num_rows() != 1) {
+        return Status::InternalError("Invalid SERIALIZE response from Python UDAF server");
+    }
+
+    std::shared_ptr<arrow::BinaryArray> binary_array;
+    RETURN_IF_ERROR(validate_and_cast_column(output, 0, "SERIALIZE", &binary_array));
+
+    int32_t length;
+    const uint8_t* data = binary_array->GetValue(0, &length);
+
+    // Check if serialization succeeded (non-empty binary)
+    if (length == 0) {
+        return Status::InternalError("SERIALIZE operation failed for place_id={}", place_id);
+    }
+
+    *serialized_state = arrow::Buffer::Wrap(data, length);
+
+    return Status::OK();
+}
+
+Status PythonUDAFClient::merge(int64_t place_id,
+                               const std::shared_ptr<arrow::Buffer>& serialized_state) {
+    if (!_process->is_alive()) {
+        return Status::RuntimeError("Python UDAF process is not alive");
+    }
+
+    // Create unified batch for MERGE operation (data = serialized_state)
+    std::shared_ptr<arrow::RecordBatch> batch;
+    RETURN_IF_ERROR(create_unified_batch(UDAFOperation::MERGE, place_id, nullptr, serialized_state,
+                                         &batch));
+
+    // Send to server and check response
+    std::shared_ptr<arrow::RecordBatch> output;
+    RETURN_IF_ERROR(_send_operation(batch.get(), &output));
+
+    // Validate response: [success: bool]
+    if (output->num_columns() != 1 || output->num_rows() != 1) {
+        return Status::InternalError("Invalid MERGE response from Python UDAF server");
+    }
+
+    std::shared_ptr<arrow::BooleanArray> bool_array;
+    RETURN_IF_ERROR(validate_and_cast_column(output, 0, "MERGE", &bool_array));
+
+    if (!bool_array->Value(0)) {
+        return Status::InternalError("MERGE operation failed for place_id={}", place_id);
+    }
+
+    return Status::OK();
+}
+
+Status PythonUDAFClient::finalize(int64_t place_id, std::shared_ptr<arrow::RecordBatch>* output) {
+    if (!_process->is_alive()) {
+        return Status::RuntimeError("Python UDAF process is not alive");
+    }
+
+    // Create unified batch for FINALIZE operation (no metadata or data needed)
+    std::shared_ptr<arrow::RecordBatch> batch;
+    RETURN_IF_ERROR(
+            create_unified_batch(UDAFOperation::FINALIZE, place_id, nullptr, nullptr, &batch));
+
+    // Send and receive final result
+    RETURN_IF_ERROR(_send_operation(batch.get(), output));
+
+    // Validate basic response structure (but allow NULL values)
+    if (!output || !(*output) || (*output)->num_columns() != 1 || (*output)->num_rows() != 1) {
+        return Status::InternalError(
+                "Invalid FINALIZE response: expected 1 column and 1 row, got {} columns and {} "
+                "rows",
+                output && (*output) ? (*output)->num_columns() : 0,
+                output && (*output) ? (*output)->num_rows() : 0);
+    }
+
+    return Status::OK();
+}
+
+Status PythonUDAFClient::reset(int64_t place_id) {
+    if (!_process->is_alive()) {
+        return Status::RuntimeError("Python UDAF process is not alive");
+    }
+
+    // Create unified batch for RESET operation (no metadata or data needed)
+    std::shared_ptr<arrow::RecordBatch> batch;
+    RETURN_IF_ERROR(create_unified_batch(UDAFOperation::RESET, place_id, nullptr, nullptr, &batch));
+
+    // Send to server and check response
+    std::shared_ptr<arrow::RecordBatch> output;
+    RETURN_IF_ERROR(_send_operation(batch.get(), &output));
+
+    // Validate response: [success: bool]
+    if (output->num_columns() != 1 || output->num_rows() != 1) {
+        return Status::InternalError("Invalid RESET response from Python UDAF server");
+    }
+
+    std::shared_ptr<arrow::BooleanArray> bool_array;
+    RETURN_IF_ERROR(validate_and_cast_column(output, 0, "RESET", &bool_array));
+
+    if (!bool_array->Value(0)) {
+        return Status::InternalError("RESET operation failed for place_id={}", place_id);
+    }
+
+    return Status::OK();
+}
+
+Status PythonUDAFClient::destroy(int64_t place_id) {
+    if (!_process->is_alive()) {
+        return Status::RuntimeError("Python UDAF process is not alive");
+    }
+
+    // Create unified batch for DESTROY operation (no metadata or data needed)
+    std::shared_ptr<arrow::RecordBatch> batch;
+    RETURN_IF_ERROR(
+            create_unified_batch(UDAFOperation::DESTROY, place_id, nullptr, nullptr, &batch));
+
+    // Send to server and check response
+    std::shared_ptr<arrow::RecordBatch> output;
+    RETURN_IF_ERROR(_send_operation(batch.get(), &output));
+
+    // Validate response: [success: bool]
+    if (output->num_columns() != 1 || output->num_rows() != 1) {
+        return Status::InternalError("Invalid DESTROY response from Python UDAF server");
+    }
+
+    std::shared_ptr<arrow::BooleanArray> bool_array;
+    RETURN_IF_ERROR(validate_and_cast_column(output, 0, "DESTROY", &bool_array));
+
+    if (!bool_array->Value(0)) {
+        // Log warning but don't fail - cleanup should be best-effort

Review Comment:
   In which cases would the destroy call fail?
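For context, the methods in this file drive one aggregate state through the
create -> accumulate -> serialize/merge -> finalize -> destroy protocol. Below is a
minimal caller sketch using only the signatures defined in this patch; the helper name
run_udaf_lifecycle is hypothetical, and obtaining the client and input batch is assumed
to happen elsewhere (single-place code path shown):

    // Hypothetical helper, not part of the patch: exercises the full UDAF lifecycle.
    Status run_udaf_lifecycle(const PythonUDAFClientPtr& client, const arrow::RecordBatch& input) {
        constexpr int64_t place_id = 0;
        // CREATE: allocate the aggregate state on the Python server.
        RETURN_IF_ERROR(client->create(place_id));

        // ACCUMULATE: single-place path, so no places array or place_offset is needed.
        RETURN_IF_ERROR(client->accumulate(place_id, /*is_single_place=*/true, input,
                                           /*row_start=*/0, /*row_end=*/input.num_rows(),
                                           /*places=*/nullptr, /*place_offset=*/0));

        // SERIALIZE + MERGE: how a partial state would travel between instances
        // (merged back into the same state here purely for illustration).
        std::shared_ptr<arrow::Buffer> state;
        RETURN_IF_ERROR(client->serialize(place_id, &state));
        RETURN_IF_ERROR(client->merge(place_id, state));

        // FINALIZE: the result arrives as a 1-column, 1-row RecordBatch (NULL allowed).
        std::shared_ptr<arrow::RecordBatch> result;
        RETURN_IF_ERROR(client->finalize(place_id, &result));

        // DESTROY: best-effort cleanup of the remote state.
        return client->destroy(place_id);
    }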
