HappenLee commented on code in PR #57868:
URL: https://github.com/apache/doris/pull/57868#discussion_r2554048809
##########
be/src/udf/python/python_udaf_client.cpp:
##########
@@ -0,0 +1,613 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "udf/python/python_udaf_client.h"
+
+#include "arrow/builder.h"
+#include "arrow/flight/client.h"
+#include "arrow/flight/server.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/type.h"
+#include "common/compiler_util.h"
+#include "common/status.h"
+#include "udf/python/python_udf_meta.h"
+#include "udf/python/python_udf_runtime.h"
+#include "util/arrow/utils.h"
+
+namespace doris {
+
+// Unified Schema for ALL UDAF operations
+// This ensures gRPC Flight Stream uses the same schema for all RecordBatches
+// Fields: [operation: int8, place_id: int64, metadata: binary, data: binary]
+// - operation: UDAFOperation enum value
+// - place_id: Aggregate state identifier
+// - metadata: Serialized metadata (e.g., is_single_place, row_start, row_end, place_offset)
+// - data: Serialized operation-specific data (e.g., input RecordBatch, serialized_state)
+static const std::shared_ptr<arrow::Schema> kUnifiedUDAFSchema = arrow::schema({
+        arrow::field("operation", arrow::int8()),
+        arrow::field("place_id", arrow::int64()),
+        arrow::field("metadata", arrow::binary()),
+        arrow::field("data", arrow::binary()),
+});
+
+// Metadata Schema for ACCUMULATE operation
+// Fields: [is_single_place: bool, row_start: int64, row_end: int64, place_offset: int64]
+static const std::shared_ptr<arrow::Schema> kAccumulateMetadataSchema = arrow::schema({
+        arrow::field("is_single_place", arrow::boolean()),
+        arrow::field("row_start", arrow::int64()),
+        arrow::field("row_end", arrow::int64()),
+        arrow::field("place_offset", arrow::int64()),
+});
+
+// Helper function to serialize RecordBatch to binary
+static Status serialize_record_batch(const arrow::RecordBatch& batch,
+                                     std::shared_ptr<arrow::Buffer>* out) {
+    auto output_stream_result = arrow::io::BufferOutputStream::Create();
+    if (UNLIKELY(!output_stream_result.ok())) {
+        return Status::InternalError("Failed to create buffer output stream: {}",
+                                     output_stream_result.status().message());
+    }
+    auto output_stream = std::move(output_stream_result).ValueOrDie();
+
+    auto writer_result = arrow::ipc::MakeStreamWriter(output_stream, batch.schema());
+    if (UNLIKELY(!writer_result.ok())) {
+        return Status::InternalError("Failed to create IPC writer: {}",
+                                     writer_result.status().message());
+    }
+    auto writer = std::move(writer_result).ValueOrDie();
+
+    RETURN_DORIS_STATUS_IF_ERROR(writer->WriteRecordBatch(batch));
+    RETURN_DORIS_STATUS_IF_ERROR(writer->Close());
+
+    auto buffer_result = output_stream->Finish();
+    if (UNLIKELY(!buffer_result.ok())) {
+        return Status::InternalError("Failed to finish buffer: {}",
+                                     buffer_result.status().message());
+    }
+    *out = std::move(buffer_result).ValueOrDie();
+    return Status::OK();
+}
+
+// Helper function to deserialize RecordBatch from binary
+static Status deserialize_record_batch(const std::shared_ptr<arrow::Buffer>& buffer,
+                                       std::shared_ptr<arrow::RecordBatch>* out) {
+    // Create BufferReader from the input buffer
+    auto input_stream = std::make_shared<arrow::io::BufferReader>(buffer);
+
+    // Open IPC stream reader
+    auto reader_result = arrow::ipc::RecordBatchStreamReader::Open(input_stream);
+    if (UNLIKELY(!reader_result.ok())) {
+        return Status::InternalError("Failed to open IPC reader: {}",
+                                     reader_result.status().message());
+    }
+    auto reader = std::move(reader_result).ValueOrDie();
+
+    // Read the first (and only) RecordBatch
+    auto batch_result = reader->Next();
+    if (UNLIKELY(!batch_result.ok())) {
+        return Status::InternalError("Failed to read RecordBatch: {}",
+                                     batch_result.status().message());
+    }
+
+    *out = std::move(batch_result).ValueOrDie();
+    if (UNLIKELY(!*out)) {
+        return Status::InternalError("Deserialized RecordBatch is null");
+    }
+
+    return Status::OK();
+}
+
+// Helper function to validate and cast Arrow column to expected type
+template <typename ArrowArrayType>
+static Status validate_and_cast_column(const std::shared_ptr<arrow::RecordBatch>& batch,

Review Comment:
   Every call site passes column_index = 0; we'd better delete the argument.
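
A minimal sketch of the helper with the argument dropped, as the reviewer suggests. Only the truncated first line of the signature is visible in the quoted hunk, so the second parameter and the body below are assumptions, not the PR's actual code:

// Sketch of the suggested simplification: callers always read column 0,
// so the column_index parameter is removed. Output parameter is assumed.
template <typename ArrowArrayType>
static Status validate_and_cast_column(const std::shared_ptr<arrow::RecordBatch>& batch,
                                       std::shared_ptr<ArrowArrayType>* out) {
    // All call sites read the first column, so index it directly.
    if (UNLIKELY(batch->num_columns() < 1)) {
        return Status::InternalError("RecordBatch has no columns");
    }
    auto column = batch->column(0);
    auto casted = std::dynamic_pointer_cast<ArrowArrayType>(column);
    if (UNLIKELY(casted == nullptr)) {
        return Status::InternalError("Unexpected Arrow column type: {}",
                                     column->type()->ToString());
    }
    *out = std::move(casted);
    return Status::OK();
}

If a second call site with a different column index appears later, the parameter can be reintroduced then; until that point it is dead generality.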
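For context on the wire format the file sets up, here is a self-contained sketch (not part of the PR) that assembles a single-row batch matching the [operation, place_id, metadata, data] layout of kUnifiedUDAFSchema and round-trips it through the same Arrow IPC stream calls that serialize_record_batch/deserialize_record_batch use. The operation code, place_id, and payload bytes are illustrative assumptions.

#include <arrow/api.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/reader.h>
#include <arrow/ipc/writer.h>

#include <iostream>

arrow::Status round_trip_demo() {
    // Build the four single-row columns: [operation, place_id, metadata, data].
    arrow::Int8Builder op_builder;
    ARROW_RETURN_NOT_OK(op_builder.Append(1)); // assumed UDAFOperation enum value
    std::shared_ptr<arrow::Array> op_array;
    ARROW_RETURN_NOT_OK(op_builder.Finish(&op_array));

    arrow::Int64Builder place_builder;
    ARROW_RETURN_NOT_OK(place_builder.Append(42)); // arbitrary place_id
    std::shared_ptr<arrow::Array> place_array;
    ARROW_RETURN_NOT_OK(place_builder.Finish(&place_array));

    arrow::BinaryBuilder meta_builder;
    ARROW_RETURN_NOT_OK(meta_builder.Append(std::string("meta-bytes"))); // placeholder payload
    std::shared_ptr<arrow::Array> meta_array;
    ARROW_RETURN_NOT_OK(meta_builder.Finish(&meta_array));

    arrow::BinaryBuilder data_builder;
    ARROW_RETURN_NOT_OK(data_builder.Append(std::string("data-bytes"))); // placeholder payload
    std::shared_ptr<arrow::Array> data_array;
    ARROW_RETURN_NOT_OK(data_builder.Finish(&data_array));

    auto schema = arrow::schema({arrow::field("operation", arrow::int8()),
                                 arrow::field("place_id", arrow::int64()),
                                 arrow::field("metadata", arrow::binary()),
                                 arrow::field("data", arrow::binary())});
    auto batch = arrow::RecordBatch::Make(schema, /*num_rows=*/1,
                                          {op_array, place_array, meta_array, data_array});

    // Serialize via the IPC stream format, mirroring serialize_record_batch().
    ARROW_ASSIGN_OR_RAISE(auto sink, arrow::io::BufferOutputStream::Create());
    ARROW_ASSIGN_OR_RAISE(auto writer, arrow::ipc::MakeStreamWriter(sink, batch->schema()));
    ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
    ARROW_RETURN_NOT_OK(writer->Close());
    ARROW_ASSIGN_OR_RAISE(auto buffer, sink->Finish());

    // Deserialize, mirroring deserialize_record_batch().
    auto source = std::make_shared<arrow::io::BufferReader>(buffer);
    ARROW_ASSIGN_OR_RAISE(auto reader, arrow::ipc::RecordBatchStreamReader::Open(source));
    ARROW_ASSIGN_OR_RAISE(auto decoded, reader->Next());
    std::cout << "round-tripped rows: " << decoded->num_rows() << std::endl;
    return arrow::Status::OK();
}

int main() {
    auto st = round_trip_demo();
    if (!st.ok()) {
        std::cerr << st.ToString() << std::endl;
        return 1;
    }
    return 0;
}

Because every operation shares this one schema, a single Flight DoExchange stream can carry all UDAF phases without renegotiating schemas per batch, which is the stated point of the unified layout.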
