This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 305aa4a89ef [SPARK-41971][SQL][PYTHON] Add a config for pandas conversion how to handle struct types
305aa4a89ef is described below

commit 305aa4a89efe02f517f82039225a99b31b20146f
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Thu May 4 11:01:28 2023 -0700

[SPARK-41971][SQL][PYTHON] Add a config for pandas conversion how to handle struct types

### What changes were proposed in this pull request?

Adds a config that controls how struct types are handled when converting to a pandas DataFrame.

- `spark.sql.execution.pandas.structHandlingMode` (default: `"legacy"`)

  The conversion mode for struct types when creating a pandas DataFrame.

#### When `"legacy"`, the behavior is the same as before, except that with Arrow optimization and Spark Connect, a more readable exception is raised when there are duplicated nested field names.

```py
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
Traceback (most recent call last):
...
pyspark.errors.exceptions.connect.UnsupportedOperationException: [DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT] Duplicated field names in Arrow Struct are not allowed, got [a, a].
```

#### When `"row"`, convert to Row object regardless of Arrow optimization.

```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', False)
>>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
   x       y
0  1  (1, 2)
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
   x       y
0  1  (1, 2)
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
>>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
   x       y
0  1  (1, 2)
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
   x       y
0  1  (1, 2)
```

#### When `"dict"`, convert to dict and use suffixed key names, e.g., `a_0`, `a_1`, if there are duplicated nested field names, regardless of Arrow optimization.

```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'dict')
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', False)
>>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
   x                 y
0  1  {'a': 1, 'b': 2}
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
   x                     y
0  1  {'a_0': 1, 'a_1': 2}
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
>>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
   x                 y
0  1  {'a': 1, 'b': 2}
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
   x                     y
0  1  {'a_0': 1, 'a_1': 2}
```

### Why are the changes needed?

Currently there are three different behaviors when calling `df.toPandas()` on a DataFrame with nested struct types:

- vanilla PySpark with Arrow optimization disabled

  ```py
  >>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', False)
  >>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
     x       y
  0  1  (1, 2)
  ```

  It uses `Row` objects for struct types, and duplicated field names are allowed.

  ```py
  >>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
     x       y
  0  1  (1, 2)
  ```

- vanilla PySpark with Arrow optimization enabled

  ```py
  >>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
  >>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
     x                 y
  0  1  {'a': 1, 'b': 2}
  ```

  It uses `dict` for struct types.
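  For reference, the key-suffixing scheme mentioned above (`a_0`, `a_1`, ...) — used by Spark Connect and by the new `dict` mode — is implemented by the `_dedup_names` helper this commit adds to `python/pyspark/sql/pandas/types.py`. Below is a simplified, standalone sketch of that helper; the top-level function name and the demo call are illustrative, not part of the commit.

  ```py
  import itertools
  from typing import Callable, List

  def dedup_names(names: List[str]) -> List[str]:
      """Suffix only the names that appear more than once, e.g. ["a", "a"] -> ["a_0", "a_1"]."""
      if len(set(names)) == len(names):
          return names  # all names are unique; nothing to rename

      def gen_dedup(name: str) -> Callable[[], str]:
          counter = itertools.count()  # independent counter per duplicated name
          return lambda: f"{name}_{next(counter)}"

      def gen_identity(name: str) -> Callable[[], str]:
          return lambda: name

      # one name generator per distinct name: a counting one for duplicates, identity otherwise
      gen_new_name = {
          name: gen_dedup(name) if len(list(group)) > 1 else gen_identity(name)
          for name, group in itertools.groupby(sorted(names))
      }
      return [gen_new_name[name]() for name in names]

  print(dedup_names(["a", "x", "a"]))  # ['a_0', 'x', 'a_1']
  ```

  Only names that occur more than once are suffixed, each duplicated name is counted independently, and the original column order is preserved. The current Arrow-optimized path, by contrast, applies no such deduplication.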
It raises an Exception when there are duplicated nested field names: ```py >>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas() Traceback (most recent call last): ... pyarrow.lib.ArrowInvalid: Ran out of field metadata, likely malformed ``` - Spark Connect ```py >>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas() x y 0 1 {'a': 1, 'b': 2} ``` using `dict` for struct types. If there are duplicated nested field names, the duplicated keys are suffixed: ```py >>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas() x y 0 1 {'a_0': 1, 'a_1': 2} ``` ### Does this PR introduce _any_ user-facing change? Users will be able to configure the behavior. ### How was this patch tested? Modified the related tests. Closes #40988 from ueshin/issues/SPARK-41971/struct_in_pandas. Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Xinrong Meng <xinr...@apache.org> --- .../scala/org/apache/spark/sql/SparkSession.scala | 3 +- .../sql/connect/client/util/ConvertToArrow.scala | 4 +- .../sql/connect/planner/SparkConnectPlanner.scala | 8 +- .../service/SparkConnectStreamHandler.scala | 57 ++--- .../connect/planner/SparkConnectPlannerSuite.scala | 5 +- .../connect/planner/SparkConnectProtoSuite.scala | 3 +- core/src/main/resources/error/error-classes.json | 5 + python/pyspark/errors/__init__.py | 2 + python/pyspark/errors/error_classes.py | 5 + python/pyspark/errors/exceptions/base.py | 6 + python/pyspark/errors/exceptions/captured.py | 9 + python/pyspark/errors/exceptions/connect.py | 9 + python/pyspark/sql/connect/client.py | 47 +++- python/pyspark/sql/connect/conversion.py | 21 +- python/pyspark/sql/connect/dataframe.py | 4 +- python/pyspark/sql/connect/session.py | 10 +- python/pyspark/sql/pandas/conversion.py | 204 +++++------------ python/pyspark/sql/pandas/types.py | 242 ++++++++++++++++++++- .../pyspark/sql/tests/connect/test_parity_arrow.py | 3 + python/pyspark/sql/tests/test_arrow.py | 56 ++++- python/pyspark/sql/types.py | 14 ++ .../spark/sql/errors/QueryExecutionErrors.scala | 7 + .../spark/sql/execution/arrow/ArrowWriter.scala | 7 +- .../org/apache/spark/sql/internal/SQLConf.scala | 19 ++ .../org/apache/spark/sql/util/ArrowUtils.scala | 47 +++- .../apache/spark/sql/util/ArrowUtilsSuite.scala | 25 ++- .../main/scala/org/apache/spark/sql/Dataset.scala | 12 +- .../sql/execution/arrow/ArrowConverters.scala | 33 +-- .../ApplyInPandasWithStatePythonRunner.scala | 2 + .../sql/execution/python/ArrowPythonRunner.scala | 2 + .../python/CoGroupedArrowPythonRunner.scala | 3 +- .../sql/execution/python/PythonArrowInput.scala | 4 +- .../spark/sql/execution/r/ArrowRRunner.scala | 3 +- .../sql/execution/arrow/ArrowConvertersSuite.scala | 29 ++- 34 files changed, 629 insertions(+), 281 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 48640878211..a8bfac5d71f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -121,7 +121,8 @@ class SparkSession private[sql] ( newDataset(encoder) { builder => if (data.nonEmpty) { val timeZoneId = conf.get("spark.sql.session.timeZone") - val (arrowData, arrowDataSize) = ConvertToArrow(encoder, data, timeZoneId, allocator) + val (arrowData, arrowDataSize) = + ConvertToArrow(encoder, data, timeZoneId, errorOnDuplicatedFieldNames = 
true, allocator) if (arrowDataSize <= conf.get("spark.sql.session.localRelationCacheThreshold").toInt) { builder.getLocalRelationBuilder .setSchema(encoder.schema.json) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala index 46a9493d138..14235094f2d 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala @@ -40,8 +40,10 @@ private[sql] object ConvertToArrow { encoder: AgnosticEncoder[T], data: Iterator[T], timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, bufferAllocator: BufferAllocator): (ByteString, Int) = { - val arrowSchema = ArrowUtils.toArrowSchema(encoder.schema, timeZoneId) + val arrowSchema = + ArrowUtils.toArrowSchema(encoder.schema, timeZoneId, errorOnDuplicatedFieldNames) val root = VectorSchemaRoot.create(arrowSchema, bufferAllocator) val writer: ArrowWriter = ArrowWriter.create(root) val unloader = new VectorUnloader(root) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 4af604ed4b9..36af01d251c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1993,14 +1993,18 @@ class SparkConnectPlanner(val session: SparkSession) { // Convert the data. 
val bytes = if (rows.isEmpty) { - ArrowConverters.createEmptyArrowBatch(schema, timeZoneId) + ArrowConverters.createEmptyArrowBatch( + schema, + timeZoneId, + errorOnDuplicatedFieldNames = false) } else { val batches = ArrowConverters.toBatchWithSchemaIterator( rows.iterator, schema, maxRecordsPerBatch, maxBatchSize, - timeZoneId) + timeZoneId, + errorOnDuplicatedFieldNames = false) assert(batches.hasNext) val bytes = batches.next() assert(!batches.hasNext, s"remaining batches: ${batches.size}") diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 062ef892979..4958fd69b9d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connect.service -import java.util.concurrent.atomic.AtomicInteger - import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -41,7 +39,7 @@ import org.apache.spark.sql.connect.service.SparkConnectStreamHandler.processAsA import org.apache.spark.sql.execution.{LocalTableScanExec, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} import org.apache.spark.sql.execution.arrow.ArrowConverters -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType, UserDefinedType} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ThreadUtils, Utils} class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse]) @@ -131,13 +129,15 @@ object SparkConnectStreamHandler { schema: StructType, maxRecordsPerBatch: Int, maxBatchSize: Long, - timeZoneId: String): Iterator[InternalRow] => Iterator[Batch] = { rows => + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean): Iterator[InternalRow] => Iterator[Batch] = { rows => val batches = ArrowConverters.toBatchWithSchemaIterator( rows, schema, maxRecordsPerBatch, maxBatchSize, - timeZoneId) + timeZoneId, + errorOnDuplicatedFieldNames) batches.map(b => b -> batches.rowCountInLastBatch) } @@ -145,45 +145,19 @@ object SparkConnectStreamHandler { sessionId: String, dataframe: DataFrame, responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { - - def deduplicateFieldNames(dt: DataType): DataType = dt match { - case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType) - case st @ StructType(fields) => - val newNames = if (st.names.toSet.size == st.names.length) { - st.names - } else { - val genNawName = st.names.groupBy(identity).map { - case (name, names) if names.length > 1 => - val i = new AtomicInteger() - name -> { () => s"${name}_${i.getAndIncrement()}" } - case (name, _) => name -> { () => name } - } - st.names.map(genNawName(_)()) - } - val newFields = - fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) => - StructField(name, deduplicateFieldNames(dataType), nullable, metadata) - } - StructType(newFields) - case ArrayType(elementType, containsNull) => - ArrayType(deduplicateFieldNames(elementType), containsNull) - case MapType(keyType, valueType, valueContainsNull) => - MapType( - deduplicateFieldNames(keyType), - deduplicateFieldNames(valueType), - valueContainsNull) - case _ => dt - } - val 
spark = dataframe.sparkSession - val schema = deduplicateFieldNames(dataframe.schema).asInstanceOf[StructType] + val schema = dataframe.schema val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone // Conservatively sets it 70% because the size is not accurate but estimated. val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong - val rowToArrowConverter = SparkConnectStreamHandler - .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId) + val rowToArrowConverter = SparkConnectStreamHandler.rowToArrowConverter( + schema, + maxRecordsPerBatch, + maxBatchSize, + timeZoneId, + errorOnDuplicatedFieldNames = false) var numSent = 0 def sendBatch(bytes: Array[Byte], count: Long): Unit = { @@ -279,7 +253,12 @@ object SparkConnectStreamHandler { // Make sure at least 1 batch will be sent. if (numSent == 0) { - sendBatch(ArrowConverters.createEmptyArrowBatch(schema, timeZoneId), 0L) + sendBatch( + ArrowConverters.createEmptyArrowBatch( + schema, + timeZoneId, + errorOnDuplicatedFieldNames = false), + 0L) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 88b4be16e5a..37d4bec9c87 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -89,7 +89,8 @@ trait SparkConnectPlanTest extends SharedSparkSession { StructType.fromAttributes(attrs.map(_.toAttribute)), Long.MaxValue, Long.MaxValue, - null) + null, + true) .next() localRelationBuilder.setData(ByteString.copyFrom(bytes)) @@ -464,7 +465,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { test("Empty ArrowBatch") { val schema = StructType(Seq(StructField("int", IntegerType))) - val data = ArrowConverters.createEmptyArrowBatch(schema, null) + val data = ArrowConverters.createEmptyArrowBatch(schema, null, true) val localRelation = proto.Relation .newBuilder() .setLocalRelation( diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 96dae647db6..8cb5c1a2919 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -1042,7 +1042,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { StructType.fromAttributes(attributes), Long.MaxValue, Long.MaxValue, - null) + null, + true) .next() proto.Relation .newBuilder() diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 753908932c8..4eea5f9684e 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -545,6 +545,11 @@ ], "sqlState" : "22012" }, + "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT" : { + "message" : [ + "Duplicated field names in Arrow Struct are not allowed, got <fieldNames>." 
+ ] + }, "DUPLICATED_MAP_KEY" : { "message" : [ "Duplicate map key <key> was found, please check the input data. If you want to remove the duplicated keys, you can set <mapKeyDedupPolicy> to \"LAST_WIN\" so that the key inserted at last takes precedence." diff --git a/python/pyspark/errors/__init__.py b/python/pyspark/errors/__init__.py index 1525f351ea4..a9bcb973a6f 100644 --- a/python/pyspark/errors/__init__.py +++ b/python/pyspark/errors/__init__.py @@ -25,6 +25,7 @@ from pyspark.errors.exceptions.base import ( # noqa: F401 ParseException, IllegalArgumentException, ArithmeticException, + UnsupportedOperationException, ArrayIndexOutOfBoundsException, DateTimeException, NumberFormatException, @@ -50,6 +51,7 @@ __all__ = [ "ParseException", "IllegalArgumentException", "ArithmeticException", + "UnsupportedOperationException", "ArrayIndexOutOfBoundsException", "DateTimeException", "NumberFormatException", diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 3f52a14a607..c2ed64d01fb 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -149,6 +149,11 @@ ERROR_CLASSES_JSON = """ "Argument `<arg_name>`(type: <arg_type>) should only contain a type in [<allowed_types>], got <return_type>" ] }, + "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT" : { + "message" : [ + "Duplicated field names in Arrow Struct are not allowed, got <field_names>" + ] + }, "EXCEED_RETRY" : { "message" : [ "Retries exceeded but no exception caught." diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py index 543ee9473e3..1b9a6b0229e 100644 --- a/python/pyspark/errors/exceptions/base.py +++ b/python/pyspark/errors/exceptions/base.py @@ -126,6 +126,12 @@ class ArithmeticException(PySparkException): """ +class UnsupportedOperationException(PySparkException): + """ + Unsupported operation exception thrown from Spark with an error class. + """ + + class ArrayIndexOutOfBoundsException(PySparkException): """ Array index out of bounds exception thrown from Spark with an error class. diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py index 5b008f4ab00..d62b7d24347 100644 --- a/python/pyspark/errors/exceptions/captured.py +++ b/python/pyspark/errors/exceptions/captured.py @@ -26,6 +26,7 @@ from pyspark.errors.exceptions.base import ( AnalysisException as BaseAnalysisException, IllegalArgumentException as BaseIllegalArgumentException, ArithmeticException as BaseArithmeticException, + UnsupportedOperationException as BaseUnsupportedOperationException, ArrayIndexOutOfBoundsException as BaseArrayIndexOutOfBoundsException, DateTimeException as BaseDateTimeException, NumberFormatException as BaseNumberFormatException, @@ -141,6 +142,8 @@ def convert_exception(e: Py4JJavaError) -> CapturedException: return IllegalArgumentException(origin=e) elif is_instance_of(gw, e, "java.lang.ArithmeticException"): return ArithmeticException(origin=e) + elif is_instance_of(gw, e, "java.lang.UnsupportedOperationException"): + return UnsupportedOperationException(origin=e) elif is_instance_of(gw, e, "java.lang.ArrayIndexOutOfBoundsException"): return ArrayIndexOutOfBoundsException(origin=e) elif is_instance_of(gw, e, "java.time.DateTimeException"): @@ -262,6 +265,12 @@ class ArithmeticException(CapturedException, BaseArithmeticException): """ +class UnsupportedOperationException(CapturedException, BaseUnsupportedOperationException): + """ + Unsupported operation exception. 
+ """ + + class ArrayIndexOutOfBoundsException(CapturedException, BaseArrayIndexOutOfBoundsException): """ Array index out of bounds exception. diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py index f8f234ed2ee..48b213080c1 100644 --- a/python/pyspark/errors/exceptions/connect.py +++ b/python/pyspark/errors/exceptions/connect.py @@ -22,6 +22,7 @@ from pyspark.errors.exceptions.base import ( AnalysisException as BaseAnalysisException, IllegalArgumentException as BaseIllegalArgumentException, ArithmeticException as BaseArithmeticException, + UnsupportedOperationException as BaseUnsupportedOperationException, ArrayIndexOutOfBoundsException as BaseArrayIndexOutOfBoundsException, DateTimeException as BaseDateTimeException, NumberFormatException as BaseNumberFormatException, @@ -69,6 +70,8 @@ def convert_exception(info: "ErrorInfo", message: str) -> SparkConnectException: return IllegalArgumentException(message) elif "java.lang.ArithmeticException" in classes: return ArithmeticException(message) + elif "java.lang.UnsupportedOperationException" in classes: + return UnsupportedOperationException(message) elif "java.lang.ArrayIndexOutOfBoundsException" in classes: return ArrayIndexOutOfBoundsException(message) elif "java.time.DateTimeException" in classes: @@ -151,6 +154,12 @@ class ArithmeticException(SparkConnectGrpcException, BaseArithmeticException): """ +class UnsupportedOperationException(SparkConnectGrpcException, BaseUnsupportedOperationException): + """ + Unsupported operation exception thrown from Spark Connect. + """ + + class ArrayIndexOutOfBoundsException(SparkConnectGrpcException, BaseArrayIndexOutOfBoundsException): """ Array index out of bounds exception thrown from Spark Connect. diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index ffcc4f768ae..070c4ab19d3 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -72,8 +72,8 @@ from pyspark.sql.connect.expressions import ( CommonInlineUserDefinedFunction, JavaUDF, ) -from pyspark.sql.pandas.types import _check_series_localize_timestamps, _convert_map_items_to_dict -from pyspark.sql.types import DataType, MapType, StructType, TimestampType +from pyspark.sql.pandas.types import _create_converter_to_pandas +from pyspark.sql.types import DataType, StructType, TimestampType, _has_type from pyspark.rdd import PythonEvalType from pyspark.storagelevel import StorageLevel from pyspark.errors import PySparkValueError, PySparkRuntimeError @@ -705,17 +705,35 @@ class SparkConnectClient(object): schema = schema or types.from_arrow_schema(table.schema) assert schema is not None and isinstance(schema, StructType) - pdf = table.to_pandas() - pdf.columns = schema.fieldNames() + # Rename columns to avoid duplicated column names. 
+ pdf = table.rename_columns([f"col_{i}" for i in range(table.num_columns)]).to_pandas() + pdf.columns = schema.names - for field, pa_field in zip(schema, table.schema): - if isinstance(field.dataType, TimestampType): - assert pa_field.type.tz is not None - pdf[field.name] = _check_series_localize_timestamps( - pdf[field.name], pa_field.type.tz - ) - elif isinstance(field.dataType, MapType): - pdf[field.name] = _convert_map_items_to_dict(pdf[field.name]) + timezone: Optional[str] = None + struct_in_pandas: Optional[str] = None + error_on_duplicated_field_names: bool = False + if any(_has_type(f.dataType, (StructType, TimestampType)) for f in schema.fields): + timezone, struct_in_pandas = self.get_configs( + "spark.sql.session.timeZone", "spark.sql.execution.pandas.structHandlingMode" + ) + + if struct_in_pandas == "legacy": + error_on_duplicated_field_names = True + struct_in_pandas = "dict" + + pdf = pd.concat( + [ + _create_converter_to_pandas( + field.dataType, + field.nullable, + timezone=timezone, + struct_in_pandas=struct_in_pandas, + error_on_duplicated_field_names=error_on_duplicated_field_names, + )(pser) + for (_, pser), field, pa_field in zip(pdf.items(), schema.fields, table.schema) + ], + axis="columns", + ) if len(metrics) > 0: pdf.attrs["metrics"] = metrics @@ -1068,6 +1086,11 @@ class SparkConnectClient(object): req.user_context.user_id = self._user_id return req + def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: + op = pb2.ConfigRequest.Operation(get=pb2.ConfigRequest.Get(keys=keys)) + configs = dict(self.config(op).pairs) + return tuple(configs.get(key) for key in keys) + def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult: """ Call the config RPC of Spark Connect. diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py index 16679e80205..a7ea88fb007 100644 --- a/python/pyspark/sql/connect/conversion.py +++ b/python/pyspark/sql/connect/conversion.py @@ -19,7 +19,6 @@ from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) import array -import itertools import datetime import decimal @@ -45,6 +44,7 @@ from pyspark.sql.types import ( from pyspark.storagelevel import StorageLevel from pyspark.sql.connect.types import to_arrow_schema import pyspark.sql.connect.proto as pb2 +from pyspark.sql.pandas.types import _dedup_names from typing import ( Any, @@ -507,22 +507,3 @@ def _deduplicate_field_names(dt: DataType) -> DataType: ) else: return dt - - -def _dedup_names(names: List[str]) -> List[str]: - if len(set(names)) == len(names): - return names - else: - - def _gen_dedup(_name: str) -> Callable[[], str]: - _i = itertools.count() - return lambda: f"{_name}_{next(_i)}" - - def _gen_identity(_name: str) -> Callable[[], str]: - return lambda: _name - - gen_new_name = { - name: _gen_dedup(name) if len(list(group)) > 1 else _gen_identity(name) - for name, group in itertools.groupby(sorted(names)) - } - return [gen_new_name[name]() for name in names] diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 005eed69722..50eadf46200 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -111,7 +111,7 @@ class DataFrame: repl_eager_eval_enabled, repl_eager_eval_max_num_rows, repl_eager_eval_truncate, - ) = self._session._get_configs( + ) = self._session._client.get_configs( "spark.sql.repl.eagerEval.enabled", "spark.sql.repl.eagerEval.maxNumRows", "spark.sql.repl.eagerEval.truncate", @@ 
-131,7 +131,7 @@ class DataFrame: repl_eager_eval_enabled, repl_eager_eval_max_num_rows, repl_eager_eval_truncate, - ) = self._session._get_configs( + ) = self._session._client.get_configs( "spark.sql.repl.eagerEval.enabled", "spark.sql.repl.eagerEval.maxNumRows", "spark.sql.repl.eagerEval.truncate", diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index fde861d12b9..3bd842f7847 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -47,7 +47,6 @@ from pandas.api.types import ( # type: ignore[attr-defined] ) from pyspark import SparkContext, SparkConf, __version__ -from pyspark.sql.connect import proto from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder from pyspark.sql.connect.conf import RuntimeConf from pyspark.sql.connect.dataframe import DataFrame @@ -226,11 +225,6 @@ class SparkSession: def readStream(self) -> "DataStreamReader": return DataStreamReader(self) - def _get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: - op = proto.ConfigRequest.Operation(get=proto.ConfigRequest.Get(keys=keys)) - configs = dict(self._client.config(op).pairs) - return tuple(configs.get(key) for key in keys) - def _inferSchemaFromList( self, data: Iterable[Any], names: Optional[List[str]] = None ) -> StructType: @@ -244,7 +238,7 @@ class SparkSession: infer_dict_as_struct, infer_array_from_first_element, prefer_timestamp_ntz, - ) = self._get_configs( + ) = self._client.get_configs( "spark.sql.pyspark.inferNestedDictAsStruct.enabled", "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled", "spark.sql.timestampType", @@ -334,7 +328,7 @@ class SparkSession: for t in data.dtypes ] - timezone, safecheck = self._get_configs( + timezone, safecheck = self._client.get_configs( "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely" ) diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index ce0143d1851..0c29dcceed0 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -15,29 +15,20 @@ # limitations under the License. 
# import sys -from collections import Counter -from typing import List, Optional, Type, Union, no_type_check, overload, TYPE_CHECKING -from warnings import catch_warnings, simplefilter, warn +from typing import ( + List, + Optional, + Union, + no_type_check, + overload, + TYPE_CHECKING, +) +from warnings import warn from pyspark.errors.exceptions.captured import unwrap_spark_exception from pyspark.rdd import _load_from_socket from pyspark.sql.pandas.serializers import ArrowCollectSerializer -from pyspark.sql.types import ( - IntegralType, - ByteType, - ShortType, - IntegerType, - LongType, - FloatType, - DoubleType, - BooleanType, - MapType, - TimestampType, - TimestampNTZType, - DayTimeIntervalType, - StructType, - DataType, -) +from pyspark.sql.types import TimestampType, StructType, DataType from pyspark.sql.utils import is_timestamp_ntz_preferred from pyspark.traceback_utils import SCCallSiteSync @@ -85,16 +76,16 @@ class PandasConversionMixin: assert isinstance(self, DataFrame) + from pyspark.sql.pandas.types import _create_converter_to_pandas from pyspark.sql.pandas.utils import require_minimum_pandas_version require_minimum_pandas_version() - import numpy as np import pandas as pd - from pandas.core.dtypes.common import is_timedelta64_dtype jconf = self.sparkSession._jconf timezone = jconf.sessionLocalTimeZone() + struct_in_pandas = jconf.pandasStructHandlingMode() if jconf.arrowPySparkEnabled(): use_arrow = True @@ -132,18 +123,10 @@ class PandasConversionMixin: # of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled. if use_arrow: try: - from pyspark.sql.pandas.types import ( - _check_series_localize_timestamps, - _convert_map_items_to_dict, - ) import pyarrow - # Rename columns to avoid duplicated column names. - tmp_column_names = ["col_{}".format(i) for i in range(len(self.columns))] self_destruct = jconf.arrowPySparkSelfDestructEnabled() - batches = self.toDF(*tmp_column_names)._collect_as_arrow( - split_batches=self_destruct - ) + batches = self._collect_as_arrow(split_batches=self_destruct) if len(batches) > 0: table = pyarrow.Table.from_batches(batches) # Ensure only the table has a reference to the batches, so that @@ -165,32 +148,34 @@ class PandasConversionMixin: "use_threads": False, } ) - pdf = table.to_pandas(**pandas_options) + # Rename columns to avoid duplicated column names. + pdf = table.rename_columns( + [f"col_{i}" for i in range(table.num_columns)] + ).to_pandas(**pandas_options) + # Rename back to the original column names. 
pdf.columns = self.columns - for field in self.schema: - if isinstance(field.dataType, TimestampType): - pdf[field.name] = _check_series_localize_timestamps( - pdf[field.name], timezone - ) - elif isinstance(field.dataType, MapType): - pdf[field.name] = _convert_map_items_to_dict(pdf[field.name]) - return pdf else: - corrected_panda_types = {} - for index, field in enumerate(self.schema): - pandas_type = PandasConversionMixin._to_corrected_pandas_type( - field.dataType - ) - corrected_panda_types[tmp_column_names[index]] = ( - object if pandas_type is None else pandas_type - ) - - pdf = pd.DataFrame(columns=tmp_column_names).astype( - dtype=corrected_panda_types - ) - pdf.columns = self.columns - return pdf + pdf = pd.DataFrame(columns=self.columns) + + error_on_duplicated_field_names = False + if struct_in_pandas == "legacy": + error_on_duplicated_field_names = True + struct_in_pandas = "dict" + + return pd.concat( + [ + _create_converter_to_pandas( + field.dataType, + field.nullable, + timezone=timezone, + struct_in_pandas=struct_in_pandas, + error_on_duplicated_field_names=error_on_duplicated_field_names, + )(pser) + for (_, pser), field in zip(pdf.items(), self.schema.fields) + ], + axis="columns", + ) except Exception as e: # We might have to allow fallback here as well but multiple Spark jobs can # be executed. So, simply fail in this case for now. @@ -207,107 +192,20 @@ class PandasConversionMixin: # Below is toPandas without Arrow optimization. pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - column_counter = Counter(self.columns) - - corrected_dtypes: List[Optional[Type]] = [None] * len(self.schema) - for index, field in enumerate(self.schema): - # We use `iloc` to access columns with duplicate column names. - if column_counter[field.name] > 1: - pandas_col = pdf.iloc[:, index] - else: - pandas_col = pdf[field.name] - - pandas_type = PandasConversionMixin._to_corrected_pandas_type(field.dataType) - # SPARK-21766: if an integer field is nullable and has null values, it can be - # inferred by pandas as a float column. If we convert the column with NaN back - # to integer type e.g., np.int16, we will hit an exception. So we use the - # pandas-inferred float type, rather than the corrected type from the schema - # in this case. - if pandas_type is not None and not ( - isinstance(field.dataType, IntegralType) - and field.nullable - and pandas_col.isnull().any() - ): - corrected_dtypes[index] = pandas_type - # Ensure we fall back to nullable numpy types. - if isinstance(field.dataType, IntegralType) and pandas_col.isnull().any(): - corrected_dtypes[index] = np.float64 - if isinstance(field.dataType, BooleanType) and pandas_col.isnull().any(): - corrected_dtypes[index] = object - - df = pd.DataFrame() - for index, t in enumerate(corrected_dtypes): - column_name = self.schema[index].name - - # We use `iloc` to access columns with duplicate column names. - if column_counter[column_name] > 1: - series = pdf.iloc[:, index] - else: - series = pdf[column_name] - # No need to cast for non-empty series for timedelta. The type is already correct. - should_check_timedelta = is_timedelta64_dtype(t) and len(pdf) == 0 - - if (t is not None and not is_timedelta64_dtype(t)) or should_check_timedelta: - series = series.astype(t, copy=False) - - with catch_warnings(): - from pandas.errors import PerformanceWarning - - simplefilter(action="ignore", category=PerformanceWarning) - # `insert` API makes copy of data, - # we only do it for Series of duplicate column names. 
- # `pdf.iloc[:, index] = pdf.iloc[:, index]...` doesn't always work - # because `iloc` could return a view or a copy depending by context. - if column_counter[column_name] > 1: - df.insert(index, column_name, series, allow_duplicates=True) - else: - df[column_name] = series - - if timezone is None: - return df - else: - from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz - - for field in self.schema: - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if isinstance(field.dataType, TimestampType): - df[field.name] = _check_series_convert_timestamps_local_tz( - df[field.name], timezone - ) - return df - - @staticmethod - def _to_corrected_pandas_type(dt: DataType) -> Optional[Type]: - """ - When converting Spark SQL records to Pandas `pandas.DataFrame`, the inferred data type - may be wrong. This method gets the corrected data type for Pandas if that type may be - inferred incorrectly. - """ - import numpy as np - - if type(dt) == ByteType: - return np.int8 - elif type(dt) == ShortType: - return np.int16 - elif type(dt) == IntegerType: - return np.int32 - elif type(dt) == LongType: - return np.int64 - elif type(dt) == FloatType: - return np.float32 - elif type(dt) == DoubleType: - return np.float64 - elif type(dt) == BooleanType: - return bool - elif type(dt) == TimestampType: - return np.datetime64 - elif type(dt) == TimestampNTZType: - return np.datetime64 - elif type(dt) == DayTimeIntervalType: - return np.timedelta64 - else: - return None + return pd.concat( + [ + _create_converter_to_pandas( + field.dataType, + field.nullable, + timezone=timezone, + struct_in_pandas=("row" if struct_in_pandas == "legacy" else struct_in_pandas), + error_on_duplicated_field_names=False, + )(pser) + for (_, pser), field in zip(pdf.items(), self.schema.fields) + ], + axis="columns", + ) def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch"]: """ diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 70d50ca6e95..d23e83b1a5d 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -19,7 +19,8 @@ Type-specific codes between pandas and PyArrow. Also contains some utils to correct pandas instances during the type conversion. """ -from typing import Optional, TYPE_CHECKING +import itertools +from typing import Any, Callable, List, Optional, TYPE_CHECKING from pyspark.sql.types import ( cast, @@ -27,6 +28,7 @@ from pyspark.sql.types import ( ByteType, ShortType, IntegerType, + IntegralType, LongType, FloatType, DoubleType, @@ -43,10 +45,13 @@ from pyspark.sql.types import ( StructField, NullType, DataType, + Row, + _create_row, ) -from pyspark.errors import PySparkTypeError +from pyspark.errors import PySparkTypeError, UnsupportedOperationException if TYPE_CHECKING: + import pandas as pd import pyarrow as pa from pyspark.sql.pandas._typing import SeriesLike as PandasSeriesLike @@ -462,3 +467,236 @@ def _convert_dict_to_map_items(s: "PandasSeriesLike") -> "PandasSeriesLike": :return: pandas.Series of lists of (key, value) pairs """ return cast("PandasSeriesLike", s.apply(lambda d: list(d.items()) if d is not None else None)) + + +def _to_corrected_pandas_type(dt: DataType) -> Optional[Any]: + """ + When converting Spark SQL records to Pandas `pandas.DataFrame`, the inferred data type + may be wrong. This method gets the corrected data type for Pandas if that type may be + inferred incorrectly. 
+ """ + import numpy as np + + if type(dt) == ByteType: + return np.int8 + elif type(dt) == ShortType: + return np.int16 + elif type(dt) == IntegerType: + return np.int32 + elif type(dt) == LongType: + return np.int64 + elif type(dt) == FloatType: + return np.float32 + elif type(dt) == DoubleType: + return np.float64 + elif type(dt) == BooleanType: + return bool + elif type(dt) == TimestampType: + return np.dtype("datetime64[ns]") + elif type(dt) == TimestampNTZType: + return np.dtype("datetime64[ns]") + elif type(dt) == DayTimeIntervalType: + return np.dtype("timedelta64[ns]") + else: + return None + + +def _create_converter_to_pandas( + data_type: DataType, + nullable: bool = True, + *, + timezone: Optional[str] = None, + struct_in_pandas: Optional[str] = None, + error_on_duplicated_field_names: bool = True, +) -> Callable[["pd.Series"], "pd.Series"]: + """ + Create a converter of pandas Series that is created from Spark's Python objects, + or `pyarrow.Table.to_pandas` method. + + Parameters + ---------- + data_type : :class:`DataType` + The data type corresponding to the pandas Series to be converted. + nullable : bool, optional + Whether the column is nullable or not. (default ``True``) + timezone : str, optional + The timezone to convert from. If there is a timestamp type, it's required. + struct_in_pandas : str, optional + How to handle struct type. If there is a struct type, it's required. + When ``row``, :class:`Row` object will be used. + When ``dict``, :class:`dict` will be used. If there are duplicated field names, + The fields will be suffixed, like `a_0`, `a_1`. + Must be one of: ``row``, ``dict``. + error_on_duplicated_field_names : bool, optional + Whether raise an exception when there are duplicated field names. + (default ``True``) + + Returns + ------- + The converter of `pandas.Series` + """ + import numpy as np + import pandas as pd + from pandas.core.dtypes.common import is_datetime64tz_dtype + + pandas_type = _to_corrected_pandas_type(data_type) + + if pandas_type is not None: + # SPARK-21766: if an integer field is nullable and has null values, it can be + # inferred by pandas as a float column. If we convert the column with NaN back + # to integer type e.g., np.int16, we will hit an exception. So we use the + # pandas-inferred float type, rather than the corrected type from the schema + # in this case. 
+ if isinstance(data_type, IntegralType) and nullable: + + def correct_dtype(pser: pd.Series) -> pd.Series: + if pser.isnull().any(): + return pser.astype(np.float64, copy=False) + else: + return pser.astype(pandas_type, copy=False) + + elif isinstance(data_type, BooleanType) and nullable: + + def correct_dtype(pser: pd.Series) -> pd.Series: + if pser.isnull().any(): + return pser.astype(object, copy=False) + else: + return pser.astype(pandas_type, copy=False) + + elif isinstance(data_type, TimestampType): + assert timezone is not None + + def correct_dtype(pser: pd.Series) -> pd.Series: + if not is_datetime64tz_dtype(pser.dtype): + pser = pser.astype(pandas_type, copy=False) + return _check_series_convert_timestamps_local_tz(pser, timezone=cast(str, timezone)) + + else: + + def correct_dtype(pser: pd.Series) -> pd.Series: + return pser.astype(pandas_type, copy=False) + + return correct_dtype + + def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]: + + if isinstance(dt, ArrayType): + _element_conv = _converter(dt.elementType) + if _element_conv is None: + return None + + def convert_array(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, np.ndarray): + # `pyarrow.Table.to_pandas` uses `np.ndarray`. + return np.array([_element_conv(v) for v in value]) # type: ignore[misc] + else: + assert isinstance(value, list) + # otherwise, `list` should be used. + return [_element_conv(v) for v in value] # type: ignore[misc] + + return convert_array + + elif isinstance(dt, MapType): + _key_conv = _converter(dt.keyType) or (lambda x: x) + _value_conv = _converter(dt.valueType) or (lambda x: x) + + def convert_map(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, list): + # `pyarrow.Table.to_pandas` uses `list` of key-value tuple. + return {_key_conv(k): _value_conv(v) for k, v in value} + else: + assert isinstance(value, dict) + # otherwise, `dict` should be used. + return {_key_conv(k): _value_conv(v) for k, v in value.items()} + + return convert_map + + elif isinstance(dt, StructType): + assert struct_in_pandas is not None + + field_names = dt.names + + if error_on_duplicated_field_names and len(set(field_names)) != len(field_names): + raise UnsupportedOperationException( + error_class="DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT", + message_parameters={"field_names": str(field_names)}, + ) + + dedup_field_names = _dedup_names(field_names) + + field_convs = [_converter(f.dataType) or (lambda x: x) for f in dt.fields] + + if struct_in_pandas == "row": + + def convert_struct_as_row(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, dict): + # `pyarrow.Table.to_pandas` uses `dict`. + _values = [ + field_convs[i](value.get(name, None)) + for i, name in enumerate(dedup_field_names) + ] + return _create_row(field_names, _values) + else: + assert isinstance(value, Row) + # otherwise, `Row` should be used. + _values = [field_convs[i](value[i]) for i, name in enumerate(value)] + return _create_row(field_names, _values) + + return convert_struct_as_row + + elif struct_in_pandas == "dict": + + def convert_struct_as_dict(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, dict): + # `pyarrow.Table.to_pandas` uses `dict`. + return { + name: field_convs[i](value.get(name, None)) + for i, name in enumerate(dedup_field_names) + } + else: + assert isinstance(value, Row) + # otherwise, `Row` should be used. 
+ return { + dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value) + } + + return convert_struct_as_dict + + else: + raise ValueError(f"Unknown value for `struct_in_pandas`: {struct_in_pandas}") + + else: + return None + + conv = _converter(data_type) + if conv is not None: + return lambda pser: pser.apply(conv) # type: ignore[return-value] + else: + return lambda pser: pser + + +def _dedup_names(names: List[str]) -> List[str]: + if len(set(names)) == len(names): + return names + else: + + def _gen_dedup(_name: str) -> Callable[[], str]: + _i = itertools.count() + return lambda: f"{_name}_{next(_i)}" + + def _gen_identity(_name: str) -> Callable[[], str]: + return lambda: _name + + gen_new_name = { + name: _gen_dedup(name) if len(list(group)) > 1 else _gen_identity(name) + for name, group in itertools.groupby(sorted(names)) + } + return [gen_new_name[name]() for name in names] diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py index f2fa9ece4df..d27077f8907 100644 --- a/python/pyspark/sql/tests/connect/test_parity_arrow.py +++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py @@ -106,6 +106,9 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase): def test_toPandas_error(self): self.check_toPandas_error(True) + def test_toPandas_duplicate_field_names(self): + self.check_toPandas_duplicate_field_names(True) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_parity_arrow import * # noqa: F401 diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 84c782e8d95..52e13782199 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -55,7 +55,7 @@ from pyspark.testing.sqlutils import ( pyarrow_requirement_message, ) from pyspark.testing.utils import QuietTest -from pyspark.errors import ArithmeticException, PySparkTypeError +from pyspark.errors import ArithmeticException, PySparkTypeError, UnsupportedOperationException if have_pandas: import pandas as pd @@ -888,6 +888,60 @@ class ArrowTestsMixin: with self.assertRaises(ArithmeticException): self.spark.sql("select 1/0").toPandas() + def test_toPandas_duplicate_field_names(self): + for arrow_enabled in [True, False]: + with self.subTest(arrow_enabled=arrow_enabled): + self.check_toPandas_duplicate_field_names(arrow_enabled) + + def check_toPandas_duplicate_field_names(self, arrow_enabled): + data = [Row(Row("a", 1), Row(2, 3, "b", 4, "c")), Row(Row("x", 6), Row(7, 8, "y", 9, "z"))] + schema = ( + StructType() + .add("struct", StructType().add("x", StringType()).add("x", IntegerType())) + .add( + "struct", + StructType() + .add("a", IntegerType()) + .add("x", IntegerType()) + .add("x", StringType()) + .add("y", IntegerType()) + .add("y", StringType()), + ) + ) + for struct_in_pandas in ["legacy", "row", "dict"]: + df = self.spark.createDataFrame(data, schema=schema) + + with self.subTest(struct_in_pandas=struct_in_pandas): + with self.sql_conf( + { + "spark.sql.execution.arrow.pyspark.enabled": arrow_enabled, + "spark.sql.execution.pandas.structHandlingMode": struct_in_pandas, + } + ): + if arrow_enabled and struct_in_pandas == "legacy": + with self.assertRaisesRegexp( + UnsupportedOperationException, "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT" + ): + df.toPandas() + else: + if struct_in_pandas == "dict": + expected = pd.DataFrame( + [ + [ + {"x_0": "a", "x_1": 1}, + {"a": 2, "x_0": 3, "x_1": "b", "y_0": 4, "y_1": "c"}, + ], + [ + {"x_0": "x", "x_1": 
6}, + {"a": 7, "x_0": 8, "x_1": "y", "y_0": 9, "y_1": "z"}, + ], + ], + columns=schema.names, + ) + else: + expected = pd.DataFrame.from_records(data, columns=schema.names) + assert_frame_equal(df.toPandas(), expected) + @unittest.skipIf( not have_pandas or not have_pyarrow, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 6c7751c522a..70d90a03c10 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1687,6 +1687,20 @@ def _has_nulltype(dt: DataType) -> bool: return isinstance(dt, NullType) +def _has_type(dt: DataType, dts: Union[type, Tuple[type, ...]]) -> bool: + """Return whether there are specified types""" + if isinstance(dt, dts): + return True + elif isinstance(dt, StructType): + return any(_has_type(f.dataType, dts) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_type(dt.elementType, dts) + elif isinstance(dt, MapType): + return _has_type(dt.keyType, dts) or _has_type(dt.valueType, dts) + else: + return False + + @overload def _merge_type(a: StructType, b: StructType, name: Optional[str] = None) -> StructType: ... diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index e7d310c25c2..111f0391a72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1144,6 +1144,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { messageParameters = Map("typeName" -> toSQLType(typeName))) } + def duplicatedFieldNameInArrowStructError( + fieldNames: Seq[String]): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT", + messageParameters = Map("fieldNames" -> fieldNames.mkString("[", ", ", "]"))) + } + def notSupportTypeError(dataType: DataType): Throwable = { new SparkException( errorClass = "_LEGACY_ERROR_TEMP_2100", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index af7126495c5..efdbc583207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -30,8 +30,11 @@ import org.apache.spark.sql.util.ArrowUtils object ArrowWriter { - def create(schema: StructType, timeZoneId: String): ArrowWriter = { - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + def create( + schema: StructType, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean = true): ArrowWriter = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames) val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) create(root) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 874f95af1cb..c9974d2dfa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2835,6 +2835,23 @@ object SQLConf { .version("3.0.0") .fallbackConf(BUFFER_SIZE) + val PANDAS_STRUCT_HANDLING_MODE = + buildConf("spark.sql.execution.pandas.structHandlingMode") + .doc( + "The 
conversion mode of struct type when creating pandas DataFrame. " + + "When \"legacy\"," + + "1. when Arrow optimization is disabled, convert to Row object, " + + "2. when Arrow optimization is enabled, convert to dict or raise an Exception " + + "if there are duplicated nested field names. " + + "When \"row\", convert to Row object regardless of Arrow optimization. " + + "When \"dict\", convert to dict and use suffixed key names, e.g., a_0, a_1, " + + "if there are duplicated nested field names, regardless of Arrow optimization." + ) + .version("3.5.0") + .stringConf + .checkValues(Set("legacy", "row", "dict")) + .createWithDefaultString("legacy") + val PYSPARK_SIMPLIFIED_TRACEBACK = buildConf("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled") .doc( @@ -4855,6 +4872,8 @@ class SQLConf extends Serializable with Logging { def pandasUDFBufferSize: Int = getConf(PANDAS_UDF_BUFFER_SIZE) + def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE) + def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK) def pandasGroupedMapAssignColumnsByName: Boolean = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index d6a8fec81dd..719691a338f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.util +import java.util.concurrent.atomic.AtomicInteger + import scala.collection.JavaConverters._ import org.apache.arrow.memory.RootAllocator @@ -135,9 +137,16 @@ private[sql] object ArrowUtils { } /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */ - def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { + def toArrowSchema( + schema: StructType, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean): Schema = { new Schema(schema.map { field => - toArrowField(field.name, field.dataType, field.nullable, timeZoneId) + toArrowField( + field.name, + deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames), + field.nullable, + timeZoneId) }.asJava) } @@ -157,4 +166,38 @@ private[sql] object ArrowUtils { conf.arrowSafeTypeConversion.toString) Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*) } + + private def deduplicateFieldNames( + dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match { + case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames) + case st @ StructType(fields) => + val newNames = if (st.names.toSet.size == st.names.length) { + st.names + } else { + if (errorOnDuplicatedFieldNames) { + throw QueryExecutionErrors.duplicatedFieldNameInArrowStructError(st.names) + } + val genNawName = st.names.groupBy(identity).map { + case (name, names) if names.length > 1 => + val i = new AtomicInteger() + name -> { () => s"${name}_${i.getAndIncrement()}" } + case (name, _) => name -> { () => name } + } + st.names.map(genNawName(_)()) + } + val newFields = + fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) => + StructField( + name, deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), nullable, metadata) + } + StructType(newFields) + case ArrayType(elementType, containsNull) => + ArrayType(deduplicateFieldNames(elementType, errorOnDuplicatedFieldNames), containsNull) + case MapType(keyType, valueType, valueContainsNull) 
=> + MapType( + deduplicateFieldNames(keyType, errorOnDuplicatedFieldNames), + deduplicateFieldNames(valueType, errorOnDuplicatedFieldNames), + valueContainsNull) + case _ => dt + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala index 2f78d03db80..28ed061a71b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala @@ -30,7 +30,7 @@ class ArrowUtilsSuite extends SparkFunSuite { def roundtrip(dt: DataType): Unit = { dt match { case schema: StructType => - assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null)) === schema) + assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null, true)) === schema) case _ => roundtrip(new StructType().add("value", dt)) } @@ -67,7 +67,7 @@ class ArrowUtilsSuite extends SparkFunSuite { def roundtripWithTz(timeZoneId: String): Unit = { val schema = new StructType().add("value", TimestampType) - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, true) val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp] assert(fieldType.getTimezone() === timeZoneId) assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema) @@ -97,4 +97,25 @@ class ArrowUtilsSuite extends SparkFunSuite { "struct", new StructType().add("i", IntegerType).add("arr", ArrayType(IntegerType)))) } + + test("struct with duplicated field names") { + + def check(dt: DataType, expected: DataType): Unit = { + val schema = new StructType().add("value", dt) + intercept[SparkUnsupportedOperationException] { + ArrowUtils.toArrowSchema(schema, null, true) + } + assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null, false)) + === new StructType().add("value", expected)) + } + + roundtrip(new StructType().add("i", IntegerType).add("i", StringType)) + + check(new StructType().add("i", IntegerType).add("i", StringType), + new StructType().add("i_0", IntegerType).add("i_1", StringType)) + check(ArrayType(new StructType().add("i", IntegerType).add("i", StringType)), + ArrayType(new StructType().add("i_0", IntegerType).add("i_1", StringType))) + check(MapType(StringType, new StructType().add("i", IntegerType).add("i", StringType)), + MapType(StringType, new StructType().add("i_0", IntegerType).add("i_1", StringType))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7973ba38b1a..7137e24ed96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -4173,7 +4173,8 @@ class Dataset[T] private[sql]( withAction("collectAsArrowToR", queryExecution) { plan => val buffer = new ByteArrayOutputStream() val out = new DataOutputStream(outputStream) - val batchWriter = new ArrowBatchStreamWriter(schema, buffer, timeZoneId) + val batchWriter = + new ArrowBatchStreamWriter(schema, buffer, timeZoneId, errorOnDuplicatedFieldNames = true) val arrowBatchRdd = toArrowBatchRdd(plan) val numPartitions = arrowBatchRdd.partitions.length @@ -4222,11 +4223,14 @@ class Dataset[T] private[sql]( */ private[sql] def collectAsArrowToPython: Array[Any] = { val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + val errorOnDuplicatedFieldNames = + 
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy" PythonRDD.serveToStream("serve-Arrow") { outputStream => withAction("collectAsArrowToPython", queryExecution) { plan => val out = new DataOutputStream(outputStream) - val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) + val batchWriter = + new ArrowBatchStreamWriter(schema, out, timeZoneId, errorOnDuplicatedFieldNames) // Batches ordered by (index of partition, batch index in that partition) tuple val batchOrder = ArrayBuffer.empty[(Int, Int)] @@ -4355,10 +4359,12 @@ class Dataset[T] private[sql]( val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + val errorOnDuplicatedFieldNames = + sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy" plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() ArrowConverters.toBatchIterator( - iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) + iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index b22c80d17e8..8d6d2c78051 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -48,9 +48,10 @@ import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils} private[sql] class ArrowBatchStreamWriter( schema: StructType, out: OutputStream, - timeZoneId: String) { + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean) { - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames) val writeChannel = new WriteChannel(Channels.newChannel(out)) // Write the Arrow schema first, before batches @@ -77,9 +78,11 @@ private[sql] object ArrowConverters extends Logging { schema: StructType, maxRecordsPerBatch: Long, timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, context: TaskContext) extends Iterator[Array[Byte]] { - protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + protected val arrowSchema = + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames) private val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) @@ -128,9 +131,10 @@ private[sql] object ArrowConverters extends Logging { maxRecordsPerBatch: Long, maxEstimatedBatchSize: Long, timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, context: TaskContext) extends ArrowBatchIterator( - rowIter, schema, maxRecordsPerBatch, timeZoneId, context) { + rowIter, schema, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context) { private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema) var rowCountInLastBatch: Long = 0 @@ -190,9 +194,10 @@ private[sql] object ArrowConverters extends Logging { schema: StructType, maxRecordsPerBatch: Long, timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, context: TaskContext): ArrowBatchIterator = { new ArrowBatchIterator( - rowIter, schema, maxRecordsPerBatch, timeZoneId, context) + rowIter, schema, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context) } /** @@ -204,16 +209,20 @@ private[sql] 
object ArrowConverters extends Logging { schema: StructType, maxRecordsPerBatch: Long, maxEstimatedBatchSize: Long, - timeZoneId: String): ArrowBatchWithSchemaIterator = { + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean): ArrowBatchWithSchemaIterator = { new ArrowBatchWithSchemaIterator( - rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize, timeZoneId, TaskContext.get) + rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize, + timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get) } private[sql] def createEmptyArrowBatch( schema: StructType, - timeZoneId: String): Array[Byte] = { + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean): Array[Byte] = { new ArrowBatchWithSchemaIterator( - Iterator.empty, schema, 0L, 0L, timeZoneId, TaskContext.get) { + Iterator.empty, schema, 0L, 0L, + timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get) { override def hasNext: Boolean = true }.next() } @@ -275,7 +284,8 @@ private[sql] object ArrowConverters extends Logging { extends InternalRowIterator(arrowBatchIter, context) { override def nextBatch(): (Iterator[InternalRow], StructType) = { - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val arrowSchema = + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true) val root = VectorSchemaRoot.create(arrowSchema, allocator) resources.append(root) val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) @@ -353,9 +363,6 @@ private[sql] object ArrowConverters extends Logging { new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } - /** - * Create a DataFrame from an iterator of serialized ArrowRecordBatches. - */ /** * Create a DataFrame from an iterator of serialized ArrowRecordBatches. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index 79773a9d534..d4b6f0db9c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -70,6 +70,8 @@ class ApplyInPandasWithStatePythonRunner( override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA) + override val errorOnDuplicatedFieldNames: Boolean = true + override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback override val bufferSize: Int = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index dbafc444281..427dfbbb32d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -39,6 +39,8 @@ class ArrowPythonRunner( with BasicPythonArrowInput with BasicPythonArrowOutput { + override val errorOnDuplicatedFieldNames: Boolean = true + override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index 1df9f37188a..763a6d08d78 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -96,7 +96,8 @@ class CoGroupedArrowPythonRunner( schema: StructType, dataOut: DataOutputStream, name: String): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val arrowSchema = + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 5a0541d11cb..26ce10b6aae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -42,6 +42,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => protected val timeZoneId: String + protected val errorOnDuplicatedFieldNames: Boolean + protected def pythonMetrics: Map[String, SQLMetric] protected def writeIteratorToArrowStream( @@ -73,7 +75,7 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala index ae7b7ef2351..69faa4c8fec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala @@ -83,7 +83,8 @@ class ArrowRRunner( */ override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = { if (inputIterator.hasNext) { - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val arrowSchema = + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true) val allocator = ArrowUtils.rootAllocator.newChildAllocator( "stdout writer for R", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 82e4c970837..e458d33a7cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1381,7 +1381,7 @@ class ArrowConvertersSuite extends SharedSparkSession { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, ctx) + val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, true, ctx) val outputRowIter = 
ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) var count = 0 @@ -1402,12 +1402,12 @@ class ArrowConvertersSuite extends SharedSparkSession { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, ctx) + val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, true, ctx) // Write batches to Arrow stream format as a byte array val out = new ByteArrayOutputStream() Utils.tryWithResource(new DataOutputStream(out)) { dataOut => - val writer = new ArrowBatchStreamWriter(schema, dataOut, null) + val writer = new ArrowBatchStreamWriter(schema, dataOut, null, true) writer.writeBatches(batchIter) writer.end() } @@ -1445,8 +1445,8 @@ class ArrowConvertersSuite extends SharedSparkSession { proj(row).copy() } val ctx = TaskContext.empty() - val batchIter = - ArrowConverters.toBatchWithSchemaIterator(inputRows.iterator, schema, 5, 1024 * 1024, null) + val batchIter = ArrowConverters.toBatchWithSchemaIterator( + inputRows.iterator, schema, 5, 1024 * 1024, null, true) val (outputRowIter, outputType) = ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx) var count = 0 @@ -1465,7 +1465,7 @@ class ArrowConvertersSuite extends SharedSparkSession { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() val batchIter = - ArrowConverters.toBatchWithSchemaIterator(Iterator.empty, schema, 5, 1024 * 1024, null) + ArrowConverters.toBatchWithSchemaIterator(Iterator.empty, schema, 5, 1024 * 1024, null, true) val (outputRowIter, outputType) = ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx) assert(0 == outputRowIter.length) @@ -1479,7 +1479,7 @@ class ArrowConvertersSuite extends SharedSparkSession { proj(row).copy() } val batchIter1 = ArrowConverters.toBatchWithSchemaIterator( - inputRows1.iterator, schema1, 5, 1024 * 1024, null) + inputRows1.iterator, schema1, 5, 1024 * 1024, null, true) val schema2 = StructType(Seq(StructField("field2", IntegerType, nullable = true))) val inputRows2 = Array(InternalRow(1)).map { row => @@ -1487,7 +1487,7 @@ class ArrowConvertersSuite extends SharedSparkSession { proj(row).copy() } val batchIter2 = ArrowConverters.toBatchWithSchemaIterator( - inputRows2.iterator, schema2, 5, 1024 * 1024, null) + inputRows2.iterator, schema2, 5, 1024 * 1024, null, true) val iter = batchIter1.toArray ++ batchIter2 @@ -1499,23 +1499,28 @@ class ArrowConvertersSuite extends SharedSparkSession { /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ private def collectAndValidate( - df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { + df: DataFrame, + json: String, + file: String, + timeZoneId: String = null, + errorOnDuplicatedFieldNames: Boolean = true): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, batchBytes, tempFile, timeZoneId) + validateConversion(df.schema, batchBytes, tempFile, timeZoneId, errorOnDuplicatedFieldNames) } private def validateConversion( sparkSchema: StructType, batchBytes: Array[Byte], jsonFile: File, - timeZoneId: String = null): Unit = { + timeZoneId: String = null, + errorOnDuplicatedFieldNames: Boolean = true): Unit = { val 
allocator = new RootAllocator(Long.MaxValue) val jsonReader = new JsonFileReader(jsonFile, allocator) - val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, errorOnDuplicatedFieldNames) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema)
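For reference, the renaming scheme that the new `ArrowUtils.deduplicateFieldNames` applies when `errorOnDuplicatedFieldNames` is `false` (i.e. for the `"row"` and `"dict"` modes) can be restated in isolation as the sketch below, reduced to a flat sequence of names; `suffixDuplicates` is a hypothetical helper name, not part of the commit.

```scala
import java.util.concurrent.atomic.AtomicInteger

// Names that occur more than once get an index suffix in order of appearance;
// unique names are left untouched. This mirrors the groupBy/AtomicInteger logic
// the commit adds to ArrowUtils, applied here to plain strings instead of StructFields.
def suffixDuplicates(names: Seq[String]): Seq[String] = {
  val gen = names.groupBy(identity).map {
    case (name, occurrences) if occurrences.length > 1 =>
      val i = new AtomicInteger()
      name -> { () => s"${name}_${i.getAndIncrement()}" }
    case (name, _) =>
      name -> { () => name }
  }
  names.map(name => gen(name)())
}

// suffixDuplicates(Seq("a", "a", "b"))  ==> Seq("a_0", "a_1", "b")
```

This is the same rule the ArrowUtilsSuite test above exercises: a `struct<i: int, i: string>` round-trips as `struct<i_0: int, i_1: string>` once the strict check is turned off.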
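On the driver side, whether the strict check is enabled follows directly from the new config: in `Dataset.scala` (e.g. `collectAsArrowToPython`) the patch computes `errorOnDuplicatedFieldNames` as `pandasStructHandlingMode == "legacy"`. A minimal restatement of that mapping, with a hypothetical function name:

```scala
// Only the default "legacy" mode keeps the strict duplicate-field-name check,
// which surfaces as DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT; "row" and "dict"
// disable it so that ArrowUtils renames duplicated struct fields instead.
def errorOnDuplicatedFieldNamesFor(structHandlingMode: String): Boolean =
  structHandlingMode == "legacy"
```

The Arrow runners touched by the patch (`ArrowPythonRunner`, `CoGroupedArrowPythonRunner`, `ApplyInPandasWithStatePythonRunner`, `ArrowRRunner`) pass the flag as `true` unconditionally.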