This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 5ee004704ced [SPARK-51079][PYTHON] Support large variable types in pandas UDF, createDataFrame and toPandas with Arrow
5ee004704ced is described below

commit 5ee004704ced1a98434f30c836e4331f8f2e4352
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Feb 6 08:57:52 2025 +0900

    [SPARK-51079][PYTHON] Support large variable types in pandas UDF, createDataFrame and toPandas with Arrow
    
    ### What changes were proposed in this pull request?
    
    This PR is a retry of https://github.com/apache/spark/pull/41569, which implements the use of large variable types everywhere within PySpark.
    
    https://github.com/apache/spark/pull/39572 implemented the core logic, but it only supported large variable types in the cases shown in bold below:
    
    - `mapInArrow`: **JVM -> Python -> JVM**
    - Pandas UDF/Function API: **JVM -> Python** -> JVM
    - createDataFrame with Arrow: Python -> JVM
    - toPandas with Arrow: JVM -> Python
    
    This PR completes them all.
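    
    For illustration only, a minimal sketch of how the flag is expected to be exercised across these paths once this change is in place (it assumes an existing `spark` session; the column name and data are made up):
    
    ```python
    import pandas as pd
    from pyspark.sql.functions import pandas_udf

    spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", "true")
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

    # createDataFrame with Arrow: Python -> JVM
    df = spark.createDataFrame(pd.DataFrame({"s": ["a", "bb", "ccc"]}))

    # Pandas UDF: JVM -> Python -> JVM
    @pandas_udf("string")
    def upper(s: pd.Series) -> pd.Series:
        return s.str.upper()

    # toPandas with Arrow: JVM -> Python
    print(df.select(upper("s")).toPandas())
    ```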
    
    ### Why are the changes needed?
    
    To support large variable types consistently across pandas UDFs, `createDataFrame` and `toPandas` with Arrow.
    
    ### Does this PR introduce _any_ user-facing change?
    
    `spark.sql.execution.arrow.useLargeVarTypes` has not been released yet, so this change does not affect any end users.
    
    ### How was this patch tested?
    
    Existing tests with `spark.sql.execution.arrow.useLargeVarTypes` enabled.
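    
    In addition, the new `prefers_large_types` argument of `to_arrow_type` can be checked directly; a small illustrative snippet (not part of the test suite):
    
    ```python
    import pyarrow as pa

    from pyspark.sql.pandas.types import to_arrow_type
    from pyspark.sql.types import BinaryType, StringType

    # With the flag, string/binary map to the large Arrow variants.
    assert to_arrow_type(StringType(), prefers_large_types=True) == pa.large_string()
    assert to_arrow_type(BinaryType(), prefers_large_types=True) == pa.large_binary()
    # The default mapping is unchanged.
    assert to_arrow_type(StringType()) == pa.string()
    ```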
    
    Closes #49790 from HyukjinKwon/SPARK-39979-followup2.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit e2ef5a482604059197c79f1f7c305df4ce57213a)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/pandas/typedef/typehints.py         | 10 ++-
 python/pyspark/sql/connect/conversion.py           |  5 +-
 python/pyspark/sql/connect/session.py              | 25 +++++-
 python/pyspark/sql/pandas/conversion.py            | 28 +++++--
 python/pyspark/sql/pandas/serializers.py           |  9 +-
 python/pyspark/sql/pandas/types.py                 | 44 ++++++++--
 python/pyspark/sql/tests/arrow/test_arrow.py       |  6 +-
 python/pyspark/worker.py                           | 95 +++++++++++++++-------
 .../org/apache/spark/sql/internal/SqlApiConf.scala |  2 +
 .../spark/sql/internal/SqlApiConfHelper.scala      |  1 +
 .../org/apache/spark/sql/util/ArrowUtils.scala     |  2 +-
 .../spark/sql/execution/arrow/ArrowWriter.scala    |  6 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  3 +-
 .../apache/spark/sql/util/ArrowUtilsSuite.scala    |  9 +-
 .../spark/sql/connect/SQLImplicitsTestSuite.scala  |  3 +-
 .../connect/client/arrow/ArrowEncoderSuite.scala   |  9 +-
 .../apache/spark/sql/connect/SparkSession.scala    |  6 +-
 .../sql/connect/client/arrow/ArrowSerializer.scala | 44 ++++++++--
 .../connect/client/arrow/ArrowVectorReader.scala   | 21 ++++-
 .../execution/SparkConnectPlanExecution.scala      | 13 ++-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 10 ++-
 .../spark/sql/connect/SparkConnectServerTest.scala |  7 +-
 .../connect/planner/SparkConnectPlannerSuite.scala |  5 +-
 .../connect/planner/SparkConnectProtoSuite.scala   |  3 +-
 .../org/apache/spark/sql/api/r/SQLUtils.scala      |  8 +-
 .../org/apache/spark/sql/classic/Dataset.scala     | 17 +++-
 .../sql/execution/arrow/ArrowConverters.scala      | 52 +++++++++---
 .../sql/execution/python/ArrowPythonRunner.scala   |  5 +-
 .../python/CoGroupedArrowPythonRunner.scala        |  4 +-
 .../python/FlatMapCoGroupsInBatchExec.scala        |  2 +
 .../spark/sql/execution/r/ArrowRRunner.scala       |  3 +-
 .../sql/execution/arrow/ArrowConvertersSuite.scala | 29 ++++---
 .../sql/execution/arrow/ArrowWriterSuite.scala     |  3 +-
 33 files changed, 372 insertions(+), 117 deletions(-)

diff --git a/python/pyspark/pandas/typedef/typehints.py 
b/python/pyspark/pandas/typedef/typehints.py
index cb82cf8d7149..4244f5831aa5 100644
--- a/python/pyspark/pandas/typedef/typehints.py
+++ b/python/pyspark/pandas/typedef/typehints.py
@@ -296,7 +296,15 @@ def spark_type_to_pandas_dtype(
     elif isinstance(spark_type, (types.TimestampType, types.TimestampNTZType)):
         return np.dtype("datetime64[ns]")
     else:
-        return np.dtype(to_arrow_type(spark_type).to_pandas_dtype())
+        from pyspark.pandas.utils import default_session
+
+        prefers_large_var_types = (
+            default_session()
+            .conf.get("spark.sql.execution.arrow.useLargeVarTypes", "false")
+            .lower()
+            == "true"
+        )
+        return np.dtype(to_arrow_type(spark_type, 
prefers_large_var_types).to_pandas_dtype())
 
 
 def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, 
types.DataType]:
diff --git a/python/pyspark/sql/connect/conversion.py 
b/python/pyspark/sql/connect/conversion.py
index d4363594a315..d36baacb10a3 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -323,7 +323,7 @@ class LocalDataToArrowConversion:
             return lambda value: value
 
     @staticmethod
-    def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
+    def convert(data: Sequence[Any], schema: StructType, use_large_var_types: 
bool) -> "pa.Table":
         assert isinstance(data, list) and len(data) > 0
 
         assert schema is not None and isinstance(schema, StructType)
@@ -372,7 +372,8 @@ class LocalDataToArrowConversion:
                     )
                     for field in schema.fields
                 ]
-            )
+            ),
+            prefers_large_types=use_large_var_types,
         )
 
         return pa.Table.from_arrays(pylist, schema=pa_schema)
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index 59349a17886b..c01c1e42a318 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -506,9 +506,13 @@ class SparkSession:
             "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
             "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
             "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
+            "spark.sql.execution.arrow.useLargeVarTypes",
         )
         timezone = configs["spark.sql.session.timeZone"]
         prefer_timestamp = configs["spark.sql.timestampType"]
+        prefers_large_types: bool = (
+            cast(str, 
configs["spark.sql.execution.arrow.useLargeVarTypes"]).lower() == "true"
+        )
 
         _table: Optional[pa.Table] = None
 
@@ -552,7 +556,9 @@ class SparkSession:
             if isinstance(schema, StructType):
                 deduped_schema = cast(StructType, 
_deduplicate_field_names(schema))
                 spark_types = [field.dataType for field in 
deduped_schema.fields]
-                arrow_schema = to_arrow_schema(deduped_schema)
+                arrow_schema = to_arrow_schema(
+                    deduped_schema, prefers_large_types=prefers_large_types
+                )
                 arrow_types = [field.type for field in arrow_schema]
                 _cols = [str(x) if not isinstance(x, str) else x for x in 
schema.fieldNames()]
             elif isinstance(schema, DataType):
@@ -570,7 +576,12 @@ class SparkSession:
                     else None
                     for t in data.dtypes
                 ]
-                arrow_types = [to_arrow_type(dt) if dt is not None else None 
for dt in spark_types]
+                arrow_types = [
+                    to_arrow_type(dt, prefers_large_types=prefers_large_types)
+                    if dt is not None
+                    else None
+                    for dt in spark_types
+                ]
 
             safecheck = 
configs["spark.sql.execution.pandas.convertToArrowArraySafely"]
 
@@ -609,7 +620,13 @@ class SparkSession:
 
             _table = (
                 _check_arrow_table_timestamps_localize(data, schema, True, 
timezone)
-                .cast(to_arrow_schema(schema, 
error_on_duplicated_field_names_in_struct=True))
+                .cast(
+                    to_arrow_schema(
+                        schema,
+                        error_on_duplicated_field_names_in_struct=True,
+                        prefers_large_types=prefers_large_types,
+                    )
+                )
                 .rename_columns(schema.names)
             )
 
@@ -684,7 +701,7 @@ class SparkSession:
             # Spark Connect will try its best to build the Arrow table with the
             # inferred schema in the client side, and then rename the columns 
and
             # cast the datatypes in the server side.
-            _table = LocalDataToArrowConversion.convert(_data, _schema)
+            _table = LocalDataToArrowConversion.convert(_data, _schema, 
prefers_large_types)
 
         # TODO: Beside the validation on number of columns, we should also 
check
         # whether the Arrow Schema is compatible with the user provided Schema.
diff --git a/python/pyspark/sql/pandas/conversion.py 
b/python/pyspark/sql/pandas/conversion.py
index 172a4fc4b234..18360fd81392 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -81,7 +81,7 @@ class PandasConversionMixin:
                 from pyspark.sql.pandas.utils import 
require_minimum_pyarrow_version
 
                 require_minimum_pyarrow_version()
-                to_arrow_schema(self.schema)
+                to_arrow_schema(self.schema, 
prefers_large_types=jconf.arrowUseLargeVarTypes())
             except Exception as e:
                 if jconf.arrowPySparkFallbackEnabled():
                     msg = (
@@ -236,7 +236,12 @@ class PandasConversionMixin:
         from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
 
         require_minimum_pyarrow_version()
-        schema = to_arrow_schema(self.schema, 
error_on_duplicated_field_names_in_struct=True)
+        prefers_large_var_types = jconf.arrowUseLargeVarTypes()
+        schema = to_arrow_schema(
+            self.schema,
+            error_on_duplicated_field_names_in_struct=True,
+            prefers_large_types=prefers_large_var_types,
+        )
 
         import pyarrow as pa
 
@@ -322,7 +327,8 @@ class PandasConversionMixin:
             from pyspark.sql.pandas.types import to_arrow_schema
             import pyarrow as pa
 
-            schema = to_arrow_schema(self.schema)
+            prefers_large_var_types = 
self.sparkSession._jconf.arrowUseLargeVarTypes()
+            schema = to_arrow_schema(self.schema, 
prefers_large_types=prefers_large_var_types)
             empty_arrays = [pa.array([], type=field.type) for field in schema]
             return [pa.RecordBatch.from_arrays(empty_arrays, schema=schema)]
 
@@ -715,9 +721,16 @@ class SparkConversionMixin:
         pdf_slices = (pdf.iloc[start : start + step] for start in range(0, 
len(pdf), step))
 
         # Create list of Arrow (columns, arrow_type, spark_type) for 
serializer dump_stream
+        prefers_large_var_types = self._jconf.arrowUseLargeVarTypes()
         arrow_data = [
             [
-                (c, to_arrow_type(t) if t is not None else None, t)
+                (
+                    c,
+                    to_arrow_type(t, 
prefers_large_types=prefers_large_var_types)
+                    if t is not None
+                    else None,
+                    t,
+                )
                 for (_, c), t in zip(pdf_slice.items(), spark_types)
             ]
             for pdf_slice in pdf_slices
@@ -785,8 +798,13 @@ class SparkConversionMixin:
         if not isinstance(schema, StructType):
             schema = from_arrow_schema(table.schema, 
prefer_timestamp_ntz=prefer_timestamp_ntz)
 
+        prefers_large_var_types = self._jconf.arrowUseLargeVarTypes()
         table = _check_arrow_table_timestamps_localize(table, schema, True, 
timezone).cast(
-            to_arrow_schema(schema, 
error_on_duplicated_field_names_in_struct=True)
+            to_arrow_schema(
+                schema,
+                error_on_duplicated_field_names_in_struct=True,
+                prefers_large_types=prefers_large_var_types,
+            )
         )
 
         # Chunk the Arrow Table into RecordBatches
diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index 536bf7307065..cd2e1230418f 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -793,6 +793,7 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
         assign_cols_by_name,
         state_object_schema,
         arrow_max_records_per_batch,
+        prefers_large_var_types,
     ):
         super(ApplyInPandasWithStateSerializer, self).__init__(
             timezone, safecheck, assign_cols_by_name
@@ -808,7 +809,9 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
             ]
         )
 
-        self.result_count_pdf_arrow_type = 
to_arrow_type(self.result_count_df_type)
+        self.result_count_pdf_arrow_type = to_arrow_type(
+            self.result_count_df_type, 
prefers_large_types=prefers_large_var_types
+        )
 
         self.result_state_df_type = StructType(
             [
@@ -819,7 +822,9 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
             ]
         )
 
-        self.result_state_pdf_arrow_type = 
to_arrow_type(self.result_state_df_type)
+        self.result_state_pdf_arrow_type = to_arrow_type(
+            self.result_state_df_type, 
prefers_large_types=prefers_large_var_types
+        )
         self.arrow_max_records_per_batch = arrow_max_records_per_batch
 
     def load_stream(self, stream):
diff --git a/python/pyspark/sql/pandas/types.py 
b/python/pyspark/sql/pandas/types.py
index d65126bb3db9..fcd70d4d1839 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -67,6 +67,7 @@ def to_arrow_type(
     dt: DataType,
     error_on_duplicated_field_names_in_struct: bool = False,
     timestamp_utc: bool = True,
+    prefers_large_types: bool = False,
 ) -> "pa.DataType":
     """
     Convert Spark data type to PyArrow type
@@ -107,8 +108,12 @@ def to_arrow_type(
         arrow_type = pa.float64()
     elif type(dt) == DecimalType:
         arrow_type = pa.decimal128(dt.precision, dt.scale)
+    elif type(dt) == StringType and prefers_large_types:
+        arrow_type = pa.large_string()
     elif type(dt) == StringType:
         arrow_type = pa.string()
+    elif type(dt) == BinaryType and prefers_large_types:
+        arrow_type = pa.large_binary()
     elif type(dt) == BinaryType:
         arrow_type = pa.binary()
     elif type(dt) == DateType:
@@ -125,19 +130,34 @@ def to_arrow_type(
     elif type(dt) == ArrayType:
         field = pa.field(
             "element",
-            to_arrow_type(dt.elementType, 
error_on_duplicated_field_names_in_struct, timestamp_utc),
+            to_arrow_type(
+                dt.elementType,
+                error_on_duplicated_field_names_in_struct,
+                timestamp_utc,
+                prefers_large_types,
+            ),
             nullable=dt.containsNull,
         )
         arrow_type = pa.list_(field)
     elif type(dt) == MapType:
         key_field = pa.field(
             "key",
-            to_arrow_type(dt.keyType, 
error_on_duplicated_field_names_in_struct, timestamp_utc),
+            to_arrow_type(
+                dt.keyType,
+                error_on_duplicated_field_names_in_struct,
+                timestamp_utc,
+                prefers_large_types,
+            ),
             nullable=False,
         )
         value_field = pa.field(
             "value",
-            to_arrow_type(dt.valueType, 
error_on_duplicated_field_names_in_struct, timestamp_utc),
+            to_arrow_type(
+                dt.valueType,
+                error_on_duplicated_field_names_in_struct,
+                timestamp_utc,
+                prefers_large_types,
+            ),
             nullable=dt.valueContainsNull,
         )
         arrow_type = pa.map_(key_field, value_field)
@@ -152,7 +172,10 @@ def to_arrow_type(
             pa.field(
                 field.name,
                 to_arrow_type(
-                    field.dataType, error_on_duplicated_field_names_in_struct, 
timestamp_utc
+                    field.dataType,
+                    error_on_duplicated_field_names_in_struct,
+                    timestamp_utc,
+                    prefers_large_types,
                 ),
                 nullable=field.nullable,
             )
@@ -163,7 +186,10 @@ def to_arrow_type(
         arrow_type = pa.null()
     elif isinstance(dt, UserDefinedType):
         arrow_type = to_arrow_type(
-            dt.sqlType(), error_on_duplicated_field_names_in_struct, 
timestamp_utc
+            dt.sqlType(),
+            error_on_duplicated_field_names_in_struct,
+            timestamp_utc,
+            prefers_large_types,
         )
     elif type(dt) == VariantType:
         fields = [
@@ -185,6 +211,7 @@ def to_arrow_schema(
     schema: StructType,
     error_on_duplicated_field_names_in_struct: bool = False,
     timestamp_utc: bool = True,
+    prefers_large_types: bool = False,
 ) -> "pa.Schema":
     """
     Convert a schema from Spark to Arrow
@@ -212,7 +239,12 @@ def to_arrow_schema(
     fields = [
         pa.field(
             field.name,
-            to_arrow_type(field.dataType, 
error_on_duplicated_field_names_in_struct, timestamp_utc),
+            to_arrow_type(
+                field.dataType,
+                error_on_duplicated_field_names_in_struct,
+                timestamp_utc,
+                prefers_large_types,
+            ),
             nullable=field.nullable,
         )
         for field in schema
diff --git a/python/pyspark/sql/tests/arrow/test_arrow.py 
b/python/pyspark/sql/tests/arrow/test_arrow.py
index a2ee113b6386..065f97fcf7c7 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow.py
@@ -730,7 +730,11 @@ class ArrowTestsMixin:
     def test_schema_conversion_roundtrip(self):
         from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
 
-        arrow_schema = to_arrow_schema(self.schema)
+        arrow_schema = to_arrow_schema(self.schema, prefers_large_types=False)
+        schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True)
+        self.assertEqual(self.schema, schema_rt)
+
+        arrow_schema = to_arrow_schema(self.schema, prefers_large_types=True)
         schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True)
         self.assertEqual(self.schema, schema_rt)
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e799498cdd80..7bac0157caee 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -117,10 +117,12 @@ def wrap_udf(f, args_offsets, kwargs_offsets, 
return_type):
         return args_kwargs_offsets, lambda *a: func(*a)
 
 
-def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type):
+def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type, 
runner_conf):
     func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, 
kwargs_offsets)
 
-    arrow_return_type = to_arrow_type(return_type)
+    arrow_return_type = to_arrow_type(
+        return_type, prefers_large_types=use_large_var_types(runner_conf)
+    )
 
     def verify_result_type(result):
         if not hasattr(result, "__len__"):
@@ -159,7 +161,9 @@ def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, 
return_type, runner_co
 
     func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, 
kwargs_offsets)
 
-    arrow_return_type = to_arrow_type(return_type)
+    arrow_return_type = to_arrow_type(
+        return_type, prefers_large_types=use_large_var_types(runner_conf)
+    )
 
     # "result_func" ensures the result of a Python UDF to be consistent 
with/without Arrow
     # optimization.
@@ -205,8 +209,10 @@ def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, 
return_type, runner_co
     )
 
 
-def wrap_pandas_batch_iter_udf(f, return_type):
-    arrow_return_type = to_arrow_type(return_type)
+def wrap_pandas_batch_iter_udf(f, return_type, runner_conf):
+    arrow_return_type = to_arrow_type(
+        return_type, prefers_large_types=use_large_var_types(runner_conf)
+    )
     iter_type_label = "pandas.DataFrame" if type(return_type) == StructType 
else "pandas.Series"
 
     def verify_result(result):
@@ -303,8 +309,10 @@ def verify_pandas_result(result, return_type, 
assign_cols_by_name, truncate_retu
             )
 
 
-def wrap_arrow_batch_iter_udf(f, return_type):
-    arrow_return_type = to_arrow_type(return_type)
+def wrap_arrow_batch_iter_udf(f, return_type, runner_conf):
+    arrow_return_type = to_arrow_type(
+        return_type, prefers_large_types=use_large_var_types(runner_conf)
+    )
 
     def verify_result(result):
         if not isinstance(result, Iterator) and not hasattr(result, 
"__iter__"):
@@ -364,6 +372,7 @@ def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, 
runner_conf):
 
 
 def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
+    _use_large_var_types = use_large_var_types(runner_conf)
     _assign_cols_by_name = assign_cols_by_name(runner_conf)
 
     def wrapped(left_key_series, left_value_series, right_key_series, 
right_value_series):
@@ -384,7 +393,8 @@ def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, 
runner_conf):
 
         return result
 
-    return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), 
to_arrow_type(return_type))]
+    arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
+    return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), 
arrow_return_type)]
 
 
 def verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types):
@@ -482,10 +492,12 @@ def wrap_grouped_map_arrow_udf(f, return_type, argspec, 
runner_conf):
 
         return result.to_batches()
 
-    return lambda k, v: (wrapped(k, v), to_arrow_type(return_type))
+    arrow_return_type = to_arrow_type(return_type, 
use_large_var_types(runner_conf))
+    return lambda k, v: (wrapped(k, v), arrow_return_type)
 
 
 def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
+    _use_large_var_types = use_large_var_types(runner_conf)
     _assign_cols_by_name = assign_cols_by_name(runner_conf)
 
     def wrapped(key_series, value_series):
@@ -502,7 +514,8 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, 
runner_conf):
 
         return result
 
-    return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
+    arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
+    return lambda k, v: [(wrapped(k, v), arrow_return_type)]
 
 
 def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
@@ -517,7 +530,8 @@ def wrap_grouped_transform_with_state_pandas_udf(f, 
return_type, runner_conf):
 
         return result_iter
 
-    return lambda p, m, k, v: [(wrapped(p, m, k, v), 
to_arrow_type(return_type))]
+    arrow_return_type = to_arrow_type(return_type, 
use_large_var_types(runner_conf))
+    return lambda p, m, k, v: [(wrapped(p, m, k, v), arrow_return_type)]
 
 
 def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, 
runner_conf):
@@ -535,10 +549,11 @@ def 
wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runn
 
         return result_iter
 
-    return lambda p, m, k, v: [(wrapped(p, m, k, v), 
to_arrow_type(return_type))]
+    arrow_return_type = to_arrow_type(return_type, 
use_large_var_types(runner_conf))
+    return lambda p, m, k, v: [(wrapped(p, m, k, v), arrow_return_type)]
 
 
-def wrap_grouped_map_pandas_udf_with_state(f, return_type):
+def wrap_grouped_map_pandas_udf_with_state(f, return_type, runner_conf):
     """
     Provides a new lambda instance wrapping user function of 
applyInPandasWithState.
 
@@ -553,6 +568,7 @@ def wrap_grouped_map_pandas_udf_with_state(f, return_type):
     Along with the returned iterator, the lambda instance will also produce 
the return_type as
     converted to the arrow schema.
     """
+    _use_large_var_types = use_large_var_types(runner_conf)
 
     def wrapped(key_series, value_series_gen, state):
         """
@@ -627,13 +643,16 @@ def wrap_grouped_map_pandas_udf_with_state(f, 
return_type):
             state,
         )
 
-    return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))]
+    arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
+    return lambda k, v, s: [(wrapped(k, v, s), arrow_return_type)]
 
 
-def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type):
+def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, 
runner_conf):
     func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, 
kwargs_offsets)
 
-    arrow_return_type = to_arrow_type(return_type)
+    arrow_return_type = to_arrow_type(
+        return_type, prefers_large_types=use_large_var_types(runner_conf)
+    )
 
     def wrapped(*series):
         import pandas as pd
@@ -653,9 +672,13 @@ def wrap_window_agg_pandas_udf(
     window_bound_types_str = runner_conf.get("pandas_window_bound_types")
     window_bound_type = [t.strip().lower() for t in 
window_bound_types_str.split(",")][udf_index]
     if window_bound_type == "bounded":
-        return wrap_bounded_window_agg_pandas_udf(f, args_offsets, 
kwargs_offsets, return_type)
+        return wrap_bounded_window_agg_pandas_udf(
+            f, args_offsets, kwargs_offsets, return_type, runner_conf
+        )
     elif window_bound_type == "unbounded":
-        return wrap_unbounded_window_agg_pandas_udf(f, args_offsets, 
kwargs_offsets, return_type)
+        return wrap_unbounded_window_agg_pandas_udf(
+            f, args_offsets, kwargs_offsets, return_type, runner_conf
+        )
     else:
         raise PySparkRuntimeError(
             errorClass="INVALID_WINDOW_BOUND_TYPE",
@@ -665,14 +688,16 @@ def wrap_window_agg_pandas_udf(
         )
 
 
-def wrap_unbounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, 
return_type):
+def wrap_unbounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, 
return_type, runner_conf):
     func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, 
kwargs_offsets)
 
     # This is similar to grouped_agg_pandas_udf, the only difference
     # is that window_agg_pandas_udf needs to repeat the return value
     # to match window length, where grouped_agg_pandas_udf just returns
     # the scalar value.
-    arrow_return_type = to_arrow_type(return_type)
+    arrow_return_type = to_arrow_type(
+        return_type, prefers_large_types=use_large_var_types(runner_conf)
+    )
 
     def wrapped(*series):
         import pandas as pd
@@ -686,12 +711,14 @@ def wrap_unbounded_window_agg_pandas_udf(f, args_offsets, 
kwargs_offsets, return
     )
 
 
-def wrap_bounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, 
return_type):
+def wrap_bounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, 
return_type, runner_conf):
     # args_offsets should have at least 2 for begin_index, end_index.
     assert len(args_offsets) >= 2, len(args_offsets)
     func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets[2:], 
kwargs_offsets)
 
-    arrow_return_type = to_arrow_type(return_type)
+    arrow_return_type = to_arrow_type(
+        return_type, prefers_large_types=use_large_var_types(runner_conf)
+    )
 
     def wrapped(begin_index, end_index, *series):
         import pandas as pd
@@ -865,15 +892,15 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index, profil
 
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets, 
return_type)
+        return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets, 
return_type, runner_conf)
     elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
         return wrap_arrow_batch_udf(func, args_offsets, kwargs_offsets, 
return_type, runner_conf)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
-        return args_offsets, wrap_pandas_batch_iter_udf(func, return_type)
+        return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, 
runner_conf)
     elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
-        return args_offsets, wrap_pandas_batch_iter_udf(func, return_type)
+        return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, 
runner_conf)
     elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
-        return args_offsets, wrap_arrow_batch_iter_udf(func, return_type)
+        return args_offsets, wrap_arrow_batch_iter_udf(func, return_type, 
runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = inspect.getfullargspec(chained_func)  # signature was lost 
when wrapping it
         return args_offsets, wrap_grouped_map_pandas_udf(func, return_type, 
argspec, runner_conf)
@@ -881,7 +908,7 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index, profil
         argspec = inspect.getfullargspec(chained_func)  # signature was lost 
when wrapping it
         return args_offsets, wrap_grouped_map_arrow_udf(func, return_type, 
argspec, runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
-        return args_offsets, wrap_grouped_map_pandas_udf_with_state(func, 
return_type)
+        return args_offsets, wrap_grouped_map_pandas_udf_with_state(func, 
return_type, runner_conf)
     elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
         return args_offsets, wrap_grouped_transform_with_state_pandas_udf(
             func, return_type, runner_conf
@@ -897,7 +924,9 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index, profil
         argspec = inspect.getfullargspec(chained_func)  # signature was lost 
when wrapping it
         return args_offsets, wrap_cogrouped_map_arrow_udf(func, return_type, 
argspec, runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
-        return wrap_grouped_agg_pandas_udf(func, args_offsets, kwargs_offsets, 
return_type)
+        return wrap_grouped_agg_pandas_udf(
+            func, args_offsets, kwargs_offsets, return_type, runner_conf
+        )
     elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
         return wrap_window_agg_pandas_udf(
             func, args_offsets, kwargs_offsets, return_type, runner_conf, 
udf_index
@@ -922,6 +951,10 @@ def assign_cols_by_name(runner_conf):
     )
 
 
+def use_large_var_types(runner_conf):
+    return runner_conf.get("spark.sql.execution.arrow.useLargeVarTypes", 
"false").lower() == "true"
+
+
 # Read and process a serialized user-defined table function (UDTF) from a 
socket.
 # It expects the UDTF to be in a specific format and performs various checks to
 # ensure the UDTF is valid. This function also prepares a mapper function for 
applying
@@ -1254,7 +1287,9 @@ def read_udtf(pickleSer, infile, eval_type):
         def wrap_arrow_udtf(f, return_type):
             import pandas as pd
 
-            arrow_return_type = to_arrow_type(return_type)
+            arrow_return_type = to_arrow_type(
+                return_type, 
prefers_large_types=use_large_var_types(runner_conf)
+            )
             return_type_size = len(return_type)
 
             def verify_result(result):
@@ -1499,6 +1534,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # NOTE: if timezone is set here, that implies respectSessionTimeZone 
is True
         timezone = runner_conf.get("spark.sql.session.timeZone", None)
+        prefers_large_var_types = use_large_var_types(runner_conf)
         safecheck = (
             
runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", 
"false").lower()
             == "true"
@@ -1521,6 +1557,7 @@ def read_udfs(pickleSer, infile, eval_type):
                 _assign_cols_by_name,
                 state_object_schema,
                 arrow_max_records_per_batch,
+                prefers_large_var_types,
             )
         elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
             arrow_max_records_per_batch = runner_conf.get(
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
index 76cd436b39b5..cb517c689ea1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
@@ -56,6 +56,8 @@ private[sql] object SqlApiConf {
   val LEGACY_TIME_PARSER_POLICY_KEY: String = 
SqlApiConfHelper.LEGACY_TIME_PARSER_POLICY_KEY
   val CASE_SENSITIVE_KEY: String = SqlApiConfHelper.CASE_SENSITIVE_KEY
   val SESSION_LOCAL_TIMEZONE_KEY: String = 
SqlApiConfHelper.SESSION_LOCAL_TIMEZONE_KEY
+  val ARROW_EXECUTION_USE_LARGE_VAR_TYPES: String =
+    SqlApiConfHelper.ARROW_EXECUTION_USE_LARGE_VAR_TYPES
   val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = {
     SqlApiConfHelper.LOCAL_RELATION_CACHE_THRESHOLD_KEY
   }
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
index 13ef13e5894e..486a7dfb58dd 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
@@ -33,6 +33,7 @@ private[sql] object SqlApiConfHelper {
   val SESSION_LOCAL_TIMEZONE_KEY: String = "spark.sql.session.timeZone"
   val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = 
"spark.sql.session.localRelationCacheThreshold"
   val DEFAULT_COLLATION: String = "spark.sql.session.collation.default"
+  val ARROW_EXECUTION_USE_LARGE_VAR_TYPES = 
"spark.sql.execution.arrow.useLargeVarTypes"
 
   val confGetter: AtomicReference[() => SqlApiConf] = {
     new AtomicReference[() => SqlApiConf](() => DefaultSqlApiConf)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 55d1aff8261d..587ca43e5730 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -202,7 +202,7 @@ private[sql] object ArrowUtils {
       schema: StructType,
       timeZoneId: String,
       errorOnDuplicatedFieldNames: Boolean,
-      largeVarTypes: Boolean = false): Schema = {
+      largeVarTypes: Boolean): Schema = {
     new Schema(schema.map { field =>
       toArrowField(
         field.name,
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 065b4b8c821a..c496b0e82c26 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -33,8 +33,10 @@ object ArrowWriter {
   def create(
       schema: StructType,
       timeZoneId: String,
-      errorOnDuplicatedFieldNames: Boolean = true): ArrowWriter = {
-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, 
errorOnDuplicatedFieldNames)
+      errorOnDuplicatedFieldNames: Boolean = true,
+      largeVarTypes: Boolean = false): ArrowWriter = {
+    val arrowSchema = ArrowUtils.toArrowSchema(
+      schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
     val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator)
     create(root)
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 13fd8742cfa5..ad37262b6ed2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3421,9 +3421,8 @@ object SQLConf {
       .doc("When using Apache Arrow, use large variable width vectors for 
string and binary " +
         "types. Regular string and binary types have a 2GiB limit for a column 
in a single " +
         "record batch. Large variable types remove this limitation at the cost 
of higher memory " +
-        "usage per value. Note that this only works for DataFrame.mapInArrow.")
+        "usage per value.")
       .version("3.5.0")
-      .internal()
       .booleanConf
       .createWithDefault(false)
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
index c705a6b791bd..7124c94b390d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
@@ -30,7 +30,8 @@ class ArrowUtilsSuite extends SparkFunSuite {
   def roundtrip(dt: DataType): Unit = {
     dt match {
       case schema: StructType =>
-        assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, 
null, true)) === schema)
+        assert(ArrowUtils.fromArrowSchema(
+          ArrowUtils.toArrowSchema(schema, null, true, false)) === schema)
       case _ =>
         roundtrip(new StructType().add("value", dt))
     }
@@ -69,7 +70,7 @@ class ArrowUtilsSuite extends SparkFunSuite {
 
     def roundtripWithTz(timeZoneId: String): Unit = {
       val schema = new StructType().add("value", TimestampType)
-      val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, true)
+      val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, true, 
false)
       val fieldType = 
arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp]
       assert(fieldType.getTimezone() === timeZoneId)
       assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema)
@@ -105,9 +106,9 @@ class ArrowUtilsSuite extends SparkFunSuite {
     def check(dt: DataType, expected: DataType): Unit = {
       val schema = new StructType().add("value", dt)
       intercept[SparkUnsupportedOperationException] {
-        ArrowUtils.toArrowSchema(schema, null, true)
+        ArrowUtils.toArrowSchema(schema, null, true, false)
       }
-      assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null, 
false))
+      assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null, 
false, false))
         === new StructType().add("value", expected))
     }
 
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala
index 2791c6b6add5..c7b4748f1222 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala
@@ -64,7 +64,8 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
           input = Iterator.single(expected),
           enc = encoder,
           allocator = allocator,
-          timeZoneId = "UTC")
+          timeZoneId = "UTC",
+          largeVarTypes = false)
         val fromArrow = ArrowDeserializers.deserializeFromArrow(
           input = Iterator.single(batch.toByteArray),
           encoder = encoder,
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index f6662b3351ba..58e19389cae2 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -106,7 +106,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
       maxRecordsPerBatch = maxRecordsPerBatch,
       maxBatchSize = maxBatchSize,
       batchSizeCheckInterval = batchSizeCheckInterval,
-      timeZoneId = "UTC")
+      timeZoneId = "UTC",
+      largeVarTypes = false)
 
     val inspectedIterator = if (inspectBatch != null) {
       arrowIterator.map { batch =>
@@ -183,7 +184,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
       allocator,
       maxRecordsPerBatch = 1024,
       maxBatchSize = 8 * 1024,
-      timeZoneId = "UTC")
+      timeZoneId = "UTC",
+      largeVarTypes = false)
   }
 
   private def compareIterators[T](expected: Iterator[T], actual: Iterator[T]): 
Unit = {
@@ -626,7 +628,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
         allocator,
         maxRecordsPerBatch = 128,
         maxBatchSize = 1024,
-        timeZoneId = "UTC")
+        timeZoneId = "UTC",
+        largeVarTypes = false)
       intercept[NullPointerException] {
         iterator.next()
       }
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index b0d7d6aa5d13..e4ac4a1ba619 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect
 
 import java.net.URI
 import java.nio.file.{Files, Paths}
+import java.util.Locale
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicLong
 
@@ -110,7 +111,8 @@ class SparkSession private[sql] (
   private def createDataset[T](encoder: AgnosticEncoder[T], data: 
Iterator[T]): Dataset[T] = {
     newDataset(encoder) { builder =>
       if (data.nonEmpty) {
-        val arrowData = ArrowSerializer.serialize(data, encoder, allocator, 
timeZoneId)
+        val arrowData =
+          ArrowSerializer.serialize(data, encoder, allocator, timeZoneId, 
largeVarTypes)
         if (arrowData.size() <= 
conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt) {
           builder.getLocalRelationBuilder
             .setSchema(encoder.schema.json)
@@ -467,6 +469,8 @@ class SparkSession private[sql] (
   }
 
   private[sql] def timeZoneId: String = 
conf.get(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY)
+  private[sql] def largeVarTypes: Boolean =
+    
conf.get(SqlApiConf.ARROW_EXECUTION_USE_LARGE_VAR_TYPES).toLowerCase(Locale.ROOT).toBoolean
 
   private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): 
SparkResult[T] = {
     val value = executeInternal(plan)
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index c01390bf0785..584a318f039d 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._
 
 import com.google.protobuf.ByteString
 import org.apache.arrow.memory.BufferAllocator
-import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, 
DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, 
IntervalYearVector, IntVector, NullVector, SmallIntVector, 
TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, 
VarCharVector, VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector._
 import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
 import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel}
 import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
@@ -50,8 +50,10 @@ import org.apache.spark.unsafe.types.VariantVal
 class ArrowSerializer[T](
     private[this] val enc: AgnosticEncoder[T],
     private[this] val allocator: BufferAllocator,
-    private[this] val timeZoneId: String) {
-  private val (root, serializer) = ArrowSerializer.serializerFor(enc, 
allocator, timeZoneId)
+    private[this] val timeZoneId: String,
+    private[this] val largeVarTypes: Boolean) {
+  private val (root, serializer) =
+    ArrowSerializer.serializerFor(enc, allocator, timeZoneId, largeVarTypes)
   private val vectors = root.getFieldVectors.asScala
   private val unloader = new VectorUnloader(root)
   private val schemaBytes = {
@@ -144,12 +146,13 @@ object ArrowSerializer {
       maxRecordsPerBatch: Int,
       maxBatchSize: Long,
       timeZoneId: String,
+      largeVarTypes: Boolean,
       batchSizeCheckInterval: Int = 128): CloseableIterator[Array[Byte]] = {
     assert(maxRecordsPerBatch > 0)
     assert(maxBatchSize > 0)
     assert(batchSizeCheckInterval > 0)
     new CloseableIterator[Array[Byte]] {
-      private val serializer = new ArrowSerializer[T](enc, allocator, 
timeZoneId)
+      private val serializer = new ArrowSerializer[T](enc, allocator, 
timeZoneId, largeVarTypes)
       private val bytes = new ByteArrayOutputStream
       private var hasWrittenFirstBatch = false
 
@@ -191,8 +194,9 @@ object ArrowSerializer {
       input: Iterator[T],
       enc: AgnosticEncoder[T],
       allocator: BufferAllocator,
-      timeZoneId: String): ByteString = {
-    val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId)
+      timeZoneId: String,
+      largeVarTypes: Boolean): ByteString = {
+    val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId, 
largeVarTypes)
     try {
       input.foreach(serializer.append)
       val output = ByteString.newOutput()
@@ -211,9 +215,14 @@ object ArrowSerializer {
   def serializerFor[T](
       encoder: AgnosticEncoder[T],
       allocator: BufferAllocator,
-      timeZoneId: String): (VectorSchemaRoot, Serializer) = {
+      timeZoneId: String,
+      largeVarTypes: Boolean): (VectorSchemaRoot, Serializer) = {
     val arrowSchema =
-      ArrowUtils.toArrowSchema(encoder.schema, timeZoneId, 
errorOnDuplicatedFieldNames = true)
+      ArrowUtils.toArrowSchema(
+        encoder.schema,
+        timeZoneId,
+        errorOnDuplicatedFieldNames = true,
+        largeVarTypes = largeVarTypes)
     val root = VectorSchemaRoot.create(arrowSchema, allocator)
     val serializer = if (encoder.schema != encoder.dataType) {
       assert(root.getSchema.getFields.size() == 1)
@@ -264,19 +273,36 @@ object ArrowSerializer {
         new FieldSerializer[String, VarCharVector](v) {
           override def set(index: Int, value: String): Unit = setString(v, 
index, value)
         }
+      case (StringEncoder, v: LargeVarCharVector) =>
+        new FieldSerializer[String, LargeVarCharVector](v) {
+          override def set(index: Int, value: String): Unit = setString(v, 
index, value)
+        }
       case (JavaEnumEncoder(_), v: VarCharVector) =>
         new FieldSerializer[Enum[_], VarCharVector](v) {
           override def set(index: Int, value: Enum[_]): Unit = setString(v, 
index, value.name())
         }
+      case (JavaEnumEncoder(_), v: LargeVarCharVector) =>
+        new FieldSerializer[Enum[_], LargeVarCharVector](v) {
+          override def set(index: Int, value: Enum[_]): Unit = setString(v, 
index, value.name())
+        }
       case (ScalaEnumEncoder(_, _), v: VarCharVector) =>
         new FieldSerializer[Enumeration#Value, VarCharVector](v) {
           override def set(index: Int, value: Enumeration#Value): Unit =
             setString(v, index, value.toString)
         }
+      case (ScalaEnumEncoder(_, _), v: LargeVarCharVector) =>
+        new FieldSerializer[Enumeration#Value, LargeVarCharVector](v) {
+          override def set(index: Int, value: Enumeration#Value): Unit =
+            setString(v, index, value.toString)
+        }
       case (BinaryEncoder, v: VarBinaryVector) =>
         new FieldSerializer[Array[Byte], VarBinaryVector](v) {
           override def set(index: Int, value: Array[Byte]): Unit = 
vector.setSafe(index, value)
         }
+      case (BinaryEncoder, v: LargeVarBinaryVector) =>
+        new FieldSerializer[Array[Byte], LargeVarBinaryVector](v) {
+          override def set(index: Int, value: Array[Byte]): Unit = 
vector.setSafe(index, value)
+        }
       case (SparkDecimalEncoder(_), v: DecimalVector) =>
         new FieldSerializer[Decimal, DecimalVector](v) {
           override def set(index: Int, value: Decimal): Unit =
@@ -477,7 +503,7 @@ object ArrowSerializer {
 
   private val methodLookup = MethodHandles.lookup()
 
-  private def setString(vector: VarCharVector, index: Int, string: String): 
Unit = {
+  private def setString(vector: VariableWidthFieldVector, index: Int, string: 
String): Unit = {
     val bytes = Text.encode(string)
     vector.setSafe(index, bytes, 0, bytes.limit())
   }
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala
index 53d8d46e6268..3dbfce18e7b4 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala
@@ -20,7 +20,7 @@ import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Timestamp}
 import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, 
ZoneOffset}
 
-import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, 
DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, 
IntervalYearVector, IntVector, NullVector, SmallIntVector, 
TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, 
VarCharVector}
+import org.apache.arrow.vector._
 import org.apache.arrow.vector.util.Text
 
 import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkIntervalUtils, 
SparkStringUtils, TimestampFormatter}
@@ -82,7 +82,9 @@ object ArrowVectorReader {
       case v: Float8Vector => new Float8VectorReader(v)
       case v: DecimalVector => new DecimalVectorReader(v)
       case v: VarCharVector => new VarCharVectorReader(v)
+      case v: LargeVarCharVector => new LargeVarCharVectorReader(v)
       case v: VarBinaryVector => new VarBinaryVectorReader(v)
+      case v: LargeVarBinaryVector => new LargeVarBinaryVectorReader(v)
       case v: DurationVector => new DurationVectorReader(v)
       case v: IntervalYearVector => new IntervalYearVectorReader(v)
       case v: DateDayVector => new DateDayVectorReader(v, timeZoneId)
@@ -189,12 +191,29 @@ private[arrow] class VarCharVectorReader(v: VarCharVector)
   override def getString(i: Int): String = Text.decode(vector.get(i))
 }
 
+private[arrow] class LargeVarCharVectorReader(v: LargeVarCharVector)
+    extends TypedArrowVectorReader[LargeVarCharVector](v) {
+  // This is currently a bit heavy on allocations:
+  // - byte array created in VarCharVector.get
+  // - CharBuffer created CharSetEncoder
+  // - char array in String
+  // By using direct buffers and reusing the char buffer
+  // we could get rid of the first two allocations.
+  override def getString(i: Int): String = Text.decode(vector.get(i))
+}
+
 private[arrow] class VarBinaryVectorReader(v: VarBinaryVector)
     extends TypedArrowVectorReader[VarBinaryVector](v) {
   override def getBytes(i: Int): Array[Byte] = vector.get(i)
   override def getString(i: Int): String = 
SparkStringUtils.getHexString(getBytes(i))
 }
 
+private[arrow] class LargeVarBinaryVectorReader(v: LargeVarBinaryVector)
+    extends TypedArrowVectorReader[LargeVarBinaryVector](v) {
+  override def getBytes(i: Int): Array[Byte] = vector.get(i)
+  override def getString(i: Int): String = 
SparkStringUtils.getHexString(getBytes(i))
+}
+
 private[arrow] class DurationVectorReader(v: DurationVector)
     extends TypedArrowVectorReader[DurationVector](v) {
   override def getDuration(i: Int): Duration = vector.getObject(i)
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 497576a6630d..fc3f18063416 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -90,14 +90,16 @@ private[execution] class 
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
       maxRecordsPerBatch: Int,
       maxBatchSize: Long,
       timeZoneId: String,
-      errorOnDuplicatedFieldNames: Boolean): Iterator[InternalRow] => 
Iterator[Batch] = { rows =>
+      errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean): Iterator[InternalRow] => Iterator[Batch] = { 
rows =>
     val batches = ArrowConverters.toBatchWithSchemaIterator(
       rows,
       schema,
       maxRecordsPerBatch,
       maxBatchSize,
       timeZoneId,
-      errorOnDuplicatedFieldNames)
+      errorOnDuplicatedFieldNames,
+      largeVarTypes)
     batches.map(b => b -> batches.rowCountInLastBatch)
   }
 
@@ -110,6 +112,7 @@ private[execution] class 
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     val schema = dataframe.schema
     val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+    val largeVarTypes = spark.sessionState.conf.arrowUseLargeVarTypes
     // Conservatively sets it 70% because the size is not accurate but 
estimated.
     val maxBatchSize = 
(SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
 
@@ -118,7 +121,8 @@ private[execution] class 
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
       maxRecordsPerBatch,
       maxBatchSize,
       timeZoneId,
-      errorOnDuplicatedFieldNames = false)
+      errorOnDuplicatedFieldNames = false,
+      largeVarTypes = largeVarTypes)
 
     var numSent = 0
     def sendBatch(bytes: Array[Byte], count: Long, startOffset: Long): Unit = {
@@ -239,7 +243,8 @@ private[execution] class 
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
         ArrowConverters.createEmptyArrowBatch(
           schema,
           timeZoneId,
-          errorOnDuplicatedFieldNames = false),
+          errorOnDuplicatedFieldNames = false,
+          largeVarTypes = largeVarTypes),
         0L,
         0L)
     }
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6019969be05b..2a296522b620 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2550,7 +2550,8 @@ class SparkConnectPlanner(
       result.iterator,
       StringEncoder,
       ArrowUtils.rootAllocator,
-      session.sessionState.conf.sessionLocalTimeZone)
+      session.sessionState.conf.sessionLocalTimeZone,
+      session.sessionState.conf.arrowUseLargeVarTypes)
     val sqlCommandResult = SqlCommandResult.newBuilder()
     
sqlCommandResult.getRelationBuilder.getLocalRelationBuilder.setData(arrowData)
     responseObserver.onNext(
@@ -2613,13 +2614,15 @@ class SparkConnectPlanner(
       val schema = df.schema
       val maxBatchSize = 
(SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
       val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
+      val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes
 
       // Convert the data.
       val bytes = if (rows.isEmpty) {
         ArrowConverters.createEmptyArrowBatch(
           schema,
           timeZoneId,
-          errorOnDuplicatedFieldNames = false)
+          errorOnDuplicatedFieldNames = false,
+          largeVarTypes = largeVarTypes)
       } else {
         val batches = ArrowConverters.toBatchWithSchemaIterator(
           rowIter = rows.iterator,
@@ -2627,7 +2630,8 @@ class SparkConnectPlanner(
           maxRecordsPerBatch = -1,
           maxEstimatedBatchSize = maxBatchSize,
           timeZoneId = timeZoneId,
-          errorOnDuplicatedFieldNames = false)
+          errorOnDuplicatedFieldNames = false,
+          largeVarTypes = largeVarTypes)
         assert(batches.hasNext)
         val bytes = batches.next()
         assert(!batches.hasNext, s"remaining batches: ${batches.size}")
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index db430549818d..76c88d515ec0 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -146,7 +146,12 @@ trait SparkConnectServerTest extends SharedSparkSession {
   protected def buildLocalRelation[A <: Product: TypeTag](data: Seq[A]) = {
     val encoder = ScalaReflection.encoderFor[A]
     val arrowData =
-      ArrowSerializer.serialize(data.iterator, encoder, allocator, TimeZone.getDefault.getID)
+      ArrowSerializer.serialize(
+        data.iterator,
+        encoder,
+        allocator,
+        TimeZone.getDefault.getID,
+        largeVarTypes = false)
     val localRelation = proto.LocalRelation
       .newBuilder()
       .setData(arrowData)
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 2a09d5f8e8bd..72f7065b4424 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -101,7 +101,8 @@ trait SparkConnectPlanTest extends SharedSparkSession {
         Long.MaxValue,
         Long.MaxValue,
         timeZoneId,
-        true)
+        true,
+        false)
       .next()
 
     localRelationBuilder.setData(ByteString.copyFrom(bytes))
@@ -478,7 +479,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
 
   test("Empty ArrowBatch") {
     val schema = StructType(Seq(StructField("int", IntegerType)))
-    val data = ArrowConverters.createEmptyArrowBatch(schema, null, true)
+    val data = ArrowConverters.createEmptyArrowBatch(schema, null, true, false)
     val localRelation = proto.Relation
       .newBuilder()
       .setLocalRelation(
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 2bbd6863b110..494aceb2fb58 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -1115,7 +1115,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
         Long.MaxValue,
         Long.MaxValue,
         null,
-        true)
+        true,
+        false)
       .next()
     proto.Relation
       .newBuilder()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index ad58fc0c2fcf..1efd8f9e3220 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -249,7 +249,13 @@ private[sql] object SQLUtils extends Logging {
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
     val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
       val context = TaskContext.get()
-      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, true, context)
+      ArrowConverters.fromBatchIterator(
+        iter,
+        schema,
+        timeZoneId,
+        true,
+        false,
+        context)
     }
     sparkSession.internalCreateDataFrame(rdd.setName("arrow"), schema)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index d78a3a391edb..8930b5895d32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2089,7 +2089,8 @@ class Dataset[T] private[sql](
         val buffer = new ByteArrayOutputStream()
         val out = new DataOutputStream(outputStream)
         val batchWriter =
-          new ArrowBatchStreamWriter(schema, buffer, timeZoneId, errorOnDuplicatedFieldNames = true)
+          new ArrowBatchStreamWriter(
+            schema, buffer, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = false)
         val arrowBatchRdd = toArrowBatchRdd(plan)
         val numPartitions = arrowBatchRdd.partitions.length
 
@@ -2140,12 +2141,14 @@ class Dataset[T] private[sql](
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
     val errorOnDuplicatedFieldNames =
       sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
+    val largeVarTypes = sparkSession.sessionState.conf.arrowUseLargeVarTypes
 
     PythonRDD.serveToStream("serve-Arrow") { outputStream =>
       withAction("collectAsArrowToPython", queryExecution) { plan =>
         val out = new DataOutputStream(outputStream)
         val batchWriter =
-          new ArrowBatchStreamWriter(schema, out, timeZoneId, errorOnDuplicatedFieldNames)
+          new ArrowBatchStreamWriter(
+            schema, out, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
 
         // Batches ordered by (index of partition, batch index in that partition) tuple
         val batchOrder = ArrayBuffer.empty[(Int, Int)]
@@ -2294,10 +2297,18 @@ class Dataset[T] private[sql](
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
     val errorOnDuplicatedFieldNames =
       sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
+    val largeVarTypes =
+      sparkSession.sessionState.conf.arrowUseLargeVarTypes
     plan.execute().mapPartitionsInternal { iter =>
       val context = TaskContext.get()
       ArrowConverters.toBatchIterator(
-        iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context)
+        iter,
+        schemaCaptured,
+        maxRecordsPerBatch,
+        timeZoneId,
+        errorOnDuplicatedFieldNames,
+        largeVarTypes,
+        context)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 7ea1bd6ff7dc..ed490347ae82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -51,9 +51,11 @@ private[sql] class ArrowBatchStreamWriter(
     schema: StructType,
     out: OutputStream,
     timeZoneId: String,
-    errorOnDuplicatedFieldNames: Boolean) {
+    errorOnDuplicatedFieldNames: Boolean,
+    largeVarTypes: Boolean) {
 
-  val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames)
+  val arrowSchema = ArrowUtils.toArrowSchema(
+    schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
   val writeChannel = new WriteChannel(Channels.newChannel(out))
 
   // Write the Arrow schema first, before batches
@@ -81,10 +83,11 @@ private[sql] object ArrowConverters extends Logging {
       maxRecordsPerBatch: Long,
       timeZoneId: String,
       errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean,
       context: TaskContext) extends Iterator[Array[Byte]] with AutoCloseable {
 
     protected val arrowSchema =
-      ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames)
+      ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
     private val allocator =
       ArrowUtils.rootAllocator.newChildAllocator(
         s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)
@@ -137,9 +140,16 @@ private[sql] object ArrowConverters extends Logging {
       maxEstimatedBatchSize: Long,
       timeZoneId: String,
       errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean,
       context: TaskContext)
     extends ArrowBatchIterator(
-      rowIter, schema, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context) {
+      rowIter,
+      schema,
+      maxRecordsPerBatch,
+      timeZoneId,
+      errorOnDuplicatedFieldNames,
+      largeVarTypes,
+      context) {
 
     private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
     var rowCountInLastBatch: Long = 0
@@ -205,9 +215,16 @@ private[sql] object ArrowConverters extends Logging {
       maxRecordsPerBatch: Long,
       timeZoneId: String,
       errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean,
       context: TaskContext): ArrowBatchIterator = {
     new ArrowBatchIterator(
-      rowIter, schema, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context)
+      rowIter,
+      schema,
+      maxRecordsPerBatch,
+      timeZoneId,
+      errorOnDuplicatedFieldNames,
+      largeVarTypes,
+      context)
   }
 
   /**
@@ -220,19 +237,21 @@ private[sql] object ArrowConverters extends Logging {
       maxRecordsPerBatch: Long,
       maxEstimatedBatchSize: Long,
       timeZoneId: String,
-      errorOnDuplicatedFieldNames: Boolean): ArrowBatchWithSchemaIterator = {
+      errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean): ArrowBatchWithSchemaIterator = {
     new ArrowBatchWithSchemaIterator(
       rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize,
-      timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get())
+      timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes, TaskContext.get())
   }
 
   private[sql] def createEmptyArrowBatch(
       schema: StructType,
       timeZoneId: String,
-      errorOnDuplicatedFieldNames: Boolean): Array[Byte] = {
+      errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean): Array[Byte] = {
     val batches = new ArrowBatchWithSchemaIterator(
         Iterator.empty, schema, 0L, 0L,
-        timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get()) {
+        timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes, TaskContext.get()) {
       override def hasNext: Boolean = true
     }
     Utils.tryWithSafeFinally {
@@ -299,12 +318,13 @@ private[sql] object ArrowConverters extends Logging {
       schema: StructType,
       timeZoneId: String,
       errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean,
       context: TaskContext)
       extends InternalRowIterator(arrowBatchIter, context) {
 
     override def nextBatch(): (Iterator[InternalRow], StructType) = {
       val arrowSchema =
-        ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames)
+        ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
       val root = VectorSchemaRoot.create(arrowSchema, allocator)
       resources.append(root)
       val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator)
@@ -344,9 +364,12 @@ private[sql] object ArrowConverters extends Logging {
       schema: StructType,
       timeZoneId: String,
       errorOnDuplicatedFieldNames: Boolean,
-      context: TaskContext): Iterator[InternalRow] = new 
InternalRowIteratorWithoutSchema(
-    arrowBatchIter, schema, timeZoneId, errorOnDuplicatedFieldNames, context
-  )
+      largeVarTypes: Boolean,
+      context: TaskContext): Iterator[InternalRow] = {
+    new InternalRowIteratorWithoutSchema(
+      arrowBatchIter, schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes, context
+    )
+  }
 
   /**
    * Maps iterator from serialized ArrowRecordBatches to InternalRows. Different from
@@ -393,6 +416,7 @@ private[sql] object ArrowConverters extends Logging {
     val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
     val attrs = toAttributes(schema)
     val batchesInDriver = arrowBatches.toArray
+    val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes
     val shouldUseRDD = session.sessionState.conf
       .arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum
 
@@ -407,6 +431,7 @@ private[sql] object ArrowConverters extends Logging {
             schema,
             timezone,
             errorOnDuplicatedFieldNames = false,
+            largeVarTypes = largeVarTypes,
             TaskContext.get())
         }
       session.internalCreateDataFrame(rdd.setName("arrow"), schema)
@@ -417,6 +442,7 @@ private[sql] object ArrowConverters extends Logging {
         schema,
         session.sessionState.conf.sessionLocalTimeZone,
         errorOnDuplicatedFieldNames = false,
+        largeVarTypes = largeVarTypes,
         TaskContext.get())
 
       // Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 1bddd81fbfe2..bf2142422562 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -120,6 +120,9 @@ object ArrowPythonRunner {
     val arrowAyncParallelism = conf.pythonUDFArrowConcurrencyLevel.map(v =>
       Seq(SQLConf.PYTHON_UDF_ARROW_CONCURRENCY_LEVEL.key -> v.toString)
     ).getOrElse(Seq.empty)
-    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++ arrowAyncParallelism: _*)
+    val useLargeVarTypes = Seq(SQLConf.ARROW_EXECUTION_USE_LARGE_VAR_TYPES.key ->
+      conf.arrowUseLargeVarTypes.toString)
+    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++
+      arrowAyncParallelism ++ useLargeVarTypes: _*)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 59e8970b9c9b..9caa344d00c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -45,6 +45,7 @@ class CoGroupedArrowPythonRunner(
     leftSchema: StructType,
     rightSchema: StructType,
     timeZoneId: String,
+    largeVarTypes: Boolean,
     conf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
@@ -109,7 +110,8 @@ class CoGroupedArrowPythonRunner(
           dataOut: DataOutputStream,
           name: String): Unit = {
         val arrowSchema =
-          ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true)
+          ArrowUtils.toArrowSchema(
+            schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = largeVarTypes)
         val allocator = ArrowUtils.rootAllocator.newChildAllocator(
           s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue)
         val root = VectorSchemaRoot.create(arrowSchema, allocator)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
index 66ed2bca7677..af487218391e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
@@ -42,6 +42,7 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
   protected val pythonEvalType: Int
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val largeVarTypes = conf.arrowUseLargeVarTypes
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
   private val pythonUDF = func.asInstanceOf[PythonUDF]
   private val pandasFunction = pythonUDF.func
@@ -84,6 +85,7 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
           DataTypeUtils.fromAttributes(leftDedup),
           DataTypeUtils.fromAttributes(rightDedup),
           sessionLocalTimeZone,
+          largeVarTypes,
           pythonRunnerConf,
           pythonMetrics,
           jobArtifactUUID,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
index 45ecf8700950..aaf2f256273d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
@@ -85,7 +85,8 @@ class ArrowRRunner(
       override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
         if (inputIterator.hasNext) {
           val arrowSchema =
-            ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true)
+            ArrowUtils.toArrowSchema(
+              schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = false)
           val allocator = ArrowUtils.rootAllocator.newChildAllocator(
             "stdout writer for R", 0, Long.MaxValue)
           val root = VectorSchemaRoot.create(arrowSchema, allocator)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index 33e5d46ee233..39c3d8df7550 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -1376,8 +1376,10 @@ class ArrowConvertersSuite extends SharedSparkSession {
     val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
 
     val ctx = TaskContext.empty()
-    val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, true, ctx)
-    val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, true, ctx)
+    val batchIter = ArrowConverters.toBatchIterator(
+      inputRows.iterator, schema, 5, null, true, false, ctx)
+    val outputRowIter = ArrowConverters.fromBatchIterator(
+      batchIter, schema, null, true, false, ctx)
 
     var count = 0
     outputRowIter.zipWithIndex.foreach { case (row, i) =>
@@ -1397,12 +1399,13 @@ class ArrowConvertersSuite extends SharedSparkSession {
 
     val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
     val ctx = TaskContext.empty()
-    val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, true, ctx)
+    val batchIter = ArrowConverters.toBatchIterator(
+      inputRows.iterator, schema, 5, null, true, false, ctx)
 
     // Write batches to Arrow stream format as a byte array
     val out = new ByteArrayOutputStream()
     Utils.tryWithResource(new DataOutputStream(out)) { dataOut =>
-      val writer = new ArrowBatchStreamWriter(schema, dataOut, null, true)
+      val writer = new ArrowBatchStreamWriter(schema, dataOut, null, true, false)
       writer.writeBatches(batchIter)
       writer.end()
     }
@@ -1410,7 +1413,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
     // Read Arrow stream into batches, then convert back to rows
     val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray)
     val readBatches = ArrowConverters.getBatchesFromStream(in)
-    val outputRowIter = ArrowConverters.fromBatchIterator(readBatches, schema, null, true, ctx)
+    val outputRowIter = ArrowConverters.fromBatchIterator(
+      readBatches, schema, null, true, false, ctx)
 
     var count = 0
     outputRowIter.zipWithIndex.foreach { case (row, i) =>
@@ -1441,7 +1445,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
     }
     val ctx = TaskContext.empty()
     val batchIter = ArrowConverters.toBatchWithSchemaIterator(
-      inputRows.iterator, schema, 5, 1024 * 1024, null, true)
+      inputRows.iterator, schema, 5, 1024 * 1024, null, true, false)
     val (outputRowIter, outputType) = ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx)
 
     var count = 0
@@ -1460,7 +1464,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
     val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
     val ctx = TaskContext.empty()
     val batchIter =
-      ArrowConverters.toBatchWithSchemaIterator(Iterator.empty, schema, 5, 1024 * 1024, null, true)
+      ArrowConverters.toBatchWithSchemaIterator(
+        Iterator.empty, schema, 5, 1024 * 1024, null, true, false)
     val (outputRowIter, outputType) = ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx)
 
     assert(0 == outputRowIter.length)
@@ -1474,7 +1479,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
       proj(row).copy()
     }
     val batchIter1 = ArrowConverters.toBatchWithSchemaIterator(
-      inputRows1.iterator, schema1, 5, 1024 * 1024, null, true)
+      inputRows1.iterator, schema1, 5, 1024 * 1024, null, true, false)
 
     val schema2 = StructType(Seq(StructField("field2", IntegerType, nullable = true)))
     val inputRows2 = Array(InternalRow(1)).map { row =>
@@ -1482,7 +1487,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
       proj(row).copy()
     }
     val batchIter2 = ArrowConverters.toBatchWithSchemaIterator(
-      inputRows2.iterator, schema2, 5, 1024 * 1024, null, true)
+      inputRows2.iterator, schema2, 5, 1024 * 1024, null, true, false)
 
     val iter = batchIter1.toArray ++ batchIter2
 
@@ -1511,11 +1516,13 @@ class ArrowConvertersSuite extends SharedSparkSession {
       batchBytes: Array[Byte],
       jsonFile: File,
       timeZoneId: String = null,
-      errorOnDuplicatedFieldNames: Boolean = true): Unit = {
+      errorOnDuplicatedFieldNames: Boolean = true,
+      largeVarTypes: Boolean = false): Unit = {
     val allocator = new RootAllocator(Long.MaxValue)
     val jsonReader = new JsonFileReader(jsonFile, allocator)
 
-    val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, errorOnDuplicatedFieldNames)
+    val arrowSchema = ArrowUtils.toArrowSchema(
+      sparkSchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
     val jsonSchema = jsonReader.start()
     Validator.compareSchemas(arrowSchema, jsonSchema)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index 7830ea1da177..acf258a373c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -158,7 +158,8 @@ class ArrowWriterSuite extends SparkFunSuite {
         schema: StructType,
         timeZoneId: String): (ArrowWriter, Int) = {
       val arrowSchema =
-        ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true)
+        ArrowUtils.toArrowSchema(
+          schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = false)
       val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator)
       val vector = root.getFieldVectors.get(0)
       vector.allocateNew()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
