This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9cea0fb663c [SPARK-39979][SQL] Add option to use large variable width 
vectors for arrow UDF operations
9cea0fb663c is described below

commit 9cea0fb663caa0ff13e07b2424cabeb56e6b9dbd
Author: Adam Binford <adam...@gmail.com>
AuthorDate: Mon May 29 09:05:24 2023 +0900

    [SPARK-39979][SQL] Add option to use large variable width vectors for arrow 
UDF operations
    
    ### What changes were proposed in this pull request?
    
    Adds a new config, `spark.sql.execution.arrow.useLargeVarTypes`, that makes Arrow-based UDF operations use the `LargeUtf8` and `LargeBinary` Arrow types. These types make Arrow use `LargeVarCharVector` and `LargeVarBinaryVector` instead of the regular `VarCharVector` and `VarBinaryVector`, respectively. The config is disabled by default to preserve the current behavior.
    
    ### Why are the changes needed?
    
    `VarCharVector` and `VarBinaryVector` have a size limit of 2 GiB for a single vector, because they use 4-byte integers to track the offset of each value in the vector. This limit can be hit during certain operations. We have run into it most often during an `applyInPandas` operation: the entire group is sent as a single RecordBatch, and there is no way to split it into chunks smaller than the entire group. However, other map and UDF operations can [...]
    
    The large vector types use 8-byte longs to track value offsets, which removes the 2 GiB total size limit.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Adds an option that can help users work around what currently results in an `IndexOutOfBoundsException`. Raising that exception is itself a bug that has been fixed in Arrow. In the next Arrow release it should instead be an `OversizedAllocationException`, whose message suggests using the large variable-width types.
    
    ### How was this patch tested?
    
    A few new tests are added. I also temporarily enabled the setting by default for a full CI run, and all existing tests passed. I can add more tests if needed.
    
    Closes #39572 from Kimahriman/large-binary-vector.
    
    Authored-by: Adam Binford <adam...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/pandas/types.py                 |  4 ++
 python/pyspark/sql/tests/pandas/test_pandas_map.py | 19 +++++++-
 python/pyspark/sql/tests/test_arrow_map.py         | 15 ++++++
 .../spark/sql/vectorized/ArrowColumnVector.java    | 44 +++++++++++++++++
 .../spark/sql/execution/arrow/ArrowWriter.scala    | 30 ++++++++++++
 .../org/apache/spark/sql/internal/SQLConf.scala    | 12 +++++
 .../org/apache/spark/sql/util/ArrowUtils.scala     | 37 +++++++++-----
 .../execution/python/AggregateInPandasExec.scala   |  2 +
 .../ApplyInPandasWithStatePythonRunner.scala       |  2 +
 .../sql/execution/python/ArrowEvalPythonExec.scala |  2 +
 .../sql/execution/python/ArrowPythonRunner.scala   |  1 +
 .../python/FlatMapGroupsInPandasExec.scala         |  2 +
 .../sql/execution/python/MapInBatchExec.scala      |  3 ++
 .../sql/execution/python/PythonArrowInput.scala    |  5 +-
 .../sql/execution/python/WindowInPandasExec.scala  |  2 +
 .../sql/execution/arrow/ArrowWriterSuite.scala     |  8 +++-
 .../sql/vectorized/ArrowColumnVectorSuite.scala    | 56 +++++++++++++++++++++-
 17 files changed, 229 insertions(+), 15 deletions(-)

diff --git a/python/pyspark/sql/pandas/types.py 
b/python/pyspark/sql/pandas/types.py
index ae7c25e0828..757deff6130 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -166,8 +166,12 @@ def from_arrow_type(at: "pa.DataType", 
prefer_timestamp_ntz: bool = False) -> Da
         spark_type = DecimalType(precision=at.precision, scale=at.scale)
     elif types.is_string(at):
         spark_type = StringType()
+    elif types.is_large_string(at):
+        spark_type = StringType()
     elif types.is_binary(at):
         spark_type = BinaryType()
+    elif types.is_large_binary(at):
+        spark_type = BinaryType()
     elif types.is_date32(at):
         spark_type = DateType()
     elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 2f6f3f0df57..3d9a90bc81c 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -22,7 +22,7 @@ import unittest
 from typing import cast
 
 from pyspark.sql import Row
-from pyspark.sql.functions import lit
+from pyspark.sql.functions import col, encode, lit
 from pyspark.errors import PythonException
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
@@ -68,6 +68,23 @@ class MapInPandasTestsMixin:
         expected = df.collect()
         self.assertEqual(actual, expected)
 
+    def test_large_variable_types(self):
+        with self.sql_conf({"spark.sql.execution.arrow.useLargeVarTypes": 
True}):
+
+            def func(iterator):
+                for pdf in iterator:
+                    assert isinstance(pdf, pd.DataFrame)
+                    yield pdf
+
+            df = (
+                self.spark.range(10, numPartitions=3)
+                .select(col("id").cast("string").alias("str"))
+                .withColumn("bin", encode(col("str"), "utf8"))
+            )
+            actual = df.mapInPandas(func, "str string, bin binary").collect()
+            expected = df.collect()
+            self.assertEqual(actual, expected)
+
     def test_different_output_length(self):
         def func(iterator):
             for _ in iterator:
diff --git a/python/pyspark/sql/tests/test_arrow_map.py 
b/python/pyspark/sql/tests/test_arrow_map.py
index ff3d9b96b6b..050f2c32665 100644
--- a/python/pyspark/sql/tests/test_arrow_map.py
+++ b/python/pyspark/sql/tests/test_arrow_map.py
@@ -64,6 +64,21 @@ class MapInArrowTestsMixin(object):
         expected = df.collect()
         self.assertEqual(actual, expected)
 
+    def test_large_variable_width_types(self):
+        with self.sql_conf({"spark.sql.execution.arrow.useLargeVarTypes": 
True}):
+            data = [("foo", b"foo"), (None, None), ("bar", b"bar")]
+            df = self.spark.createDataFrame(data, "a string, b binary")
+
+            def func(iterator):
+                for batch in iterator:
+                    assert isinstance(batch, pa.RecordBatch)
+                    assert batch.schema.types == [pa.large_string(), 
pa.large_binary()]
+                    yield batch
+
+            actual = df.mapInArrow(func, df.schema).collect()
+            expected = df.collect()
+            self.assertEqual(actual, expected)
+
     def test_different_output_length(self):
         def func(iterator):
             for _ in iterator:
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index 742cf511395..635ad9994cb 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -19,6 +19,7 @@ package org.apache.spark.sql.vectorized;
 
 import org.apache.arrow.vector.*;
 import org.apache.arrow.vector.complex.*;
+import org.apache.arrow.vector.holders.NullableLargeVarCharHolder;
 import org.apache.arrow.vector.holders.NullableVarCharHolder;
 
 import org.apache.spark.annotation.DeveloperApi;
@@ -160,8 +161,12 @@ public class ArrowColumnVector extends ColumnVector {
       accessor = new DecimalAccessor((DecimalVector) vector);
     } else if (vector instanceof VarCharVector) {
       accessor = new StringAccessor((VarCharVector) vector);
+    } else if (vector instanceof LargeVarCharVector) {
+      accessor = new LargeStringAccessor((LargeVarCharVector) vector);
     } else if (vector instanceof VarBinaryVector) {
       accessor = new BinaryAccessor((VarBinaryVector) vector);
+    } else if (vector instanceof LargeVarBinaryVector) {
+      accessor = new LargeBinaryAccessor((LargeVarBinaryVector) vector);
     } else if (vector instanceof DateDayVector) {
       accessor = new DateAccessor((DateDayVector) vector);
     } else if (vector instanceof TimeStampMicroTZVector) {
@@ -406,6 +411,30 @@ public class ArrowColumnVector extends ColumnVector {
     }
   }
 
+  static class LargeStringAccessor extends ArrowVectorAccessor {
+
+    private final LargeVarCharVector accessor;
+    private final NullableLargeVarCharHolder stringResult = new 
NullableLargeVarCharHolder();
+
+    LargeStringAccessor(LargeVarCharVector vector) {
+      super(vector);
+      this.accessor = vector;
+    }
+
+    @Override
+    final UTF8String getUTF8String(int rowId) {
+      accessor.get(rowId, stringResult);
+      if (stringResult.isSet == 0) {
+        return null;
+      } else {
+        return UTF8String.fromAddress(null,
+          stringResult.buffer.memoryAddress() + stringResult.start,
+          // A single string cannot be larger than the max integer size, so 
the conversion is safe
+          (int)(stringResult.end - stringResult.start));
+      }
+    }
+  }
+
   static class BinaryAccessor extends ArrowVectorAccessor {
 
     private final VarBinaryVector accessor;
@@ -421,6 +450,21 @@ public class ArrowColumnVector extends ColumnVector {
     }
   }
 
+  static class LargeBinaryAccessor extends ArrowVectorAccessor {
+
+    private final LargeVarBinaryVector accessor;
+
+    LargeBinaryAccessor(LargeVarBinaryVector vector) {
+      super(vector);
+      this.accessor = vector;
+    }
+
+    @Override
+    final byte[] getBinary(int rowId) {
+      return accessor.getObject(rowId);
+    }
+  }
+
   static class DateAccessor extends ArrowVectorAccessor {
 
     private final DateDayVector accessor;
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index efdbc583207..a55e4f0cfcd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -60,7 +60,9 @@ object ArrowWriter {
       case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
         new DecimalWriter(vector, precision, scale)
       case (StringType, vector: VarCharVector) => new StringWriter(vector)
+      case (StringType, vector: LargeVarCharVector) => new 
LargeStringWriter(vector)
       case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
+      case (BinaryType, vector: LargeVarBinaryVector) => new 
LargeBinaryWriter(vector)
       case (DateType, vector: DateDayVector) => new DateWriter(vector)
       case (TimestampType, vector: TimeStampMicroTZVector) => new 
TimestampWriter(vector)
       case (TimestampNTZType, vector: TimeStampMicroVector) => new 
TimestampNTZWriter(vector)
@@ -255,6 +257,21 @@ private[arrow] class StringWriter(val valueVector: 
VarCharVector) extends ArrowF
   }
 }
 
+private[arrow] class LargeStringWriter(
+    val valueVector: LargeVarCharVector) extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val utf8 = input.getUTF8String(ordinal)
+    val utf8ByteBuffer = utf8.getByteBuffer
+    // todo: for off-heap UTF8String, how to pass in to arrow without copy?
+    valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), 
utf8.numBytes())
+  }
+}
+
 private[arrow] class BinaryWriter(
     val valueVector: VarBinaryVector) extends ArrowFieldWriter {
 
@@ -268,6 +285,19 @@ private[arrow] class BinaryWriter(
   }
 }
 
+private[arrow] class LargeBinaryWriter(
+    val valueVector: LargeVarBinaryVector) extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val bytes = input.getBinary(ordinal)
+    valueVector.setSafe(count, bytes, 0, bytes.length)
+  }
+}
+
 private[arrow] class DateWriter(val valueVector: DateDayVector) extends 
ArrowFieldWriter {
 
   override def setNull(): Unit = {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b1e0285e6ae..e8185202a7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2825,6 +2825,16 @@ object SQLConf {
       .intConf
       .createWithDefault(10000)
 
+  val ARROW_EXECUTION_USE_LARGE_VAR_TYPES =
+    buildConf("spark.sql.execution.arrow.useLargeVarTypes")
+      .doc("When using Apache Arrow, use large variable width vectors for 
string and binary " +
+        "types. Regular string and binary types have a 2GiB limit for a column 
in a single " +
+        "record batch. Large variable types remove this limitation at the cost 
of higher memory " +
+        "usage per value.")
+      .version("3.5.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val PANDAS_UDF_BUFFER_SIZE =
     buildConf("spark.sql.execution.pandas.udf.buffer.size")
       .doc(
@@ -4890,6 +4900,8 @@ class SQLConf extends Serializable with Logging {
 
   def arrowMaxRecordsPerBatch: Int = 
getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
 
+  def arrowUseLargeVarTypes: Boolean = 
getConf(ARROW_EXECUTION_USE_LARGE_VAR_TYPES)
+
   def pandasUDFBufferSize: Int = getConf(PANDAS_UDF_BUFFER_SIZE)
 
   def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 719691a338f..e880e973176 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -37,7 +37,8 @@ private[sql] object ArrowUtils {
   // todo: support more types.
 
   /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for 
TimestampTypes */
-  def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match {
+  def toArrowType(
+      dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): 
ArrowType = dt match {
     case BooleanType => ArrowType.Bool.INSTANCE
     case ByteType => new ArrowType.Int(8, true)
     case ShortType => new ArrowType.Int(8 * 2, true)
@@ -45,8 +46,10 @@ private[sql] object ArrowUtils {
     case LongType => new ArrowType.Int(8 * 8, true)
     case FloatType => new 
ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
     case DoubleType => new 
ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
-    case StringType => ArrowType.Utf8.INSTANCE
-    case BinaryType => ArrowType.Binary.INSTANCE
+    case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE
+    case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE
+    case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE
+    case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE
     case DecimalType.Fixed(precision, scale) => new 
ArrowType.Decimal(precision, scale)
     case DateType => new ArrowType.Date(DateUnit.DAY)
     case TimestampType if timeZoneId == null =>
@@ -73,6 +76,8 @@ private[sql] object ArrowUtils {
       if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType
     case ArrowType.Utf8.INSTANCE => StringType
     case ArrowType.Binary.INSTANCE => BinaryType
+    case ArrowType.LargeUtf8.INSTANCE => StringType
+    case ArrowType.LargeBinary.INSTANCE => BinaryType
     case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
     case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
     case ts: ArrowType.Timestamp
@@ -86,17 +91,22 @@ private[sql] object ArrowUtils {
 
   /** Maps field from Spark to Arrow. NOTE: timeZoneId required for 
TimestampType */
   def toArrowField(
-      name: String, dt: DataType, nullable: Boolean, timeZoneId: String): 
Field = {
+      name: String,
+      dt: DataType,
+      nullable: Boolean,
+      timeZoneId: String,
+      largeVarTypes: Boolean = false): Field = {
     dt match {
       case ArrayType(elementType, containsNull) =>
         val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
         new Field(name, fieldType,
-          Seq(toArrowField("element", elementType, containsNull, 
timeZoneId)).asJava)
+          Seq(toArrowField("element", elementType, containsNull, timeZoneId,
+            largeVarTypes)).asJava)
       case StructType(fields) =>
         val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, 
null)
         new Field(name, fieldType,
           fields.map { field =>
-            toArrowField(field.name, field.dataType, field.nullable, 
timeZoneId)
+            toArrowField(field.name, field.dataType, field.nullable, 
timeZoneId, largeVarTypes)
           }.toSeq.asJava)
       case MapType(keyType, valueType, valueContainsNull) =>
         val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
@@ -107,10 +117,13 @@ private[sql] object ArrowUtils {
               .add(MapVector.KEY_NAME, keyType, nullable = false)
               .add(MapVector.VALUE_NAME, valueType, nullable = 
valueContainsNull),
             nullable = false,
-            timeZoneId)).asJava)
-      case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, 
nullable, timeZoneId)
+            timeZoneId,
+            largeVarTypes)).asJava)
+      case udt: UserDefinedType[_] =>
+        toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
       case dataType =>
-        val fieldType = new FieldType(nullable, toArrowType(dataType, 
timeZoneId), null)
+        val fieldType = new FieldType(nullable, toArrowType(dataType, 
timeZoneId,
+          largeVarTypes), null)
         new Field(name, fieldType, Seq.empty[Field].asJava)
     }
   }
@@ -140,13 +153,15 @@ private[sql] object ArrowUtils {
   def toArrowSchema(
       schema: StructType,
       timeZoneId: String,
-      errorOnDuplicatedFieldNames: Boolean): Schema = {
+      errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean = false): Schema = {
     new Schema(schema.map { field =>
       toArrowField(
         field.name,
         deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames),
         field.nullable,
-        timeZoneId)
+        timeZoneId,
+        largeVarTypes)
     }.asJava)
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index a9a9679bb36..c51a3a5cce3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -101,6 +101,7 @@ case class AggregateInPandasExec(
     val inputRDD = child.execute()
 
     val sessionLocalTimeZone = conf.sessionLocalTimeZone
+    val largeVarTypes = conf.arrowUseLargeVarTypes
     val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
 
     val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
@@ -167,6 +168,7 @@ case class AggregateInPandasExec(
         argOffsets,
         aggInputSchema,
         sessionLocalTimeZone,
+        largeVarTypes,
         pythonRunnerConf,
         pythonMetrics).compute(projectedRowIter, context.partitionId(), 
context)
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index ac73e53266d..35676406f14 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -78,6 +78,8 @@ class ApplyInPandasWithStatePythonRunner(
 
   override val simplifiedTraceback: Boolean = 
sqlConf.pysparkSimplifiedTraceback
 
+  override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
+
   override val bufferSize: Int = {
     val configuredSize = sqlConf.pandasUDFBufferSize
     if (configuredSize < 4) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index b11dd4947af..86a5d13aed0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -65,6 +65,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], 
resultAttrs: Seq[Attribute]
 
   private val batchSize = conf.arrowMaxRecordsPerBatch
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val largeVarTypes = conf.arrowUseLargeVarTypes
   private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
 
   protected override def evaluate(
@@ -85,6 +86,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], 
resultAttrs: Seq[Attribute]
       argOffsets,
       schema,
       sessionLocalTimeZone,
+      largeVarTypes,
       pythonRunnerConf,
       pythonMetrics).compute(batchIter, context.partitionId(), context)
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index d727c1b5ca0..175d67e9043 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -33,6 +33,7 @@ class ArrowPythonRunner(
     argOffsets: Array[Array[Int]],
     protected override val schema: StructType,
     protected override val timeZoneId: String,
+    protected override val largeVarTypes: Boolean,
     protected override val workerConf: Map[String, String],
     val pythonMetrics: Map[String, SQLMetric])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, 
evalType, argOffsets)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 271ccdb6b27..8da53cc6c99 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -53,6 +53,7 @@ case class FlatMapGroupsInPandasExec(
   extends SparkPlan with UnaryExecNode with PythonSQLMetrics {
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val largeVarTypes = conf.arrowUseLargeVarTypes
   private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
   private val pandasFunction = func.asInstanceOf[PythonUDF].func
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
@@ -89,6 +90,7 @@ case class FlatMapGroupsInPandasExec(
         Array(argOffsets),
         StructType.fromAttributes(dedupAttributes),
         sessionLocalTimeZone,
+        largeVarTypes,
         pythonRunnerConf,
         pythonMetrics)
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 0fe3acb14e8..8281435ca92 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -49,6 +49,8 @@ trait MapInBatchExec extends UnaryExecNode with 
PythonSQLMetrics {
 
   private val batchSize = conf.arrowMaxRecordsPerBatch
 
+  private val largeVarTypes = conf.arrowUseLargeVarTypes
+
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override protected def doExecute(): RDD[InternalRow] = {
@@ -77,6 +79,7 @@ trait MapInBatchExec extends UnaryExecNode with 
PythonSQLMetrics {
         argOffsets,
         StructType(Array(StructField("struct", outputTypes))),
         sessionLocalTimeZone,
+        largeVarTypes,
         pythonRunnerConf,
         pythonMetrics).compute(batchIter, context.partitionId(), context)
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 26ce10b6aae..c78ea564f18 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -44,6 +44,8 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
 
   protected val errorOnDuplicatedFieldNames: Boolean
 
+  protected val largeVarTypes: Boolean
+
   protected def pythonMetrics: Map[String, SQLMetric]
 
   protected def writeIteratorToArrowStream(
@@ -75,7 +77,8 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
       }
 
       protected override def writeIteratorToStream(dataOut: DataOutputStream): 
Unit = {
-        val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, 
errorOnDuplicatedFieldNames)
+        val arrowSchema = ArrowUtils.toArrowSchema(
+          schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
         val allocator = ArrowUtils.rootAllocator.newChildAllocator(
           s"stdout writer for $pythonExec", 0, Long.MaxValue)
         val root = VectorSchemaRoot.create(arrowSchema, allocator)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index c5493079e40..e6a65dd61dc 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -185,6 +185,7 @@ case class WindowInPandasExec(
     val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
     val spillThreshold = conf.windowExecBufferSpillThreshold
     val sessionLocalTimeZone = conf.sessionLocalTimeZone
+    val largeVarTypes = conf.arrowUseLargeVarTypes
 
     // Extract window expressions and window functions
     val windowExpressions = expressions.flatMap(_.collect { case e: 
WindowExpression => e })
@@ -385,6 +386,7 @@ case class WindowInPandasExec(
         argOffsets,
         pythonInputSchema,
         sessionLocalTimeZone,
+        largeVarTypes,
         pythonRunnerConf,
         pythonMetrics).compute(pythonInput, context.partitionId(), context)
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index a88f423ae01..86a961137f4 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -27,7 +27,11 @@ import org.apache.spark.unsafe.types.UTF8String
 class ArrowWriterSuite extends SparkFunSuite {
 
   test("simple") {
-    def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = 
{
+    def check(
+        dt: DataType,
+        data: Seq[Any],
+        timeZoneId: String = null,
+        largeVarTypes: Boolean = false): Unit = {
       val datatype = dt match {
         case _: DayTimeIntervalType => DayTimeIntervalType()
         case _: YearMonthIntervalType => YearMonthIntervalType()
@@ -77,7 +81,9 @@ class ArrowWriterSuite extends SparkFunSuite {
     check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d))
     check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, 
Decimal(4)))
     check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString))
+    check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString), 
null, true)
     check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, 
"d".getBytes()))
+    check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, 
"d".getBytes()), null, true)
     check(DateType, Seq(0, 1, 2, null, 4))
     check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), 
"America/Los_Angeles")
     check(TimestampNTZType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala
index 25beda99cd6..436cea50ad9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala
@@ -250,9 +250,36 @@ class ArrowColumnVectorSuite extends SparkFunSuite {
     allocator.close()
   }
 
+  test("large_string") {
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, 
Long.MaxValue)
+    val vector = ArrowUtils.toArrowField("string", StringType, nullable = 
true, null, true)
+      .createVector(allocator).asInstanceOf[LargeVarCharVector]
+    vector.allocateNew()
+
+    (0 until 10).foreach { i =>
+      val utf8 = s"str$i".getBytes("utf8")
+      vector.setSafe(i, utf8, 0, utf8.length)
+    }
+    vector.setNull(10)
+    vector.setValueCount(11)
+
+    val columnVector = new ArrowColumnVector(vector)
+    assert(columnVector.dataType === StringType)
+    assert(columnVector.hasNull)
+    assert(columnVector.numNulls === 1)
+
+    (0 until 10).foreach { i =>
+      assert(columnVector.getUTF8String(i) === UTF8String.fromString(s"str$i"))
+    }
+    assert(columnVector.isNullAt(10))
+
+    columnVector.close()
+    allocator.close()
+  }
+
   test("binary") {
     val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, 
Long.MaxValue)
-    val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = 
true, null)
+    val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = 
true, null, false)
       .createVector(allocator).asInstanceOf[VarBinaryVector]
     vector.allocateNew()
 
@@ -277,6 +304,33 @@ class ArrowColumnVectorSuite extends SparkFunSuite {
     allocator.close()
   }
 
+  test("large_binary") {
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, 
Long.MaxValue)
+    val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = 
true, null, true)
+      .createVector(allocator).asInstanceOf[LargeVarBinaryVector]
+    vector.allocateNew()
+
+    (0 until 10).foreach { i =>
+      val utf8 = s"str$i".getBytes("utf8")
+      vector.setSafe(i, utf8, 0, utf8.length)
+    }
+    vector.setNull(10)
+    vector.setValueCount(11)
+
+    val columnVector = new ArrowColumnVector(vector)
+    assert(columnVector.dataType === BinaryType)
+    assert(columnVector.hasNull)
+    assert(columnVector.numNulls === 1)
+
+    (0 until 10).foreach { i =>
+      assert(columnVector.getBinary(i) === s"str$i".getBytes("utf8"))
+    }
+    assert(columnVector.isNullAt(10))
+
+    columnVector.close()
+    allocator.close()
+  }
+
   test("array") {
     val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, 
Long.MaxValue)
     val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), 
nullable = true, null)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to