This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9cea0fb663c [SPARK-39979][SQL] Add option to use large variable width vectors for arrow UDF operations 9cea0fb663c is described below commit 9cea0fb663caa0ff13e07b2424cabeb56e6b9dbd Author: Adam Binford <adam...@gmail.com> AuthorDate: Mon May 29 09:05:24 2023 +0900 [SPARK-39979][SQL] Add option to use large variable width vectors for arrow UDF operations ### What changes were proposed in this pull request? Adds a new config that uses the `LargeUtf8` and `LargeBinary` arrow types for arrow-based UDF operations. These arrow types make arrow use `LargeVarCharVector` and `LargeVarBinaryVector` instead of the regular `VarCharVector` and `VarBinaryVector` respectively. This config is disabled by default to maintain the current behavior. ### Why are the changes needed? `VarCharVector` and `VarBinaryVector` have a size limit of 2 GiB for a single vector. This is because they use 4-byte integers to track the offsets of each value in the vector. During certain operations, it is possible to hit this limit. The case where we've most often run into this is during an `applyInPandas` operation, since the entire group is sent as a single RecordBatch, and there is no way to chunk it up into anything smaller than the entire group. However, other map and UDF operations can [...] The large vector types use an 8-byte long to track value offsets, removing the 2 GiB total size limit. ### Does this PR introduce _any_ user-facing change? Adds an option that can help users get around what currently results in an `IndexOutOfBoundsException`. Raising that exception is itself a bug that was fixed in Arrow; in the next release it should actually be an `OversizedAllocationException`, which suggests using the large variable width types instead. ### How was this patch tested? A few new tests are added. I also enabled the setting by default for a full CI run and all existing tests passed. I can add more tests if needed. 
Closes #39572 from Kimahriman/large-binary-vector. Authored-by: Adam Binford <adam...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/pandas/types.py | 4 ++ python/pyspark/sql/tests/pandas/test_pandas_map.py | 19 +++++++- python/pyspark/sql/tests/test_arrow_map.py | 15 ++++++ .../spark/sql/vectorized/ArrowColumnVector.java | 44 +++++++++++++++++ .../spark/sql/execution/arrow/ArrowWriter.scala | 30 ++++++++++++ .../org/apache/spark/sql/internal/SQLConf.scala | 12 +++++ .../org/apache/spark/sql/util/ArrowUtils.scala | 37 +++++++++----- .../execution/python/AggregateInPandasExec.scala | 2 + .../ApplyInPandasWithStatePythonRunner.scala | 2 + .../sql/execution/python/ArrowEvalPythonExec.scala | 2 + .../sql/execution/python/ArrowPythonRunner.scala | 1 + .../python/FlatMapGroupsInPandasExec.scala | 2 + .../sql/execution/python/MapInBatchExec.scala | 3 ++ .../sql/execution/python/PythonArrowInput.scala | 5 +- .../sql/execution/python/WindowInPandasExec.scala | 2 + .../sql/execution/arrow/ArrowWriterSuite.scala | 8 +++- .../sql/vectorized/ArrowColumnVectorSuite.scala | 56 +++++++++++++++++++++- 17 files changed, 229 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index ae7c25e0828..757deff6130 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -166,8 +166,12 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da spark_type = DecimalType(precision=at.precision, scale=at.scale) elif types.is_string(at): spark_type = StringType() + elif types.is_large_string(at): + spark_type = StringType() elif types.is_binary(at): spark_type = BinaryType() + elif types.is_large_binary(at): + spark_type = BinaryType() elif types.is_date32(at): spark_type = DateType() elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_map.py index 2f6f3f0df57..3d9a90bc81c 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py @@ -22,7 +22,7 @@ import unittest from typing import cast from pyspark.sql import Row -from pyspark.sql.functions import lit +from pyspark.sql.functions import col, encode, lit from pyspark.errors import PythonException from pyspark.testing.sqlutils import ( ReusedSQLTestCase, @@ -68,6 +68,23 @@ class MapInPandasTestsMixin: expected = df.collect() self.assertEqual(actual, expected) + def test_large_variable_types(self): + with self.sql_conf({"spark.sql.execution.arrow.useLargeVarTypes": True}): + + def func(iterator): + for pdf in iterator: + assert isinstance(pdf, pd.DataFrame) + yield pdf + + df = ( + self.spark.range(10, numPartitions=3) + .select(col("id").cast("string").alias("str")) + .withColumn("bin", encode(col("str"), "utf8")) + ) + actual = df.mapInPandas(func, "str string, bin binary").collect() + expected = df.collect() + self.assertEqual(actual, expected) + def test_different_output_length(self): def func(iterator): for _ in iterator: diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py index ff3d9b96b6b..050f2c32665 100644 --- a/python/pyspark/sql/tests/test_arrow_map.py +++ b/python/pyspark/sql/tests/test_arrow_map.py @@ -64,6 +64,21 @@ class MapInArrowTestsMixin(object): expected = df.collect() self.assertEqual(actual, expected) + def test_large_variable_width_types(self): + with self.sql_conf({"spark.sql.execution.arrow.useLargeVarTypes": True}): + data = [("foo", b"foo"), (None, None), ("bar", b"bar")] + df = self.spark.createDataFrame(data, "a string, b binary") + + def func(iterator): + for batch in iterator: + assert isinstance(batch, pa.RecordBatch) + assert batch.schema.types == [pa.large_string(), pa.large_binary()] + yield batch + + actual = df.mapInArrow(func, df.schema).collect() + 
expected = df.collect() + self.assertEqual(actual, expected) + def test_different_output_length(self): def func(iterator): for _ in iterator: diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 742cf511395..635ad9994cb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -19,6 +19,7 @@ package org.apache.spark.sql.vectorized; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; +import org.apache.arrow.vector.holders.NullableLargeVarCharHolder; import org.apache.arrow.vector.holders.NullableVarCharHolder; import org.apache.spark.annotation.DeveloperApi; @@ -160,8 +161,12 @@ public class ArrowColumnVector extends ColumnVector { accessor = new DecimalAccessor((DecimalVector) vector); } else if (vector instanceof VarCharVector) { accessor = new StringAccessor((VarCharVector) vector); + } else if (vector instanceof LargeVarCharVector) { + accessor = new LargeStringAccessor((LargeVarCharVector) vector); } else if (vector instanceof VarBinaryVector) { accessor = new BinaryAccessor((VarBinaryVector) vector); + } else if (vector instanceof LargeVarBinaryVector) { + accessor = new LargeBinaryAccessor((LargeVarBinaryVector) vector); } else if (vector instanceof DateDayVector) { accessor = new DateAccessor((DateDayVector) vector); } else if (vector instanceof TimeStampMicroTZVector) { @@ -406,6 +411,30 @@ public class ArrowColumnVector extends ColumnVector { } } + static class LargeStringAccessor extends ArrowVectorAccessor { + + private final LargeVarCharVector accessor; + private final NullableLargeVarCharHolder stringResult = new NullableLargeVarCharHolder(); + + LargeStringAccessor(LargeVarCharVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final UTF8String 
getUTF8String(int rowId) { + accessor.get(rowId, stringResult); + if (stringResult.isSet == 0) { + return null; + } else { + return UTF8String.fromAddress(null, + stringResult.buffer.memoryAddress() + stringResult.start, + // A single string cannot be larger than the max integer size, so the conversion is safe + (int)(stringResult.end - stringResult.start)); + } + } + } + static class BinaryAccessor extends ArrowVectorAccessor { private final VarBinaryVector accessor; @@ -421,6 +450,21 @@ public class ArrowColumnVector extends ColumnVector { } } + static class LargeBinaryAccessor extends ArrowVectorAccessor { + + private final LargeVarBinaryVector accessor; + + LargeBinaryAccessor(LargeVarBinaryVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final byte[] getBinary(int rowId) { + return accessor.getObject(rowId); + } + } + static class DateAccessor extends ArrowVectorAccessor { private final DateDayVector accessor; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index efdbc583207..a55e4f0cfcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -60,7 +60,9 @@ object ArrowWriter { case (DecimalType.Fixed(precision, scale), vector: DecimalVector) => new DecimalWriter(vector, precision, scale) case (StringType, vector: VarCharVector) => new StringWriter(vector) + case (StringType, vector: LargeVarCharVector) => new LargeStringWriter(vector) case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) + case (BinaryType, vector: LargeVarBinaryVector) => new LargeBinaryWriter(vector) case (DateType, vector: DateDayVector) => new DateWriter(vector) case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) case (TimestampNTZType, vector: 
TimeStampMicroVector) => new TimestampNTZWriter(vector) @@ -255,6 +257,21 @@ private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowF } } +private[arrow] class LargeStringWriter( + val valueVector: LargeVarCharVector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val utf8 = input.getUTF8String(ordinal) + val utf8ByteBuffer = utf8.getByteBuffer + // todo: for off-heap UTF8String, how to pass in to arrow without copy? + valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes()) + } +} + private[arrow] class BinaryWriter( val valueVector: VarBinaryVector) extends ArrowFieldWriter { @@ -268,6 +285,19 @@ private[arrow] class BinaryWriter( } } +private[arrow] class LargeBinaryWriter( + val valueVector: LargeVarBinaryVector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val bytes = input.getBinary(ordinal) + valueVector.setSafe(count, bytes, 0, bytes.length) + } +} + private[arrow] class DateWriter(val valueVector: DateDayVector) extends ArrowFieldWriter { override def setNull(): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b1e0285e6ae..e8185202a7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2825,6 +2825,16 @@ object SQLConf { .intConf .createWithDefault(10000) + val ARROW_EXECUTION_USE_LARGE_VAR_TYPES = + buildConf("spark.sql.execution.arrow.useLargeVarTypes") + .doc("When using Apache Arrow, use large variable width vectors for string and binary " + + "types. 
Regular string and binary types have a 2GiB limit for a column in a single " + + "record batch. Large variable types remove this limitation at the cost of higher memory " + + "usage per value.") + .version("3.5.0") + .booleanConf + .createWithDefault(false) + val PANDAS_UDF_BUFFER_SIZE = buildConf("spark.sql.execution.pandas.udf.buffer.size") .doc( @@ -4890,6 +4900,8 @@ class SQLConf extends Serializable with Logging { def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + def arrowUseLargeVarTypes: Boolean = getConf(ARROW_EXECUTION_USE_LARGE_VAR_TYPES) + def pandasUDFBufferSize: Int = getConf(PANDAS_UDF_BUFFER_SIZE) def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 719691a338f..e880e973176 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -37,7 +37,8 @@ private[sql] object ArrowUtils { // todo: support more types. /** Maps data type from Spark to Arrow. 
NOTE: timeZoneId required for TimestampTypes */ - def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match { + def toArrowType( + dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = dt match { case BooleanType => ArrowType.Bool.INSTANCE case ByteType => new ArrowType.Int(8, true) case ShortType => new ArrowType.Int(8 * 2, true) @@ -45,8 +46,10 @@ private[sql] object ArrowUtils { case LongType => new ArrowType.Int(8 * 8, true) case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case StringType => ArrowType.Utf8.INSTANCE - case BinaryType => ArrowType.Binary.INSTANCE + case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE + case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE + case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE + case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) case DateType => new ArrowType.Date(DateUnit.DAY) case TimestampType if timeZoneId == null => @@ -73,6 +76,8 @@ private[sql] object ArrowUtils { if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType case ArrowType.Utf8.INSTANCE => StringType case ArrowType.Binary.INSTANCE => BinaryType + case ArrowType.LargeUtf8.INSTANCE => StringType + case ArrowType.LargeBinary.INSTANCE => BinaryType case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType case ts: ArrowType.Timestamp @@ -86,17 +91,22 @@ private[sql] object ArrowUtils { /** Maps field from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType */ def toArrowField( - name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = { + name: String, + dt: DataType, + nullable: Boolean, + timeZoneId: String, + largeVarTypes: Boolean = false): Field = { dt match { case ArrayType(elementType, containsNull) => val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) new Field(name, fieldType, - Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava) + Seq(toArrowField("element", elementType, containsNull, timeZoneId, + largeVarTypes)).asJava) case StructType(fields) => val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) new Field(name, fieldType, fields.map { field => - toArrowField(field.name, field.dataType, field.nullable, timeZoneId) + toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes) }.toSeq.asJava) case MapType(keyType, valueType, valueContainsNull) => val mapType = new FieldType(nullable, new ArrowType.Map(false), null) @@ -107,10 +117,13 @@ private[sql] object ArrowUtils { .add(MapVector.KEY_NAME, keyType, nullable = false) .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), nullable = false, - timeZoneId)).asJava) - case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId) + timeZoneId, + largeVarTypes)).asJava) + case udt: UserDefinedType[_] => + toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes) case dataType => - val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null) + val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId, + largeVarTypes), null) new Field(name, fieldType, Seq.empty[Field].asJava) } } @@ -140,13 +153,15 @@ private[sql] object ArrowUtils { def toArrowSchema( schema: StructType, timeZoneId: String, - errorOnDuplicatedFieldNames: Boolean): Schema = { + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean = false): Schema = { 
new Schema(schema.map { field => toArrowField( field.name, deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames), field.nullable, - timeZoneId) + timeZoneId, + largeVarTypes) }.asJava) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index a9a9679bb36..c51a3a5cce3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -101,6 +101,7 @@ case class AggregateInPandasExec( val inputRDD = child.execute() val sessionLocalTimeZone = conf.sessionLocalTimeZone + val largeVarTypes = conf.arrowUseLargeVarTypes val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip @@ -167,6 +168,7 @@ case class AggregateInPandasExec( argOffsets, aggInputSchema, sessionLocalTimeZone, + largeVarTypes, pythonRunnerConf, pythonMetrics).compute(projectedRowIter, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index ac73e53266d..35676406f14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -78,6 +78,8 @@ class ApplyInPandasWithStatePythonRunner( override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback + override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes + override val bufferSize: Int = { val configuredSize = sqlConf.pandasUDFBufferSize if (configuredSize < 4) { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index b11dd4947af..86a5d13aed0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -65,6 +65,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val largeVarTypes = conf.arrowUseLargeVarTypes private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) protected override def evaluate( @@ -85,6 +86,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] argOffsets, schema, sessionLocalTimeZone, + largeVarTypes, pythonRunnerConf, pythonMetrics).compute(batchIter, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index d727c1b5ca0..175d67e9043 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -33,6 +33,7 @@ class ArrowPythonRunner( argOffsets: Array[Array[Int]], protected override val schema: StructType, protected override val timeZoneId: String, + protected override val largeVarTypes: Boolean, protected override val workerConf: Map[String, String], val pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 271ccdb6b27..8da53cc6c99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -53,6 +53,7 @@ case class FlatMapGroupsInPandasExec( extends SparkPlan with UnaryExecNode with PythonSQLMetrics { private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val largeVarTypes = conf.arrowUseLargeVarTypes private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) private val pandasFunction = func.asInstanceOf[PythonUDF].func private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) @@ -89,6 +90,7 @@ case class FlatMapGroupsInPandasExec( Array(argOffsets), StructType.fromAttributes(dedupAttributes), sessionLocalTimeZone, + largeVarTypes, pythonRunnerConf, pythonMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala index 0fe3acb14e8..8281435ca92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala @@ -49,6 +49,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics { private val batchSize = conf.arrowMaxRecordsPerBatch + private val largeVarTypes = conf.arrowUseLargeVarTypes + override def outputPartitioning: Partitioning = child.outputPartitioning override protected def doExecute(): RDD[InternalRow] = { @@ -77,6 +79,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics { argOffsets, StructType(Array(StructField("struct", outputTypes))), sessionLocalTimeZone, + largeVarTypes, pythonRunnerConf, pythonMetrics).compute(batchIter, context.partitionId(), context) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 26ce10b6aae..c78ea564f18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -44,6 +44,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => protected val errorOnDuplicatedFieldNames: Boolean + protected val largeVarTypes: Boolean + protected def pythonMetrics: Map[String, SQLMetric] protected def writeIteratorToArrowStream( @@ -75,7 +77,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames) + val arrowSchema = ArrowUtils.toArrowSchema( + schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index c5493079e40..e6a65dd61dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -185,6 +185,7 @@ case class WindowInPandasExec( val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold val spillThreshold = conf.windowExecBufferSpillThreshold val sessionLocalTimeZone = conf.sessionLocalTimeZone + val largeVarTypes = conf.arrowUseLargeVarTypes // Extract window expressions and window functions val windowExpressions = 
expressions.flatMap(_.collect { case e: WindowExpression => e }) @@ -385,6 +386,7 @@ case class WindowInPandasExec( argOffsets, pythonInputSchema, sessionLocalTimeZone, + largeVarTypes, pythonRunnerConf, pythonMetrics).compute(pythonInput, context.partitionId(), context) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index a88f423ae01..86a961137f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -27,7 +27,11 @@ import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { test("simple") { - def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { + def check( + dt: DataType, + data: Seq[Any], + timeZoneId: String = null, + largeVarTypes: Boolean = false): Unit = { val datatype = dt match { case _: DayTimeIntervalType => DayTimeIntervalType() case _: YearMonthIntervalType => YearMonthIntervalType() @@ -77,7 +81,9 @@ class ArrowWriterSuite extends SparkFunSuite { check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d)) check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4))) check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) + check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString), null, true) check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) + check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()), null, true) check(DateType, Seq(0, 1, 2, null, 4)) check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles") check(TimestampNTZType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala index 25beda99cd6..436cea50ad9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala @@ -250,9 +250,36 @@ class ArrowColumnVectorSuite extends SparkFunSuite { allocator.close() } + test("large_string") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("string", StringType, nullable = true, null, true) + .createVector(allocator).asInstanceOf[LargeVarCharVector] + vector.allocateNew() + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + vector.setSafe(i, utf8, 0, utf8.length) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === StringType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getUTF8String(i) === UTF8String.fromString(s"str$i")) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + test("binary") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null) + val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null, false) .createVector(allocator).asInstanceOf[VarBinaryVector] vector.allocateNew() @@ -277,6 +304,33 @@ class ArrowColumnVectorSuite extends SparkFunSuite { allocator.close() } + test("large_binary") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null, true) + .createVector(allocator).asInstanceOf[LargeVarBinaryVector] + vector.allocateNew() + + (0 until 10).foreach { i => + val 
utf8 = s"str$i".getBytes("utf8") + vector.setSafe(i, utf8, 0, utf8.length) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === BinaryType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getBinary(i) === s"str$i".getBytes("utf8")) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + test("array") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, null) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org