Tartarus0zm commented on code in PR #2069:
URL: https://github.com/apache/auron/pull/2069#discussion_r2882666091


##########
native-engine/datafusion-ext-plans/src/flink/serde/pb_deserializer.rs:
##########
@@ -0,0 +1,2191 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+    any::Any,
+    cell::UnsafeCell,
+    collections::{HashMap, HashSet},
+    io::Cursor,
+    sync::Arc,
+};
+
+use arrow::array::{
+    Array, ArrayBuilder, ArrayRef, BinaryArray, BinaryBuilder, BooleanBuilder, 
Float32Builder,
+    Float64Builder, Int32Array, Int32Builder, Int64Array, Int64Builder, 
RecordBatch,
+    RecordBatchOptions, StringBuilder, StructArray, 
TimestampMillisecondBuilder, UInt32Builder,
+    UInt64Builder, new_null_array,
+};
+use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, 
TimeUnit};
+use bytes::Buf;
+use datafusion::{
+    common::ExprSchema, error::DataFusionError, 
logical_expr::UserDefinedLogicalNode,
+};
+use datafusion_ext_commons::{df_execution_err, downcast_any};
+use prost::{
+    DecodeError,
+    encoding::{DecodeContext, WireType},
+};
+use prost_reflect::{DescriptorPool, FieldDescriptor, Kind, MessageDescriptor, 
UnknownField};
+
+use crate::flink::serde::{
+    flink_deserializer::FlinkDeserializer, 
shared_array_builder::SharedArrayBuilder,
+    shared_list_array_builder::SharedListArrayBuilder,
+    shared_map_array_builder::SharedMapArrayBuilder,
+    shared_struct_array_builder::SharedStructArrayBuilder,
+};
+
/// Deserializes batches of serialized protobuf messages (with Kafka metadata) into Arrow
/// record batches.
///
/// Built from a protobuf `MessageDescriptor` by `try_new`; each call to
/// `parse_messages_with_kafka_meta` decodes one batch of messages.
pub struct PbDeserializer {
    /// Full output schema, which also contains the Kafka metadata columns
    /// (`serialized_kafka_records_partition`, `_offset`, `_timestamp`).
    output_schema: SchemaRef,
    /// `output_schema` with the three Kafka metadata columns filtered out.
    output_schema_without_meta: SchemaRef,
    /// Schema derived from the protobuf descriptor and the output schema; the finished
    /// `output_array_builders` are assembled into a struct of this schema.
    pb_schema: SchemaRef,
    /// One shared builder per top-level field of `pb_schema`.
    output_array_builders: Vec<SharedArrayBuilder>,
    /// Invoked with `row_idx + 1` after each message is decoded.
    /// NOTE(review): presumably pads builders of fields absent from that message;
    /// confirm in `ensure_output_array_builders_size`.
    ensure_size: Box<dyn FnMut(usize) + Send>,
    /// Field decoders keyed by protobuf field number. Each one is called with the cursor
    /// positioned just after the field's key, plus the decoded tag and wire type.
    value_handlers: hashbrown::HashMap<
        u32,
        Box<dyn Fn(&mut Cursor<&[u8]>, u32, WireType) -> datafusion::error::Result<()> + Send>,
        foldhash::fast::RandomState,
    >,
    /// For each field of `output_schema_without_meta`, the column indices into `pb_schema`
    /// (descending through nested structs) used by `get_output_array` to locate its data.
    msg_mapping: Vec<Vec<usize>>,
}
+
+impl FlinkDeserializer for PbDeserializer {
+    fn parse_messages_with_kafka_meta(
+        &mut self,
+        messages: &BinaryArray,
+        kafka_partition: &Int32Array,
+        kafka_offset: &Int64Array,
+        kafka_timestamp: &Int64Array,
+    ) -> datafusion::common::Result<RecordBatch> {
+        let mut msg_cursors = messages
+            .iter()
+            .map(|v| {
+                let s = v.expect("message bytes must not be null");
+                Cursor::new(s)
+            })
+            .collect::<Vec<_>>();
+        for (row_idx, msg_cursor) in msg_cursors.iter_mut().enumerate() {
+            while msg_cursor.has_remaining() {
+                let (tag, wired_type) = 
prost::encoding::decode_key(msg_cursor).map_err(|e| {
+                    DataFusionError::Execution(format!("Failed to parse 
protobuf key: {e}"))
+                })?;
+                if let Some(value_handler) = self.value_handlers.get_mut(&tag) 
{
+                    value_handler(msg_cursor, tag, wired_type)?;
+                }
+            }
+            let ensure_size = &mut self.ensure_size;
+            ensure_size(row_idx + 1);
+        }
+
+        let root_struct = StructArray::from({
+            RecordBatch::try_new_with_options(
+                self.pb_schema.clone(),
+                self.output_array_builders
+                    .iter()
+                    .map(|builder| builder.get_dyn_mut().finish())
+                    .collect(),
+                
&RecordBatchOptions::new().with_row_count(Some(messages.len())),
+            )?
+        });
+        let mut output_arrays: Vec<ArrayRef> = Vec::new();
+        output_arrays.push(Arc::new(kafka_partition.clone()));
+        output_arrays.push(Arc::new(kafka_offset.clone()));
+        output_arrays.push(Arc::new(kafka_timestamp.clone()));
+        for (field_idx, field) in 
self.output_schema_without_meta.fields().iter().enumerate() {
+            let array_ref: ArrayRef = get_output_array(&root_struct, 
&self.msg_mapping[field_idx])?;
+            if array_ref.null_count() == array_ref.len() {
+                output_arrays.push(new_null_array(field.data_type(), 
array_ref.len()));
+            } else {
+                output_arrays.push(
+                    datafusion_ext_commons::arrow::cast::cast(&array_ref, 
field.data_type())
+                        .expect("Failed to cast array"),
+                );
+            }
+        }
+        let batch = RecordBatch::try_new_with_options(
+            self.output_schema.clone(),
+            output_arrays,
+            &RecordBatchOptions::new().with_row_count(Some(messages.len())),
+        )?;
+        Ok(batch)
+    }
+}
+
+impl PbDeserializer {
+    pub fn new(
+        proto_desc_data: impl AsRef<[u8]>,
+        message_name: &str,
+        output_schema: SchemaRef,
+        // Protobuf data may contain deeply nested hierarchies, supporting the 
extraction of
+        // certain fields to the topmost layer of the Flink output. 
{"flink_output_col1":
+        // "pb_field1.pb_sub_field2", "flink_output_col2":
+        // "pb_field1.pb_sub_field3.pb_sub_sub_field1"}
+        nested_msg_mapping: &HashMap<String, String>,
+        skip_fields: &[String],
+    ) -> datafusion::error::Result<Self> {
+        let pool: DescriptorPool =
+            DescriptorPool::decode(proto_desc_data.as_ref()).map_err(|e| {
+                DataFusionError::Execution(format!("Failed to parse descriptor 
file: {e}"))
+            })?;
+
+        for message in pool.all_messages() {
+            if message.name() == message_name {
+                return Self::try_new(message, output_schema, 
nested_msg_mapping, skip_fields);
+            }
+        }
+        Err(DataFusionError::Execution(format!(
+            "Message '{message_name}' not found"
+        )))
+    }
+
+    pub fn try_new(
+        message_descriptor: MessageDescriptor,
+        output_schema: SchemaRef,
+        nested_msg_mapping: &HashMap<String, String>,
+        skip_fields: &[String],
+    ) -> datafusion::error::Result<Self> {
+        // The output schema includes Kafka's meta fields, but these are 
absent in the
+        // PB data, so they must be filtered out.
+        let output_schema_without_meta = Arc::new(Schema::new(
+            output_schema
+                .fields()
+                .iter()
+                .filter(|f| {
+                    f.name() != "serialized_kafka_records_partition"
+                        && f.name() != "serialized_kafka_records_offset"
+                        && f.name() != "serialized_kafka_records_timestamp"
+                })
+                .cloned()
+                .collect::<Fields>(),
+        ));
+        // Schema inferred from the PB descriptor.
+        let pb_schema = transfer_output_schema_to_pb_schema(
+            message_descriptor.clone(),
+            &output_schema_without_meta,
+            nested_msg_mapping.clone(),
+            &skip_fields,
+        )
+        .expect("Failed to transfer output scheam to pb scheam");
+
+        let tag_to_output_mapping =
+            create_tag_to_output_mapping(message_descriptor.clone(), 
&pb_schema);
+
+        let output_array_builders =
+            create_output_array_builders(&pb_schema, 
message_descriptor.clone())?;
+        let ensure_size = 
ensure_output_array_builders_size(&output_array_builders)?;
+
+        let value_handlers = message_descriptor
+            .fields()
+            .map(|field| {
+                Ok((
+                    field.number(),
+                    create_value_handler(
+                        &message_descriptor,
+                        field.number(),
+                        &tag_to_output_mapping,
+                        &pb_schema,
+                        &output_array_builders,
+                    )?,
+                ))
+            })
+            .collect::<datafusion::error::Result<hashbrown::HashMap<_, _, 
foldhash::fast::RandomState>>>()?;
+
+        // precompute message mappings
+        let msg_mapping = output_schema_without_meta
+            .fields()
+            .iter()
+            .map(|field| {
+                let mut mapped_field_indices = vec![];
+                let mut cur_fields = pb_schema.fields();
+
+                if let Some(nested) = nested_msg_mapping.get(field.name()) {
+                    let nested_fields = nested.split(".").collect::<Vec<_>>();
+                    for nested_field in &nested_fields[..nested_fields.len() - 
1] {
+                        match cur_fields.find(nested_field) {
+                            Some((idx, f)) => {
+                                if let DataType::Struct(fields) = 
f.data_type() {
+                                    mapped_field_indices.push(idx);
+                                    cur_fields = fields;
+                                } else {
+                                    panic!("nested field must be struct");

Review Comment:
   good suggestion



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to