This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 3874a5ab5f743a8010d06c773fff13a44b95b6bc
Author: pawelkocinski <[email protected]>
AuthorDate: Thu Nov 13 22:19:41 2025 +0100

    SEDONA-748 add working example
---
 .../sql/execution/python/SedonaArrowStrategy.scala | 171 +--------------------
 .../sql/execution/python/SedonaArrowUtils.scala    |  64 +-------
 .../execution/python/SedonaPythonArrowInput.scala  |   1 +
 3 files changed, 6 insertions(+), 230 deletions(-)

diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala
index 3869ab24b8..ff3c027c5d 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala
@@ -20,32 +20,19 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.sedona.sql.UDF.PythonEvalType
 import org.apache.spark.api.python.ChainedPythonFunctions
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.InternalRow.copyValue
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection.createObject
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.vectorized.{ColumnarBatchRow, ColumnarRow}
-//import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
-//import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, CodeGeneratorWithInterpretedFallback, Expression, InterpretedUnsafeProjection, JoinedRow, MutableProjection, Projection, PythonUDF, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.udf.SedonaArrowEvalPython
-import org.apache.spark.util.Utils
-import org.apache.spark.{ContextAwareIterator, JobArtifactSet, SparkEnv, TaskContext}
-import org.locationtech.jts.io.WKTReader
-import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder
-import java.io.File
+import org.apache.spark.{JobArtifactSet, TaskContext}
 import scala.collection.JavaConverters.asScalaIteratorConverter
-import scala.collection.mutable.ArrayBuffer
 
 // We use custom Strategy to avoid Apache Spark assert on types, we
 // can consider extending this to support other engines working with
@@ -58,16 +45,6 @@ class SedonaArrowStrategy extends Strategy {
   }
 }
 
-/**
- * The factory object for `UnsafeProjection`.
- */
-object SedonaUnsafeProjection {
-
-  def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
-    GenerateUnsafeProjection.generate(bindReferences(exprs, inputSchema), SQLConf.get.subexpressionEliminationEnabled)
-//    createObject(bindReferences(exprs, inputSchema))
-  }
-}
 // It's modification og Apache Spark's ArrowEvalPythonExec, we remove the check on the types to allow geometry types
 // here, it's initial version to allow the vectorized udf for Sedona geometry types. We can consider extending this
 // to support other engines working with arrow data
@@ -112,144 +89,4 @@ case class SedonaArrowEvalPythonExec(
 
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
-
-  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
-    udf.children match {
-      case Seq(u: PythonUDF) =>
-        val (chained, children) = collectFunctions(u)
-        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
-      case children =>
-        // There should not be any other UDFs, or the children can't be evaluated directly.
-        assert(children.forall(!_.exists(_.isInstanceOf[PythonUDF])))
-        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
-    }
-  }
-
-  override def doExecute(): RDD[InternalRow] = {
-
-    val customProjection = new Projection with Serializable{
-      def apply(row: InternalRow): InternalRow = {
-        row match {
-          case joinedRow: JoinedRow =>
-            val arrowField = joinedRow.getRight.asInstanceOf[ColumnarBatchRow]
-            val left = joinedRow.getLeft
-
-
-//            resultAttrs.zipWithIndex.map {
-//              case (x, y) =>
-//                if (x.dataType.isInstanceOf[GeometryUDT]) {
-//                  val wkbReader = new org.locationtech.jts.io.WKBReader()
-//                  wkbReader.read(left.getBinary(y))
-//
-//                  println("ssss")
-//                }
-//                GeometryUDT
-//                left.getByte(y)
-//
-//                left.setByte(y, 1.toByte)
-//
-//                println(left.getByte(y))
-//            }
-//
-//            println("ssss")
-//            arrowField.
-            row
-            // We need to convert JoinedRow to UnsafeRow
-//            val leftUnsafe = left.asInstanceOf[UnsafeRow]
-//            val rightUnsafe = right.asInstanceOf[UnsafeRow]
-//            val joinedUnsafe = new UnsafeRow(leftUnsafe.numFields + rightUnsafe.numFields)
-//            joinedUnsafe.pointTo(
-//              leftUnsafe.getBaseObject, leftUnsafe.getBaseOffset,
-//              leftUnsafe.getSizeInBytes + rightUnsafe.getSizeInBytes)
-//            joinedUnsafe.setLeft(rightUnsafe)
-//            joinedUnsafe.setRight(leftUnsafe)
-//            joinedUnsafe
-//            val wktReader = new WKTReader()
-            val resultProj = SedonaUnsafeProjection.create(output, output)
-//            val WKBWriter = new org.locationtech.jts.io.WKBWriter()
-            resultProj(new JoinedRow(left, arrowField))
-          case _ =>
-            println(row.getClass)
-            throw new UnsupportedOperationException("Unsupported row type")
-        }
-      }
-    }
-    val inputRDD = child.execute().map(_.copy())
-
-    inputRDD.mapPartitions { iter =>
-      val context = TaskContext.get()
-      val contextAwareIterator = new ContextAwareIterator(context, iter)
-
-      // The queue used to buffer input rows so we can drain it to
-      // combine input with output from Python.
-      val queue = HybridRowQueue(context.taskMemoryManager(),
-        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
-      context.addTaskCompletionListener[Unit] { ctx =>
-        queue.close()
-      }
-
-      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
-
-      // flatten all the arguments
-      val allInputs = new ArrayBuffer[Expression]
-      val dataTypes = new ArrayBuffer[DataType]
-      val argOffsets = inputs.map { input =>
-        input.map { e =>
-          if (allInputs.exists(_.semanticEquals(e))) {
-            allInputs.indexWhere(_.semanticEquals(e))
-          } else {
-            allInputs += e
-            dataTypes += e.dataType
-            allInputs.length - 1
-          }
-        }.toArray
-      }.toArray
-      val projection = MutableProjection.create(allInputs.toSeq, child.output)
-      projection.initialize(context.partitionId())
-      val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
-        StructField(s"_$i", dt)
-      }.toArray)
-
-      // Add rows to queue to join later with the result.
-      val projectedRowIter = contextAwareIterator.map { inputRow =>
-        queue.add(inputRow.asInstanceOf[UnsafeRow])
-        projection(inputRow)
-      }
-
-      val outputRowIterator = evaluate(
-        pyFuncs, argOffsets, projectedRowIter, schema, context)
-
-      val joined = new JoinedRow
-
-      outputRowIterator.map { outputRow =>
-        val joinedRow = joined(queue.remove(), outputRow)
-
-        val projected = customProjection(joinedRow)
-
-        val numFields = projected.numFields
-        val startField = numFields - resultAttrs.length
-        println(resultAttrs.length)
-
-        val row = new GenericInternalRow(numFields)
-
-        resultAttrs.zipWithIndex.map {
-          case (attr, index) =>
-            if (attr.dataType.isInstanceOf[GeometryUDT]) {
-              // Convert the geometry type to WKB
-              val wkbReader = new org.locationtech.jts.io.WKBReader()
-              val wkbWriter = new org.locationtech.jts.io.WKBWriter()
-              val geom = wkbReader.read(projected.getBinary(startField + index))
-
-              row.update(startField + index, wkbWriter.write(geom))
-
-              println("ssss")
-            }
-        }
-
-        println("ssss")
-//        3.2838116E-8
-        row
-      }
-    }
-  }
 }
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala
index 58166d173d..ec4f7c00d0 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala
@@ -21,71 +21,16 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.collection.JavaConverters._
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.complex.MapVector
-import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
 
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils.{fromArrowType, toArrowType}
 
 private[sql] object SedonaArrowUtils {
   val rootAllocator = new RootAllocator(Long.MaxValue)
 
-  // todo: support more types.
-
-  /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
-  def toArrowType(
-      dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = dt match {
-    case BooleanType => ArrowType.Bool.INSTANCE
-    case ByteType => new ArrowType.Int(8, true)
-    case ShortType => new ArrowType.Int(8 * 2, true)
-    case IntegerType => new ArrowType.Int(8 * 4, true)
-    case LongType => new ArrowType.Int(8 * 8, true)
-    case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
-    case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
-    case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE
-    case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE
-    case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE
-    case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE
-    case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale)
-    case DateType => new ArrowType.Date(DateUnit.DAY)
-    case TimestampType if timeZoneId == null =>
-      throw new IllegalStateException("Missing timezoneId where it is mandatory.")
-    case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
-    case TimestampNTZType =>
-      new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
-    case NullType => ArrowType.Null.INSTANCE
-    case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
-    case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
-    case _ =>
-      throw ExecutionErrors.unsupportedDataTypeError(dt)
-  }
-
-  def fromArrowType(dt: ArrowType): DataType = dt match {
-    case ArrowType.Bool.INSTANCE => BooleanType
-    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
-    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
-    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType
-    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
-    case float: ArrowType.FloatingPoint
-      if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType
-    case float: ArrowType.FloatingPoint
-      if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType
-    case ArrowType.Utf8.INSTANCE => StringType
-    case ArrowType.Binary.INSTANCE => BinaryType
-    case ArrowType.LargeUtf8.INSTANCE => StringType
-    case ArrowType.LargeBinary.INSTANCE => BinaryType
-    case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
-    case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
-    case ts: ArrowType.Timestamp
-      if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => TimestampNTZType
-    case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
-    case ArrowType.Null.INSTANCE => NullType
-    case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType()
-    case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType()
-    case _ => throw ExecutionErrors.unsupportedArrowTypeError(dt)
-  }
-
   /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
   def toArrowField(
       name: String,
@@ -172,13 +117,6 @@ private[sql] object SedonaArrowUtils {
     }.asJava)
   }
 
-  def fromArrowSchema(schema: Schema): StructType = {
-    StructType(schema.getFields.asScala.map { field =>
-      val dt = fromArrowField(field)
-      StructField(field.getName, dt, field.isNullable)
-    }.toArray)
-  }
-
   private def deduplicateFieldNames(
       dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match {
     case geometryType: GeometryUDT => geometryType
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
index 6791015ae9..8a5e241c51 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.ArrowUtils.toArrowSchema
 import org.apache.spark.util.Utils
 import org.apache.spark.{SparkEnv, TaskContext}
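
Editor's note: the two Scala sketches below are not part of the commit. They illustrate how the pieces touched by this patch are commonly used; the object names, session setup and main wrappers are hypothetical.

The patch keeps the custom SedonaArrowStrategy planner rule (used, per the in-code comment, to avoid Spark's assert on types) but does not show where the strategy is attached to a session. A minimal sketch, assuming it is registered through Spark's standard extraStrategies hook:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.python.SedonaArrowStrategy

// Hypothetical wiring; the commit does not show how the strategy is registered.
object SedonaArrowStrategyWiring {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("sedona-arrow-strategy-sketch")
      .getOrCreate()

    // extraStrategies lets external code contribute planner rules, so a
    // SedonaArrowEvalPython logical node can be planned into
    // SedonaArrowEvalPythonExec without patching Spark itself.
    spark.experimental.extraStrategies =
      spark.experimental.extraStrategies :+ new SedonaArrowStrategy()

    spark.stop()
  }
}

SedonaArrowUtils now imports fromArrowType / toArrowType from Spark's own ArrowUtils instead of carrying a copy of the mapping, keeping only the geometry-aware helpers locally. A minimal sketch of that delegation, assuming the caller sits inside the org.apache.spark.sql package tree so the private[sql] ArrowUtils object is visible (the object and method names below are illustrative):

package org.apache.spark.sql.execution.python

import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.sql.util.ArrowUtils.{fromArrowType, toArrowType}

object ArrowTypeDelegationSketch {
  // Non-geometry types are forwarded to Spark's own mapping rather than duplicated.
  def sparkToArrow(dt: DataType, timeZoneId: String): ArrowType = toArrowType(dt, timeZoneId)
  def arrowToSpark(at: ArrowType): DataType = fromArrowType(at)

  def main(args: Array[String]): Unit = {
    // StringType round-trips through Arrow's Utf8 type.
    val utf8 = sparkToArrow(StringType, null)
    assert(arrowToSpark(utf8) == StringType)
  }
}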
