cloud-fan commented on code in PR #56061: URL: https://github.com/apache/spark/pull/56061#discussion_r3308786018
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/WidenStatefulOperatorAttributeNullability.scala: ########## @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, ExprId} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +/** + * Shared helpers for the stateful-operator nullability fix. The fix has three + * independent components, all gated by + * [[SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT]] (pinned per-query via the + * offset log so existing queries keep their pre-fix behavior on restart): + * + * - (a) `widenStateSchema`: explicit `asNullable` at every state-schema construction + * site in each stateful physical exec. + * - (b) `widenOutputForStatefulOp`: a per-op `output` override on every stateful logical + * and physical operator, used by the operator's `output` definition. 
+ * - (c) [[WidenStatefulOperatorAttributeNullability]] (defined below in this file): a + * custom optimizer rule that widens `AttributeReference`s inside stateful ops' + * internal expressions and propagates upward to ancestor expressions. + */ +object WidenStatefulOpNullability { + + def isEnabled: Boolean = + SQLConf.get.getConf(SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT) + + /** + * Recursively widens an attribute to be fully nullable: outer `nullable = true` plus + * every nested `StructField.nullable`, `ArrayType.containsNull`, and + * `MapType.valueContainsNull` flipped to `true` via + * [[org.apache.spark.sql.types.DataType#asNullable]]. + */ + def deepWidenAttribute(a: Attribute): Attribute = a match { + case ref: AttributeReference => + AttributeReference( + ref.name, ref.dataType.asNullable, nullable = true, ref.metadata)( + ref.exprId, ref.qualifier) + case other => other.withNullability(true) + } + + /** + * Component (a): widens a state schema to fully nullable. Stateful physical execs apply + * this at every `validateAndMaybeEvolveStateSchema(...)` call site and every + * `mapPartitionsWith*StateStore(...)` call site. When the conf is off, returns the + * schema unchanged. + */ + def widenStateSchema(schema: StructType): StructType = + if (isEnabled) schema.asNullable else schema + + /** + * Component (b): wraps a stateful operator's `output` to be fully nullable. The caller + * is responsible for only calling this from within an `output` definition on a stateful + * operator; gating is handled here via [[isEnabled]]. + */ + def widenOutputForStatefulOp(base: Seq[Attribute]): Seq[Attribute] = + if (isEnabled) base.map(deepWidenAttribute) else base + + /** + * Recursively walks a schema and replaces any nested `StructType` that + * structurally matches `original` (by field names and base types, ignoring + * nullability) with `widened`. 
Used by TransformWithState execs to widen
+   * the grouping-key portion of col-family key schemas without touching
+   * user-defined key/value portions.
+   */
+  def widenGroupingKeyInSchema(
+      schema: StructType,
+      original: StructType,
+      widened: StructType): StructType = {
+    if (!isEnabled) return schema
+    if (structurallyMatches(schema, original)) {
+      widened
+    } else {
+      StructType(schema.fields.map { field =>
+        field.dataType match {
+          case st: StructType if structurallyMatches(st, original) =>
+            field.copy(dataType = widened)
+          case st: StructType =>
+            field.copy(dataType =
+              widenGroupingKeyInSchema(st, original, widened))
+          case _ => field
+        }
+      })
+    }
+  }
+
+  private def structurallyMatches(
+      a: StructType, b: StructType): Boolean = {
+    a.length == b.length && a.zip(b).forall { case (fa, fb) =>
+      fa.name == fb.name &&
+        fa.dataType.typeName == fb.dataType.typeName

Review Comment:
**Newly introduced** (`94dda32`). `structurallyMatches` only checks field count, names, and top-level `dataType.typeName`, which is insufficient to identify the grouping-key portion within a composite schema, and dangerous because the match-and-substitute logic in `widenGroupingKeyInSchema` replaces the matched field's `dataType` with `widened` (the grouping-key's widened schema), not with the field's own widened version.

Concrete failure in `TransformWithStateExec` map state. User defines `case class GK(items: Seq[Int])` as grouping key and `case class UK(items: Seq[String])` as map state user-key. Composite key is `StructType("key" -> GK, "userKey" -> UK)`. In `widenGroupingKeyInSchema(composite, original=GK, widened=widenedGK)`:
- `"key"` field (`GK`): structurally matches `GK` → replaced with `widenedGK`. Correct.
- `"userKey"` field (`UK`): `structurallyMatches(UK, GK)` is `true` (`length=1`, name `"items"`, `typeName="array"` all match) → also replaced with `widenedGK`, not widened `UK`.
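A minimal sketch of that collision, assuming simplified stand-in types rather than Spark's own `DataType` classes. `DType`, `ArrayT`, `StructT`, `Field`, `shallowMatches`, and `deepMatches` are hypothetical names used only for illustration; `shallowMatches` mirrors the shape of the helper under review, and `deepMatches` approximates a recursive comparison that ignores nullability:

```scala
// Hypothetical stand-ins for Spark's type hierarchy; illustration only.
sealed trait DType { def typeName: String }
case object IntT extends DType { val typeName = "integer" }
case object StringT extends DType { val typeName = "string" }
final case class ArrayT(element: DType, containsNull: Boolean) extends DType {
  val typeName = "array"
}
final case class Field(name: String, dataType: DType, nullable: Boolean)
final case class StructT(fields: Seq[Field]) extends DType {
  val typeName = "struct"
}

object StructuralMatchSketch {
  // Shallow check: field count, names, and top-level type names only.
  def shallowMatches(a: StructT, b: StructT): Boolean =
    a.fields.length == b.fields.length &&
      a.fields.zip(b.fields).forall { case (fa, fb) =>
        fa.name == fb.name && fa.dataType.typeName == fb.dataType.typeName
      }

  // Deep check: recurses into element and field types, ignoring nullability flags.
  def deepMatches(a: DType, b: DType): Boolean = (a, b) match {
    case (ArrayT(ea, _), ArrayT(eb, _)) => deepMatches(ea, eb)
    case (StructT(fa), StructT(fb)) =>
      fa.length == fb.length && fa.zip(fb).forall { case (x, y) =>
        x.name == y.name && deepMatches(x.dataType, y.dataType)
      }
    case (x, y) => x == y
  }

  // GK(items: Seq[Int]) and UK(items: Seq[String]) from the scenario above.
  val gk: StructT =
    StructT(Seq(Field("items", ArrayT(IntT, containsNull = false), nullable = false)))
  val uk: StructT =
    StructT(Seq(Field("items", ArrayT(StringT, containsNull = false), nullable = false)))
  // The fully nullable form of GK.
  val widenedGk: StructT =
    StructT(Seq(Field("items", ArrayT(IntT, containsNull = true), nullable = true)))
}
```

In this model `shallowMatches(uk, gk)` holds even though the element types differ, while `deepMatches(uk, gk)` does not; `deepMatches(gk, widenedGk)` still holds, which is the property the substitution logic relies on.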
The stored schema and `keyStateEncoderSpec` now claim `userKey.items: Array<Int>` while actual data is `Array<String>`, so state decode on restart mis-decodes the user-key.

Same shape arises with any pair of structs that share field count, names, and top-level type names but differ in nested types (`ArrayType` element, `MapType` key/value, nested `StructType`).

`DataType.equalsIgnoreNullability` (already in `DataType.scala`) recursively compares everything except nullability and is a drop-in replacement that also restores the helper's Scaladoc promise "without touching user-defined key/value portions":

```suggestion
  private def structurallyMatches(
      a: StructType, b: StructType): Boolean =
    DataType.equalsIgnoreNullability(a, b)
```

(Plus `import org.apache.spark.sql.types.DataType` at the top of the file.)

##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala:
##########
@@ -127,14 +131,47 @@ case class TransformWithStateInPySparkExec(
   // Each state variable has its own schema, this is a dummy one.
   protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
 
+  private lazy val widenedGroupingKeySchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(groupingKeySchema)
+
   override def getColFamilySchemas(
       shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
     // For Python, the user can explicitly set nullability on schema, so
     // we need to throw an error if the schema is nullable
-    driverProcessorHandle.getColumnFamilySchemas(
+    val schemas = driverProcessorHandle.getColumnFamilySchemas(
       shouldCheckNullable = shouldBeNullable,
       shouldSetNullable = shouldBeNullable
     )
+    widenColFamilyGroupingKeys(schemas)
+  }
+
+  private def widenColFamilyGroupingKeys(
+      schemas: Map[String, StateStoreColFamilySchema])
+    : Map[String, StateStoreColFamilySchema] = {
+    val original = groupingKeySchema
+    val widened = widenedGroupingKeySchema
+    if (original == widened) return schemas

Review Comment:
**Newly introduced** (`94dda32` + `0fa91ec`). This early return makes the value-schema widening at L172 (added in `0fa91ec` "for consistency with other stateful operators") dead code on the Python TWS path.

`groupingKeySchema` (L122) is built from `groupingAttributes`. By the time the physical exec is constructed, component (c) has already widened `groupingAttributes` at the optimizer level, so all fields are `nullable=true`. Then `widenedGroupingKeySchema = widenStateSchema(groupingKeySchema) = groupingKeySchema.asNullable = groupingKeySchema`. `original == widened` → early return fires unconditionally, and `widenStateSchema(cf.valueSchema)` never runs.

`TransformWithStateExec` (Scala) doesn't hit this because it uses `keyEncoder.schema` (the user's case-class encoder schema, independent of (c)'s widening), so `original != widened` for typical non-`Option` case classes and the value widening does run.
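The control-flow problem can be sketched in isolation. This is a simplified model, not the PR's code: `Col`, `Schema`, `ColFamily`, `widenGatedOnKeyEquality`, and `widenGatedOnFlag` are hypothetical stand-ins, and the model widens whole key schemas rather than only the grouping-key portion:

```scala
// Hypothetical stand-ins for StructType and StateStoreColFamilySchema; illustration only.
final case class Col(name: String, nullable: Boolean)
final case class Schema(cols: Seq[Col]) {
  def asNullable: Schema = Schema(cols.map(_.copy(nullable = true)))
  def allNullable: Boolean = cols.forall(_.nullable)
}
final case class ColFamily(name: String, keySchema: Schema, valueSchema: Schema)

object EarlyReturnSketch {
  // Simplification: widens the entire key and value schema of every column family.
  private def widenAll(cfs: Seq[ColFamily]): Seq[ColFamily] =
    cfs.map(cf => cf.copy(
      keySchema = cf.keySchema.asNullable,
      valueSchema = cf.valueSchema.asNullable))

  // Shape of the reviewed code: skip all widening when the grouping key needs none.
  def widenGatedOnKeyEquality(groupingKey: Schema, cfs: Seq[ColFamily]): Seq[ColFamily] =
    if (groupingKey == groupingKey.asNullable) cfs else widenAll(cfs)

  // Shape of the suggested fix: gate on the feature flag instead.
  def widenGatedOnFlag(enabled: Boolean, cfs: Seq[ColFamily]): Seq[ColFamily] =
    if (enabled) widenAll(cfs) else cfs

  // Grouping key is already nullable (as after optimizer-level widening);
  // the user-defined value schema is not.
  val alreadyWideKey: Schema = Schema(Seq(Col("id", nullable = true)))
  val families: Seq[ColFamily] = Seq(ColFamily(
    "countState",
    keySchema = alreadyWideKey,
    valueSchema = Schema(Seq(Col("count", nullable = false)))))
}
```

With an already-nullable grouping key, the equality-gated variant returns `families` untouched, so the non-nullable value schema survives; the flag-gated variant widens it whenever the flag is on.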
Net effect: in non-Avro Python TWS (where `shouldBeNullable=false`, so the processor handle doesn't pre-widen user schemas), user-defined non-nullable state value schemas are written to disk un-widened. The state-schema compat check still passes (both stored and new are equally non-nullable), but the design's component-(a) intent is silently dropped on the Python branch, and the Scala/Python paths diverge invisibly.

Suggested fix: gate on `WidenStatefulOpNullability.isEnabled` instead of `original == widened`, so the value-widening pass always runs when the conf is on, even when the grouping key happens to need no widening.

The pair of `widenColFamilyGroupingKeys` methods in the two TWS execs are otherwise identical; consolidating into `TransformWithStateExecBase` (with a per-exec `original` accessor) would prevent this kind of asymmetry from re-emerging.

##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStatefulOperatorNullabilityDriftSuite.scala:
##########
@@ -0,0 +1,534 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.{DataFrame, Encoders} +import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability +import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateStore} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} + +/** + * Regression suite for stateful-operator nullability drift. + * + * Driver: `PropagateEmptyRelation` drops empty `Union` branches without a streaming + * guard, so the surviving branch's per-column nullability becomes the Union's + * nullability and propagates into a stateful operator above -- across microbatches or + * restarts. + * + * Coverage: + * - New-query (default conf): originally-failing scenarios now complete cleanly. + * - Existing-query (conf forced false): pre-fix behavior preserved verbatim. + * - Helper invariant: `WidenStatefulOpNullability.deepWidenAttribute` recurses into + * nested types. 
+ */ +class StreamingStatefulOperatorNullabilityDriftSuite extends StreamTest { + + import testImplicits._ + + private def buildTwoSources(): (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame) = { + val inputA = MemoryStream[Int] + val inputB = MemoryStream[Int] + + val dfA = inputA.toDF().select($"value".as("key")) + val dfB = inputB.toDF() + .select(when($"value" > Int.MinValue, $"value") + .otherwise(lit(null).cast("int")) + .as("key")) + + (inputA, inputB, dfA, dfB) + } + + private def buildTwoSourcesWithWatermark() + : (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame) = { + val inputA = MemoryStream[Int] + val inputB = MemoryStream[Int] + + val dfA = inputA.toDF() + .select($"value".as("key"), + timestamp_seconds($"value").as("ts")) + .withWatermark("ts", "1 minute") + val dfB = inputB.toDF() + .select(when($"value" > Int.MinValue, $"value") + .otherwise(lit(null).cast("int")).as("key"), + timestamp_seconds($"value").as("ts")) + .withWatermark("ts", "1 minute") + + (inputA, inputB, dfA, dfB) + } + + private def runUnionBranchDropRestart( + buildSources: () => (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame), + buildQuery: (DataFrame, DataFrame) => DataFrame, + outputMode: OutputMode, + nullableToNonNullable: Boolean): Unit = { + withTempDir { checkpointDir => + val checkpointPath = checkpointDir.getAbsolutePath + + val (inputA, inputB, dfA, dfB) = buildSources() + val q = buildQuery(dfA, dfB) + + if (nullableToNonNullable) { + testStream(q, outputMode)( + StartStream(checkpointLocation = checkpointPath), + MultiAddData(inputA, 1, 2, 3)(inputB, 4, 5), + ProcessAllAvailable(), + StopStream + ) + } else { + testStream(q, outputMode)( + StartStream(checkpointLocation = checkpointPath), + AddData(inputA, 1, 2, 3), + ProcessAllAvailable(), + StopStream + ) + } + + assertJournaledStateSchemaAllNullable(checkpointPath) + + if (nullableToNonNullable) { + testStream(q, outputMode)( + StartStream(checkpointLocation = checkpointPath), + 
AddData(inputA, 6), + ProcessAllAvailable() + ) + } else { + testStream(q, outputMode)( + StartStream(checkpointLocation = checkpointPath), + MultiAddData(inputA, 6)(inputB, 7), + ProcessAllAvailable() + ) + } + } + } + + private def assertJournaledStateSchemaAllNullable(checkpointPath: String): Unit = { + val partId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA + val operatorRoot = new Path(checkpointPath, "state/0") + val partitionRoot = new Path(operatorRoot, s"$partId") + val hadoopConf = spark.sessionState.newHadoopConf() + val fm = CheckpointFileManager.create(operatorRoot, hadoopConf) + val fs = operatorRoot.getFileSystem(hadoopConf) + + def collectSchemaFiles(dir: Path): Seq[Path] = { + if (!fm.exists(dir)) return Seq.empty + if (fs.getFileStatus(dir).isDirectory) { + fs.listStatus(dir).filter(_.isFile).map(_.getPath).toSeq + } else { + Seq(dir) + } + } + + val schemaFiles = scala.collection.mutable.ArrayBuffer.empty[Path] + + val storeDirs = scala.collection.mutable.ArrayBuffer(partitionRoot) + if (fs.exists(partitionRoot)) { + fs.listStatus(partitionRoot) + .filter(_.isDirectory) + .filterNot(_.getPath.getName.startsWith("_")) + .foreach(d => storeDirs += d.getPath) + } + storeDirs.foreach { storeDir => + schemaFiles ++= collectSchemaFiles( + new Path(storeDir, "_metadata/schema")) + } + + val stateSchemaRoot = new Path(operatorRoot, "_stateSchema") + if (fs.exists(stateSchemaRoot)) { + fs.listStatus(stateSchemaRoot) + .filter(_.isDirectory) + .foreach { storeDir => + schemaFiles ++= collectSchemaFiles(storeDir.getPath) + } + } + + assert(schemaFiles.nonEmpty, + s"expected at least one schema file under $operatorRoot") + schemaFiles.foreach { schemaFile => + val inStream = fm.open(schemaFile) + try { + val schemas = StateSchemaCompatibilityChecker.readSchemaFile(inStream) + schemas.foreach { s => + assertSchemaAllNullable(s.keySchema, + s"$schemaFile: key schema for col family ${s.colFamilyName}") Review Comment: **Newly introduced** (`94dda32`). 
Removing the `valueSchema` assertion masks regression coverage for a real code path: the subsequent commit `0fa91ec` re-added `valueSchema = widenStateSchema(cf.valueSchema)` to both TWS execs "for consistency with other stateful operators". The existing Scala TWS test at L403 would actually exercise this: `keyEncoder.schema` differs from `widenStateSchema(...)` for the test's non-nullable grouping key, so the early-return doesn't fire and the value widening runs.

Re-adding the assertion also surfaces the Python TWS early-return issue I describe inline on `TransformWithStateInPySparkExec.scala` once a Python TWS regression test is added.

```suggestion
        schemas.foreach { s =>
          assertSchemaAllNullable(s.keySchema,
            s"$schemaFile: key schema for col family ${s.colFamilyName}")
          assertSchemaAllNullable(s.valueSchema,
            s"$schemaFile: value schema for col family ${s.colFamilyName}")
        }
```
