This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fbc3471e8748 [SPARK-55285][SQL][PYTHON] Fix the initialization of `PythonArrowInput`
fbc3471e8748 is described below
commit fbc3471e87480062d27dbe92f7673d6708b041a3
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Feb 2 06:52:44 2026 +0900
[SPARK-55285][SQL][PYTHON] Fix the initialization of `PythonArrowInput`
### What changes were proposed in this pull request?
Delay the initialization of `PythonArrowInput`
### Why are the changes needed?
1. The initialization of `PythonArrowInput` happens too early.
```
// Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
// constructor.
override protected lazy val schema: StructType = _schema
override protected lazy val timeZoneId: String = _timeZoneId
```
We have code like the above scattered around the Python plans to work around `arrowSchema` being initialized with a null `schema`/`timeZoneId`.
If I revert that code, some Python UDF tests fail with ``Cannot invoke "org.apache.spark.sql.types.StructType.map(scala.Function1)" because "schema" is null``, see
https://github.com/zhengruifeng/spark/actions/runs/21478957166/job/61870333678
```
JVM stacktrace:
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 1 times, most recent failure: Lost task 0.0 in stage 1.0 (TID 2) (localhost executor driver): java.lang.NullPointerException: Cannot invoke "org.apache.spark.sql.types.StructType.map(scala.Function1)" because "schema" is null
  at org.apache.spark.sql.util.ArrowUtils$.toArrowSchema(ArrowUtils.scala:316)
  at org.apache.spark.sql.execution.python.PythonArrowInput.$init$(PythonArrowInput.scala:66)
  at org.apache.spark.sql.execution.python.streaming.ApplyInPandasWithStatePythonRunner.<init>(ApplyInPandasWithStatePythonRunner.scala:68)
  at
```
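To make the failure mode concrete, here is a minimal standalone sketch of the Scala pitfall (the names are illustrative and not from the Spark codebase): a trait's body runs before the subclass constructor, so an abstract `val` read there is still null.
```
// Illustrative sketch, not Spark code: the trait initializer runs before the
// subclass assigns its override, so `schema` is observed as null.
trait ArrowInputLike {
  protected val schema: Seq[String]                      // abstract val read during trait init
  private val arrowSchema = schema.map(_.toUpperCase)    // NPE here: schema is still null
  def describe(): String = arrowSchema.mkString(", ")
}

class Runner(s: Seq[String]) extends ArrowInputLike {
  override protected val schema: Seq[String] = s         // assigned only after the trait's init
}

object InitOrderDemo {
  def main(args: Array[String]): Unit =
    new Runner(Seq("a", "b")).describe()                 // throws NullPointerException during construction
}
```
This is the same shape as the `PythonArrowInput.$init$` frame above: the trait computed `arrowSchema` before the runner subclass had assigned `schema`.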
The current fix is somewhat tricky and doesn't cover all Python nodes (e.g., we don't do it for `ArrowPythonUDTFRunner`, so its behavior might be undefined). I suspect we still have similar issues with other `val`s in some cases: `override val`s from subclasses are not always respected during initialization (`Any` fields observed as null, integers as zero, etc.).
To resolve this, I think we can change `schema`, `timeZoneId`, `errorOnDuplicatedFieldNames` and `largeVarTypes` to `def`, since they are only used once (to build the Arrow schema), and make `allocator`/`root` lazy, as sketched below.
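A rough sketch of that proposed shape, with illustrative names rather than the actual Spark traits: abstract `def`s leave no field to initialize, and the derived value is `lazy`, so nothing runs during construction.
```
// Sketch of the proposed shape: abstract `def`s plus lazy derived state.
trait ArrowInputLike {
  protected def schema: Seq[String]                      // def: nothing is evaluated at trait init time
  protected def timeZoneId: String

  protected lazy val arrowSchema: Seq[String] =          // computed on first access only
    schema.map(field => s"$field@$timeZoneId")
}

class Runner(s: Seq[String], tz: String) extends ArrowInputLike {
  override protected val schema: Seq[String] = s         // a constructor val may still implement the def
  override protected val timeZoneId: String = tz
}
```
Subclasses can still implement the `def`s with plain constructor `val`s, which is what the diff below does.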
2. When `ArrowPythonRunner` is mixed with `GroupedPythonArrowInput`, as in
https://github.com/apache/spark/blob/a03bedb6c1281c5263a42bfd20608d2ee005ab05/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala#L371-L381
the `runner` there inherits `allocator`/`root` from `ArrowPythonWithNamedArgumentRunner` -> `RowInputArrowPythonRunner` -> `BasicPythonArrowInput` -> `PythonArrowInput`, but never uses them: `GroupedPythonArrowInput` overrides `def newWriter` and creates another `allocator`/`root` in its call to `createAndStartArrowWriter`.
For this case too, we need to make `allocator` and `root` lazy to avoid the unnecessary resource allocation (see the sketch below).
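A small sketch of why the laziness matters for this mixin case (types and names are illustrative, not the real Spark API): when the mixed-in trait overrides the only method that touches the base resource, a `lazy val` is simply never forced.
```
// Illustrative sketch: a lazy resource is never allocated if the overriding
// mixin bypasses the only code path that touches it.
import java.nio.ByteBuffer

trait ArrowInputSketch {
  protected lazy val allocator: ByteBuffer = {
    println("allocating base buffer")                    // never printed in the demo below
    ByteBuffer.allocateDirect(1 << 20)
  }
  protected def newWriter(): String =
    s"base writer over ${allocator.capacity()} bytes"    // the only use of `allocator`
}

trait GroupedInputSketch extends ArrowInputSketch {
  // Like GroupedPythonArrowInput overriding `newWriter`, this path never
  // forces `allocator`, so the base trait's buffer is never allocated.
  override protected def newWriter(): String = "grouped writer with its own allocation"
}

object LazyResourceDemo extends GroupedInputSketch {
  def main(args: Array[String]): Unit = println(newWriter())   // prints only the grouped message
}
```
Running `LazyResourceDemo` prints only the grouped message; the base direct buffer is never allocated.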
### Does this PR introduce _any_ user-facing change?
Undefined behavior / failures caused by the odd initialization order are replaced by behavior that clearly respects the subclass's overrides.
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #54068 from zhengruifeng/restore_init.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/execution/python/ArrowPythonRunner.scala | 26 +++++++++------------
.../sql/execution/python/PythonArrowInput.scala | 27 ++++++++++++----------
.../ApplyInPandasWithStatePythonRunner.scala | 7 ++----
.../TransformWithStateInPySparkPythonRunner.scala | 14 ++++-------
4 files changed, 33 insertions(+), 41 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 39d82b4b037b..7a12dbd556bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -32,8 +32,8 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](
funcs: Seq[(ChainedPythonFunctions, Long)],
evalType: Int,
argOffsets: Array[Array[Int]],
- _schema: StructType,
- _timeZoneId: String,
+ override protected val schema: StructType,
+ override protected val timeZoneId: String,
protected override val largeVarTypes: Boolean,
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
@@ -67,10 +67,6 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](
override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean =
SQLConf.get.pysparkSimplifiedTraceback
- // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
- // constructor.
- override protected lazy val timeZoneId: String = _timeZoneId
- override protected lazy val schema: StructType = _schema
override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
require(
bufferSize >= 4,
@@ -82,14 +78,14 @@ abstract class RowInputArrowPythonRunner(
funcs: Seq[(ChainedPythonFunctions, Long)],
evalType: Int,
argOffsets: Array[Array[Int]],
- _schema: StructType,
- _timeZoneId: String,
+ schema: StructType,
+ timeZoneId: String,
largeVarTypes: Boolean,
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
extends BaseArrowPythonRunner[Iterator[InternalRow], ColumnarBatch](
- funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
+ funcs, evalType, argOffsets, schema, timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID)
with BasicPythonArrowInput
with BasicPythonArrowOutput
@@ -101,15 +97,15 @@ class ArrowPythonRunner(
funcs: Seq[(ChainedPythonFunctions, Long)],
evalType: Int,
argOffsets: Array[Array[Int]],
- _schema: StructType,
- _timeZoneId: String,
+ schema: StructType,
+ timeZoneId: String,
largeVarTypes: Boolean,
pythonRunnerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
extends RowInputArrowPythonRunner(
- funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
+ funcs, evalType, argOffsets, schema, timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID) {
override protected def runnerConf: Map[String, String] = super.runnerConf ++
pythonRunnerConf
@@ -126,15 +122,15 @@ class ArrowPythonWithNamedArgumentRunner(
funcs: Seq[(ChainedPythonFunctions, Long)],
evalType: Int,
argMetas: Array[Array[ArgumentMetadata]],
- _schema: StructType,
- _timeZoneId: String,
+ schema: StructType,
+ timeZoneId: String,
largeVarTypes: Boolean,
pythonRunnerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
extends RowInputArrowPythonRunner(
- funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes,
+ funcs, evalType, argMetas.map(_.map(_.offset)), schema, timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID) {
override protected def runnerConf: Map[String, String] = super.runnerConf ++
pythonRunnerConf
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index d2d16b0c9623..45011fa3cebc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -20,6 +20,7 @@ import java.io.DataOutputStream
import java.nio.channels.Channels
import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
+import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}
import org.apache.arrow.vector.ipc.ArrowStreamWriter
@@ -42,13 +43,13 @@ import org.apache.spark.util.Utils
* JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
*/
private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
- protected val schema: StructType
+ protected def schema: StructType
- protected val timeZoneId: String
+ protected def timeZoneId: String
- protected val errorOnDuplicatedFieldNames: Boolean
+ protected def errorOnDuplicatedFieldNames: Boolean
- protected val largeVarTypes: Boolean
+ protected def largeVarTypes: Boolean
protected def pythonMetrics: Map[String, SQLMetric]
@@ -62,15 +63,17 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {}
- private val arrowSchema = ArrowUtils.toArrowSchema(
- schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
- protected val allocator =
+ protected lazy val allocator: BufferAllocator =
ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
- protected val root = VectorSchemaRoot.create(arrowSchema, allocator)
+
+ protected lazy val root: VectorSchemaRoot = {
+ val arrowSchema = ArrowUtils.toArrowSchema(
+ schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
+ VectorSchemaRoot.create(arrowSchema, allocator)
+ }
// Create compression codec based on config
- private val compressionCodecName = SQLConf.get.arrowCompressionCodec
- private val codec = compressionCodecName match {
+ protected def codec: CompressionCodec = SQLConf.get.arrowCompressionCodec match {
case "none" => NoCompressionCodec.INSTANCE
case "zstd" =>
val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
@@ -85,7 +88,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
throw SparkException.internalError(
s"Unsupported Arrow compression codec: $other. Supported values: none,
zstd, lz4")
}
- protected val unloader = new VectorUnloader(root, true, codec, true)
protected var writer: ArrowStreamWriter = _
@@ -130,7 +132,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] {
self: BasePythonRunner[Iterator[InternalRow], _] =>
- protected val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
+ protected lazy val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
+ protected lazy val unloader = new VectorUnloader(root, true, codec, true)
protected val maxRecordsPerBatch: Int = {
val v = SQLConf.get.arrowMaxRecordsPerBatch
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index 2c9dfc664833..165b90046630 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -57,7 +57,7 @@ class ApplyInPandasWithStatePythonRunner(
evalType: Int,
argOffsets: Array[Array[Int]],
inputSchema: StructType,
- _timeZoneId: String,
+ override protected val timeZoneId: String,
initialRunnerConf: Map[String, String],
stateEncoder: ExpressionEncoder[Row],
keySchema: StructType,
@@ -84,10 +84,7 @@ class ApplyInPandasWithStatePythonRunner(
private val sqlConf = SQLConf.get
- // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
- // constructor.
- override protected lazy val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
- override protected lazy val timeZoneId: String = _timeZoneId
+ override val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
override val errorOnDuplicatedFieldNames: Boolean = true
override val hideTraceback: Boolean = sqlConf.pysparkHideTraceback
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index effc425f807e..e6c990ab1c89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -125,7 +125,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
dataSchema: StructType,
initStateSchema: StructType,
processorHandle: StatefulProcessorHandleImpl,
- _timeZoneId: String,
+ timeZoneId: String,
initialRunnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
@@ -133,12 +133,12 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
batchTimestampMs: Option[Long],
eventTimeWatermarkForEviction: Option[Long])
extends TransformWithStateInPySparkPythonBaseRunner[GroupedInType](
- funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
+ funcs, evalType, argOffsets, dataSchema, processorHandle, timeZoneId,
initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
batchTimestampMs, eventTimeWatermarkForEviction)
with PythonArrowInput[GroupedInType] {
- override protected lazy val schema: StructType = new StructType()
+ override protected val schema: StructType = new StructType()
.add("inputData", dataSchema)
.add("initState", initStateSchema)
@@ -218,9 +218,9 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
funcs: Seq[(ChainedPythonFunctions, Long)],
evalType: Int,
argOffsets: Array[Array[Int]],
- _schema: StructType,
+ override protected val schema: StructType,
processorHandle: StatefulProcessorHandleImpl,
- _timeZoneId: String,
+ override protected val timeZoneId: String,
initialRunnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
@@ -244,10 +244,6 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key ->
arrowMaxBytesPerBatch.toString
)
- // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
- // constructor.
- override protected lazy val schema: StructType = _schema
- override protected lazy val timeZoneId: String = _timeZoneId
override protected val errorOnDuplicatedFieldNames: Boolean = true
override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]