This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 5ee004704ced [SPARK-51079][PYTHON] Support large variable types in pandas UDF, createDataFrame and toPandas with Arrow
5ee004704ced is described below
commit 5ee004704ced1a98434f30c836e4331f8f2e4352
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Feb 6 08:57:52 2025 +0900
[SPARK-51079][PYTHON] Support large variable types in pandas UDF, createDataFrame and toPandas with Arrow
### What changes were proposed in this pull request?
This PR is a retry of https://github.com/apache/spark/pull/41569; it implements support for large variable types everywhere within PySpark.
https://github.com/apache/spark/pull/39572 implemented the core logic, but it only supports large variable types in the cases marked in bold below:
- `mapInArrow`: **JVM -> Python -> JVM**
- Pandas UDF/Function API: **JVM -> Python** -> JVM
- createDataFrame with Arrow: Python -> JVM
- toPandas with Arrow: JVM -> Python
This PR completes them all.
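For illustration only (not part of the patch), here is a minimal sketch of how the newly covered paths can be exercised once the flag is enabled; the session setup, sample data, and UDF are placeholders:
```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf

# Hypothetical local session; both config names are existing Spark configs.
spark = (
    SparkSession.builder
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.sql.execution.arrow.useLargeVarTypes", "true")
    .getOrCreate()
)

# createDataFrame with Arrow (Python -> JVM): string/binary columns are sent
# as Arrow large_string/large_binary when the flag is on.
df = spark.createDataFrame(pd.DataFrame({"value": ["spark", "arrow"]}))

# Pandas UDF (JVM -> Python -> JVM): both directions now use large variable types.
@pandas_udf("string")
def to_upper(s: pd.Series) -> pd.Series:
    return s.str.upper()

# toPandas with Arrow (JVM -> Python).
pdf = df.select(to_upper("value").alias("value")).toPandas()
```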
### Why are the changes needed?
To support large variable types consistently across all of the paths listed above.
### Does this PR introduce _any_ user-facing change?
`spark.sql.execution.arrow.useLargeVarTypes` has not been released yet, so this change does not affect any end users.
### How was this patch tested?
Existing tests with `spark.sql.execution.arrow.useLargeVarTypes` enabled.
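As an additional hedged illustration (not one of the existing tests), the schema conversion change can be spot-checked directly against `to_arrow_type`, which this patch extends with a `prefers_large_types` parameter:
```python
import pyarrow as pa
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import BinaryType, StringType

# With the flag, string/binary map to the 64-bit-offset Arrow types.
assert to_arrow_type(StringType(), prefers_large_types=True) == pa.large_string()
assert to_arrow_type(BinaryType(), prefers_large_types=True) == pa.large_binary()

# Without it, the regular 32-bit-offset types are unchanged.
assert to_arrow_type(StringType()) == pa.string()
assert to_arrow_type(BinaryType()) == pa.binary()
```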
Closes #49790 from HyukjinKwon/SPARK-39979-followup2.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit e2ef5a482604059197c79f1f7c305df4ce57213a)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/typedef/typehints.py | 10 ++-
python/pyspark/sql/connect/conversion.py | 5 +-
python/pyspark/sql/connect/session.py | 25 +++++-
python/pyspark/sql/pandas/conversion.py | 28 +++++--
python/pyspark/sql/pandas/serializers.py | 9 +-
python/pyspark/sql/pandas/types.py | 44 ++++++++--
python/pyspark/sql/tests/arrow/test_arrow.py | 6 +-
python/pyspark/worker.py | 95 +++++++++++++++-------
.../org/apache/spark/sql/internal/SqlApiConf.scala | 2 +
.../spark/sql/internal/SqlApiConfHelper.scala | 1 +
.../org/apache/spark/sql/util/ArrowUtils.scala | 2 +-
.../spark/sql/execution/arrow/ArrowWriter.scala | 6 +-
.../org/apache/spark/sql/internal/SQLConf.scala | 3 +-
.../apache/spark/sql/util/ArrowUtilsSuite.scala | 9 +-
.../spark/sql/connect/SQLImplicitsTestSuite.scala | 3 +-
.../connect/client/arrow/ArrowEncoderSuite.scala | 9 +-
.../apache/spark/sql/connect/SparkSession.scala | 6 +-
.../sql/connect/client/arrow/ArrowSerializer.scala | 44 ++++++++--
.../connect/client/arrow/ArrowVectorReader.scala | 21 ++++-
.../execution/SparkConnectPlanExecution.scala | 13 ++-
.../sql/connect/planner/SparkConnectPlanner.scala | 10 ++-
.../spark/sql/connect/SparkConnectServerTest.scala | 7 +-
.../connect/planner/SparkConnectPlannerSuite.scala | 5 +-
.../connect/planner/SparkConnectProtoSuite.scala | 3 +-
.../org/apache/spark/sql/api/r/SQLUtils.scala | 8 +-
.../org/apache/spark/sql/classic/Dataset.scala | 17 +++-
.../sql/execution/arrow/ArrowConverters.scala | 52 +++++++++---
.../sql/execution/python/ArrowPythonRunner.scala | 5 +-
.../python/CoGroupedArrowPythonRunner.scala | 4 +-
.../python/FlatMapCoGroupsInBatchExec.scala | 2 +
.../spark/sql/execution/r/ArrowRRunner.scala | 3 +-
.../sql/execution/arrow/ArrowConvertersSuite.scala | 29 ++++---
.../sql/execution/arrow/ArrowWriterSuite.scala | 3 +-
33 files changed, 372 insertions(+), 117 deletions(-)
diff --git a/python/pyspark/pandas/typedef/typehints.py
b/python/pyspark/pandas/typedef/typehints.py
index cb82cf8d7149..4244f5831aa5 100644
--- a/python/pyspark/pandas/typedef/typehints.py
+++ b/python/pyspark/pandas/typedef/typehints.py
@@ -296,7 +296,15 @@ def spark_type_to_pandas_dtype(
elif isinstance(spark_type, (types.TimestampType, types.TimestampNTZType)):
return np.dtype("datetime64[ns]")
else:
- return np.dtype(to_arrow_type(spark_type).to_pandas_dtype())
+ from pyspark.pandas.utils import default_session
+
+ prefers_large_var_types = (
+ default_session()
+ .conf.get("spark.sql.execution.arrow.useLargeVarTypes", "false")
+ .lower()
+ == "true"
+ )
+ return np.dtype(to_arrow_type(spark_type,
prefers_large_var_types).to_pandas_dtype())
def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype,
types.DataType]:
diff --git a/python/pyspark/sql/connect/conversion.py
b/python/pyspark/sql/connect/conversion.py
index d4363594a315..d36baacb10a3 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -323,7 +323,7 @@ class LocalDataToArrowConversion:
return lambda value: value
@staticmethod
- def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
+ def convert(data: Sequence[Any], schema: StructType, use_large_var_types:
bool) -> "pa.Table":
assert isinstance(data, list) and len(data) > 0
assert schema is not None and isinstance(schema, StructType)
@@ -372,7 +372,8 @@ class LocalDataToArrowConversion:
)
for field in schema.fields
]
- )
+ ),
+ prefers_large_types=use_large_var_types,
)
return pa.Table.from_arrays(pylist, schema=pa_schema)
diff --git a/python/pyspark/sql/connect/session.py
b/python/pyspark/sql/connect/session.py
index 59349a17886b..c01c1e42a318 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -506,9 +506,13 @@ class SparkSession:
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
"spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
+ "spark.sql.execution.arrow.useLargeVarTypes",
)
timezone = configs["spark.sql.session.timeZone"]
prefer_timestamp = configs["spark.sql.timestampType"]
+ prefers_large_types: bool = (
+ cast(str,
configs["spark.sql.execution.arrow.useLargeVarTypes"]).lower() == "true"
+ )
_table: Optional[pa.Table] = None
@@ -552,7 +556,9 @@ class SparkSession:
if isinstance(schema, StructType):
deduped_schema = cast(StructType,
_deduplicate_field_names(schema))
spark_types = [field.dataType for field in
deduped_schema.fields]
- arrow_schema = to_arrow_schema(deduped_schema)
+ arrow_schema = to_arrow_schema(
+ deduped_schema, prefers_large_types=prefers_large_types
+ )
arrow_types = [field.type for field in arrow_schema]
_cols = [str(x) if not isinstance(x, str) else x for x in
schema.fieldNames()]
elif isinstance(schema, DataType):
@@ -570,7 +576,12 @@ class SparkSession:
else None
for t in data.dtypes
]
- arrow_types = [to_arrow_type(dt) if dt is not None else None
for dt in spark_types]
+ arrow_types = [
+ to_arrow_type(dt, prefers_large_types=prefers_large_types)
+ if dt is not None
+ else None
+ for dt in spark_types
+ ]
safecheck =
configs["spark.sql.execution.pandas.convertToArrowArraySafely"]
@@ -609,7 +620,13 @@ class SparkSession:
_table = (
_check_arrow_table_timestamps_localize(data, schema, True,
timezone)
- .cast(to_arrow_schema(schema,
error_on_duplicated_field_names_in_struct=True))
+ .cast(
+ to_arrow_schema(
+ schema,
+ error_on_duplicated_field_names_in_struct=True,
+ prefers_large_types=prefers_large_types,
+ )
+ )
.rename_columns(schema.names)
)
@@ -684,7 +701,7 @@ class SparkSession:
# Spark Connect will try its best to build the Arrow table with the
# inferred schema in the client side, and then rename the columns
and
# cast the datatypes in the server side.
- _table = LocalDataToArrowConversion.convert(_data, _schema)
+ _table = LocalDataToArrowConversion.convert(_data, _schema,
prefers_large_types)
# TODO: Beside the validation on number of columns, we should also
check
# whether the Arrow Schema is compatible with the user provided Schema.
diff --git a/python/pyspark/sql/pandas/conversion.py
b/python/pyspark/sql/pandas/conversion.py
index 172a4fc4b234..18360fd81392 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -81,7 +81,7 @@ class PandasConversionMixin:
from pyspark.sql.pandas.utils import
require_minimum_pyarrow_version
require_minimum_pyarrow_version()
- to_arrow_schema(self.schema)
+ to_arrow_schema(self.schema,
prefers_large_types=jconf.arrowUseLargeVarTypes())
except Exception as e:
if jconf.arrowPySparkFallbackEnabled():
msg = (
@@ -236,7 +236,12 @@ class PandasConversionMixin:
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
- schema = to_arrow_schema(self.schema,
error_on_duplicated_field_names_in_struct=True)
+ prefers_large_var_types = jconf.arrowUseLargeVarTypes()
+ schema = to_arrow_schema(
+ self.schema,
+ error_on_duplicated_field_names_in_struct=True,
+ prefers_large_types=prefers_large_var_types,
+ )
import pyarrow as pa
@@ -322,7 +327,8 @@ class PandasConversionMixin:
from pyspark.sql.pandas.types import to_arrow_schema
import pyarrow as pa
- schema = to_arrow_schema(self.schema)
+ prefers_large_var_types =
self.sparkSession._jconf.arrowUseLargeVarTypes()
+ schema = to_arrow_schema(self.schema,
prefers_large_types=prefers_large_var_types)
empty_arrays = [pa.array([], type=field.type) for field in schema]
return [pa.RecordBatch.from_arrays(empty_arrays, schema=schema)]
@@ -715,9 +721,16 @@ class SparkConversionMixin:
pdf_slices = (pdf.iloc[start : start + step] for start in range(0,
len(pdf), step))
# Create list of Arrow (columns, arrow_type, spark_type) for
serializer dump_stream
+ prefers_large_var_types = self._jconf.arrowUseLargeVarTypes()
arrow_data = [
[
- (c, to_arrow_type(t) if t is not None else None, t)
+ (
+ c,
+ to_arrow_type(t,
prefers_large_types=prefers_large_var_types)
+ if t is not None
+ else None,
+ t,
+ )
for (_, c), t in zip(pdf_slice.items(), spark_types)
]
for pdf_slice in pdf_slices
@@ -785,8 +798,13 @@ class SparkConversionMixin:
if not isinstance(schema, StructType):
schema = from_arrow_schema(table.schema,
prefer_timestamp_ntz=prefer_timestamp_ntz)
+ prefers_large_var_types = self._jconf.arrowUseLargeVarTypes()
table = _check_arrow_table_timestamps_localize(table, schema, True,
timezone).cast(
- to_arrow_schema(schema,
error_on_duplicated_field_names_in_struct=True)
+ to_arrow_schema(
+ schema,
+ error_on_duplicated_field_names_in_struct=True,
+ prefers_large_types=prefers_large_var_types,
+ )
)
# Chunk the Arrow Table into RecordBatches
diff --git a/python/pyspark/sql/pandas/serializers.py
b/python/pyspark/sql/pandas/serializers.py
index 536bf7307065..cd2e1230418f 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -793,6 +793,7 @@ class
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
assign_cols_by_name,
state_object_schema,
arrow_max_records_per_batch,
+ prefers_large_var_types,
):
super(ApplyInPandasWithStateSerializer, self).__init__(
timezone, safecheck, assign_cols_by_name
@@ -808,7 +809,9 @@ class
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
]
)
- self.result_count_pdf_arrow_type =
to_arrow_type(self.result_count_df_type)
+ self.result_count_pdf_arrow_type = to_arrow_type(
+ self.result_count_df_type,
prefers_large_types=prefers_large_var_types
+ )
self.result_state_df_type = StructType(
[
@@ -819,7 +822,9 @@ class
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
]
)
- self.result_state_pdf_arrow_type =
to_arrow_type(self.result_state_df_type)
+ self.result_state_pdf_arrow_type = to_arrow_type(
+ self.result_state_df_type,
prefers_large_types=prefers_large_var_types
+ )
self.arrow_max_records_per_batch = arrow_max_records_per_batch
def load_stream(self, stream):
diff --git a/python/pyspark/sql/pandas/types.py
b/python/pyspark/sql/pandas/types.py
index d65126bb3db9..fcd70d4d1839 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -67,6 +67,7 @@ def to_arrow_type(
dt: DataType,
error_on_duplicated_field_names_in_struct: bool = False,
timestamp_utc: bool = True,
+ prefers_large_types: bool = False,
) -> "pa.DataType":
"""
Convert Spark data type to PyArrow type
@@ -107,8 +108,12 @@ def to_arrow_type(
arrow_type = pa.float64()
elif type(dt) == DecimalType:
arrow_type = pa.decimal128(dt.precision, dt.scale)
+ elif type(dt) == StringType and prefers_large_types:
+ arrow_type = pa.large_string()
elif type(dt) == StringType:
arrow_type = pa.string()
+ elif type(dt) == BinaryType and prefers_large_types:
+ arrow_type = pa.large_binary()
elif type(dt) == BinaryType:
arrow_type = pa.binary()
elif type(dt) == DateType:
@@ -125,19 +130,34 @@ def to_arrow_type(
elif type(dt) == ArrayType:
field = pa.field(
"element",
- to_arrow_type(dt.elementType,
error_on_duplicated_field_names_in_struct, timestamp_utc),
+ to_arrow_type(
+ dt.elementType,
+ error_on_duplicated_field_names_in_struct,
+ timestamp_utc,
+ prefers_large_types,
+ ),
nullable=dt.containsNull,
)
arrow_type = pa.list_(field)
elif type(dt) == MapType:
key_field = pa.field(
"key",
- to_arrow_type(dt.keyType,
error_on_duplicated_field_names_in_struct, timestamp_utc),
+ to_arrow_type(
+ dt.keyType,
+ error_on_duplicated_field_names_in_struct,
+ timestamp_utc,
+ prefers_large_types,
+ ),
nullable=False,
)
value_field = pa.field(
"value",
- to_arrow_type(dt.valueType,
error_on_duplicated_field_names_in_struct, timestamp_utc),
+ to_arrow_type(
+ dt.valueType,
+ error_on_duplicated_field_names_in_struct,
+ timestamp_utc,
+ prefers_large_types,
+ ),
nullable=dt.valueContainsNull,
)
arrow_type = pa.map_(key_field, value_field)
@@ -152,7 +172,10 @@ def to_arrow_type(
pa.field(
field.name,
to_arrow_type(
- field.dataType, error_on_duplicated_field_names_in_struct,
timestamp_utc
+ field.dataType,
+ error_on_duplicated_field_names_in_struct,
+ timestamp_utc,
+ prefers_large_types,
),
nullable=field.nullable,
)
@@ -163,7 +186,10 @@ def to_arrow_type(
arrow_type = pa.null()
elif isinstance(dt, UserDefinedType):
arrow_type = to_arrow_type(
- dt.sqlType(), error_on_duplicated_field_names_in_struct,
timestamp_utc
+ dt.sqlType(),
+ error_on_duplicated_field_names_in_struct,
+ timestamp_utc,
+ prefers_large_types,
)
elif type(dt) == VariantType:
fields = [
@@ -185,6 +211,7 @@ def to_arrow_schema(
schema: StructType,
error_on_duplicated_field_names_in_struct: bool = False,
timestamp_utc: bool = True,
+ prefers_large_types: bool = False,
) -> "pa.Schema":
"""
Convert a schema from Spark to Arrow
@@ -212,7 +239,12 @@ def to_arrow_schema(
fields = [
pa.field(
field.name,
- to_arrow_type(field.dataType,
error_on_duplicated_field_names_in_struct, timestamp_utc),
+ to_arrow_type(
+ field.dataType,
+ error_on_duplicated_field_names_in_struct,
+ timestamp_utc,
+ prefers_large_types,
+ ),
nullable=field.nullable,
)
for field in schema
diff --git a/python/pyspark/sql/tests/arrow/test_arrow.py
b/python/pyspark/sql/tests/arrow/test_arrow.py
index a2ee113b6386..065f97fcf7c7 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow.py
@@ -730,7 +730,11 @@ class ArrowTestsMixin:
def test_schema_conversion_roundtrip(self):
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
- arrow_schema = to_arrow_schema(self.schema)
+ arrow_schema = to_arrow_schema(self.schema, prefers_large_types=False)
+ schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True)
+ self.assertEqual(self.schema, schema_rt)
+
+ arrow_schema = to_arrow_schema(self.schema, prefers_large_types=True)
schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True)
self.assertEqual(self.schema, schema_rt)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e799498cdd80..7bac0157caee 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -117,10 +117,12 @@ def wrap_udf(f, args_offsets, kwargs_offsets,
return_type):
return args_kwargs_offsets, lambda *a: func(*a)
-def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type):
+def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type,
runner_conf):
func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets,
kwargs_offsets)
- arrow_return_type = to_arrow_type(return_type)
+ arrow_return_type = to_arrow_type(
+ return_type, prefers_large_types=use_large_var_types(runner_conf)
+ )
def verify_result_type(result):
if not hasattr(result, "__len__"):
@@ -159,7 +161,9 @@ def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets,
return_type, runner_co
func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets,
kwargs_offsets)
- arrow_return_type = to_arrow_type(return_type)
+ arrow_return_type = to_arrow_type(
+ return_type, prefers_large_types=use_large_var_types(runner_conf)
+ )
# "result_func" ensures the result of a Python UDF to be consistent
with/without Arrow
# optimization.
@@ -205,8 +209,10 @@ def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets,
return_type, runner_co
)
-def wrap_pandas_batch_iter_udf(f, return_type):
- arrow_return_type = to_arrow_type(return_type)
+def wrap_pandas_batch_iter_udf(f, return_type, runner_conf):
+ arrow_return_type = to_arrow_type(
+ return_type, prefers_large_types=use_large_var_types(runner_conf)
+ )
iter_type_label = "pandas.DataFrame" if type(return_type) == StructType
else "pandas.Series"
def verify_result(result):
@@ -303,8 +309,10 @@ def verify_pandas_result(result, return_type,
assign_cols_by_name, truncate_retu
)
-def wrap_arrow_batch_iter_udf(f, return_type):
- arrow_return_type = to_arrow_type(return_type)
+def wrap_arrow_batch_iter_udf(f, return_type, runner_conf):
+ arrow_return_type = to_arrow_type(
+ return_type, prefers_large_types=use_large_var_types(runner_conf)
+ )
def verify_result(result):
if not isinstance(result, Iterator) and not hasattr(result,
"__iter__"):
@@ -364,6 +372,7 @@ def wrap_cogrouped_map_arrow_udf(f, return_type, argspec,
runner_conf):
def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
+ _use_large_var_types = use_large_var_types(runner_conf)
_assign_cols_by_name = assign_cols_by_name(runner_conf)
def wrapped(left_key_series, left_value_series, right_key_series,
right_value_series):
@@ -384,7 +393,8 @@ def wrap_cogrouped_map_pandas_udf(f, return_type, argspec,
runner_conf):
return result
- return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr),
to_arrow_type(return_type))]
+ arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
+ return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr),
arrow_return_type)]
def verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types):
@@ -482,10 +492,12 @@ def wrap_grouped_map_arrow_udf(f, return_type, argspec,
runner_conf):
return result.to_batches()
- return lambda k, v: (wrapped(k, v), to_arrow_type(return_type))
+ arrow_return_type = to_arrow_type(return_type,
use_large_var_types(runner_conf))
+ return lambda k, v: (wrapped(k, v), arrow_return_type)
def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
+ _use_large_var_types = use_large_var_types(runner_conf)
_assign_cols_by_name = assign_cols_by_name(runner_conf)
def wrapped(key_series, value_series):
@@ -502,7 +514,8 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec,
runner_conf):
return result
- return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
+ arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
+ return lambda k, v: [(wrapped(k, v), arrow_return_type)]
def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
@@ -517,7 +530,8 @@ def wrap_grouped_transform_with_state_pandas_udf(f,
return_type, runner_conf):
return result_iter
- return lambda p, m, k, v: [(wrapped(p, m, k, v),
to_arrow_type(return_type))]
+ arrow_return_type = to_arrow_type(return_type,
use_large_var_types(runner_conf))
+ return lambda p, m, k, v: [(wrapped(p, m, k, v), arrow_return_type)]
def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type,
runner_conf):
@@ -535,10 +549,11 @@ def
wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runn
return result_iter
- return lambda p, m, k, v: [(wrapped(p, m, k, v),
to_arrow_type(return_type))]
+ arrow_return_type = to_arrow_type(return_type,
use_large_var_types(runner_conf))
+ return lambda p, m, k, v: [(wrapped(p, m, k, v), arrow_return_type)]
-def wrap_grouped_map_pandas_udf_with_state(f, return_type):
+def wrap_grouped_map_pandas_udf_with_state(f, return_type, runner_conf):
"""
Provides a new lambda instance wrapping user function of
applyInPandasWithState.
@@ -553,6 +568,7 @@ def wrap_grouped_map_pandas_udf_with_state(f, return_type):
Along with the returned iterator, the lambda instance will also produce
the return_type as
converted to the arrow schema.
"""
+ _use_large_var_types = use_large_var_types(runner_conf)
def wrapped(key_series, value_series_gen, state):
"""
@@ -627,13 +643,16 @@ def wrap_grouped_map_pandas_udf_with_state(f,
return_type):
state,
)
- return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))]
+ arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
+ return lambda k, v, s: [(wrapped(k, v, s), arrow_return_type)]
-def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type):
+def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type,
runner_conf):
func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets,
kwargs_offsets)
- arrow_return_type = to_arrow_type(return_type)
+ arrow_return_type = to_arrow_type(
+ return_type, prefers_large_types=use_large_var_types(runner_conf)
+ )
def wrapped(*series):
import pandas as pd
@@ -653,9 +672,13 @@ def wrap_window_agg_pandas_udf(
window_bound_types_str = runner_conf.get("pandas_window_bound_types")
window_bound_type = [t.strip().lower() for t in
window_bound_types_str.split(",")][udf_index]
if window_bound_type == "bounded":
- return wrap_bounded_window_agg_pandas_udf(f, args_offsets,
kwargs_offsets, return_type)
+ return wrap_bounded_window_agg_pandas_udf(
+ f, args_offsets, kwargs_offsets, return_type, runner_conf
+ )
elif window_bound_type == "unbounded":
- return wrap_unbounded_window_agg_pandas_udf(f, args_offsets,
kwargs_offsets, return_type)
+ return wrap_unbounded_window_agg_pandas_udf(
+ f, args_offsets, kwargs_offsets, return_type, runner_conf
+ )
else:
raise PySparkRuntimeError(
errorClass="INVALID_WINDOW_BOUND_TYPE",
@@ -665,14 +688,16 @@ def wrap_window_agg_pandas_udf(
)
-def wrap_unbounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets,
return_type):
+def wrap_unbounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets,
return_type, runner_conf):
func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets,
kwargs_offsets)
# This is similar to grouped_agg_pandas_udf, the only difference
# is that window_agg_pandas_udf needs to repeat the return value
# to match window length, where grouped_agg_pandas_udf just returns
# the scalar value.
- arrow_return_type = to_arrow_type(return_type)
+ arrow_return_type = to_arrow_type(
+ return_type, prefers_large_types=use_large_var_types(runner_conf)
+ )
def wrapped(*series):
import pandas as pd
@@ -686,12 +711,14 @@ def wrap_unbounded_window_agg_pandas_udf(f, args_offsets,
kwargs_offsets, return
)
-def wrap_bounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets,
return_type):
+def wrap_bounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets,
return_type, runner_conf):
# args_offsets should have at least 2 for begin_index, end_index.
assert len(args_offsets) >= 2, len(args_offsets)
func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets[2:],
kwargs_offsets)
- arrow_return_type = to_arrow_type(return_type)
+ arrow_return_type = to_arrow_type(
+ return_type, prefers_large_types=use_large_var_types(runner_conf)
+ )
def wrapped(begin_index, end_index, *series):
import pandas as pd
@@ -865,15 +892,15 @@ def read_single_udf(pickleSer, infile, eval_type,
runner_conf, udf_index, profil
# the last returnType will be the return type of UDF
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
- return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets,
return_type)
+ return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets,
return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
return wrap_arrow_batch_udf(func, args_offsets, kwargs_offsets,
return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
- return args_offsets, wrap_pandas_batch_iter_udf(func, return_type)
+ return args_offsets, wrap_pandas_batch_iter_udf(func, return_type,
runner_conf)
elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
- return args_offsets, wrap_pandas_batch_iter_udf(func, return_type)
+ return args_offsets, wrap_pandas_batch_iter_udf(func, return_type,
runner_conf)
elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
- return args_offsets, wrap_arrow_batch_iter_udf(func, return_type)
+ return args_offsets, wrap_arrow_batch_iter_udf(func, return_type,
runner_conf)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost
when wrapping it
return args_offsets, wrap_grouped_map_pandas_udf(func, return_type,
argspec, runner_conf)
@@ -881,7 +908,7 @@ def read_single_udf(pickleSer, infile, eval_type,
runner_conf, udf_index, profil
argspec = inspect.getfullargspec(chained_func) # signature was lost
when wrapping it
return args_offsets, wrap_grouped_map_arrow_udf(func, return_type,
argspec, runner_conf)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
- return args_offsets, wrap_grouped_map_pandas_udf_with_state(func,
return_type)
+ return args_offsets, wrap_grouped_map_pandas_udf_with_state(func,
return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
return args_offsets, wrap_grouped_transform_with_state_pandas_udf(
func, return_type, runner_conf
@@ -897,7 +924,9 @@ def read_single_udf(pickleSer, infile, eval_type,
runner_conf, udf_index, profil
argspec = inspect.getfullargspec(chained_func) # signature was lost
when wrapping it
return args_offsets, wrap_cogrouped_map_arrow_udf(func, return_type,
argspec, runner_conf)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
- return wrap_grouped_agg_pandas_udf(func, args_offsets, kwargs_offsets,
return_type)
+ return wrap_grouped_agg_pandas_udf(
+ func, args_offsets, kwargs_offsets, return_type, runner_conf
+ )
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
return wrap_window_agg_pandas_udf(
func, args_offsets, kwargs_offsets, return_type, runner_conf,
udf_index
@@ -922,6 +951,10 @@ def assign_cols_by_name(runner_conf):
)
+def use_large_var_types(runner_conf):
+ return runner_conf.get("spark.sql.execution.arrow.useLargeVarTypes",
"false").lower() == "true"
+
+
# Read and process a serialized user-defined table function (UDTF) from a
socket.
# It expects the UDTF to be in a specific format and performs various checks to
# ensure the UDTF is valid. This function also prepares a mapper function for
applying
@@ -1254,7 +1287,9 @@ def read_udtf(pickleSer, infile, eval_type):
def wrap_arrow_udtf(f, return_type):
import pandas as pd
- arrow_return_type = to_arrow_type(return_type)
+ arrow_return_type = to_arrow_type(
+ return_type,
prefers_large_types=use_large_var_types(runner_conf)
+ )
return_type_size = len(return_type)
def verify_result(result):
@@ -1499,6 +1534,7 @@ def read_udfs(pickleSer, infile, eval_type):
# NOTE: if timezone is set here, that implies respectSessionTimeZone
is True
timezone = runner_conf.get("spark.sql.session.timeZone", None)
+ prefers_large_var_types = use_large_var_types(runner_conf)
safecheck = (
runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely",
"false").lower()
== "true"
@@ -1521,6 +1557,7 @@ def read_udfs(pickleSer, infile, eval_type):
_assign_cols_by_name,
state_object_schema,
arrow_max_records_per_batch,
+ prefers_large_var_types,
)
elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
arrow_max_records_per_batch = runner_conf.get(
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
index 76cd436b39b5..cb517c689ea1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
@@ -56,6 +56,8 @@ private[sql] object SqlApiConf {
val LEGACY_TIME_PARSER_POLICY_KEY: String =
SqlApiConfHelper.LEGACY_TIME_PARSER_POLICY_KEY
val CASE_SENSITIVE_KEY: String = SqlApiConfHelper.CASE_SENSITIVE_KEY
val SESSION_LOCAL_TIMEZONE_KEY: String =
SqlApiConfHelper.SESSION_LOCAL_TIMEZONE_KEY
+ val ARROW_EXECUTION_USE_LARGE_VAR_TYPES: String =
+ SqlApiConfHelper.ARROW_EXECUTION_USE_LARGE_VAR_TYPES
val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = {
SqlApiConfHelper.LOCAL_RELATION_CACHE_THRESHOLD_KEY
}
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
index 13ef13e5894e..486a7dfb58dd 100644
---
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
@@ -33,6 +33,7 @@ private[sql] object SqlApiConfHelper {
val SESSION_LOCAL_TIMEZONE_KEY: String = "spark.sql.session.timeZone"
val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String =
"spark.sql.session.localRelationCacheThreshold"
val DEFAULT_COLLATION: String = "spark.sql.session.collation.default"
+ val ARROW_EXECUTION_USE_LARGE_VAR_TYPES =
"spark.sql.execution.arrow.useLargeVarTypes"
val confGetter: AtomicReference[() => SqlApiConf] = {
new AtomicReference[() => SqlApiConf](() => DefaultSqlApiConf)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 55d1aff8261d..587ca43e5730 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -202,7 +202,7 @@ private[sql] object ArrowUtils {
schema: StructType,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
- largeVarTypes: Boolean = false): Schema = {
+ largeVarTypes: Boolean): Schema = {
new Schema(schema.map { field =>
toArrowField(
field.name,
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 065b4b8c821a..c496b0e82c26 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -33,8 +33,10 @@ object ArrowWriter {
def create(
schema: StructType,
timeZoneId: String,
- errorOnDuplicatedFieldNames: Boolean = true): ArrowWriter = {
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId,
errorOnDuplicatedFieldNames)
+ errorOnDuplicatedFieldNames: Boolean = true,
+ largeVarTypes: Boolean = false): ArrowWriter = {
+ val arrowSchema = ArrowUtils.toArrowSchema(
+ schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator)
create(root)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 13fd8742cfa5..ad37262b6ed2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3421,9 +3421,8 @@ object SQLConf {
.doc("When using Apache Arrow, use large variable width vectors for
string and binary " +
"types. Regular string and binary types have a 2GiB limit for a column
in a single " +
"record batch. Large variable types remove this limitation at the cost
of higher memory " +
- "usage per value. Note that this only works for DataFrame.mapInArrow.")
+ "usage per value.")
.version("3.5.0")
- .internal()
.booleanConf
.createWithDefault(false)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
index c705a6b791bd..7124c94b390d 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
@@ -30,7 +30,8 @@ class ArrowUtilsSuite extends SparkFunSuite {
def roundtrip(dt: DataType): Unit = {
dt match {
case schema: StructType =>
- assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema,
null, true)) === schema)
+ assert(ArrowUtils.fromArrowSchema(
+ ArrowUtils.toArrowSchema(schema, null, true, false)) === schema)
case _ =>
roundtrip(new StructType().add("value", dt))
}
@@ -69,7 +70,7 @@ class ArrowUtilsSuite extends SparkFunSuite {
def roundtripWithTz(timeZoneId: String): Unit = {
val schema = new StructType().add("value", TimestampType)
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, true)
+ val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, true,
false)
val fieldType =
arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp]
assert(fieldType.getTimezone() === timeZoneId)
assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema)
@@ -105,9 +106,9 @@ class ArrowUtilsSuite extends SparkFunSuite {
def check(dt: DataType, expected: DataType): Unit = {
val schema = new StructType().add("value", dt)
intercept[SparkUnsupportedOperationException] {
- ArrowUtils.toArrowSchema(schema, null, true)
+ ArrowUtils.toArrowSchema(schema, null, true, false)
}
- assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null,
false))
+ assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null,
false, false))
=== new StructType().add("value", expected))
}
diff --git
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala
index 2791c6b6add5..c7b4748f1222 100644
---
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala
+++
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala
@@ -64,7 +64,8 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with
BeforeAndAfterAll {
input = Iterator.single(expected),
enc = encoder,
allocator = allocator,
- timeZoneId = "UTC")
+ timeZoneId = "UTC",
+ largeVarTypes = false)
val fromArrow = ArrowDeserializers.deserializeFromArrow(
input = Iterator.single(batch.toByteArray),
encoder = encoder,
diff --git
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index f6662b3351ba..58e19389cae2 100644
---
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -106,7 +106,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with
BeforeAndAfterAll {
maxRecordsPerBatch = maxRecordsPerBatch,
maxBatchSize = maxBatchSize,
batchSizeCheckInterval = batchSizeCheckInterval,
- timeZoneId = "UTC")
+ timeZoneId = "UTC",
+ largeVarTypes = false)
val inspectedIterator = if (inspectBatch != null) {
arrowIterator.map { batch =>
@@ -183,7 +184,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with
BeforeAndAfterAll {
allocator,
maxRecordsPerBatch = 1024,
maxBatchSize = 8 * 1024,
- timeZoneId = "UTC")
+ timeZoneId = "UTC",
+ largeVarTypes = false)
}
private def compareIterators[T](expected: Iterator[T], actual: Iterator[T]):
Unit = {
@@ -626,7 +628,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with
BeforeAndAfterAll {
allocator,
maxRecordsPerBatch = 128,
maxBatchSize = 1024,
- timeZoneId = "UTC")
+ timeZoneId = "UTC",
+ largeVarTypes = false)
intercept[NullPointerException] {
iterator.next()
}
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index b0d7d6aa5d13..e4ac4a1ba619 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect
import java.net.URI
import java.nio.file.{Files, Paths}
+import java.util.Locale
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
@@ -110,7 +111,8 @@ class SparkSession private[sql] (
private def createDataset[T](encoder: AgnosticEncoder[T], data:
Iterator[T]): Dataset[T] = {
newDataset(encoder) { builder =>
if (data.nonEmpty) {
- val arrowData = ArrowSerializer.serialize(data, encoder, allocator,
timeZoneId)
+ val arrowData =
+ ArrowSerializer.serialize(data, encoder, allocator, timeZoneId,
largeVarTypes)
if (arrowData.size() <=
conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt) {
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
@@ -467,6 +469,8 @@ class SparkSession private[sql] (
}
private[sql] def timeZoneId: String =
conf.get(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY)
+ private[sql] def largeVarTypes: Boolean =
+
conf.get(SqlApiConf.ARROW_EXECUTION_USE_LARGE_VAR_TYPES).toLowerCase(Locale.ROOT).toBoolean
private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]):
SparkResult[T] = {
val value = executeInternal(plan)
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index c01390bf0785..584a318f039d 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._
import com.google.protobuf.ByteString
import org.apache.arrow.memory.BufferAllocator
-import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector,
DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector,
IntervalYearVector, IntVector, NullVector, SmallIntVector,
TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector,
VarCharVector, VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector._
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel}
import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
@@ -50,8 +50,10 @@ import org.apache.spark.unsafe.types.VariantVal
class ArrowSerializer[T](
private[this] val enc: AgnosticEncoder[T],
private[this] val allocator: BufferAllocator,
- private[this] val timeZoneId: String) {
- private val (root, serializer) = ArrowSerializer.serializerFor(enc,
allocator, timeZoneId)
+ private[this] val timeZoneId: String,
+ private[this] val largeVarTypes: Boolean) {
+ private val (root, serializer) =
+ ArrowSerializer.serializerFor(enc, allocator, timeZoneId, largeVarTypes)
private val vectors = root.getFieldVectors.asScala
private val unloader = new VectorUnloader(root)
private val schemaBytes = {
@@ -144,12 +146,13 @@ object ArrowSerializer {
maxRecordsPerBatch: Int,
maxBatchSize: Long,
timeZoneId: String,
+ largeVarTypes: Boolean,
batchSizeCheckInterval: Int = 128): CloseableIterator[Array[Byte]] = {
assert(maxRecordsPerBatch > 0)
assert(maxBatchSize > 0)
assert(batchSizeCheckInterval > 0)
new CloseableIterator[Array[Byte]] {
- private val serializer = new ArrowSerializer[T](enc, allocator,
timeZoneId)
+ private val serializer = new ArrowSerializer[T](enc, allocator,
timeZoneId, largeVarTypes)
private val bytes = new ByteArrayOutputStream
private var hasWrittenFirstBatch = false
@@ -191,8 +194,9 @@ object ArrowSerializer {
input: Iterator[T],
enc: AgnosticEncoder[T],
allocator: BufferAllocator,
- timeZoneId: String): ByteString = {
- val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId)
+ timeZoneId: String,
+ largeVarTypes: Boolean): ByteString = {
+ val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId,
largeVarTypes)
try {
input.foreach(serializer.append)
val output = ByteString.newOutput()
@@ -211,9 +215,14 @@ object ArrowSerializer {
def serializerFor[T](
encoder: AgnosticEncoder[T],
allocator: BufferAllocator,
- timeZoneId: String): (VectorSchemaRoot, Serializer) = {
+ timeZoneId: String,
+ largeVarTypes: Boolean): (VectorSchemaRoot, Serializer) = {
val arrowSchema =
- ArrowUtils.toArrowSchema(encoder.schema, timeZoneId,
errorOnDuplicatedFieldNames = true)
+ ArrowUtils.toArrowSchema(
+ encoder.schema,
+ timeZoneId,
+ errorOnDuplicatedFieldNames = true,
+ largeVarTypes = largeVarTypes)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
val serializer = if (encoder.schema != encoder.dataType) {
assert(root.getSchema.getFields.size() == 1)
@@ -264,19 +273,36 @@ object ArrowSerializer {
new FieldSerializer[String, VarCharVector](v) {
override def set(index: Int, value: String): Unit = setString(v,
index, value)
}
+ case (StringEncoder, v: LargeVarCharVector) =>
+ new FieldSerializer[String, LargeVarCharVector](v) {
+ override def set(index: Int, value: String): Unit = setString(v,
index, value)
+ }
case (JavaEnumEncoder(_), v: VarCharVector) =>
new FieldSerializer[Enum[_], VarCharVector](v) {
override def set(index: Int, value: Enum[_]): Unit = setString(v,
index, value.name())
}
+ case (JavaEnumEncoder(_), v: LargeVarCharVector) =>
+ new FieldSerializer[Enum[_], LargeVarCharVector](v) {
+ override def set(index: Int, value: Enum[_]): Unit = setString(v,
index, value.name())
+ }
case (ScalaEnumEncoder(_, _), v: VarCharVector) =>
new FieldSerializer[Enumeration#Value, VarCharVector](v) {
override def set(index: Int, value: Enumeration#Value): Unit =
setString(v, index, value.toString)
}
+ case (ScalaEnumEncoder(_, _), v: LargeVarCharVector) =>
+ new FieldSerializer[Enumeration#Value, LargeVarCharVector](v) {
+ override def set(index: Int, value: Enumeration#Value): Unit =
+ setString(v, index, value.toString)
+ }
case (BinaryEncoder, v: VarBinaryVector) =>
new FieldSerializer[Array[Byte], VarBinaryVector](v) {
override def set(index: Int, value: Array[Byte]): Unit =
vector.setSafe(index, value)
}
+ case (BinaryEncoder, v: LargeVarBinaryVector) =>
+ new FieldSerializer[Array[Byte], LargeVarBinaryVector](v) {
+ override def set(index: Int, value: Array[Byte]): Unit =
vector.setSafe(index, value)
+ }
case (SparkDecimalEncoder(_), v: DecimalVector) =>
new FieldSerializer[Decimal, DecimalVector](v) {
override def set(index: Int, value: Decimal): Unit =
@@ -477,7 +503,7 @@ object ArrowSerializer {
private val methodLookup = MethodHandles.lookup()
- private def setString(vector: VarCharVector, index: Int, string: String):
Unit = {
+ private def setString(vector: VariableWidthFieldVector, index: Int, string:
String): Unit = {
val bytes = Text.encode(string)
vector.setSafe(index, bytes, 0, bytes.limit())
}
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala
index 53d8d46e6268..3dbfce18e7b4 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala
@@ -20,7 +20,7 @@ import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period,
ZoneOffset}
-import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector,
DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector,
IntervalYearVector, IntVector, NullVector, SmallIntVector,
TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector,
VarCharVector}
+import org.apache.arrow.vector._
import org.apache.arrow.vector.util.Text
import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkIntervalUtils,
SparkStringUtils, TimestampFormatter}
@@ -82,7 +82,9 @@ object ArrowVectorReader {
case v: Float8Vector => new Float8VectorReader(v)
case v: DecimalVector => new DecimalVectorReader(v)
case v: VarCharVector => new VarCharVectorReader(v)
+ case v: LargeVarCharVector => new LargeVarCharVectorReader(v)
case v: VarBinaryVector => new VarBinaryVectorReader(v)
+ case v: LargeVarBinaryVector => new LargeVarBinaryVectorReader(v)
case v: DurationVector => new DurationVectorReader(v)
case v: IntervalYearVector => new IntervalYearVectorReader(v)
case v: DateDayVector => new DateDayVectorReader(v, timeZoneId)
@@ -189,12 +191,29 @@ private[arrow] class VarCharVectorReader(v: VarCharVector)
override def getString(i: Int): String = Text.decode(vector.get(i))
}
+private[arrow] class LargeVarCharVectorReader(v: LargeVarCharVector)
+ extends TypedArrowVectorReader[LargeVarCharVector](v) {
+ // This is currently a bit heavy on allocations:
+ // - byte array created in VarCharVector.get
+ // - CharBuffer created CharSetEncoder
+ // - char array in String
+ // By using direct buffers and reusing the char buffer
+ // we could get rid of the first two allocations.
+ override def getString(i: Int): String = Text.decode(vector.get(i))
+}
+
private[arrow] class VarBinaryVectorReader(v: VarBinaryVector)
extends TypedArrowVectorReader[VarBinaryVector](v) {
override def getBytes(i: Int): Array[Byte] = vector.get(i)
override def getString(i: Int): String =
SparkStringUtils.getHexString(getBytes(i))
}
+private[arrow] class LargeVarBinaryVectorReader(v: LargeVarBinaryVector)
+ extends TypedArrowVectorReader[LargeVarBinaryVector](v) {
+ override def getBytes(i: Int): Array[Byte] = vector.get(i)
+ override def getString(i: Int): String =
SparkStringUtils.getHexString(getBytes(i))
+}
+
private[arrow] class DurationVectorReader(v: DurationVector)
extends TypedArrowVectorReader[DurationVector](v) {
override def getDuration(i: Int): Duration = vector.getObject(i)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 497576a6630d..fc3f18063416 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -90,14 +90,16 @@ private[execution] class
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
maxRecordsPerBatch: Int,
maxBatchSize: Long,
timeZoneId: String,
- errorOnDuplicatedFieldNames: Boolean): Iterator[InternalRow] =>
Iterator[Batch] = { rows =>
+ errorOnDuplicatedFieldNames: Boolean,
+ largeVarTypes: Boolean): Iterator[InternalRow] => Iterator[Batch] = {
rows =>
val batches = ArrowConverters.toBatchWithSchemaIterator(
rows,
schema,
maxRecordsPerBatch,
maxBatchSize,
timeZoneId,
- errorOnDuplicatedFieldNames)
+ errorOnDuplicatedFieldNames,
+ largeVarTypes)
batches.map(b => b -> batches.rowCountInLastBatch)
}
@@ -110,6 +112,7 @@ private[execution] class
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
val schema = dataframe.schema
val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+ val largeVarTypes = spark.sessionState.conf.arrowUseLargeVarTypes
// Conservatively sets it 70% because the size is not accurate but
estimated.
val maxBatchSize =
(SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
@@ -118,7 +121,8 @@ private[execution] class
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
maxRecordsPerBatch,
maxBatchSize,
timeZoneId,
- errorOnDuplicatedFieldNames = false)
+ errorOnDuplicatedFieldNames = false,
+ largeVarTypes = largeVarTypes)
var numSent = 0
def sendBatch(bytes: Array[Byte], count: Long, startOffset: Long): Unit = {
@@ -239,7 +243,8 @@ private[execution] class
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
ArrowConverters.createEmptyArrowBatch(
schema,
timeZoneId,
- errorOnDuplicatedFieldNames = false),
+ errorOnDuplicatedFieldNames = false,
+ largeVarTypes = largeVarTypes),
0L,
0L)
}
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6019969be05b..2a296522b620 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2550,7 +2550,8 @@ class SparkConnectPlanner(
result.iterator,
StringEncoder,
ArrowUtils.rootAllocator,
- session.sessionState.conf.sessionLocalTimeZone)
+ session.sessionState.conf.sessionLocalTimeZone,
+ session.sessionState.conf.arrowUseLargeVarTypes)
val sqlCommandResult = SqlCommandResult.newBuilder()
sqlCommandResult.getRelationBuilder.getLocalRelationBuilder.setData(arrowData)
responseObserver.onNext(
@@ -2613,13 +2614,15 @@ class SparkConnectPlanner(
val schema = df.schema
val maxBatchSize =
(SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
+ val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes
// Convert the data.
val bytes = if (rows.isEmpty) {
ArrowConverters.createEmptyArrowBatch(
schema,
timeZoneId,
- errorOnDuplicatedFieldNames = false)
+ errorOnDuplicatedFieldNames = false,
+ largeVarTypes = largeVarTypes)
} else {
val batches = ArrowConverters.toBatchWithSchemaIterator(
rowIter = rows.iterator,
@@ -2627,7 +2630,8 @@ class SparkConnectPlanner(
maxRecordsPerBatch = -1,
maxEstimatedBatchSize = maxBatchSize,
timeZoneId = timeZoneId,
- errorOnDuplicatedFieldNames = false)
+ errorOnDuplicatedFieldNames = false,
+ largeVarTypes = largeVarTypes)
assert(batches.hasNext)
val bytes = batches.next()
assert(!batches.hasNext, s"remaining batches: ${batches.size}")
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index db430549818d..76c88d515ec0 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -146,7 +146,12 @@ trait SparkConnectServerTest extends SharedSparkSession {
protected def buildLocalRelation[A <: Product: TypeTag](data: Seq[A]) = {
val encoder = ScalaReflection.encoderFor[A]
val arrowData =
- ArrowSerializer.serialize(data.iterator, encoder, allocator,
TimeZone.getDefault.getID)
+ ArrowSerializer.serialize(
+ data.iterator,
+ encoder,
+ allocator,
+ TimeZone.getDefault.getID,
+ largeVarTypes = false)
val localRelation = proto.LocalRelation
.newBuilder()
.setData(arrowData)
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 2a09d5f8e8bd..72f7065b4424 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -101,7 +101,8 @@ trait SparkConnectPlanTest extends SharedSparkSession {
Long.MaxValue,
Long.MaxValue,
timeZoneId,
- true)
+ true,
+ false)
.next()
localRelationBuilder.setData(ByteString.copyFrom(bytes))
@@ -478,7 +479,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with
SparkConnectPlanTest {
test("Empty ArrowBatch") {
val schema = StructType(Seq(StructField("int", IntegerType)))
- val data = ArrowConverters.createEmptyArrowBatch(schema, null, true)
+ val data = ArrowConverters.createEmptyArrowBatch(schema, null, true, false)
val localRelation = proto.Relation
.newBuilder()
.setLocalRelation(
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 2bbd6863b110..494aceb2fb58 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -1115,7 +1115,8 @@ class SparkConnectProtoSuite extends PlanTest with
SparkConnectPlanTest {
Long.MaxValue,
Long.MaxValue,
null,
- true)
+ true,
+ false)
.next()
proto.Relation
.newBuilder()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index ad58fc0c2fcf..1efd8f9e3220 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -249,7 +249,13 @@ private[sql] object SQLUtils extends Logging {
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
val context = TaskContext.get()
- ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, true,
context)
+ ArrowConverters.fromBatchIterator(
+ iter,
+ schema,
+ timeZoneId,
+ true,
+ false,
+ context)
}
sparkSession.internalCreateDataFrame(rdd.setName("arrow"), schema)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index d78a3a391edb..8930b5895d32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2089,7 +2089,8 @@ class Dataset[T] private[sql](
val buffer = new ByteArrayOutputStream()
val out = new DataOutputStream(outputStream)
val batchWriter =
- new ArrowBatchStreamWriter(schema, buffer, timeZoneId,
errorOnDuplicatedFieldNames = true)
+ new ArrowBatchStreamWriter(
+ schema, buffer, timeZoneId, errorOnDuplicatedFieldNames = true,
largeVarTypes = false)
val arrowBatchRdd = toArrowBatchRdd(plan)
val numPartitions = arrowBatchRdd.partitions.length
@@ -2140,12 +2141,14 @@ class Dataset[T] private[sql](
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
val errorOnDuplicatedFieldNames =
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
+ val largeVarTypes = sparkSession.sessionState.conf.arrowUseLargeVarTypes
PythonRDD.serveToStream("serve-Arrow") { outputStream =>
withAction("collectAsArrowToPython", queryExecution) { plan =>
val out = new DataOutputStream(outputStream)
val batchWriter =
- new ArrowBatchStreamWriter(schema, out, timeZoneId,
errorOnDuplicatedFieldNames)
+ new ArrowBatchStreamWriter(
+ schema, out, timeZoneId, errorOnDuplicatedFieldNames,
largeVarTypes)
// Batches ordered by (index of partition, batch index in that
partition) tuple
val batchOrder = ArrayBuffer.empty[(Int, Int)]
@@ -2294,10 +2297,18 @@ class Dataset[T] private[sql](
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
val errorOnDuplicatedFieldNames =
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
+ val largeVarTypes =
+ sparkSession.sessionState.conf.arrowUseLargeVarTypes
plan.execute().mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toBatchIterator(
- iter, schemaCaptured, maxRecordsPerBatch, timeZoneId,
errorOnDuplicatedFieldNames, context)
+ iter,
+ schemaCaptured,
+ maxRecordsPerBatch,
+ timeZoneId,
+ errorOnDuplicatedFieldNames,
+ largeVarTypes,
+ context)
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 7ea1bd6ff7dc..ed490347ae82 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -51,9 +51,11 @@ private[sql] class ArrowBatchStreamWriter(
schema: StructType,
out: OutputStream,
timeZoneId: String,
- errorOnDuplicatedFieldNames: Boolean) {
+ errorOnDuplicatedFieldNames: Boolean,
+ largeVarTypes: Boolean) {
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames)
+ val arrowSchema = ArrowUtils.toArrowSchema(
+ schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
val writeChannel = new WriteChannel(Channels.newChannel(out))
// Write the Arrow schema first, before batches
@@ -81,10 +83,11 @@ private[sql] object ArrowConverters extends Logging {
maxRecordsPerBatch: Long,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
+ largeVarTypes: Boolean,
context: TaskContext) extends Iterator[Array[Byte]] with AutoCloseable {
protected val arrowSchema =
- ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames)
+ ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
private val allocator =
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)
@@ -137,9 +140,16 @@ private[sql] object ArrowConverters extends Logging {
maxEstimatedBatchSize: Long,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
+ largeVarTypes: Boolean,
context: TaskContext)
extends ArrowBatchIterator(
- rowIter, schema, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context) {
+ rowIter,
+ schema,
+ maxRecordsPerBatch,
+ timeZoneId,
+ errorOnDuplicatedFieldNames,
+ largeVarTypes,
+ context) {
private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
var rowCountInLastBatch: Long = 0
@@ -205,9 +215,16 @@ private[sql] object ArrowConverters extends Logging {
maxRecordsPerBatch: Long,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
+ largeVarTypes: Boolean,
context: TaskContext): ArrowBatchIterator = {
new ArrowBatchIterator(
- rowIter, schema, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context)
+ rowIter,
+ schema,
+ maxRecordsPerBatch,
+ timeZoneId,
+ errorOnDuplicatedFieldNames,
+ largeVarTypes,
+ context)
}
/**
@@ -220,19 +237,21 @@ private[sql] object ArrowConverters extends Logging {
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
timeZoneId: String,
- errorOnDuplicatedFieldNames: Boolean): ArrowBatchWithSchemaIterator = {
+ errorOnDuplicatedFieldNames: Boolean,
+ largeVarTypes: Boolean): ArrowBatchWithSchemaIterator = {
new ArrowBatchWithSchemaIterator(
rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize,
- timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get())
+ timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes, TaskContext.get())
}
private[sql] def createEmptyArrowBatch(
schema: StructType,
timeZoneId: String,
- errorOnDuplicatedFieldNames: Boolean): Array[Byte] = {
+ errorOnDuplicatedFieldNames: Boolean,
+ largeVarTypes: Boolean): Array[Byte] = {
val batches = new ArrowBatchWithSchemaIterator(
Iterator.empty, schema, 0L, 0L,
- timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get()) {
+ timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes, TaskContext.get()) {
override def hasNext: Boolean = true
}
Utils.tryWithSafeFinally {
@@ -299,12 +318,13 @@ private[sql] object ArrowConverters extends Logging {
schema: StructType,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
+ largeVarTypes: Boolean,
context: TaskContext)
extends InternalRowIterator(arrowBatchIter, context) {
override def nextBatch(): (Iterator[InternalRow], StructType) = {
val arrowSchema =
- ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames)
+ ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
resources.append(root)
val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator)
@@ -344,9 +364,12 @@ private[sql] object ArrowConverters extends Logging {
schema: StructType,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
- context: TaskContext): Iterator[InternalRow] = new InternalRowIteratorWithoutSchema(
- arrowBatchIter, schema, timeZoneId, errorOnDuplicatedFieldNames, context
- )
+ largeVarTypes: Boolean,
+ context: TaskContext): Iterator[InternalRow] = {
+ new InternalRowIteratorWithoutSchema(
+ arrowBatchIter, schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes, context
+ )
+ }
/**
* Maps iterator from serialized ArrowRecordBatches to InternalRows. Different from
@@ -393,6 +416,7 @@ private[sql] object ArrowConverters extends Logging {
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
val attrs = toAttributes(schema)
val batchesInDriver = arrowBatches.toArray
+ val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes
val shouldUseRDD = session.sessionState.conf
.arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum
@@ -407,6 +431,7 @@ private[sql] object ArrowConverters extends Logging {
schema,
timezone,
errorOnDuplicatedFieldNames = false,
+ largeVarTypes = largeVarTypes,
TaskContext.get())
}
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
@@ -417,6 +442,7 @@ private[sql] object ArrowConverters extends Logging {
schema,
session.sessionState.conf.sessionLocalTimeZone,
errorOnDuplicatedFieldNames = false,
+ largeVarTypes = largeVarTypes,
TaskContext.get())
// Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
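Taken together, the ArrowConverters changes widen both directions of the conversion with a largeVarTypes flag. A minimal round-trip sketch of the new signatures, assuming `rows: Seq[InternalRow]`, `schema: StructType` and `timeZoneId: String` are already in scope (it mirrors the test updates in ArrowConvertersSuite further below):

    import org.apache.spark.TaskContext
    import org.apache.spark.sql.execution.arrow.ArrowConverters

    val ctx = TaskContext.empty()  // test-only helper, used the same way in the suite below
    val batches = ArrowConverters.toBatchIterator(
      rows.iterator, schema, 5 /* maxRecordsPerBatch */, timeZoneId,
      true /* errorOnDuplicatedFieldNames */, true /* largeVarTypes */, ctx)
    val rowsBack = ArrowConverters.fromBatchIterator(
      batches, schema, timeZoneId,
      true /* errorOnDuplicatedFieldNames */, true /* largeVarTypes */, ctx)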
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 1bddd81fbfe2..bf2142422562 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -120,6 +120,9 @@ object ArrowPythonRunner {
val arrowAyncParallelism = conf.pythonUDFArrowConcurrencyLevel.map(v =>
Seq(SQLConf.PYTHON_UDF_ARROW_CONCURRENCY_LEVEL.key -> v.toString)
).getOrElse(Seq.empty)
- Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++ arrowAyncParallelism: _*)
+ val useLargeVarTypes = Seq(SQLConf.ARROW_EXECUTION_USE_LARGE_VAR_TYPES.key ->
+ conf.arrowUseLargeVarTypes.toString)
+ Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++
+ arrowAyncParallelism ++ useLargeVarTypes: _*)
}
}
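With the getPythonRunnerConfMap change, the large-var-types setting now rides along with the other Arrow-related entries sent to the Python worker. A hedged sketch of what a caller would observe, where `conf` is any SQLConf instance:

    import org.apache.spark.sql.execution.python.ArrowPythonRunner
    import org.apache.spark.sql.internal.SQLConf

    val runnerConf: Map[String, String] = ArrowPythonRunner.getPythonRunnerConfMap(conf)
    // After this patch the map is expected to always carry the entry keyed by
    // SQLConf.ARROW_EXECUTION_USE_LARGE_VAR_TYPES.key, "false" unless the conf is enabled.
    assert(runnerConf.contains(SQLConf.ARROW_EXECUTION_USE_LARGE_VAR_TYPES.key))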
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 59e8970b9c9b..9caa344d00c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -45,6 +45,7 @@ class CoGroupedArrowPythonRunner(
leftSchema: StructType,
rightSchema: StructType,
timeZoneId: String,
+ largeVarTypes: Boolean,
conf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
@@ -109,7 +110,8 @@ class CoGroupedArrowPythonRunner(
dataOut: DataOutputStream,
name: String): Unit = {
val arrowSchema =
- ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true)
+ ArrowUtils.toArrowSchema(
+ schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = largeVarTypes)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
index 66ed2bca7677..af487218391e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
@@ -42,6 +42,7 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
protected val pythonEvalType: Int
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+ private val largeVarTypes = conf.arrowUseLargeVarTypes
private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
private val pythonUDF = func.asInstanceOf[PythonUDF]
private val pandasFunction = pythonUDF.func
@@ -84,6 +85,7 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
DataTypeUtils.fromAttributes(leftDedup),
DataTypeUtils.fromAttributes(rightDedup),
sessionLocalTimeZone,
+ largeVarTypes,
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
index 45ecf8700950..aaf2f256273d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
@@ -85,7 +85,8 @@ class ArrowRRunner(
override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
if (inputIterator.hasNext) {
val arrowSchema =
- ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true)
+ ArrowUtils.toArrowSchema(
+ schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = false)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
"stdout writer for R", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index 33e5d46ee233..39c3d8df7550 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -1376,8 +1376,10 @@ class ArrowConvertersSuite extends SharedSparkSession {
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val ctx = TaskContext.empty()
- val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, true, ctx)
- val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, true, ctx)
+ val batchIter = ArrowConverters.toBatchIterator(
+ inputRows.iterator, schema, 5, null, true, false, ctx)
+ val outputRowIter = ArrowConverters.fromBatchIterator(
+ batchIter, schema, null, true, false, ctx)
var count = 0
outputRowIter.zipWithIndex.foreach { case (row, i) =>
@@ -1397,12 +1399,13 @@ class ArrowConvertersSuite extends SharedSparkSession {
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val ctx = TaskContext.empty()
- val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, true, ctx)
+ val batchIter = ArrowConverters.toBatchIterator(
+ inputRows.iterator, schema, 5, null, true, false, ctx)
// Write batches to Arrow stream format as a byte array
val out = new ByteArrayOutputStream()
Utils.tryWithResource(new DataOutputStream(out)) { dataOut =>
- val writer = new ArrowBatchStreamWriter(schema, dataOut, null, true)
+ val writer = new ArrowBatchStreamWriter(schema, dataOut, null, true, false)
writer.writeBatches(batchIter)
writer.end()
}
@@ -1410,7 +1413,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
// Read Arrow stream into batches, then convert back to rows
val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray)
val readBatches = ArrowConverters.getBatchesFromStream(in)
- val outputRowIter = ArrowConverters.fromBatchIterator(readBatches, schema, null, true, ctx)
+ val outputRowIter = ArrowConverters.fromBatchIterator(
+ readBatches, schema, null, true, false, ctx)
var count = 0
outputRowIter.zipWithIndex.foreach { case (row, i) =>
@@ -1441,7 +1445,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
}
val ctx = TaskContext.empty()
val batchIter = ArrowConverters.toBatchWithSchemaIterator(
- inputRows.iterator, schema, 5, 1024 * 1024, null, true)
+ inputRows.iterator, schema, 5, 1024 * 1024, null, true, false)
val (outputRowIter, outputType) =
ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx)
var count = 0
@@ -1460,7 +1464,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val ctx = TaskContext.empty()
val batchIter =
- ArrowConverters.toBatchWithSchemaIterator(Iterator.empty, schema, 5, 1024 * 1024, null, true)
+ ArrowConverters.toBatchWithSchemaIterator(
+ Iterator.empty, schema, 5, 1024 * 1024, null, true, false)
val (outputRowIter, outputType) =
ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx)
assert(0 == outputRowIter.length)
@@ -1474,7 +1479,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
proj(row).copy()
}
val batchIter1 = ArrowConverters.toBatchWithSchemaIterator(
- inputRows1.iterator, schema1, 5, 1024 * 1024, null, true)
+ inputRows1.iterator, schema1, 5, 1024 * 1024, null, true, false)
val schema2 = StructType(Seq(StructField("field2", IntegerType, nullable = true)))
val inputRows2 = Array(InternalRow(1)).map { row =>
@@ -1482,7 +1487,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
proj(row).copy()
}
val batchIter2 = ArrowConverters.toBatchWithSchemaIterator(
- inputRows2.iterator, schema2, 5, 1024 * 1024, null, true)
+ inputRows2.iterator, schema2, 5, 1024 * 1024, null, true, false)
val iter = batchIter1.toArray ++ batchIter2
@@ -1511,11 +1516,13 @@ class ArrowConvertersSuite extends SharedSparkSession {
batchBytes: Array[Byte],
jsonFile: File,
timeZoneId: String = null,
- errorOnDuplicatedFieldNames: Boolean = true): Unit = {
+ errorOnDuplicatedFieldNames: Boolean = true,
+ largeVarTypes: Boolean = false): Unit = {
val allocator = new RootAllocator(Long.MaxValue)
val jsonReader = new JsonFileReader(jsonFile, allocator)
- val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, errorOnDuplicatedFieldNames)
+ val arrowSchema = ArrowUtils.toArrowSchema(
+ sparkSchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
val jsonSchema = jsonReader.start()
Validator.compareSchemas(arrowSchema, jsonSchema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index 7830ea1da177..acf258a373c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -158,7 +158,8 @@ class ArrowWriterSuite extends SparkFunSuite {
schema: StructType,
timeZoneId: String): (ArrowWriter, Int) = {
val arrowSchema =
- ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true)
+ ArrowUtils.toArrowSchema(
+ schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = false)
val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator)
val vector = root.getFieldVectors.get(0)
vector.allocateNew()
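As a closing sanity check on the widened ArrowUtils.toArrowSchema signature, a hedged sketch of the schema-level difference the new parameter makes; the Utf8-versus-LargeUtf8 and Binary-versus-LargeBinary mapping is the expected effect of the flag and is stated here as an assumption, not something this suite change asserts:

    import org.apache.spark.sql.types.{BinaryType, StringType, StructField, StructType}
    import org.apache.spark.sql.util.ArrowUtils

    val schema = StructType(Seq(StructField("s", StringType), StructField("b", BinaryType)))
    val small = ArrowUtils.toArrowSchema(schema, null, errorOnDuplicatedFieldNames = true, largeVarTypes = false)
    val large = ArrowUtils.toArrowSchema(schema, null, errorOnDuplicatedFieldNames = true, largeVarTypes = true)
    // Expectation (assumption): `small` keeps 32-bit-offset Utf8/Binary vectors, while
    // `large` maps the same Spark types to LargeUtf8/LargeBinary with 64-bit offsets.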