This is an automated email from the ASF dual-hosted git repository.
xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 305aa4a89ef [SPARK-41971][SQL][PYTHON] Add a config for pandas conversion how to handle struct types
305aa4a89ef is described below
commit 305aa4a89efe02f517f82039225a99b31b20146f
Author: Takuya UESHIN <[email protected]>
AuthorDate: Thu May 4 11:01:28 2023 -0700
[SPARK-41971][SQL][PYTHON] Add a config for pandas conversion how to handle struct types
### What changes were proposed in this pull request?
Adds a config that controls how pandas conversion handles struct types.
- `spark.sql.execution.pandas.structHandlingMode` (default: `"legacy"`)
The conversion mode for struct types when creating a pandas DataFrame.
#### When `"legacy"`, the behavior is the same as before, except that
Arrow-optimized conversion and Spark Connect raise a more readable exception
when there are duplicated nested field names.
```py
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
Traceback (most recent call last):
...
pyspark.errors.exceptions.connect.UnsupportedOperationException: [DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT] Duplicated field names in Arrow Struct are not allowed, got [a, a].
```
#### When `"row"`, convert structs to `Row` objects regardless of Arrow optimization.
```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', False)
>>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
x y
0 1 (1, 2)
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
x y
0 1 (1, 2)
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
>>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
x y
0 1 (1, 2)
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
x y
0 1 (1, 2)
```
#### When `"dict"`, convert structs to `dict` and, if there are duplicated
nested field names, use suffixed key names, e.g., `a_0`, `a_1`, regardless of
Arrow optimization.
```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'dict')
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', False)
>>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
x y
0 1 {'a': 1, 'b': 2}
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
x y
0 1 {'a_0': 1, 'a_1': 2}
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
>>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
x y
0 1 {'a': 1, 'b': 2}
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
x y
0 1 {'a_0': 1, 'a_1': 2}
```
### Why are the changes needed?
Currently there are three different behaviors when calling `df.toPandas()` on
nested struct types:
- vanilla PySpark with Arrow optimization disabled
```py
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', False)
>>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
x y
0 1 (1, 2)
```
It uses `Row` objects for struct types and allows duplicated field names:
```py
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
x y
0 1 (1, 2)
```
- vanilla PySpark with Arrow optimization enabled
```py
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
>>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
x y
0 1 {'a': 1, 'b': 2}
```
It uses `dict` for struct types and raises an exception when there are
duplicated nested field names:
```py
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
Traceback (most recent call last):
...
pyarrow.lib.ArrowInvalid: Ran out of field metadata, likely malformed
```
- Spark Connect
```py
>>> spark.sql("values (1, struct(1 as a, 2 as b)) as t(x, y)").toPandas()
x y
0 1 {'a': 1, 'b': 2}
```
It uses `dict` for struct types; if there are duplicated nested field names,
the duplicated keys are suffixed:
```py
>>> spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
x y
0 1 {'a_0': 1, 'a_1': 2}
```
### Does this PR introduce _any_ user-facing change?
Yes. Users will be able to configure the behavior via the new
`spark.sql.execution.pandas.structHandlingMode` config, as sketched below.
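A minimal sketch of opting in to the new mode (the local session setup below is
hypothetical and only for illustration; the config name and values come from
this change):
```py
from pyspark.sql import SparkSession

# Hypothetical local session, just for illustration.
spark = SparkSession.builder.master("local[1]").getOrCreate()

# Pick one of "legacy" (default), "row", or "dict".
spark.conf.set("spark.sql.execution.pandas.structHandlingMode", "dict")

# Duplicated nested field names are suffixed instead of failing or colliding.
pdf = spark.sql("values (1, struct(1 as a, 2 as a)) as t(x, y)").toPandas()
print(pdf.loc[0, "y"])  # {'a_0': 1, 'a_1': 2}
```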
### How was this patch tested?
Modified the related tests.
Closes #40988 from ueshin/issues/SPARK-41971/struct_in_pandas.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Xinrong Meng <[email protected]>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 3 +-
.../sql/connect/client/util/ConvertToArrow.scala | 4 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 8 +-
.../service/SparkConnectStreamHandler.scala | 57 ++---
.../connect/planner/SparkConnectPlannerSuite.scala | 5 +-
.../connect/planner/SparkConnectProtoSuite.scala | 3 +-
core/src/main/resources/error/error-classes.json | 5 +
python/pyspark/errors/__init__.py | 2 +
python/pyspark/errors/error_classes.py | 5 +
python/pyspark/errors/exceptions/base.py | 6 +
python/pyspark/errors/exceptions/captured.py | 9 +
python/pyspark/errors/exceptions/connect.py | 9 +
python/pyspark/sql/connect/client.py | 47 +++-
python/pyspark/sql/connect/conversion.py | 21 +-
python/pyspark/sql/connect/dataframe.py | 4 +-
python/pyspark/sql/connect/session.py | 10 +-
python/pyspark/sql/pandas/conversion.py | 204 +++++------------
python/pyspark/sql/pandas/types.py | 242 ++++++++++++++++++++-
.../pyspark/sql/tests/connect/test_parity_arrow.py | 3 +
python/pyspark/sql/tests/test_arrow.py | 56 ++++-
python/pyspark/sql/types.py | 14 ++
.../spark/sql/errors/QueryExecutionErrors.scala | 7 +
.../spark/sql/execution/arrow/ArrowWriter.scala | 7 +-
.../org/apache/spark/sql/internal/SQLConf.scala | 19 ++
.../org/apache/spark/sql/util/ArrowUtils.scala | 47 +++-
.../apache/spark/sql/util/ArrowUtilsSuite.scala | 25 ++-
.../main/scala/org/apache/spark/sql/Dataset.scala | 12 +-
.../sql/execution/arrow/ArrowConverters.scala | 33 +--
.../ApplyInPandasWithStatePythonRunner.scala | 2 +
.../sql/execution/python/ArrowPythonRunner.scala | 2 +
.../python/CoGroupedArrowPythonRunner.scala | 3 +-
.../sql/execution/python/PythonArrowInput.scala | 4 +-
.../spark/sql/execution/r/ArrowRRunner.scala | 3 +-
.../sql/execution/arrow/ArrowConvertersSuite.scala | 29 ++-
34 files changed, 629 insertions(+), 281 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 48640878211..a8bfac5d71f 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -121,7 +121,8 @@ class SparkSession private[sql] (
newDataset(encoder) { builder =>
if (data.nonEmpty) {
val timeZoneId = conf.get("spark.sql.session.timeZone")
- val (arrowData, arrowDataSize) = ConvertToArrow(encoder, data,
timeZoneId, allocator)
+ val (arrowData, arrowDataSize) =
+ ConvertToArrow(encoder, data, timeZoneId,
errorOnDuplicatedFieldNames = true, allocator)
if (arrowDataSize <=
conf.get("spark.sql.session.localRelationCacheThreshold").toInt) {
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala
index 46a9493d138..14235094f2d 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala
@@ -40,8 +40,10 @@ private[sql] object ConvertToArrow {
encoder: AgnosticEncoder[T],
data: Iterator[T],
timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean,
bufferAllocator: BufferAllocator): (ByteString, Int) = {
- val arrowSchema = ArrowUtils.toArrowSchema(encoder.schema, timeZoneId)
+ val arrowSchema =
+ ArrowUtils.toArrowSchema(encoder.schema, timeZoneId,
errorOnDuplicatedFieldNames)
val root = VectorSchemaRoot.create(arrowSchema, bufferAllocator)
val writer: ArrowWriter = ArrowWriter.create(root)
val unloader = new VectorUnloader(root)
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4af604ed4b9..36af01d251c 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1993,14 +1993,18 @@ class SparkConnectPlanner(val session: SparkSession) {
// Convert the data.
val bytes = if (rows.isEmpty) {
- ArrowConverters.createEmptyArrowBatch(schema, timeZoneId)
+ ArrowConverters.createEmptyArrowBatch(
+ schema,
+ timeZoneId,
+ errorOnDuplicatedFieldNames = false)
} else {
val batches = ArrowConverters.toBatchWithSchemaIterator(
rows.iterator,
schema,
maxRecordsPerBatch,
maxBatchSize,
- timeZoneId)
+ timeZoneId,
+ errorOnDuplicatedFieldNames = false)
assert(batches.hasNext)
val bytes = batches.next()
assert(!batches.hasNext, s"remaining batches: ${batches.size}")
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 062ef892979..4958fd69b9d 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.connect.service
-import java.util.concurrent.atomic.AtomicInteger
-
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
@@ -41,7 +39,7 @@ import
org.apache.spark.sql.connect.service.SparkConnectStreamHandler.processAsA
import org.apache.spark.sql.execution.{LocalTableScanExec, SparkPlan,
SQLExecution}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec,
AdaptiveSparkPlanHelper, QueryStageExec}
import org.apache.spark.sql.execution.arrow.ArrowConverters
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField,
StructType, UserDefinedType}
+import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{ThreadUtils, Utils}
class SparkConnectStreamHandler(responseObserver:
StreamObserver[ExecutePlanResponse])
@@ -131,13 +129,15 @@ object SparkConnectStreamHandler {
schema: StructType,
maxRecordsPerBatch: Int,
maxBatchSize: Long,
- timeZoneId: String): Iterator[InternalRow] => Iterator[Batch] = { rows =>
+ timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean): Iterator[InternalRow] =>
Iterator[Batch] = { rows =>
val batches = ArrowConverters.toBatchWithSchemaIterator(
rows,
schema,
maxRecordsPerBatch,
maxBatchSize,
- timeZoneId)
+ timeZoneId,
+ errorOnDuplicatedFieldNames)
batches.map(b => b -> batches.rowCountInLastBatch)
}
@@ -145,45 +145,19 @@ object SparkConnectStreamHandler {
sessionId: String,
dataframe: DataFrame,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
-
- def deduplicateFieldNames(dt: DataType): DataType = dt match {
- case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType)
- case st @ StructType(fields) =>
- val newNames = if (st.names.toSet.size == st.names.length) {
- st.names
- } else {
- val genNawName = st.names.groupBy(identity).map {
- case (name, names) if names.length > 1 =>
- val i = new AtomicInteger()
- name -> { () => s"${name}_${i.getAndIncrement()}" }
- case (name, _) => name -> { () => name }
- }
- st.names.map(genNawName(_)())
- }
- val newFields =
- fields.zip(newNames).map { case (StructField(_, dataType, nullable,
metadata), name) =>
- StructField(name, deduplicateFieldNames(dataType), nullable,
metadata)
- }
- StructType(newFields)
- case ArrayType(elementType, containsNull) =>
- ArrayType(deduplicateFieldNames(elementType), containsNull)
- case MapType(keyType, valueType, valueContainsNull) =>
- MapType(
- deduplicateFieldNames(keyType),
- deduplicateFieldNames(valueType),
- valueContainsNull)
- case _ => dt
- }
-
val spark = dataframe.sparkSession
- val schema =
deduplicateFieldNames(dataframe.schema).asInstanceOf[StructType]
+ val schema = dataframe.schema
val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
// Conservatively sets it 70% because the size is not accurate but
estimated.
val maxBatchSize =
(SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
- val rowToArrowConverter = SparkConnectStreamHandler
- .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize,
timeZoneId)
+ val rowToArrowConverter = SparkConnectStreamHandler.rowToArrowConverter(
+ schema,
+ maxRecordsPerBatch,
+ maxBatchSize,
+ timeZoneId,
+ errorOnDuplicatedFieldNames = false)
var numSent = 0
def sendBatch(bytes: Array[Byte], count: Long): Unit = {
@@ -279,7 +253,12 @@ object SparkConnectStreamHandler {
// Make sure at least 1 batch will be sent.
if (numSent == 0) {
- sendBatch(ArrowConverters.createEmptyArrowBatch(schema, timeZoneId), 0L)
+ sendBatch(
+ ArrowConverters.createEmptyArrowBatch(
+ schema,
+ timeZoneId,
+ errorOnDuplicatedFieldNames = false),
+ 0L)
}
}
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 88b4be16e5a..37d4bec9c87 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -89,7 +89,8 @@ trait SparkConnectPlanTest extends SharedSparkSession {
StructType.fromAttributes(attrs.map(_.toAttribute)),
Long.MaxValue,
Long.MaxValue,
- null)
+ null,
+ true)
.next()
localRelationBuilder.setData(ByteString.copyFrom(bytes))
@@ -464,7 +465,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with
SparkConnectPlanTest {
test("Empty ArrowBatch") {
val schema = StructType(Seq(StructField("int", IntegerType)))
- val data = ArrowConverters.createEmptyArrowBatch(schema, null)
+ val data = ArrowConverters.createEmptyArrowBatch(schema, null, true)
val localRelation = proto.Relation
.newBuilder()
.setLocalRelation(
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 96dae647db6..8cb5c1a2919 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -1042,7 +1042,8 @@ class SparkConnectProtoSuite extends PlanTest with
SparkConnectPlanTest {
StructType.fromAttributes(attributes),
Long.MaxValue,
Long.MaxValue,
- null)
+ null,
+ true)
.next()
proto.Relation
.newBuilder()
diff --git a/core/src/main/resources/error/error-classes.json
b/core/src/main/resources/error/error-classes.json
index 753908932c8..4eea5f9684e 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -545,6 +545,11 @@
],
"sqlState" : "22012"
},
+ "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT" : {
+ "message" : [
+ "Duplicated field names in Arrow Struct are not allowed, got
<fieldNames>."
+ ]
+ },
"DUPLICATED_MAP_KEY" : {
"message" : [
"Duplicate map key <key> was found, please check the input data. If you
want to remove the duplicated keys, you can set <mapKeyDedupPolicy> to
\"LAST_WIN\" so that the key inserted at last takes precedence."
diff --git a/python/pyspark/errors/__init__.py
b/python/pyspark/errors/__init__.py
index 1525f351ea4..a9bcb973a6f 100644
--- a/python/pyspark/errors/__init__.py
+++ b/python/pyspark/errors/__init__.py
@@ -25,6 +25,7 @@ from pyspark.errors.exceptions.base import ( # noqa: F401
ParseException,
IllegalArgumentException,
ArithmeticException,
+ UnsupportedOperationException,
ArrayIndexOutOfBoundsException,
DateTimeException,
NumberFormatException,
@@ -50,6 +51,7 @@ __all__ = [
"ParseException",
"IllegalArgumentException",
"ArithmeticException",
+ "UnsupportedOperationException",
"ArrayIndexOutOfBoundsException",
"DateTimeException",
"NumberFormatException",
diff --git a/python/pyspark/errors/error_classes.py
b/python/pyspark/errors/error_classes.py
index 3f52a14a607..c2ed64d01fb 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -149,6 +149,11 @@ ERROR_CLASSES_JSON = """
"Argument `<arg_name>`(type: <arg_type>) should only contain a type in
[<allowed_types>], got <return_type>"
]
},
+ "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT" : {
+ "message" : [
+ "Duplicated field names in Arrow Struct are not allowed, got
<field_names>"
+ ]
+ },
"EXCEED_RETRY" : {
"message" : [
"Retries exceeded but no exception caught."
diff --git a/python/pyspark/errors/exceptions/base.py
b/python/pyspark/errors/exceptions/base.py
index 543ee9473e3..1b9a6b0229e 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -126,6 +126,12 @@ class ArithmeticException(PySparkException):
"""
+class UnsupportedOperationException(PySparkException):
+ """
+ Unsupported operation exception thrown from Spark with an error class.
+ """
+
+
class ArrayIndexOutOfBoundsException(PySparkException):
"""
Array index out of bounds exception thrown from Spark with an error class.
diff --git a/python/pyspark/errors/exceptions/captured.py
b/python/pyspark/errors/exceptions/captured.py
index 5b008f4ab00..d62b7d24347 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -26,6 +26,7 @@ from pyspark.errors.exceptions.base import (
AnalysisException as BaseAnalysisException,
IllegalArgumentException as BaseIllegalArgumentException,
ArithmeticException as BaseArithmeticException,
+ UnsupportedOperationException as BaseUnsupportedOperationException,
ArrayIndexOutOfBoundsException as BaseArrayIndexOutOfBoundsException,
DateTimeException as BaseDateTimeException,
NumberFormatException as BaseNumberFormatException,
@@ -141,6 +142,8 @@ def convert_exception(e: Py4JJavaError) ->
CapturedException:
return IllegalArgumentException(origin=e)
elif is_instance_of(gw, e, "java.lang.ArithmeticException"):
return ArithmeticException(origin=e)
+ elif is_instance_of(gw, e, "java.lang.UnsupportedOperationException"):
+ return UnsupportedOperationException(origin=e)
elif is_instance_of(gw, e, "java.lang.ArrayIndexOutOfBoundsException"):
return ArrayIndexOutOfBoundsException(origin=e)
elif is_instance_of(gw, e, "java.time.DateTimeException"):
@@ -262,6 +265,12 @@ class ArithmeticException(CapturedException,
BaseArithmeticException):
"""
+class UnsupportedOperationException(CapturedException,
BaseUnsupportedOperationException):
+ """
+ Unsupported operation exception.
+ """
+
+
class ArrayIndexOutOfBoundsException(CapturedException,
BaseArrayIndexOutOfBoundsException):
"""
Array index out of bounds exception.
diff --git a/python/pyspark/errors/exceptions/connect.py
b/python/pyspark/errors/exceptions/connect.py
index f8f234ed2ee..48b213080c1 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -22,6 +22,7 @@ from pyspark.errors.exceptions.base import (
AnalysisException as BaseAnalysisException,
IllegalArgumentException as BaseIllegalArgumentException,
ArithmeticException as BaseArithmeticException,
+ UnsupportedOperationException as BaseUnsupportedOperationException,
ArrayIndexOutOfBoundsException as BaseArrayIndexOutOfBoundsException,
DateTimeException as BaseDateTimeException,
NumberFormatException as BaseNumberFormatException,
@@ -69,6 +70,8 @@ def convert_exception(info: "ErrorInfo", message: str) ->
SparkConnectException:
return IllegalArgumentException(message)
elif "java.lang.ArithmeticException" in classes:
return ArithmeticException(message)
+ elif "java.lang.UnsupportedOperationException" in classes:
+ return UnsupportedOperationException(message)
elif "java.lang.ArrayIndexOutOfBoundsException" in classes:
return ArrayIndexOutOfBoundsException(message)
elif "java.time.DateTimeException" in classes:
@@ -151,6 +154,12 @@ class ArithmeticException(SparkConnectGrpcException,
BaseArithmeticException):
"""
+class UnsupportedOperationException(SparkConnectGrpcException,
BaseUnsupportedOperationException):
+ """
+ Unsupported operation exception thrown from Spark Connect.
+ """
+
+
class ArrayIndexOutOfBoundsException(SparkConnectGrpcException,
BaseArrayIndexOutOfBoundsException):
"""
Array index out of bounds exception thrown from Spark Connect.
diff --git a/python/pyspark/sql/connect/client.py
b/python/pyspark/sql/connect/client.py
index ffcc4f768ae..070c4ab19d3 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -72,8 +72,8 @@ from pyspark.sql.connect.expressions import (
CommonInlineUserDefinedFunction,
JavaUDF,
)
-from pyspark.sql.pandas.types import _check_series_localize_timestamps,
_convert_map_items_to_dict
-from pyspark.sql.types import DataType, MapType, StructType, TimestampType
+from pyspark.sql.pandas.types import _create_converter_to_pandas
+from pyspark.sql.types import DataType, StructType, TimestampType, _has_type
from pyspark.rdd import PythonEvalType
from pyspark.storagelevel import StorageLevel
from pyspark.errors import PySparkValueError, PySparkRuntimeError
@@ -705,17 +705,35 @@ class SparkConnectClient(object):
schema = schema or types.from_arrow_schema(table.schema)
assert schema is not None and isinstance(schema, StructType)
- pdf = table.to_pandas()
- pdf.columns = schema.fieldNames()
+ # Rename columns to avoid duplicated column names.
+ pdf = table.rename_columns([f"col_{i}" for i in
range(table.num_columns)]).to_pandas()
+ pdf.columns = schema.names
- for field, pa_field in zip(schema, table.schema):
- if isinstance(field.dataType, TimestampType):
- assert pa_field.type.tz is not None
- pdf[field.name] = _check_series_localize_timestamps(
- pdf[field.name], pa_field.type.tz
- )
- elif isinstance(field.dataType, MapType):
- pdf[field.name] = _convert_map_items_to_dict(pdf[field.name])
+ timezone: Optional[str] = None
+ struct_in_pandas: Optional[str] = None
+ error_on_duplicated_field_names: bool = False
+ if any(_has_type(f.dataType, (StructType, TimestampType)) for f in
schema.fields):
+ timezone, struct_in_pandas = self.get_configs(
+ "spark.sql.session.timeZone",
"spark.sql.execution.pandas.structHandlingMode"
+ )
+
+ if struct_in_pandas == "legacy":
+ error_on_duplicated_field_names = True
+ struct_in_pandas = "dict"
+
+ pdf = pd.concat(
+ [
+ _create_converter_to_pandas(
+ field.dataType,
+ field.nullable,
+ timezone=timezone,
+ struct_in_pandas=struct_in_pandas,
+
error_on_duplicated_field_names=error_on_duplicated_field_names,
+ )(pser)
+ for (_, pser), field, pa_field in zip(pdf.items(),
schema.fields, table.schema)
+ ],
+ axis="columns",
+ )
if len(metrics) > 0:
pdf.attrs["metrics"] = metrics
@@ -1068,6 +1086,11 @@ class SparkConnectClient(object):
req.user_context.user_id = self._user_id
return req
+ def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
+ op = pb2.ConfigRequest.Operation(get=pb2.ConfigRequest.Get(keys=keys))
+ configs = dict(self.config(op).pairs)
+ return tuple(configs.get(key) for key in keys)
+
def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult:
"""
Call the config RPC of Spark Connect.
diff --git a/python/pyspark/sql/connect/conversion.py
b/python/pyspark/sql/connect/conversion.py
index 16679e80205..a7ea88fb007 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -19,7 +19,6 @@ from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
import array
-import itertools
import datetime
import decimal
@@ -45,6 +44,7 @@ from pyspark.sql.types import (
from pyspark.storagelevel import StorageLevel
from pyspark.sql.connect.types import to_arrow_schema
import pyspark.sql.connect.proto as pb2
+from pyspark.sql.pandas.types import _dedup_names
from typing import (
Any,
@@ -507,22 +507,3 @@ def _deduplicate_field_names(dt: DataType) -> DataType:
)
else:
return dt
-
-
-def _dedup_names(names: List[str]) -> List[str]:
- if len(set(names)) == len(names):
- return names
- else:
-
- def _gen_dedup(_name: str) -> Callable[[], str]:
- _i = itertools.count()
- return lambda: f"{_name}_{next(_i)}"
-
- def _gen_identity(_name: str) -> Callable[[], str]:
- return lambda: _name
-
- gen_new_name = {
- name: _gen_dedup(name) if len(list(group)) > 1 else
_gen_identity(name)
- for name, group in itertools.groupby(sorted(names))
- }
- return [gen_new_name[name]() for name in names]
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index 005eed69722..50eadf46200 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -111,7 +111,7 @@ class DataFrame:
repl_eager_eval_enabled,
repl_eager_eval_max_num_rows,
repl_eager_eval_truncate,
- ) = self._session._get_configs(
+ ) = self._session._client.get_configs(
"spark.sql.repl.eagerEval.enabled",
"spark.sql.repl.eagerEval.maxNumRows",
"spark.sql.repl.eagerEval.truncate",
@@ -131,7 +131,7 @@ class DataFrame:
repl_eager_eval_enabled,
repl_eager_eval_max_num_rows,
repl_eager_eval_truncate,
- ) = self._session._get_configs(
+ ) = self._session._client.get_configs(
"spark.sql.repl.eagerEval.enabled",
"spark.sql.repl.eagerEval.maxNumRows",
"spark.sql.repl.eagerEval.truncate",
diff --git a/python/pyspark/sql/connect/session.py
b/python/pyspark/sql/connect/session.py
index fde861d12b9..3bd842f7847 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -47,7 +47,6 @@ from pandas.api.types import ( # type: ignore[attr-defined]
)
from pyspark import SparkContext, SparkConf, __version__
-from pyspark.sql.connect import proto
from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
from pyspark.sql.connect.conf import RuntimeConf
from pyspark.sql.connect.dataframe import DataFrame
@@ -226,11 +225,6 @@ class SparkSession:
def readStream(self) -> "DataStreamReader":
return DataStreamReader(self)
- def _get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
- op =
proto.ConfigRequest.Operation(get=proto.ConfigRequest.Get(keys=keys))
- configs = dict(self._client.config(op).pairs)
- return tuple(configs.get(key) for key in keys)
-
def _inferSchemaFromList(
self, data: Iterable[Any], names: Optional[List[str]] = None
) -> StructType:
@@ -244,7 +238,7 @@ class SparkSession:
infer_dict_as_struct,
infer_array_from_first_element,
prefer_timestamp_ntz,
- ) = self._get_configs(
+ ) = self._client.get_configs(
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
"spark.sql.timestampType",
@@ -334,7 +328,7 @@ class SparkSession:
for t in data.dtypes
]
- timezone, safecheck = self._get_configs(
+ timezone, safecheck = self._client.get_configs(
"spark.sql.session.timeZone",
"spark.sql.execution.pandas.convertToArrowArraySafely"
)
diff --git a/python/pyspark/sql/pandas/conversion.py
b/python/pyspark/sql/pandas/conversion.py
index ce0143d1851..0c29dcceed0 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -15,29 +15,20 @@
# limitations under the License.
#
import sys
-from collections import Counter
-from typing import List, Optional, Type, Union, no_type_check, overload,
TYPE_CHECKING
-from warnings import catch_warnings, simplefilter, warn
+from typing import (
+ List,
+ Optional,
+ Union,
+ no_type_check,
+ overload,
+ TYPE_CHECKING,
+)
+from warnings import warn
from pyspark.errors.exceptions.captured import unwrap_spark_exception
from pyspark.rdd import _load_from_socket
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
-from pyspark.sql.types import (
- IntegralType,
- ByteType,
- ShortType,
- IntegerType,
- LongType,
- FloatType,
- DoubleType,
- BooleanType,
- MapType,
- TimestampType,
- TimestampNTZType,
- DayTimeIntervalType,
- StructType,
- DataType,
-)
+from pyspark.sql.types import TimestampType, StructType, DataType
from pyspark.sql.utils import is_timestamp_ntz_preferred
from pyspark.traceback_utils import SCCallSiteSync
@@ -85,16 +76,16 @@ class PandasConversionMixin:
assert isinstance(self, DataFrame)
+ from pyspark.sql.pandas.types import _create_converter_to_pandas
from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()
- import numpy as np
import pandas as pd
- from pandas.core.dtypes.common import is_timedelta64_dtype
jconf = self.sparkSession._jconf
timezone = jconf.sessionLocalTimeZone()
+ struct_in_pandas = jconf.pandasStructHandlingMode()
if jconf.arrowPySparkEnabled():
use_arrow = True
@@ -132,18 +123,10 @@ class PandasConversionMixin:
# of PyArrow is found, if
'spark.sql.execution.arrow.pyspark.enabled' is enabled.
if use_arrow:
try:
- from pyspark.sql.pandas.types import (
- _check_series_localize_timestamps,
- _convert_map_items_to_dict,
- )
import pyarrow
- # Rename columns to avoid duplicated column names.
- tmp_column_names = ["col_{}".format(i) for i in
range(len(self.columns))]
self_destruct = jconf.arrowPySparkSelfDestructEnabled()
- batches = self.toDF(*tmp_column_names)._collect_as_arrow(
- split_batches=self_destruct
- )
+ batches =
self._collect_as_arrow(split_batches=self_destruct)
if len(batches) > 0:
table = pyarrow.Table.from_batches(batches)
# Ensure only the table has a reference to the
batches, so that
@@ -165,32 +148,34 @@ class PandasConversionMixin:
"use_threads": False,
}
)
- pdf = table.to_pandas(**pandas_options)
+ # Rename columns to avoid duplicated column names.
+ pdf = table.rename_columns(
+ [f"col_{i}" for i in range(table.num_columns)]
+ ).to_pandas(**pandas_options)
+
# Rename back to the original column names.
pdf.columns = self.columns
- for field in self.schema:
- if isinstance(field.dataType, TimestampType):
- pdf[field.name] =
_check_series_localize_timestamps(
- pdf[field.name], timezone
- )
- elif isinstance(field.dataType, MapType):
- pdf[field.name] =
_convert_map_items_to_dict(pdf[field.name])
- return pdf
else:
- corrected_panda_types = {}
- for index, field in enumerate(self.schema):
- pandas_type =
PandasConversionMixin._to_corrected_pandas_type(
- field.dataType
- )
- corrected_panda_types[tmp_column_names[index]] = (
- object if pandas_type is None else pandas_type
- )
-
- pdf = pd.DataFrame(columns=tmp_column_names).astype(
- dtype=corrected_panda_types
- )
- pdf.columns = self.columns
- return pdf
+ pdf = pd.DataFrame(columns=self.columns)
+
+ error_on_duplicated_field_names = False
+ if struct_in_pandas == "legacy":
+ error_on_duplicated_field_names = True
+ struct_in_pandas = "dict"
+
+ return pd.concat(
+ [
+ _create_converter_to_pandas(
+ field.dataType,
+ field.nullable,
+ timezone=timezone,
+ struct_in_pandas=struct_in_pandas,
+
error_on_duplicated_field_names=error_on_duplicated_field_names,
+ )(pser)
+ for (_, pser), field in zip(pdf.items(),
self.schema.fields)
+ ],
+ axis="columns",
+ )
except Exception as e:
# We might have to allow fallback here as well but
multiple Spark jobs can
# be executed. So, simply fail in this case for now.
@@ -207,107 +192,20 @@ class PandasConversionMixin:
# Below is toPandas without Arrow optimization.
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
- column_counter = Counter(self.columns)
-
- corrected_dtypes: List[Optional[Type]] = [None] * len(self.schema)
- for index, field in enumerate(self.schema):
- # We use `iloc` to access columns with duplicate column names.
- if column_counter[field.name] > 1:
- pandas_col = pdf.iloc[:, index]
- else:
- pandas_col = pdf[field.name]
-
- pandas_type =
PandasConversionMixin._to_corrected_pandas_type(field.dataType)
- # SPARK-21766: if an integer field is nullable and has null
values, it can be
- # inferred by pandas as a float column. If we convert the column
with NaN back
- # to integer type e.g., np.int16, we will hit an exception. So we
use the
- # pandas-inferred float type, rather than the corrected type from
the schema
- # in this case.
- if pandas_type is not None and not (
- isinstance(field.dataType, IntegralType)
- and field.nullable
- and pandas_col.isnull().any()
- ):
- corrected_dtypes[index] = pandas_type
- # Ensure we fall back to nullable numpy types.
- if isinstance(field.dataType, IntegralType) and
pandas_col.isnull().any():
- corrected_dtypes[index] = np.float64
- if isinstance(field.dataType, BooleanType) and
pandas_col.isnull().any():
- corrected_dtypes[index] = object
-
- df = pd.DataFrame()
- for index, t in enumerate(corrected_dtypes):
- column_name = self.schema[index].name
-
- # We use `iloc` to access columns with duplicate column names.
- if column_counter[column_name] > 1:
- series = pdf.iloc[:, index]
- else:
- series = pdf[column_name]
- # No need to cast for non-empty series for timedelta. The type is
already correct.
- should_check_timedelta = is_timedelta64_dtype(t) and len(pdf) == 0
-
- if (t is not None and not is_timedelta64_dtype(t)) or
should_check_timedelta:
- series = series.astype(t, copy=False)
-
- with catch_warnings():
- from pandas.errors import PerformanceWarning
-
- simplefilter(action="ignore", category=PerformanceWarning)
- # `insert` API makes copy of data,
- # we only do it for Series of duplicate column names.
- # `pdf.iloc[:, index] = pdf.iloc[:, index]...` doesn't always
work
- # because `iloc` could return a view or a copy depending by
context.
- if column_counter[column_name] > 1:
- df.insert(index, column_name, series,
allow_duplicates=True)
- else:
- df[column_name] = series
-
- if timezone is None:
- return df
- else:
- from pyspark.sql.pandas.types import
_check_series_convert_timestamps_local_tz
-
- for field in self.schema:
- # TODO: handle nested timestamps, such as
ArrayType(TimestampType())?
- if isinstance(field.dataType, TimestampType):
- df[field.name] = _check_series_convert_timestamps_local_tz(
- df[field.name], timezone
- )
- return df
-
- @staticmethod
- def _to_corrected_pandas_type(dt: DataType) -> Optional[Type]:
- """
- When converting Spark SQL records to Pandas `pandas.DataFrame`, the
inferred data type
- may be wrong. This method gets the corrected data type for Pandas if
that type may be
- inferred incorrectly.
- """
- import numpy as np
-
- if type(dt) == ByteType:
- return np.int8
- elif type(dt) == ShortType:
- return np.int16
- elif type(dt) == IntegerType:
- return np.int32
- elif type(dt) == LongType:
- return np.int64
- elif type(dt) == FloatType:
- return np.float32
- elif type(dt) == DoubleType:
- return np.float64
- elif type(dt) == BooleanType:
- return bool
- elif type(dt) == TimestampType:
- return np.datetime64
- elif type(dt) == TimestampNTZType:
- return np.datetime64
- elif type(dt) == DayTimeIntervalType:
- return np.timedelta64
- else:
- return None
+ return pd.concat(
+ [
+ _create_converter_to_pandas(
+ field.dataType,
+ field.nullable,
+ timezone=timezone,
+ struct_in_pandas=("row" if struct_in_pandas == "legacy"
else struct_in_pandas),
+ error_on_duplicated_field_names=False,
+ )(pser)
+ for (_, pser), field in zip(pdf.items(), self.schema.fields)
+ ],
+ axis="columns",
+ )
def _collect_as_arrow(self, split_batches: bool = False) ->
List["pa.RecordBatch"]:
"""
diff --git a/python/pyspark/sql/pandas/types.py
b/python/pyspark/sql/pandas/types.py
index 70d50ca6e95..d23e83b1a5d 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -19,7 +19,8 @@
Type-specific codes between pandas and PyArrow. Also contains some utils to
correct
pandas instances during the type conversion.
"""
-from typing import Optional, TYPE_CHECKING
+import itertools
+from typing import Any, Callable, List, Optional, TYPE_CHECKING
from pyspark.sql.types import (
cast,
@@ -27,6 +28,7 @@ from pyspark.sql.types import (
ByteType,
ShortType,
IntegerType,
+ IntegralType,
LongType,
FloatType,
DoubleType,
@@ -43,10 +45,13 @@ from pyspark.sql.types import (
StructField,
NullType,
DataType,
+ Row,
+ _create_row,
)
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import PySparkTypeError, UnsupportedOperationException
if TYPE_CHECKING:
+ import pandas as pd
import pyarrow as pa
from pyspark.sql.pandas._typing import SeriesLike as PandasSeriesLike
@@ -462,3 +467,236 @@ def _convert_dict_to_map_items(s: "PandasSeriesLike") ->
"PandasSeriesLike":
:return: pandas.Series of lists of (key, value) pairs
"""
return cast("PandasSeriesLike", s.apply(lambda d: list(d.items()) if d is
not None else None))
+
+
+def _to_corrected_pandas_type(dt: DataType) -> Optional[Any]:
+ """
+ When converting Spark SQL records to Pandas `pandas.DataFrame`, the
inferred data type
+ may be wrong. This method gets the corrected data type for Pandas if that
type may be
+ inferred incorrectly.
+ """
+ import numpy as np
+
+ if type(dt) == ByteType:
+ return np.int8
+ elif type(dt) == ShortType:
+ return np.int16
+ elif type(dt) == IntegerType:
+ return np.int32
+ elif type(dt) == LongType:
+ return np.int64
+ elif type(dt) == FloatType:
+ return np.float32
+ elif type(dt) == DoubleType:
+ return np.float64
+ elif type(dt) == BooleanType:
+ return bool
+ elif type(dt) == TimestampType:
+ return np.dtype("datetime64[ns]")
+ elif type(dt) == TimestampNTZType:
+ return np.dtype("datetime64[ns]")
+ elif type(dt) == DayTimeIntervalType:
+ return np.dtype("timedelta64[ns]")
+ else:
+ return None
+
+
+def _create_converter_to_pandas(
+ data_type: DataType,
+ nullable: bool = True,
+ *,
+ timezone: Optional[str] = None,
+ struct_in_pandas: Optional[str] = None,
+ error_on_duplicated_field_names: bool = True,
+) -> Callable[["pd.Series"], "pd.Series"]:
+ """
+ Create a converter of pandas Series that is created from Spark's Python
objects,
+ or `pyarrow.Table.to_pandas` method.
+
+ Parameters
+ ----------
+ data_type : :class:`DataType`
+ The data type corresponding to the pandas Series to be converted.
+ nullable : bool, optional
+ Whether the column is nullable or not. (default ``True``)
+ timezone : str, optional
+ The timezone to convert from. If there is a timestamp type, it's
required.
+ struct_in_pandas : str, optional
+ How to handle struct type. If there is a struct type, it's required.
+ When ``row``, :class:`Row` object will be used.
+ When ``dict``, :class:`dict` will be used. If there are duplicated
field names,
+ The fields will be suffixed, like `a_0`, `a_1`.
+ Must be one of: ``row``, ``dict``.
+ error_on_duplicated_field_names : bool, optional
+ Whether raise an exception when there are duplicated field names.
+ (default ``True``)
+
+ Returns
+ -------
+ The converter of `pandas.Series`
+ """
+ import numpy as np
+ import pandas as pd
+ from pandas.core.dtypes.common import is_datetime64tz_dtype
+
+ pandas_type = _to_corrected_pandas_type(data_type)
+
+ if pandas_type is not None:
+ # SPARK-21766: if an integer field is nullable and has null values, it
can be
+ # inferred by pandas as a float column. If we convert the column with
NaN back
+ # to integer type e.g., np.int16, we will hit an exception. So we use
the
+ # pandas-inferred float type, rather than the corrected type from the
schema
+ # in this case.
+ if isinstance(data_type, IntegralType) and nullable:
+
+ def correct_dtype(pser: pd.Series) -> pd.Series:
+ if pser.isnull().any():
+ return pser.astype(np.float64, copy=False)
+ else:
+ return pser.astype(pandas_type, copy=False)
+
+ elif isinstance(data_type, BooleanType) and nullable:
+
+ def correct_dtype(pser: pd.Series) -> pd.Series:
+ if pser.isnull().any():
+ return pser.astype(object, copy=False)
+ else:
+ return pser.astype(pandas_type, copy=False)
+
+ elif isinstance(data_type, TimestampType):
+ assert timezone is not None
+
+ def correct_dtype(pser: pd.Series) -> pd.Series:
+ if not is_datetime64tz_dtype(pser.dtype):
+ pser = pser.astype(pandas_type, copy=False)
+ return _check_series_convert_timestamps_local_tz(pser,
timezone=cast(str, timezone))
+
+ else:
+
+ def correct_dtype(pser: pd.Series) -> pd.Series:
+ return pser.astype(pandas_type, copy=False)
+
+ return correct_dtype
+
+ def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]:
+
+ if isinstance(dt, ArrayType):
+ _element_conv = _converter(dt.elementType)
+ if _element_conv is None:
+ return None
+
+ def convert_array(value: Any) -> Any:
+ if value is None:
+ return None
+ elif isinstance(value, np.ndarray):
+ # `pyarrow.Table.to_pandas` uses `np.ndarray`.
+ return np.array([_element_conv(v) for v in value]) #
type: ignore[misc]
+ else:
+ assert isinstance(value, list)
+ # otherwise, `list` should be used.
+ return [_element_conv(v) for v in value] # type:
ignore[misc]
+
+ return convert_array
+
+ elif isinstance(dt, MapType):
+ _key_conv = _converter(dt.keyType) or (lambda x: x)
+ _value_conv = _converter(dt.valueType) or (lambda x: x)
+
+ def convert_map(value: Any) -> Any:
+ if value is None:
+ return None
+ elif isinstance(value, list):
+ # `pyarrow.Table.to_pandas` uses `list` of key-value tuple.
+ return {_key_conv(k): _value_conv(v) for k, v in value}
+ else:
+ assert isinstance(value, dict)
+ # otherwise, `dict` should be used.
+ return {_key_conv(k): _value_conv(v) for k, v in
value.items()}
+
+ return convert_map
+
+ elif isinstance(dt, StructType):
+ assert struct_in_pandas is not None
+
+ field_names = dt.names
+
+ if error_on_duplicated_field_names and len(set(field_names)) !=
len(field_names):
+ raise UnsupportedOperationException(
+ error_class="DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT",
+ message_parameters={"field_names": str(field_names)},
+ )
+
+ dedup_field_names = _dedup_names(field_names)
+
+ field_convs = [_converter(f.dataType) or (lambda x: x) for f in
dt.fields]
+
+ if struct_in_pandas == "row":
+
+ def convert_struct_as_row(value: Any) -> Any:
+ if value is None:
+ return None
+ elif isinstance(value, dict):
+ # `pyarrow.Table.to_pandas` uses `dict`.
+ _values = [
+ field_convs[i](value.get(name, None))
+ for i, name in enumerate(dedup_field_names)
+ ]
+ return _create_row(field_names, _values)
+ else:
+ assert isinstance(value, Row)
+ # otherwise, `Row` should be used.
+ _values = [field_convs[i](value[i]) for i, name in
enumerate(value)]
+ return _create_row(field_names, _values)
+
+ return convert_struct_as_row
+
+ elif struct_in_pandas == "dict":
+
+ def convert_struct_as_dict(value: Any) -> Any:
+ if value is None:
+ return None
+ elif isinstance(value, dict):
+ # `pyarrow.Table.to_pandas` uses `dict`.
+ return {
+ name: field_convs[i](value.get(name, None))
+ for i, name in enumerate(dedup_field_names)
+ }
+ else:
+ assert isinstance(value, Row)
+ # otherwise, `Row` should be used.
+ return {
+ dedup_field_names[i]: field_convs[i](v) for i, v
in enumerate(value)
+ }
+
+ return convert_struct_as_dict
+
+ else:
+ raise ValueError(f"Unknown value for `struct_in_pandas`:
{struct_in_pandas}")
+
+ else:
+ return None
+
+ conv = _converter(data_type)
+ if conv is not None:
+ return lambda pser: pser.apply(conv) # type: ignore[return-value]
+ else:
+ return lambda pser: pser
+
+
+def _dedup_names(names: List[str]) -> List[str]:
+ if len(set(names)) == len(names):
+ return names
+ else:
+
+ def _gen_dedup(_name: str) -> Callable[[], str]:
+ _i = itertools.count()
+ return lambda: f"{_name}_{next(_i)}"
+
+ def _gen_identity(_name: str) -> Callable[[], str]:
+ return lambda: _name
+
+ gen_new_name = {
+ name: _gen_dedup(name) if len(list(group)) > 1 else
_gen_identity(name)
+ for name, group in itertools.groupby(sorted(names))
+ }
+ return [gen_new_name[name]() for name in names]
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py
b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index f2fa9ece4df..d27077f8907 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -106,6 +106,9 @@ class ArrowParityTests(ArrowTestsMixin,
ReusedConnectTestCase):
def test_toPandas_error(self):
self.check_toPandas_error(True)
+ def test_toPandas_duplicate_field_names(self):
+ self.check_toPandas_duplicate_field_names(True)
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_parity_arrow import * # noqa: F401
diff --git a/python/pyspark/sql/tests/test_arrow.py
b/python/pyspark/sql/tests/test_arrow.py
index 84c782e8d95..52e13782199 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -55,7 +55,7 @@ from pyspark.testing.sqlutils import (
pyarrow_requirement_message,
)
from pyspark.testing.utils import QuietTest
-from pyspark.errors import ArithmeticException, PySparkTypeError
+from pyspark.errors import ArithmeticException, PySparkTypeError,
UnsupportedOperationException
if have_pandas:
import pandas as pd
@@ -888,6 +888,60 @@ class ArrowTestsMixin:
with self.assertRaises(ArithmeticException):
self.spark.sql("select 1/0").toPandas()
+ def test_toPandas_duplicate_field_names(self):
+ for arrow_enabled in [True, False]:
+ with self.subTest(arrow_enabled=arrow_enabled):
+ self.check_toPandas_duplicate_field_names(arrow_enabled)
+
+ def check_toPandas_duplicate_field_names(self, arrow_enabled):
+ data = [Row(Row("a", 1), Row(2, 3, "b", 4, "c")), Row(Row("x", 6),
Row(7, 8, "y", 9, "z"))]
+ schema = (
+ StructType()
+ .add("struct", StructType().add("x", StringType()).add("x",
IntegerType()))
+ .add(
+ "struct",
+ StructType()
+ .add("a", IntegerType())
+ .add("x", IntegerType())
+ .add("x", StringType())
+ .add("y", IntegerType())
+ .add("y", StringType()),
+ )
+ )
+ for struct_in_pandas in ["legacy", "row", "dict"]:
+ df = self.spark.createDataFrame(data, schema=schema)
+
+ with self.subTest(struct_in_pandas=struct_in_pandas):
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.pyspark.enabled":
arrow_enabled,
+ "spark.sql.execution.pandas.structHandlingMode":
struct_in_pandas,
+ }
+ ):
+ if arrow_enabled and struct_in_pandas == "legacy":
+ with self.assertRaisesRegexp(
+ UnsupportedOperationException,
"DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT"
+ ):
+ df.toPandas()
+ else:
+ if struct_in_pandas == "dict":
+ expected = pd.DataFrame(
+ [
+ [
+ {"x_0": "a", "x_1": 1},
+ {"a": 2, "x_0": 3, "x_1": "b", "y_0":
4, "y_1": "c"},
+ ],
+ [
+ {"x_0": "x", "x_1": 6},
+ {"a": 7, "x_0": 8, "x_1": "y", "y_0":
9, "y_1": "z"},
+ ],
+ ],
+ columns=schema.names,
+ )
+ else:
+ expected = pd.DataFrame.from_records(data,
columns=schema.names)
+ assert_frame_equal(df.toPandas(), expected)
+
@unittest.skipIf(
not have_pandas or not have_pyarrow,
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 6c7751c522a..70d90a03c10 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1687,6 +1687,20 @@ def _has_nulltype(dt: DataType) -> bool:
return isinstance(dt, NullType)
+def _has_type(dt: DataType, dts: Union[type, Tuple[type, ...]]) -> bool:
+ """Return whether there are specified types"""
+ if isinstance(dt, dts):
+ return True
+ elif isinstance(dt, StructType):
+ return any(_has_type(f.dataType, dts) for f in dt.fields)
+ elif isinstance(dt, ArrayType):
+ return _has_type(dt.elementType, dts)
+ elif isinstance(dt, MapType):
+ return _has_type(dt.keyType, dts) or _has_type(dt.valueType, dts)
+ else:
+ return False
+
+
@overload
def _merge_type(a: StructType, b: StructType, name: Optional[str] = None) ->
StructType:
...
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index e7d310c25c2..111f0391a72 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1144,6 +1144,13 @@ private[sql] object QueryExecutionErrors extends
QueryErrorsBase {
messageParameters = Map("typeName" -> toSQLType(typeName)))
}
+ def duplicatedFieldNameInArrowStructError(
+ fieldNames: Seq[String]): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT",
+ messageParameters = Map("fieldNames" -> fieldNames.mkString("[", ", ",
"]")))
+ }
+
def notSupportTypeError(dataType: DataType): Throwable = {
new SparkException(
errorClass = "_LEGACY_ERROR_TEMP_2100",
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index af7126495c5..efdbc583207 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -30,8 +30,11 @@ import org.apache.spark.sql.util.ArrowUtils
object ArrowWriter {
- def create(schema: StructType, timeZoneId: String): ArrowWriter = {
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+ def create(
+ schema: StructType,
+ timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean = true): ArrowWriter = {
+ val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId,
errorOnDuplicatedFieldNames)
val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator)
create(root)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 874f95af1cb..c9974d2dfa8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2835,6 +2835,23 @@ object SQLConf {
.version("3.0.0")
.fallbackConf(BUFFER_SIZE)
+ val PANDAS_STRUCT_HANDLING_MODE =
+ buildConf("spark.sql.execution.pandas.structHandlingMode")
+ .doc(
+ "The conversion mode of struct type when creating pandas DataFrame. " +
+ "When \"legacy\"," +
+ "1. when Arrow optimization is disabled, convert to Row object, " +
+ "2. when Arrow optimization is enabled, convert to dict or raise an
Exception " +
+ "if there are duplicated nested field names. " +
+ "When \"row\", convert to Row object regardless of Arrow optimization.
" +
+ "When \"dict\", convert to dict and use suffixed key names, e.g., a_0,
a_1, " +
+ "if there are duplicated nested field names, regardless of Arrow
optimization."
+ )
+ .version("3.5.0")
+ .stringConf
+ .checkValues(Set("legacy", "row", "dict"))
+ .createWithDefaultString("legacy")
+
val PYSPARK_SIMPLIFIED_TRACEBACK =
buildConf("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled")
.doc(
@@ -4855,6 +4872,8 @@ class SQLConf extends Serializable with Logging {
def pandasUDFBufferSize: Int = getConf(PANDAS_UDF_BUFFER_SIZE)
+ def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE)
+
def pysparkSimplifiedTraceback: Boolean =
getConf(PYSPARK_SIMPLIFIED_TRACEBACK)
def pandasGroupedMapAssignColumnsByName: Boolean =
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index d6a8fec81dd..719691a338f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.util
+import java.util.concurrent.atomic.AtomicInteger
+
import scala.collection.JavaConverters._
import org.apache.arrow.memory.RootAllocator
@@ -135,9 +137,16 @@ private[sql] object ArrowUtils {
}
/** Maps schema from Spark to Arrow. NOTE: timeZoneId required for
TimestampType in StructType */
- def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
+ def toArrowSchema(
+ schema: StructType,
+ timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean): Schema = {
new Schema(schema.map { field =>
- toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
+ toArrowField(
+ field.name,
+ deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames),
+ field.nullable,
+ timeZoneId)
}.asJava)
}
@@ -157,4 +166,38 @@ private[sql] object ArrowUtils {
conf.arrowSafeTypeConversion.toString)
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
}
+
+ private def deduplicateFieldNames(
+ dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match
{
+ case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType,
errorOnDuplicatedFieldNames)
+ case st @ StructType(fields) =>
+ val newNames = if (st.names.toSet.size == st.names.length) {
+ st.names
+ } else {
+ if (errorOnDuplicatedFieldNames) {
+ throw
QueryExecutionErrors.duplicatedFieldNameInArrowStructError(st.names)
+ }
+ val genNawName = st.names.groupBy(identity).map {
+ case (name, names) if names.length > 1 =>
+ val i = new AtomicInteger()
+ name -> { () => s"${name}_${i.getAndIncrement()}" }
+ case (name, _) => name -> { () => name }
+ }
+ st.names.map(genNawName(_)())
+ }
+ val newFields =
+ fields.zip(newNames).map { case (StructField(_, dataType, nullable,
metadata), name) =>
+ StructField(
+ name, deduplicateFieldNames(dataType,
errorOnDuplicatedFieldNames), nullable, metadata)
+ }
+ StructType(newFields)
+ case ArrayType(elementType, containsNull) =>
+ ArrayType(deduplicateFieldNames(elementType,
errorOnDuplicatedFieldNames), containsNull)
+ case MapType(keyType, valueType, valueContainsNull) =>
+ MapType(
+ deduplicateFieldNames(keyType, errorOnDuplicatedFieldNames),
+ deduplicateFieldNames(valueType, errorOnDuplicatedFieldNames),
+ valueContainsNull)
+ case _ => dt
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
index 2f78d03db80..28ed061a71b 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
@@ -30,7 +30,7 @@ class ArrowUtilsSuite extends SparkFunSuite {
def roundtrip(dt: DataType): Unit = {
dt match {
case schema: StructType =>
- assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema,
null)) === schema)
+ assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema,
null, true)) === schema)
case _ =>
roundtrip(new StructType().add("value", dt))
}
@@ -67,7 +67,7 @@ class ArrowUtilsSuite extends SparkFunSuite {
def roundtripWithTz(timeZoneId: String): Unit = {
val schema = new StructType().add("value", TimestampType)
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+ val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, true)
val fieldType =
arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp]
assert(fieldType.getTimezone() === timeZoneId)
assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema)
@@ -97,4 +97,25 @@ class ArrowUtilsSuite extends SparkFunSuite {
"struct",
new StructType().add("i", IntegerType).add("arr",
ArrayType(IntegerType))))
}
+
+ test("struct with duplicated field names") {
+
+ def check(dt: DataType, expected: DataType): Unit = {
+ val schema = new StructType().add("value", dt)
+ intercept[SparkUnsupportedOperationException] {
+ ArrowUtils.toArrowSchema(schema, null, true)
+ }
+ assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null,
false))
+ === new StructType().add("value", expected))
+ }
+
+ roundtrip(new StructType().add("i", IntegerType).add("i", StringType))
+
+ check(new StructType().add("i", IntegerType).add("i", StringType),
+ new StructType().add("i_0", IntegerType).add("i_1", StringType))
+ check(ArrayType(new StructType().add("i", IntegerType).add("i",
StringType)),
+ ArrayType(new StructType().add("i_0", IntegerType).add("i_1",
StringType)))
+ check(MapType(StringType, new StructType().add("i", IntegerType).add("i",
StringType)),
+ MapType(StringType, new StructType().add("i_0", IntegerType).add("i_1",
StringType)))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 7973ba38b1a..7137e24ed96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -4173,7 +4173,8 @@ class Dataset[T] private[sql](
withAction("collectAsArrowToR", queryExecution) { plan =>
val buffer = new ByteArrayOutputStream()
val out = new DataOutputStream(outputStream)
- val batchWriter = new ArrowBatchStreamWriter(schema, buffer,
timeZoneId)
+ val batchWriter =
+ new ArrowBatchStreamWriter(schema, buffer, timeZoneId,
errorOnDuplicatedFieldNames = true)
val arrowBatchRdd = toArrowBatchRdd(plan)
val numPartitions = arrowBatchRdd.partitions.length
@@ -4222,11 +4223,14 @@ class Dataset[T] private[sql](
*/
private[sql] def collectAsArrowToPython: Array[Any] = {
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+ val errorOnDuplicatedFieldNames =
+ sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
PythonRDD.serveToStream("serve-Arrow") { outputStream =>
withAction("collectAsArrowToPython", queryExecution) { plan =>
val out = new DataOutputStream(outputStream)
- val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
+ val batchWriter =
+ new ArrowBatchStreamWriter(schema, out, timeZoneId,
errorOnDuplicatedFieldNames)
// Batches ordered by (index of partition, batch index in that
partition) tuple
val batchOrder = ArrayBuffer.empty[(Int, Int)]
@@ -4355,10 +4359,12 @@ class Dataset[T] private[sql](
val schemaCaptured = this.schema
val maxRecordsPerBatch =
sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+ val errorOnDuplicatedFieldNames =
+ sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
plan.execute().mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toBatchIterator(
- iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
+ iter, schemaCaptured, maxRecordsPerBatch, timeZoneId,
errorOnDuplicatedFieldNames, context)
}
}
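As a quick illustration of the conversion used twice above (not code from the patch): only the default "legacy" mode keeps the strict behavior, while the new "row" and "dict" modes turn the error off so duplicated nested field names can be renamed instead of rejected.
```scala
// Hypothetical helper mirroring the two occurrences in Dataset.scala above.
def errorOnDuplicatedFieldNames(structHandlingMode: String): Boolean =
  structHandlingMode match {
    case "legacy"       => true   // unchanged behavior: toArrowSchema fails fast
    case "row" | "dict" => false  // let toArrowSchema suffix duplicated names instead
    case other          => throw new IllegalArgumentException(s"unknown mode: $other")
  }
```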
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index b22c80d17e8..8d6d2c78051 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -48,9 +48,10 @@ import org.apache.spark.util.{ByteBufferOutputStream,
SizeEstimator, Utils}
private[sql] class ArrowBatchStreamWriter(
schema: StructType,
out: OutputStream,
- timeZoneId: String) {
+ timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean) {
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+ val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId,
errorOnDuplicatedFieldNames)
val writeChannel = new WriteChannel(Channels.newChannel(out))
// Write the Arrow schema first, before batches
@@ -77,9 +78,11 @@ private[sql] object ArrowConverters extends Logging {
schema: StructType,
maxRecordsPerBatch: Long,
timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean,
context: TaskContext) extends Iterator[Array[Byte]] {
- protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+ protected val arrowSchema =
+ ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames)
private val allocator =
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)
@@ -128,9 +131,10 @@ private[sql] object ArrowConverters extends Logging {
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean,
context: TaskContext)
extends ArrowBatchIterator(
- rowIter, schema, maxRecordsPerBatch, timeZoneId, context) {
+ rowIter, schema, maxRecordsPerBatch, timeZoneId,
errorOnDuplicatedFieldNames, context) {
private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
var rowCountInLastBatch: Long = 0
@@ -190,9 +194,10 @@ private[sql] object ArrowConverters extends Logging {
schema: StructType,
maxRecordsPerBatch: Long,
timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean,
context: TaskContext): ArrowBatchIterator = {
new ArrowBatchIterator(
- rowIter, schema, maxRecordsPerBatch, timeZoneId, context)
+ rowIter, schema, maxRecordsPerBatch, timeZoneId,
errorOnDuplicatedFieldNames, context)
}
/**
@@ -204,16 +209,20 @@ private[sql] object ArrowConverters extends Logging {
schema: StructType,
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
- timeZoneId: String): ArrowBatchWithSchemaIterator = {
+ timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean): ArrowBatchWithSchemaIterator = {
new ArrowBatchWithSchemaIterator(
- rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize, timeZoneId,
TaskContext.get)
+ rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize,
+ timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get)
}
private[sql] def createEmptyArrowBatch(
schema: StructType,
- timeZoneId: String): Array[Byte] = {
+ timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean): Array[Byte] = {
new ArrowBatchWithSchemaIterator(
- Iterator.empty, schema, 0L, 0L, timeZoneId, TaskContext.get) {
+ Iterator.empty, schema, 0L, 0L,
+ timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get) {
override def hasNext: Boolean = true
}.next()
}
@@ -275,7 +284,8 @@ private[sql] object ArrowConverters extends Logging {
extends InternalRowIterator(arrowBatchIter, context) {
override def nextBatch(): (Iterator[InternalRow], StructType) = {
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+ val arrowSchema =
+ ArrowUtils.toArrowSchema(schema, timeZoneId,
errorOnDuplicatedFieldNames = true)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
resources.append(root)
val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(),
allocator)
@@ -353,9 +363,6 @@ private[sql] object ArrowConverters extends Logging {
new ReadChannel(Channels.newChannel(in)), allocator) // throws
IOException
}
- /**
- * Create a DataFrame from an iterator of serialized ArrowRecordBatches.
- */
/**
* Create a DataFrame from an iterator of serialized ArrowRecordBatches.
*/
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index 79773a9d534..d4b6f0db9c9 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -70,6 +70,8 @@ class ApplyInPandasWithStatePythonRunner(
override protected val schema: StructType = inputSchema.add("__state",
STATE_METADATA_SCHEMA)
+ override val errorOnDuplicatedFieldNames: Boolean = true
+
override val simplifiedTraceback: Boolean =
sqlConf.pysparkSimplifiedTraceback
override val bufferSize: Int = {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index dbafc444281..427dfbbb32d 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -39,6 +39,8 @@ class ArrowPythonRunner(
with BasicPythonArrowInput
with BasicPythonArrowOutput {
+ override val errorOnDuplicatedFieldNames: Boolean = true
+
override val simplifiedTraceback: Boolean =
SQLConf.get.pysparkSimplifiedTraceback
override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 1df9f37188a..763a6d08d78 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -96,7 +96,8 @@ class CoGroupedArrowPythonRunner(
schema: StructType,
dataOut: DataOutputStream,
name: String): Unit = {
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+ val arrowSchema =
+ ArrowUtils.toArrowSchema(schema, timeZoneId,
errorOnDuplicatedFieldNames = true)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 5a0541d11cb..26ce10b6aae 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -42,6 +42,8 @@ private[python] trait PythonArrowInput[IN] { self:
BasePythonRunner[IN, _] =>
protected val timeZoneId: String
+ protected val errorOnDuplicatedFieldNames: Boolean
+
protected def pythonMetrics: Map[String, SQLMetric]
protected def writeIteratorToArrowStream(
@@ -73,7 +75,7 @@ private[python] trait PythonArrowInput[IN] { self:
BasePythonRunner[IN, _] =>
}
protected override def writeIteratorToStream(dataOut: DataOutputStream):
Unit = {
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+ val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId,
errorOnDuplicatedFieldNames)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdout writer for $pythonExec", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
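The trait change above only declares the flag; each concrete runner now has to decide it. A toy sketch of that pattern, with illustrative names that are not Spark's:
```scala
// The trait requires the flag, concrete runners pin it to a value,
// mirroring how PythonArrowInput and its runners are wired in this patch.
trait ArrowInputLike {
  protected val errorOnDuplicatedFieldNames: Boolean
  def strict: Boolean = errorOnDuplicatedFieldNames
}

// UDF runners keep the strict behavior, as ArrowPythonRunner does above.
class StrictRunner extends ArrowInputLike {
  override protected val errorOnDuplicatedFieldNames: Boolean = true
}
```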
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
index ae7b7ef2351..69faa4c8fec 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
@@ -83,7 +83,8 @@ class ArrowRRunner(
*/
override protected def writeIteratorToStream(dataOut: DataOutputStream):
Unit = {
if (inputIterator.hasNext) {
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+ val arrowSchema =
+ ArrowUtils.toArrowSchema(schema, timeZoneId,
errorOnDuplicatedFieldNames = true)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
"stdout writer for R", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index 82e4c970837..e458d33a7cd 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -1381,7 +1381,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
val schema = StructType(Seq(StructField("int", IntegerType, nullable =
true)))
val ctx = TaskContext.empty()
- val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator,
schema, 5, null, ctx)
+ val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator,
schema, 5, null, true, ctx)
val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema,
null, ctx)
var count = 0
@@ -1402,12 +1402,12 @@ class ArrowConvertersSuite extends SharedSparkSession {
val schema = StructType(Seq(StructField("int", IntegerType, nullable =
true)))
val ctx = TaskContext.empty()
- val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator,
schema, 5, null, ctx)
+ val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator,
schema, 5, null, true, ctx)
// Write batches to Arrow stream format as a byte array
val out = new ByteArrayOutputStream()
Utils.tryWithResource(new DataOutputStream(out)) { dataOut =>
- val writer = new ArrowBatchStreamWriter(schema, dataOut, null)
+ val writer = new ArrowBatchStreamWriter(schema, dataOut, null, true)
writer.writeBatches(batchIter)
writer.end()
}
@@ -1445,8 +1445,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
proj(row).copy()
}
val ctx = TaskContext.empty()
- val batchIter =
- ArrowConverters.toBatchWithSchemaIterator(inputRows.iterator, schema, 5,
1024 * 1024, null)
+ val batchIter = ArrowConverters.toBatchWithSchemaIterator(
+ inputRows.iterator, schema, 5, 1024 * 1024, null, true)
val (outputRowIter, outputType) =
ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx)
var count = 0
@@ -1465,7 +1465,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
val schema = StructType(Seq(StructField("int", IntegerType, nullable =
true)))
val ctx = TaskContext.empty()
val batchIter =
- ArrowConverters.toBatchWithSchemaIterator(Iterator.empty, schema, 5,
1024 * 1024, null)
+ ArrowConverters.toBatchWithSchemaIterator(Iterator.empty, schema, 5,
1024 * 1024, null, true)
val (outputRowIter, outputType) =
ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx)
assert(0 == outputRowIter.length)
@@ -1479,7 +1479,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
proj(row).copy()
}
val batchIter1 = ArrowConverters.toBatchWithSchemaIterator(
- inputRows1.iterator, schema1, 5, 1024 * 1024, null)
+ inputRows1.iterator, schema1, 5, 1024 * 1024, null, true)
val schema2 = StructType(Seq(StructField("field2", IntegerType, nullable =
true)))
val inputRows2 = Array(InternalRow(1)).map { row =>
@@ -1487,7 +1487,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
proj(row).copy()
}
val batchIter2 = ArrowConverters.toBatchWithSchemaIterator(
- inputRows2.iterator, schema2, 5, 1024 * 1024, null)
+ inputRows2.iterator, schema2, 5, 1024 * 1024, null, true)
val iter = batchIter1.toArray ++ batchIter2
@@ -1499,23 +1499,28 @@ class ArrowConvertersSuite extends SharedSparkSession {
/** Test that a converted DataFrame to Arrow record batch equals batch read
from JSON file */
private def collectAndValidate(
- df: DataFrame, json: String, file: String, timeZoneId: String = null):
Unit = {
+ df: DataFrame,
+ json: String,
+ file: String,
+ timeZoneId: String = null,
+ errorOnDuplicatedFieldNames: Boolean = true): Unit = {
// NOTE: coalesce to single partition because can only load 1 batch in
validator
val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head
val tempFile = new File(tempDataPath, file)
Files.write(json, tempFile, StandardCharsets.UTF_8)
- validateConversion(df.schema, batchBytes, tempFile, timeZoneId)
+ validateConversion(df.schema, batchBytes, tempFile, timeZoneId,
errorOnDuplicatedFieldNames)
}
private def validateConversion(
sparkSchema: StructType,
batchBytes: Array[Byte],
jsonFile: File,
- timeZoneId: String = null): Unit = {
+ timeZoneId: String = null,
+ errorOnDuplicatedFieldNames: Boolean = true): Unit = {
val allocator = new RootAllocator(Long.MaxValue)
val jsonReader = new JsonFileReader(jsonFile, allocator)
- val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId)
+ val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId,
errorOnDuplicatedFieldNames)
val jsonSchema = jsonReader.start()
Validator.compareSchemas(arrowSchema, jsonSchema)