This is an automated email from the ASF dual-hosted git repository. richox pushed a commit to branch dev-v6.0.0-decimal-cast in repository https://gitbox.apache.org/repos/asf/auron.git
commit bcbe6736a7ade7185ef6896eba3aa511f5464a23 Author: zhangli20 <[email protected]> AuthorDate: Thu Jan 22 15:29:41 2026 +0800 optimize UDAF wrapper --- native-engine/blaze-jni-bridge/src/jni_bridge.rs | 24 ++-- .../datafusion-ext-plans/src/agg/count.rs | 132 ++++++++++++++++---- .../src/agg/spark_udaf_wrapper.rs | 65 +++------- .../spark/sql/blaze/SparkUDAFWrapperContext.scala | 137 +++++++++------------ 4 files changed, 199 insertions(+), 159 deletions(-) diff --git a/native-engine/blaze-jni-bridge/src/jni_bridge.rs b/native-engine/blaze-jni-bridge/src/jni_bridge.rs index 44b53508..f20cd17a 100644 --- a/native-engine/blaze-jni-bridge/src/jni_bridge.rs +++ b/native-engine/blaze-jni-bridge/src/jni_bridge.rs @@ -1228,10 +1228,10 @@ pub struct SparkUDAFWrapperContext<'a> { pub method_merge_ret: ReturnType, pub method_eval: JMethodID, pub method_eval_ret: ReturnType, - pub method_serializeRows: JMethodID, - pub method_serializeRows_ret: ReturnType, - pub method_deserializeRows: JMethodID, - pub method_deserializeRows_ret: ReturnType, + pub method_exportRows: JMethodID, + pub method_exportRows_ret: ReturnType, + pub method_importRows: JMethodID, + pub method_importRows_ret: ReturnType, pub method_spill: JMethodID, pub method_spill_ret: ReturnType, pub method_unspill: JMethodID, @@ -1281,18 +1281,18 @@ impl<'a> SparkUDAFWrapperContext<'a> { "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;[IJ)V", )?, method_eval_ret: ReturnType::Primitive(Primitive::Void), - method_serializeRows: env.get_method_id( + method_exportRows: env.get_method_id( class, - "serializeRows", - "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;[I)[B", + "exportRows", + "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;[IJ)V", )?, - method_serializeRows_ret: ReturnType::Array, - method_deserializeRows: env.get_method_id( + method_exportRows_ret: ReturnType::Array, + method_importRows: env.get_method_id( class, - "deserializeRows", - "(Ljava/nio/ByteBuffer;)Lorg/apache/spark/sql/blaze/BufferRowsColumn;", + "importRows", + "(J)Lorg/apache/spark/sql/blaze/BufferRowsColumn;", )?, - method_deserializeRows_ret: ReturnType::Object, + method_importRows_ret: ReturnType::Object, method_spill: env.get_method_id( class, "spill", diff --git a/native-engine/datafusion-ext-plans/src/agg/count.rs b/native-engine/datafusion-ext-plans/src/agg/count.rs index 61508c77..ad4b71cf 100644 --- a/native-engine/datafusion-ext-plans/src/agg/count.rs +++ b/native-engine/datafusion-ext-plans/src/agg/count.rs @@ -20,14 +20,18 @@ use std::{ use arrow::{array::*, datatypes::*}; use datafusion::{common::Result, physical_expr::PhysicalExprRef}; -use datafusion_ext_commons::downcast_any; +use datafusion_ext_commons::{ + downcast_any, + io::{read_len, write_len}, +}; use crate::{ agg::{ - acc::{AccColumn, AccColumnRef, AccPrimColumn}, + acc::{AccColumn, AccColumnRef}, agg::{Agg, IdxSelection}, }, - idx_for_zipped, + idx_for, idx_for_zipped, + memmgr::spill::{SpillCompressedReader, SpillCompressedWriter}, }; pub struct AggCount { @@ -76,11 +80,9 @@ impl Agg for AggCount { } fn create_acc_column(&self, num_rows: usize) -> Box<dyn AccColumn> { - Box::new(AccPrimColumn::<i64>::new(num_rows, DataType::Int64)) - } - - fn acc_array_data_types(&self) -> &[DataType] { - &[DataType::Int64] + Box::new(AccCountColumn { + values: vec![0; num_rows], + }) } fn partial_update( @@ -90,15 +92,32 @@ impl Agg for AggCount { partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, ) -> Result<()> { - let accs = downcast_any!(accs, mut AccPrimColumn<i64>)?; + let accs = downcast_any!(accs, mut AccCountColumn)?; accs.ensure_size(acc_idx); - idx_for_zipped! { - ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { - let add = partial_args - .iter() - .all(|arg| arg.is_valid(partial_arg_idx)) as i64; - accs.set_value(acc_idx, Some(accs.value(acc_idx).unwrap_or(0) + add)); + if partial_args.is_empty() { + idx_for_zipped! { + ((acc_idx, _partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if acc_idx >= accs.values.len() { + accs.values.push(1); + } else { + accs.values[acc_idx] += 1; + } + } + } + } else { + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + let add = partial_args + .iter() + .all(|arg| arg.is_valid(partial_arg_idx)) as i64; + + if acc_idx >= accs.values.len() { + accs.values.push(add); + } else { + accs.values[acc_idx] += add; + } + } } } Ok(()) @@ -111,19 +130,17 @@ impl Agg for AggCount { merging_accs: &mut AccColumnRef, merging_acc_idx: IdxSelection<'_>, ) -> Result<()> { - let accs = downcast_any!(accs, mut AccPrimColumn<i64>)?; - let merging_accs = downcast_any!(merging_accs, mut AccPrimColumn<i64>)?; + let accs = downcast_any!(accs, mut AccCountColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccCountColumn)?; accs.ensure_size(acc_idx); idx_for_zipped! { ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { - let v = match (accs.value(acc_idx), merging_accs.value(merging_acc_idx)) { - (Some(a), Some(b)) => Some(a + b), - (Some(a), _) => Some(a), - (_, Some(b)) => Some(b), - _ => Some(0), - }; - accs.set_value(acc_idx, v); + if acc_idx < accs.values.len() { + accs.values[acc_idx] += merging_accs.values[merging_acc_idx]; + } else { + accs.values.push(merging_accs.values[merging_acc_idx]); + } } } Ok(()) @@ -132,4 +149,71 @@ impl Agg for AggCount { fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result<ArrayRef> { Ok(accs.freeze_to_arrays(acc_idx)?[0].clone()) } + + fn acc_array_data_types(&self) -> &[DataType] { + &[DataType::Int64] + } +} + +pub struct AccCountColumn { + pub values: Vec<i64>, +} + +impl AccColumn for AccCountColumn { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn resize(&mut self, num_accs: usize) { + self.values.resize(num_accs, 0); + } + + fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + } + + fn num_records(&self) -> usize { + self.values.len() + } + + fn mem_used(&self) -> usize { + self.values.capacity() * 2 * size_of::<i64>() + } + + fn freeze_to_arrays(&mut self, idx: IdxSelection<'_>) -> Result<Vec<ArrayRef>> { + let mut values = Vec::with_capacity(idx.len()); + idx_for! { + (idx in idx) => { + values.push(self.values[idx]); + } + } + Ok(vec![Arc::new(Int64Array::from(values))]) + } + + fn unfreeze_from_arrays(&mut self, arrays: &[ArrayRef]) -> Result<()> { + let array = downcast_any!(arrays[0], Int64Array)?; + self.values = array.iter().map(|v| v.unwrap_or(0)).collect(); + Ok(()) + } + + fn spill(&mut self, idx: IdxSelection<'_>, w: &mut SpillCompressedWriter) -> Result<()> { + idx_for! { + (idx in idx) => { + write_len(self.values[idx] as usize, w)?; + } + } + Ok(()) + } + + fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> { + assert_eq!(self.num_records(), 0, "expect empty AccColumn"); + for _ in 0..num_rows { + self.values.push(read_len(r)? as i64); + } + Ok(()) + } } diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs index c891285f..1dada28f 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs @@ -15,21 +15,21 @@ use std::{ any::Any, fmt::{Debug, Display, Formatter}, - io::{Cursor, Read, Write}, sync::Arc, }; use arrow::{ array::{ - Array, ArrayAccessor, ArrayRef, BinaryArray, BinaryBuilder, StructArray, as_struct_array, + Array, ArrayRef, StructArray, as_struct_array, make_array, }, datatypes::{DataType, Field, Schema, SchemaRef}, ffi::{FFI_ArrowArray, FFI_ArrowSchema, from_ffi}, record_batch::{RecordBatch, RecordBatchOptions}, }; +use arrow::ffi::from_ffi_and_data_type; use blaze_jni_bridge::{ - jni_bridge::LocalRef, jni_call, jni_get_byte_array_len, jni_get_byte_array_region, + jni_bridge::LocalRef, jni_call, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object, jni_new_prim_array, }; use datafusion::{ @@ -37,7 +37,7 @@ use datafusion::{ physical_expr::PhysicalExprRef, }; use datafusion_ext_commons::{ - UninitializedInit, downcast_any, + downcast_any, io::{read_len, write_len}, }; use jni::objects::{GlobalRef, JObject}; @@ -300,30 +300,22 @@ impl AccUDAFBufferRowsColumn { idx: IdxSelection<'_>, cache: &OnceCell<LocalRef>, ) -> Result<ArrayRef> { + + let mut ffi_exported_rows = FFI_ArrowArray::empty(); let idx_array = cache.get_or_try_init(move || jni_new_prim_array!(int, &idx.to_int32_vec()[..]))?; - let serialized = jni_call!( - SparkUDAFWrapperContext(self.jcontext.as_obj()).serializeRows( + jni_call!( + SparkUDAFWrapperContext(self.jcontext.as_obj()).exportRows( self.obj.as_obj(), idx_array.as_obj(), - ) -> JObject)?; - let serialized_len = jni_get_byte_array_len!(serialized.as_obj())?; - let mut serialized_bytes = Vec::uninitialized_init(serialized_len); - jni_get_byte_array_region!(serialized.as_obj(), 0, &mut serialized_bytes[..])?; - - // UnsafeRow is serialized with big-endian i32 length prefix - let mut serialized_pos = 0; - let mut binary_builder = BinaryBuilder::with_capacity(idx.len(), 0); - for i in 0..idx.len() { - let mut bytes_len_buf = [0u8; 4]; - bytes_len_buf.copy_from_slice(&serialized_bytes[serialized_pos..][..4]); - let bytes_len = i32::from_be_bytes(bytes_len_buf) as usize; - serialized_pos += 4; - - binary_builder.append_value(&serialized_bytes[serialized_pos..][..bytes_len]); - serialized_pos += bytes_len; - } - Ok(Arc::new(binary_builder.finish())) + &mut ffi_exported_rows as *mut FFI_ArrowArray as i64, + ) -> ())?; + let exported_rows_data = unsafe { + // safety: import output binary array from SparkUDAFWrapperContext.exportedRows() + from_ffi_and_data_type(ffi_exported_rows, DataType::Binary)? + }; + let exported_rows = make_array(exported_rows_data); + Ok(exported_rows) } pub fn spill_with_indices_cache( @@ -404,29 +396,12 @@ impl AccColumn for AccUDAFBufferRowsColumn { fn unfreeze_from_arrays(&mut self, arrays: &[ArrayRef]) -> Result<()> { assert_eq!(self.num_records(), 0, "expect empty AccColumn"); - let array = downcast_any!(arrays[0], BinaryArray)?; - - let mut cursors = vec![]; - for i in 0..array.len() { - cursors.push(Cursor::new(array.value(i))); - } - - let mut data = vec![]; - for (i, cursor) in cursors.iter_mut().enumerate() { - let bytes_len = array.value(i).len(); - data.write_all((bytes_len as i32).to_be_bytes().as_ref())?; - std::io::copy(&mut cursor.take(bytes_len as u64), &mut data)?; - } - - let data_buffer = jni_new_direct_byte_buffer!(data)?; + let num_rows = arrays[0].len(); + let ffi_imported_rows = FFI_ArrowArray::new(&arrays[0].to_data()); let rows = jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj()) - .deserializeRows(data_buffer.as_obj()) -> JObject)?; + .importRows(&ffi_imported_rows as *const FFI_ArrowArray as i64) -> JObject)?; self.obj = jni_new_global_ref!(rows.as_obj())?; - assert_eq!( - self.num_records(), - cursors.len(), - "unfreeze rows count mismatch" - ); + assert_eq!(self.num_records(), num_rows, "unfreeze rows count mismatch"); Ok(()) } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index ee11feba..4a44ad48 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -22,15 +22,14 @@ import java.io.EOFException import java.io.InputStream import java.io.OutputStream import java.nio.ByteBuffer - import scala.collection.mutable import scala.collection.mutable.ArrayBuffer - import org.apache.arrow.c.ArrowArray import org.apache.arrow.c.Data -import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.{FieldVector, VarBinaryVector, VectorSchemaRoot} import org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider +import org.apache.arrow.vector.types.pojo.Field import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.blaze.memory.OnHeapSpillManager @@ -167,12 +166,12 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging { } } - def serializeRows(rows: BufferRowsColumn[B], indices: Array[Int]): Array[Byte] = { - aggEvaluator.get.serializeRows(rows, indices.iterator) + def exportRows(rows: BufferRowsColumn[B], indices: Array[Int], outputArrowBinaryArrayPtr: Long): Unit = { + aggEvaluator.get.exportRows(rows, indices.iterator, outputArrowBinaryArrayPtr) } - def deserializeRows(dataBuffer: ByteBuffer): BufferRowsColumn[B] = { - aggEvaluator.get.deserializeRows(dataBuffer) + def importRows(inputArrowBinaryArrayPtr: Long): BufferRowsColumn[B] = { + aggEvaluator.get.importRows(inputArrowBinaryArrayPtr) } def spill( @@ -205,14 +204,8 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] extends Logging { def createEmptyColumn(): R - def serializeRows( - rows: R, - indices: Iterator[Int], - streamWrapper: OutputStream => OutputStream = { s => s }): Array[Byte] - - def deserializeRows( - dataBuffer: ByteBuffer, - streamWrapper: InputStream => InputStream = { s => s }): R + def exportRows(rows: R, indices: Iterator[Int], outputArrowBinaryArrayPtr: Long): Unit + def importRows(inputArrowBinaryArrayPtr: Long): BufferRowsColumn[B] def spill( memTracker: SparkUDAFMemTracker, @@ -222,7 +215,7 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] extends Logging { val hsm = OnHeapSpillManager.current val spillId = memTracker.getSpill(spillIdx) val byteBuffer = - ByteBuffer.wrap(serializeRows(rows, indices, spillCodec.compressedOutputStream)) + ByteBuffer.wrap(exportRows(rows, indices, spillCodec.compressedOutputStream)) val spillBlockSize = byteBuffer.limit() hsm.writeSpill(spillId, byteBuffer) spillBlockSize @@ -238,7 +231,7 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] extends Logging { val readSize = hsm.readSpill(spillId, byteBuffer).toLong assert(readSize == spillBlockSize) byteBuffer.flip() - deserializeRows(byteBuffer, spillCodec.compressedInputStream) + importRows(byteBuffer, spillCodec.compressedInputStream) } } @@ -267,42 +260,45 @@ class DeclarativeEvaluator(val agg: DeclarativeAggregate, inputAttributes: Seq[A DeclarativeAggRowsColumn(this, ArrayBuffer()) } - override def serializeRows( + override def exportRows( rows: DeclarativeAggRowsColumn, indices: Iterator[Int], - streamWrapper: OutputStream => OutputStream): Array[Byte] = { + outputArrowBinaryArrayPtr: Long): Unit = { - val numFields = agg.aggBufferSchema.length - val outputDataStream = new ByteArrayOutputStream() - val wrappedStream = streamWrapper(outputDataStream) - val serializer = new UnsafeRowSerializer(numFields).newInstance() + Using.resource(new VarBinaryVector("output", ROOT_ALLOCATOR)) { binaryVector => + val rowDataStream = new ByteArrayOutputStream() + val rowDataBuffer = new Array[Byte](1024) - Using(serializer.serializeStream(wrappedStream)) { ser => - for (i <- indices) { - ser.writeValue(rows.rows(i)) - rows.rows(i) = releasedRow + for ((rowIdx, outputRowIdx) <- indices.zipWithIndex) { + rows.rows(rowIdx).writeToStream(rowDataStream, rowDataBuffer) + rows.rows(rowIdx) = releasedRow + binaryVector.setSafe(outputRowIdx, rowDataStream.toByteArray) + rowDataStream.reset() + } + + Using.resource(ArrowArray.wrap(outputArrowBinaryArrayPtr)) { outputArray => + Data.exportVector(ROOT_ALLOCATOR, binaryVector, new MapDictionaryProvider, outputArray) } } - wrappedStream.close() - outputDataStream.toByteArray } - override def deserializeRows( - dataBuffer: ByteBuffer, - streamWrapper: InputStream => InputStream): DeclarativeAggRowsColumn = { - val numFields = agg.aggBufferSchema.length - val deserializer = new UnsafeRowSerializer(numFields).newInstance() - val inputDataStream = new ByteBufferInputStream(dataBuffer) - val wrappedStream = streamWrapper(inputDataStream) - val rows = new ArrayBuffer[UnsafeRow]() - - Using.resource(deserializer.deserializeStream(wrappedStream)) { deser => - for (row <- deser.asKeyValueIterator.map(_._2.asInstanceOf[UnsafeRow].copy())) { + override def importRows(inputArrowBinaryArrayPtr: Long): DeclarativeAggRowsColumn = { + Using.resource(new VarBinaryVector("input", ROOT_ALLOCATOR)) { binaryVector => + Using.resource(ArrowArray.wrap(inputArrowBinaryArrayPtr)) { inputArray => + Data.importIntoVector(ROOT_ALLOCATOR, inputArray, binaryVector, new MapDictionaryProvider) + } + val numRows = binaryVector.getValueCount + val numFields = agg.aggBufferSchema.length + val rows = new ArrayBuffer[UnsafeRow]() + + for (rowIdx <- 0 until numRows) { + val row = new UnsafeRow(numFields) + val rowData = binaryVector.get(rowIdx) + row.pointTo(rowData, rowData.length) rows.append(row) } + DeclarativeAggRowsColumn(this, rows) } - wrappedStream.close() - DeclarativeAggRowsColumn(this, rows) } } @@ -378,51 +374,36 @@ class TypedImperativeEvaluator[B](val agg: TypedImperativeAggregate[B]) new TypedImperativeAggRowsColumn[B](this, ArrayBuffer()) } - override def serializeRows( + override def exportRows( rows: TypedImperativeAggRowsColumn[B], indices: Iterator[Int], - streamWrapper: OutputStream => OutputStream): Array[Byte] = { + outputArrowBinaryArrayPtr: Long): Unit = { - val outputStream = new ByteArrayOutputStream() - val wrappedStream = streamWrapper(outputStream) - val dataOut = new DataOutputStream(wrappedStream) + Using.resource(new VarBinaryVector("output", ROOT_ALLOCATOR)) { binaryVector => + for ((rowIdx, outputRowIdx) <- indices.zipWithIndex) { + binaryVector.setSafe(outputRowIdx, rows.serializedRow(rowIdx)) + rows.rows(rowIdx) = releasedRow + } - for (i <- indices) { - val bytes = rows.serializedRow(i) - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - rows.rows(i) = releasedRow + Using.resource(ArrowArray.wrap(outputArrowBinaryArrayPtr)) { outputArray => + Data.exportVector(ROOT_ALLOCATOR, binaryVector, new MapDictionaryProvider, outputArray) + } } - dataOut.close() - outputStream.toByteArray - } - - override def deserializeRows( - dataBuffer: ByteBuffer, - streamWrapper: InputStream => InputStream): TypedImperativeAggRowsColumn[B] = { - val rows = ArrayBuffer[RowType]() - val inputStream = new ByteBufferInputStream(dataBuffer) - val wrappedStream = streamWrapper(inputStream) - val dataIn = new DataInputStream(wrappedStream) - var finished = false - - while (!finished) { - var length = -1 - try { - length = dataIn.readInt() - } catch { - case _: EOFException => - finished = true + + override def importRows(inputArrowBinaryArrayPtr: Long): TypedImperativeAggRowsColumn[B] = { + Using.resource(new VarBinaryVector("input", ROOT_ALLOCATOR)) { binaryVector => + Using.resource(ArrowArray.wrap(inputArrowBinaryArrayPtr)) { inputArray => + Data.importIntoVector(ROOT_ALLOCATOR, inputArray, binaryVector, new MapDictionaryProvider) } + val numRows = binaryVector.getValueCount + val rows = ArrayBuffer[RowType]() - if (!finished) { - val bytes = new Array[Byte](length) - dataIn.read(bytes) - rows.append(SerializedRowType(bytes)) + for (rowIdx <- 0 until numRows) { + val rowData = binaryVector.get(rowIdx) + rows.append(SerializedRowType(rowData)) } + TypedImperativeAggRowsColumn(this, rows) } - dataIn.close() - TypedImperativeAggRowsColumn(this, rows) } }
