This is an automated email from the ASF dual-hosted git repository.

HeartSaVioR pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new f4bec8339d78 [SPARK-57003][SQL][SS] Widen stateful operator output and 
state schema nullability
f4bec8339d78 is described below

commit f4bec8339d7879eb7db4035dc6adce84319c2ecb
Author: Jungtaek Lim <[email protected]>
AuthorDate: Wed May 27 22:30:16 2026 +0900

    [SPARK-57003][SQL][SS] Widen stateful operator output and state schema 
nullability
    
    ### What changes were proposed in this pull request?
    
    Introduce a three-component fix for stateful-operator nullability drift, 
gated by `spark.sql.streaming.statefulOperator.alwaysNullableOutput.enabled` 
(pinned per-query via the offset log):
    
    - (a) `WidenStatefulOpNullability.widenStateSchema`: every stateful 
physical exec widens its state key/value schema to fully nullable at 
construction. This covers `StateStoreSaveExec`, `BaseStreamingDeduplicateExec`, 
`StreamingSymmetricHashJoinExec`, `FlatMapGroupsWithStateExec`, 
`TransformWithStateExec` (including user-defined state variable col family 
schemas), `TransformWithStateInPySparkExec`, and `StreamingGlobalLimitExec`.
    - (b) `WidenStatefulOpNullability.widenOutputForStatefulOp`: every stateful 
logical and physical operator widens its declared `output` to fully nullable.
    - (c) `WidenStatefulOperatorAttributeNullability`: an optimizer rule that 
widens `AttributeReference`s inside stateful ops' internal expressions and 
propagates upward through ancestor expressions. The rule uses 
`resolveOperatorsUp` (bottom-up) and scopes the widening precisely: at a 
stateful operator, all children's output is included (for internal expression 
references like grouping keys); at non-stateful ancestors, only children whose 
subtrees contain a stateful operator are include [...]
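
The scoping in (c) can be sketched on a toy plan tree. This is only an illustrative model: `Node`, `containsStateful`, and `widenableIds` are hypothetical names, and integer ids stand in for `exprId`s. It mirrors the selection logic in the rule body of this diff, where the stateful operator's own output is included alongside its children's output.

```scala
// Toy plan model; Node and widenableIds are hypothetical stand-ins, not Spark classes.
object WidenScopeSketch {
  final case class Node(
      name: String,
      isStateful: Boolean,
      output: Set[Int],          // integer ids stand in for exprIds
      children: List[Node] = Nil) {
    def containsStateful: Boolean =
      isStateful || children.exists(_.containsStateful)
  }

  // Ids whose references inside `p`'s expressions would be widened.
  def widenableIds(p: Node): Set[Int] =
    if (!p.containsStateful || p.children.isEmpty) {
      Set.empty // leaves and subtrees without a stateful operator are skipped
    } else if (p.isStateful) {
      // stateful operator: its own output plus every child's output
      p.output ++ p.children.flatMap(_.output)
    } else {
      // non-stateful ancestor: only children that sit above a stateful operator
      p.children.filter(_.containsStateful).flatMap(_.output).toSet
    }
}
```

For a join whose left child is a stateful aggregate and whose right child is a plain source, only the aggregate's output ids are selected at the join; the sibling source is left untouched.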
    
    With the above fix, we aim to ensure that the state schema is "fully" 
nullable (top-level columns, nested columns, and collection types) regardless 
of the input schema, and that the output schema of the stateful operator is 
"fully" nullable as well. The output schema change is necessary because, even 
if the input schema is non-nullable, state can produce null values, so the 
output can be nullable.
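
What "fully" nullable means can be sketched with a small stand-in type model. The types below (`DType`, `Field`, `StructT`, and so on) are hypothetical and only mimic the shape of a schema; the actual change relies on Spark's own `asNullable`, as shown in the diff.

```scala
// Toy schema model; these types are illustrative stand-ins, not Spark's classes.
object DeepWidenSketch {
  sealed trait DType
  case object IntT extends DType
  case object StringT extends DType
  final case class ArrayT(element: DType, containsNull: Boolean) extends DType
  final case class MapT(key: DType, value: DType, valueContainsNull: Boolean) extends DType
  final case class Field(name: String, dataType: DType, nullable: Boolean)
  final case class StructT(fields: List[Field]) extends DType

  // Flip every nullability flag to true, recursing into nested types.
  def widen(t: DType): DType = t match {
    case ArrayT(e, _)  => ArrayT(widen(e), containsNull = true)
    case MapT(k, v, _) => MapT(widen(k), widen(v), valueContainsNull = true)
    case s: StructT    => widenStruct(s)
    case leaf          => leaf
  }

  def widenField(f: Field): Field =
    f.copy(dataType = widen(f.dataType), nullable = true)

  def widenStruct(s: StructT): StructT = StructT(s.fields.map(widenField))
}
```

Because every flag ends up `true` regardless of its starting value, applying the widening twice gives the same result as applying it once, and two schemas that differ only in nullability widen to the same schema.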
    
    ### Why are the changes needed?
    
    This has been a long-standing conflict between the streaming engine and the 
query optimizer.
    
    By nature, a streaming query is meant to be long-running, and in many 
cases it spans multiple Spark versions. Also, the logical plan is not always 
the same across batches (e.g. there are multiple stream sources and one of the 
sources has no new data at batch N). This exposes the streaming query to 
changes in analyzer and optimizer behavior.
    
    The state schema of a stateful operator is mostly determined by the 
operator's input schema, and nullability is no exception. If the input schema 
has a nullable column, the state schema has a nullable column, and likewise 
for a non-nullable column.
    
    One of the query optimizer's optimizations is to flip nullability, say, 
from nullable to non-nullable where appropriate. This can happen directly or 
indirectly, and the most problematic case is when the optimization is applied 
"selectively", i.e. in some batches but not in others.
    
    One simple example is the elimination of a Union branch: in a streaming 
query that combines multiple streams with Union, batch N could have one stream 
non-empty while another is empty. In that case, `PropagateEmptyRelation` can 
drop empty `Union` branches, causing a per-column nullability flip that 
propagates into a stateful operator's state schema across microbatches or 
restarts. This causes either `STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE` on restart 
or a codegen NPE when 
state-res [...]
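
The nullability flip in the Union example can be illustrated with a simplified model. Everything here is an assumption made for illustration: `Col`, `Branch`, `unionOutput`, and `pruneEmpty` are hypothetical names, and `pruneEmpty` only approximates the effect of dropping empty branches. The model assumes a union column is nullable when it is nullable in any remaining branch.

```scala
// Simplified model of a Union over streaming branches; all names are hypothetical.
object UnionDriftSketch {
  final case class Col(name: String, nullable: Boolean)
  final case class Branch(cols: List[Col], isEmpty: Boolean)

  // A union column is nullable if it is nullable in any branch (assumed rule).
  def unionOutput(branches: List[Branch]): List[Col] =
    branches.map(_.cols).transpose.map { sameColumn =>
      Col(sameColumn.head.name, sameColumn.exists(_.nullable))
    }

  // Stand-in for dropping empty branches before the output is computed.
  def pruneEmpty(branches: List[Branch]): List[Branch] = {
    val nonEmpty = branches.filterNot(_.isEmpty)
    if (nonEmpty.isEmpty) branches else nonEmpty
  }
}
```

With one branch declaring `key` as non-nullable and another declaring it nullable, the union output is nullable while both branches have data, but becomes non-nullable in a batch where the nullable branch is empty and gets pruned. A state schema derived from that output would differ between the two batches.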
    
    ### Does this PR introduce _any_ user-facing change?
    
    No user-visible behavior change for new queries (all stateful operator 
outputs become nullable, which is semantically correct). Existing queries keep 
their original behavior via the offset log gate.
    
    ### How was this patch tested?
    
    New `StreamingStatefulOperatorNullabilityDriftSuite` covering:
    - New-query path: Union-branch-drop restart scenarios for aggregate, 
dropDuplicates, dropDuplicatesWithinWatermark, stream-stream join, 
flatMapGroupsWithState, and transformWithState.
    - Codegen NPE regression with struct grouping keys.
    - Existing-query path: widening forced off still triggers schema mismatch.
    - State schema assertion validates all state stores and column families 
(both v2 file format and v3 directory format including `_stateSchema`).
    - Rule-level: scope check (non-stateful subtrees skipped).
    - Helper-level: `deepWidenAttribute` recursion into nested types.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes. Generated-by: Claude 4.7 Opus
    
    Closes #56061 from HeartSaVioR/widen-stateful-op-nullability.
    
    Authored-by: Jungtaek Lim <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
    (cherry picked from commit 0fb04a4ac9aa41cc5b9d0fa4c42877d2f1f450eb)
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 ...WidenStatefulOperatorAttributeNullability.scala | 167 +++++++
 .../plans/logical/basicLogicalOperators.scala      |  32 +-
 .../spark/sql/catalyst/plans/logical/object.scala  |  12 +-
 .../plans/logical/pythonLogicalOperators.scala     |  10 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  18 +
 .../streaming/ClientStreamingQuerySuite.scala      |   2 +-
 .../sql/execution/adaptive/AQEOptimizer.scala      |   5 +-
 .../FlatMapGroupsInPandasWithStateExec.scala       |   4 +-
 .../TransformWithStateInPySparkExec.scala          |  45 +-
 .../streaming/checkpointing/OffsetSeq.scala        |   6 +-
 .../FlatMapGroupsWithStateExec.scala               |  24 +-
 .../join/StreamingSymmetricHashJoinExec.scala      |  34 +-
 .../operators/stateful/statefulOperators.scala     |  87 ++--
 .../operators/stateful/streamingLimits.scala       |   4 +-
 .../TransformWithStateExec.scala                   |  55 ++-
 .../streaming/runtime/IncrementalExecution.scala   |   4 +-
 .../spark/sql/streaming/StreamingJoinSuite.scala   |  10 +-
 .../spark/sql/streaming/StreamingJoinV4Suite.scala |  12 +-
 ...mingStatefulOperatorNullabilityDriftSuite.scala | 534 +++++++++++++++++++++
 .../sql/streaming/TransformWithStateSuite.scala    |   4 +-
 20 files changed, 976 insertions(+), 93 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/WidenStatefulOperatorAttributeNullability.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/WidenStatefulOperatorAttributeNullability.scala
new file mode 100644
index 000000000000..b2ce8780a2ed
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/WidenStatefulOperatorAttributeNullability.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, ExprId}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * Shared helpers for the stateful-operator nullability fix. The fix has three
+ * independent components, all gated by
+ * [[SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT]] (pinned per-query via 
the
+ * offset log so existing queries keep their pre-fix behavior on restart):
+ *
+ *   - (a) `widenStateSchema`: explicit `asNullable` at every state-schema 
construction
+ *         site in each stateful physical exec.
+ *   - (b) `widenOutputForStatefulOp`: a per-op `output` override on every 
stateful logical
+ *         and physical operator, used by the operator's `output` definition.
+ *   - (c) [[WidenStatefulOperatorAttributeNullability]] (defined below in 
this file): a
+ *         custom optimizer rule that widens `AttributeReference`s inside 
stateful ops'
+ *         internal expressions and propagates upward to ancestor expressions.
+ */
+object WidenStatefulOpNullability {
+
+  def isEnabled: Boolean =
+    SQLConf.get.getConf(SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT)
+
+  /**
+   * Recursively widens an attribute to be fully nullable: outer `nullable = 
true` plus
+   * every nested `StructField.nullable`, `ArrayType.containsNull`, and
+   * `MapType.valueContainsNull` flipped to `true` via
+   * [[org.apache.spark.sql.types.DataType#asNullable]].
+   */
+  def deepWidenAttribute(a: Attribute): Attribute = a match {
+    case ref: AttributeReference =>
+      AttributeReference(
+        ref.name, ref.dataType.asNullable, nullable = true, ref.metadata)(
+        ref.exprId, ref.qualifier)
+    case other => other.withNullability(true)
+  }
+
+  /**
+   * Component (a): widens a state schema to fully nullable. Stateful physical 
execs apply
+   * this at every `validateAndMaybeEvolveStateSchema(...)` call site and every
+   * `mapPartitionsWith*StateStore(...)` call site. When the conf is off, 
returns the
+   * schema unchanged.
+   */
+  def widenStateSchema(schema: StructType): StructType =
+    if (isEnabled) schema.asNullable else schema
+
+  /**
+   * Component (b): wraps a stateful operator's `output` to be fully nullable. 
The caller
+   * is responsible for only calling this from within an `output` definition 
on a stateful
+   * operator; gating is handled here via [[isEnabled]].
+   */
+  def widenOutputForStatefulOp(base: Seq[Attribute]): Seq[Attribute] =
+    if (isEnabled) base.map(deepWidenAttribute) else base
+
+  /**
+   * Recursively walks a schema and replaces any nested `StructType` that
+   * structurally matches `original` (by field names and base types, ignoring
+   * nullability) with `widened`. Used by TransformWithState execs to widen
+   * the grouping-key portion of col-family key schemas without touching
+   * user-defined key/value portions.
+   */
+  def widenGroupingKeyInSchema(
+      schema: StructType,
+      original: StructType,
+      widened: StructType): StructType = {
+    if (!isEnabled) return schema
+    if (DataType.equalsIgnoreNullability(schema, original)) {
+      widened
+    } else {
+      StructType(schema.fields.map { field =>
+        field.dataType match {
+          case st: StructType
+              if DataType.equalsIgnoreNullability(st, original) =>
+            field.copy(dataType = widened)
+          case st: StructType =>
+            field.copy(dataType =
+              widenGroupingKeyInSchema(st, original, widened))
+          case _ => field
+        }
+      })
+    }
+  }
+}
+
+/**
+ * Component (c) of the stateful-operator nullability fix: a custom optimizer 
rule that
+ * widens `AttributeReference`s inside streaming-stateful operators' internal 
expressions
+ * and propagates the widening upward to ancestor operators' expressions.
+ *
+ * The rule does NOT introduce any new logical or physical node. It is purely 
an
+ * attribute-rewrite pass using `resolveOperatorsUp` (bottom-up): for every 
node whose
+ * subtree contains a stateful operator, collect `exprId`s from children's 
output, then
+ * deep-widen every `AttributeReference` in the node's expressions whose 
`exprId` is in
+ * that set via [[WidenStatefulOpNullability#deepWidenAttribute]].
+ *
+ * At a stateful operator itself, all children's output attributes are 
included because
+ * the operator's internal expressions (e.g. grouping keys) reference them 
directly.
+ * At non-stateful ancestor operators, only children whose subtrees contain a 
stateful
+ * operator are included, to avoid unnecessary widening of non-stateful 
siblings.
+ * The node's own `p.output` is not needed for non-stateful ancestors because 
the
+ * bottom-up traversal guarantees children are already transformed, so their 
output
+ * attributes are already nullable and the ancestor's expressions reference 
those
+ * children's `exprId`s.
+ *
+ * '''Scope.''' The walk only fires on nodes whose subtree contains a stateful 
operator.
+ *
+ * '''Ordering constraint.''' This rule must run AFTER every 
`UpdateAttributeNullability`
+ * invocation in both the main optimizer and AQE.
+ *
+ * '''Idempotence.''' [[WidenStatefulOpNullability#deepWidenAttribute]] is 
idempotent.
+ */
+object WidenStatefulOperatorAttributeNullability extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.getConf(SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT) ||
+        !plan.containsStatefulOperator) {
+      return plan
+    }
+    plan.resolveOperatorsUp {
+      case p if !p.resolved => p
+      case p: LeafNode => p
+      case p if !p.containsStatefulOperator => p
+      case p =>
+        val widenableAttrs = if (p.isStateful) {
+          p.output ++ p.children.flatMap(_.output)
+        } else {
+          p.children.filter(_.containsStatefulOperator).flatMap(_.output)
+        }
+        val widenableExprIds: Set[ExprId] = widenableAttrs
+          .iterator.collect { case ar: AttributeReference => ar.exprId }.toSet
+        if (widenableExprIds.isEmpty) {
+          p
+        } else {
+          p.transformExpressions {
+            case ar: AttributeReference if 
widenableExprIds.contains(ar.exprId) =>
+              val widened = WidenStatefulOpNullability.deepWidenAttribute(ar)
+              if (ar.dataType == widened.dataType && ar.nullable == 
widened.nullable) {
+                ar
+              } else {
+                widened
+              }
+          }
+        }
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ac0784474e2e..840b26232f0b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.{AliasIdentifier, InternalRow, 
SQLConfHelper}
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, AnsiTypeCoercion, 
MultiInstanceRelation, Resolver, TypeCoercion, TypeCoercionBase, 
UnresolvedUnaryNode}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, AnsiTypeCoercion, 
MultiInstanceRelation, Resolver, TypeCoercion, TypeCoercionBase, 
UnresolvedUnaryNode, WidenStatefulOpNullability}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, 
CatalogTable}
 import 
org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN
 import org.apache.spark.sql.catalyst.expressions._
@@ -746,7 +746,10 @@ case class Join(
     }
   }
 
-  override def output: Seq[Attribute] = Join.computeOutput(joinType, 
left.output, right.output)
+  override def output: Seq[Attribute] = {
+    val base = Join.computeOutput(joinType, left.output, right.output)
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) 
else base
+  }
 
   override def metadataOutput: Seq[Attribute] = {
     joinType match {
@@ -1226,7 +1229,10 @@ case class Aggregate(
     expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions
   }
 
-  override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
+  override def output: Seq[Attribute] = {
+    val base = aggregateExpressions.map(_.toAttribute)
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) 
else base
+  }
   override def metadataOutput: Seq[Attribute] = Nil
   override def maxRows: Option[Long] = {
     if (groupingExpressions.isEmpty) {
@@ -1750,7 +1756,10 @@ object Limit {
  * order.
  */
 case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends 
UnaryNode {
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = {
+    val base = child.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) 
else base
+  }
   override def maxRows: Option[Long] = {
     limitExpr match {
       case IntegerLiteral(limit) => Some(limit)
@@ -2005,7 +2014,10 @@ case class Sample(
  */
 case class Distinct(child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = {
+    val base = child.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) 
else base
+  }
   final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
   override protected def withNewChildInternal(newChild: LogicalPlan): Distinct 
=
     copy(child = newChild)
@@ -2175,7 +2187,10 @@ case class Deduplicate(
     keys: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = {
+    val base = child.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) 
else base
+  }
   final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
   override protected def withNewChildInternal(newChild: LogicalPlan): 
Deduplicate =
     copy(child = newChild)
@@ -2187,7 +2202,10 @@ case class DeduplicateWithinWatermark(keys: 
Seq[Attribute], child: LogicalPlan)
   override def references: AttributeSet = AttributeSet(keys) ++
     
AttributeSet(child.output.filter(_.metadata.contains(EventTimeWatermark.delayKey)))
   override def maxRows: Option[Long] = child.maxRows
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = {
+    val base = child.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) 
else base
+  }
   final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
   override protected def withNewChildInternal(newChild: LogicalPlan): 
DeduplicateWithinWatermark =
     copy(child = newChild)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 0c6f59073559..720b0dd640d0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.{catalyst, Encoder, Row}
-import org.apache.spark.sql.catalyst.analysis.{Resolver, 
UnresolvedDeserializer}
+import org.apache.spark.sql.catalyst.analysis.{Resolver, 
UnresolvedDeserializer, WidenStatefulOpNullability}
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
@@ -568,6 +568,11 @@ case class FlatMapGroupsWithState(
       newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapGroupsWithState =
     copy(child = newLeft, initialState = newRight)
   override def isStateful: Boolean = child.isStreaming
+
+  override def output: Seq[Attribute] = {
+    val base = super.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) 
else base
+  }
 }
 
 object TransformWithState {
@@ -657,6 +662,11 @@ case class TransformWithState(
       newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithState =
     copy(child = newLeft, initialState = newRight)
   override def isStateful: Boolean = child.isStreaming
+
+  override def output: Seq[Attribute] = {
+    val base = super.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) 
else base
+  }
 }
 
 /** Factory for constructing new `FlatMapGroupsInR` nodes. */
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 56dc2f6de043..31e7d9402968 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, 
MultiInstanceRelation, UnresolvedAttribute, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, 
MultiInstanceRelation, UnresolvedAttribute, UnresolvedStar, 
WidenStatefulOpNullability}
 import 
org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
AttributeSet, Expression, ExpressionDescription, ExpressionInfo, JsonToStructs, 
PythonUDF, PythonUDTF}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -159,7 +159,9 @@ case class FlatMapGroupsInPandasWithState(
     timeout: GroupStateTimeout,
     child: LogicalPlan) extends UnaryNode {
 
-  override def output: Seq[Attribute] = outputAttrs
+  override def output: Seq[Attribute] =
+    if (isStateful) 
WidenStatefulOpNullability.widenOutputForStatefulOp(outputAttrs)
+    else outputAttrs
 
   override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
 
@@ -206,7 +208,9 @@ case class TransformWithStateInPySpark(
 
   override def right: LogicalPlan = initialState
 
-  override def output: Seq[Attribute] = outputAttrs
+  override def output: Seq[Attribute] =
+    if (isStateful) 
WidenStatefulOpNullability.widenOutputForStatefulOp(outputAttrs)
+    else outputAttrs
 
   override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 328f434195f4..0aed28e92558 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3444,6 +3444,24 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT =
+    
buildConf("spark.sql.streaming.statefulOperator.alwaysNullableOutput.enabled")
+      .internal()
+      .withBindingPolicy(ConfigBindingPolicy.SESSION)
+      .doc("When true, every streaming stateful operator reports its output 
schema with " +
+        "nullable=true on all columns (including nested struct fields, array 
elements, and " +
+        "map values), and the state schema is widened at every construction 
site, so the " +
+        "existing state schema " +
+        "compatibility check trivially passes regardless of input nullability. 
" +
+        "This prevents query-optimizer decisions (e.g., PropagateEmptyRelation 
dropping a " +
+        "Union branch) from flipping the state schema nullability across 
microbatches or " +
+        "restarts. The effective value is pinned per query via the offset log 
at batch 0, " +
+        "so pre-existing queries keep their original behavior; only newly 
started queries " +
+        "pick this up.")
+      .version("4.3.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val FILESTREAM_SINK_METADATA_IGNORED =
     buildConf("spark.sql.streaming.fileStreamSink.ignoreMetadata")
       .internal()
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
index c8a25652dacb..057e2fdc4775 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
@@ -86,7 +86,7 @@ class ClientStreamingQuerySuite extends QueryTest with 
RemoteSparkSession with L
         .count()
         .selectExpr("window.start as timestamp", "count as num_events")
 
-      assert(countsDF.schema.toDDL == "timestamp TIMESTAMP,num_events BIGINT 
NOT NULL")
+      assert(countsDF.schema.toDDL == "timestamp TIMESTAMP,num_events BIGINT")
 
       // Start the query
       val queryName = "sparkConnectStreamingQuery"
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
index f16c6d9cfe6d..3c23930090ab 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.internal.LogKeys.{BATCH_NAME, RULE_NAME}
-import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability
+import org.apache.spark.sql.catalyst.analysis.{UpdateAttributeNullability, 
WidenStatefulOperatorAttributeNullability}
 import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, 
EliminateLimits, OptimizeOneRowPlan}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
LogicalPlanIntegrity}
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
@@ -44,7 +44,8 @@ class AQEOptimizer(conf: SQLConf, 
extendedRuntimeOptimizerRules: Seq[Rule[Logica
     Batch("Dynamic Join Selection", Once, DynamicJoinSelection),
     Batch("Eliminate Limits", fixedPoint, EliminateLimits),
     Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan)) :+
-    Batch("User Provided Runtime Optimizers", fixedPoint, 
extendedRuntimeOptimizerRules: _*)
+    Batch("User Provided Runtime Optimizers", fixedPoint, 
extendedRuntimeOptimizerRules: _*) :+
+    Batch("Widen Stateful Op Nullability", Once, 
WidenStatefulOperatorAttributeNullability)
 
   final override protected def batches: Seq[Batch] = {
     val excludedRules = conf.getConf(SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala
index e9430ed9f9b7..a61f90515836 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala
@@ -20,6 +20,7 @@ import org.apache.spark.{JobArtifactSet, SparkException, 
SparkUnsupportedOperati
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, 
ProcessingTimeTimeout}
@@ -81,7 +82,8 @@ case class FlatMapGroupsInPandasWithStateExec(
   override protected val stateEncoder: ExpressionEncoder[Any] =
     
ExpressionEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]
 
-  override def output: Seq[Attribute] = outAttributes
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(outAttributes)
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
index 45f2af5c1dfe..d3fd757784e0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, 
PythonEvalType}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
PythonUDF}
 import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark
@@ -39,7 +40,7 @@ import 
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOper
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{TransformWithStateExecBase,
 TransformWithStateVariableInfo}
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.{DriverStatefulProcessorHandleImpl,
 StatefulProcessorHandleImpl}
-import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, 
StateStoreProvider, StateStoreProviderId}
+import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, 
RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, 
StateStoreProvider, StateStoreProviderId, 
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{OutputMode, TimeMode}
 import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
@@ -51,7 +52,7 @@ import org.apache.spark.util.{CompletionIterator, 
SerializableConfiguration, Uti
  *
  * @param functionExpr function called on each group
  * @param groupingAttributes used to group the data
- * @param output used to define the output rows
+ * @param outputAttrs used to define the output rows
  * @param outputMode defines the output mode for the statefulProcessor
  * @param timeMode The time mode semantics of the stateful processor for timers and TTL.
  * @param stateInfo Used to identify the state store for a given operator.
@@ -69,7 +70,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti
 case class TransformWithStateInPySparkExec(
     functionExpr: Expression,
     groupingAttributes: Seq[Attribute],
-    output: Seq[Attribute],
+    outputAttrs: Seq[Attribute],
     outputMode: OutputMode,
     timeMode: TimeMode,
     stateInfo: Option[StatefulOperatorStateInfo],
@@ -94,6 +95,9 @@ case class TransformWithStateInPySparkExec(
     initialStateGroupingAttrs,
     initialState) {
 
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(outputAttrs)
+
   // NOTE: This is needed to comply with existing release of transformWithStateInPandas.
   override def shortName: String = if (
     userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS
@@ -127,16 +131,49 @@ case class TransformWithStateInPySparkExec(
   // Each state variable has its own schema, this is a dummy one.
   protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
 
+  private lazy val widenedGroupingKeySchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(groupingKeySchema)
+
   override def getColFamilySchemas(
       shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
     // For Python, the user can explicitly set nullability on schema, so
     // we need to throw an error if the schema is nullable
-    driverProcessorHandle.getColumnFamilySchemas(
+    val schemas = driverProcessorHandle.getColumnFamilySchemas(
       shouldCheckNullable = shouldBeNullable,
       shouldSetNullable = shouldBeNullable
     )
+    widenColFamilyGroupingKeys(schemas)
   }
 
+  private def widenColFamilyGroupingKeys(
+      schemas: Map[String, StateStoreColFamilySchema])
+      : Map[String, StateStoreColFamilySchema] = {
+    if (!WidenStatefulOpNullability.isEnabled) return schemas
+    val original = groupingKeySchema
+    val widened = widenedGroupingKeySchema
+    def widenKey(ks: StructType): StructType =
+      WidenStatefulOpNullability.widenGroupingKeyInSchema(
+        ks, original, widened)
+    schemas.map { case (name, cf) =>
+      val widenedSpec = cf.keyStateEncoderSpec.map {
+        case NoPrefixKeyStateEncoderSpec(ks) =>
+          NoPrefixKeyStateEncoderSpec(widenKey(ks))
+        case PrefixKeyScanStateEncoderSpec(ks, n) =>
+          PrefixKeyScanStateEncoderSpec(widenKey(ks), n)
+        case RangeKeyScanStateEncoderSpec(ks, o) =>
+          RangeKeyScanStateEncoderSpec(widenKey(ks), o)
+        case TimestampAsPrefixKeyStateEncoderSpec(ks) =>
+          TimestampAsPrefixKeyStateEncoderSpec(widenKey(ks))
+        case TimestampAsPostfixKeyStateEncoderSpec(ks) =>
+          TimestampAsPostfixKeyStateEncoderSpec(widenKey(ks))
+      }
+      name -> cf.copy(
+        keySchema = widenKey(cf.keySchema),
+        keyStateEncoderSpec = widenedSpec)
+    }
+  }
+
+
   override def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
     driverProcessorHandle.getStateVariableInfos
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
index bf2278b81492..9ba99ac2c036 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
@@ -204,7 +204,8 @@ object OffsetSeqMetadata extends Logging {
     STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION,
     PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN, STREAMING_STATE_STORE_ENCODING_FORMAT,
     STATE_STORE_ROW_CHECKSUM_ENABLED, PROTOBUF_EXTENSIONS_SUPPORT_ENABLED,
-    ENABLE_STREAMING_SOURCE_EVOLUTION
+    ENABLE_STREAMING_SOURCE_EVOLUTION,
+    STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT
   )
 
   /**
@@ -254,7 +255,8 @@ object OffsetSeqMetadata extends Logging {
     STATE_STORE_ROW_CHECKSUM_ENABLED.key -> "false",
     STATE_STORE_ROCKSDB_MERGE_OPERATOR_VERSION.key -> "1",
     PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key -> "false",
-    ENABLE_STREAMING_SOURCE_EVOLUTION.key -> "false"
+    ENABLE_STREAMING_SOURCE_EVOLUTION.key -> "false",
+    STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT.key -> "false"
   )
 
   def readValue[T](metadataLog: OffsetSeqMetadataBase, confKey: ConfigEntry[T]): String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala
index 6b9f90a9ab5c..48d1dad70f5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.{SparkException, SparkThrowable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -36,6 +37,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.join.Streamin
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
 import org.apache.spark.sql.streaming.GroupStateTimeout.NoTimeout
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
 
 /**
@@ -72,6 +74,11 @@ trait FlatMapGroupsWithStateExecBase
   lazy val stateManager: StateManager =
     createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion)
 
+  private lazy val stateKeySchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(groupingAttributes.toStructType)
+  private lazy val stateValueSchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(stateManager.stateSchema)
+
   /**
    * Distribute by grouping attributes - We need the underlying data and the initial state data
    * to have the same grouping so that the data are co-lacated on the same task.
@@ -200,7 +207,7 @@ trait FlatMapGroupsWithStateExecBase
       batchId: Long,
       stateSchemaVersion: Int): List[StateSchemaValidationResult] = {
     val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
-      groupingAttributes.toStructType, 0, stateManager.stateSchema))
+      stateKeySchema, 0, stateValueSchema))
     List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
       newStateSchema, session.sessionState, stateSchemaVersion))
   }
@@ -243,9 +250,9 @@ trait FlatMapGroupsWithStateExecBase
           val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId)
           val store = StateStore.get(
             storeProviderId,
-            groupingAttributes.toStructType,
-            stateManager.stateSchema,
-            NoPrefixKeyStateEncoderSpec(groupingAttributes.toStructType),
+            stateKeySchema,
+            stateValueSchema,
+            NoPrefixKeyStateEncoderSpec(stateKeySchema),
             stateInfo.get.storeVersion,
             stateInfo.get.getStateStoreCkptId(partitionId).map(_.head),
             None,
@@ -257,9 +264,9 @@ trait FlatMapGroupsWithStateExecBase
     } else {
       child.execute().mapPartitionsWithStateStore[InternalRow](
         getStateInfo,
-        groupingAttributes.toStructType,
-        stateManager.stateSchema,
-        NoPrefixKeyStateEncoderSpec(groupingAttributes.toStructType),
+        stateKeySchema,
+        stateValueSchema,
+        NoPrefixKeyStateEncoderSpec(stateKeySchema),
         session.sessionState,
         Some(session.streams.stateStoreCoordinator)
       ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
@@ -425,6 +432,9 @@ case class FlatMapGroupsWithStateExec(
     skipEmittingInitialStateKeys: Boolean,
     child: SparkPlan)
   extends FlatMapGroupsWithStateExecBase with BinaryExecNode with ObjectProducerExec {
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(super.output)
+
   import GroupStateImpl._
   import FlatMapGroupsWithStateExecHelper._
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
index 9eca04c98591..8f90a603c7ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -231,13 +232,16 @@ case class StreamingSymmetricHashJoinExec(
     StatefulOpClusteredDistribution(leftKeys, getStateInfo.numPartitions) ::
       StatefulOpClusteredDistribution(rightKeys, getStateInfo.numPartitions) :: Nil
 
-  override def output: Seq[Attribute] = joinType match {
-    case _: InnerLike => left.output ++ right.output
-    case LeftOuter => left.output ++ right.output.map(_.withNullability(true))
-    case RightOuter => left.output.map(_.withNullability(true)) ++ right.output
-    case FullOuter => (left.output ++ right.output).map(_.withNullability(true))
-    case LeftSemi => left.output
-    case _ => throwBadJoinTypeException()
+  override def output: Seq[Attribute] = {
+    val base = joinType match {
+      case _: InnerLike => left.output ++ right.output
+      case LeftOuter => left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter => left.output.map(_.withNullability(true)) ++ right.output
+      case FullOuter => (left.output ++ right.output).map(_.withNullability(true))
+      case LeftSemi => left.output
+      case _ => throwBadJoinTypeException()
+    }
+    WidenStatefulOpNullability.widenOutputForStatefulOp(base)
   }
 
   override def outputPartitioning: Partitioning = joinType match {
@@ -279,11 +283,16 @@ case class StreamingSymmetricHashJoinExec(
   override def getColFamilySchemas(
       shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
     assert(useVirtualColumnFamilies)
-    // We only have one state store for the join, but there are four distinct schemas
-    SymmetricHashJoinStateManager
+    val raw = SymmetricHashJoinStateManager
       .getSchemasForStateStoreWithColFamily(LeftSide, left.output, leftKeys, stateFormatVersion) ++
-    SymmetricHashJoinStateManager
-      .getSchemasForStateStoreWithColFamily(RightSide, right.output, rightKeys, stateFormatVersion)
+      SymmetricHashJoinStateManager
+        .getSchemasForStateStoreWithColFamily(
+          RightSide, right.output, rightKeys, stateFormatVersion)
+    raw.map { case (name, cf) =>
+      name -> cf.copy(
+        keySchema = WidenStatefulOpNullability.widenStateSchema(cf.keySchema),
+        valueSchema = WidenStatefulOpNullability.widenStateSchema(cf.valueSchema))
+    }
   }
 
   override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
@@ -328,7 +337,8 @@ case class StreamingSymmetricHashJoinExec(
         // we have to add the default column family schema because the RocksDBStateEncoder
         // expects this entry to be present in the stateSchemaProvider.
         val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
-          keySchema, 0, valueSchema))
+          WidenStatefulOpNullability.widenStateSchema(keySchema), 0,
+          WidenStatefulOpNullability.widenStateSchema(valueSchema)))
         StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
           newStateSchema, session.sessionState, stateSchemaVersion, storeName = stateStoreName)
       }.toList
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
index 59a2b9ee74f8..022fa3469eea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
@@ -767,11 +768,16 @@ case class StateStoreRestoreExec(
   private[sql] val stateManager = StreamingAggregationStateManager.createStateManager(
     keyExpressions, child.output, stateFormatVersion)
 
+  private val stateKeySchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType)
+  private val stateValueSchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema)
+
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int):
     List[StateSchemaValidationResult] = {
     val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
-      0, keyExpressions.toStructType, 0, stateManager.getStateValueSchema))
+      0, stateKeySchema, 0, stateValueSchema))
     List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo,
       hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion))
   }
@@ -781,9 +787,9 @@ case class StateStoreRestoreExec(
 
     child.execute().mapPartitionsWithReadStateStore(
       getStateInfo,
-      keyExpressions.toStructType,
-      stateManager.getStateValueSchema,
-      NoPrefixKeyStateEncoderSpec(keyExpressions.toStructType),
+      stateKeySchema,
+      stateValueSchema,
+      NoPrefixKeyStateEncoderSpec(stateKeySchema),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
       val hasInput = iter.hasNext
@@ -805,7 +811,8 @@ case class StateStoreRestoreExec(
     }
   }
 
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
@@ -838,13 +845,18 @@ case class StateStoreSaveExec(
   private[sql] val stateManager = StreamingAggregationStateManager.createStateManager(
     keyExpressions, child.output, stateFormatVersion)
 
+  private val stateKeySchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType)
+  private val stateValueSchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema)
+
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration,
       batchId: Long,
       stateSchemaVersion: Int): List[StateSchemaValidationResult] = {
     val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
-      keySchemaId = 0, keyExpressions.toStructType, valueSchemaId = 0,
-      stateManager.getStateValueSchema))
+      keySchemaId = 0, stateKeySchema, valueSchemaId = 0,
+      stateValueSchema))
     List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo,
       hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion))
   }
@@ -856,9 +868,9 @@ case class StateStoreSaveExec(
 
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
-      keyExpressions.toStructType,
-      stateManager.getStateValueSchema,
-      NoPrefixKeyStateEncoderSpec(keyExpressions.toStructType),
+      stateKeySchema,
+      stateValueSchema,
+      NoPrefixKeyStateEncoderSpec(stateKeySchema),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
         val numOutputRows = longMetric("numOutputRows")
@@ -1000,7 +1012,8 @@ case class StateStoreSaveExec(
     }
   }
 
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
@@ -1054,12 +1067,17 @@ case class SessionWindowStateStoreRestoreExec(
   private val stateManager = StreamingSessionWindowStateManager.createStateManager(
     keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion)
 
+  private val stateKeySchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(stateManager.getStateKeySchema)
+  private val stateValueSchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema)
+
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int):
     List[StateSchemaValidationResult] = {
     val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
-      keySchemaId = 0, stateManager.getStateKeySchema, valueSchemaId = 0,
-      stateManager.getStateValueSchema))
+      keySchemaId = 0, stateKeySchema, valueSchemaId = 0,
+      stateValueSchema))
     List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
       newStateSchema, session.sessionState, stateSchemaVersion))
   }
@@ -1069,9 +1087,9 @@ case class SessionWindowStateStoreRestoreExec(
 
     child.execute().mapPartitionsWithReadStateStore(
       getStateInfo,
-      stateManager.getStateKeySchema,
-      stateManager.getStateValueSchema,
-      PrefixKeyScanStateEncoderSpec(stateManager.getStateKeySchema,
+      stateKeySchema,
+      stateValueSchema,
+      PrefixKeyScanStateEncoderSpec(stateKeySchema,
         stateManager.getNumColsForPrefixKey),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
@@ -1099,7 +1117,8 @@ case class SessionWindowStateStoreRestoreExec(
     }
   }
 
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
@@ -1147,11 +1166,16 @@ case class SessionWindowStateStoreSaveExec(
   private val stateManager = StreamingSessionWindowStateManager.createStateManager(
     keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion)
 
+  private val stateKeySchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(stateManager.getStateKeySchema)
+  private val stateValueSchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema)
+
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int):
     List[StateSchemaValidationResult] = {
     val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
-      stateManager.getStateKeySchema, 0, stateManager.getStateValueSchema))
+      stateKeySchema, 0, stateValueSchema))
     List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
       newStateSchema, session.sessionState, stateSchemaVersion))
   }
@@ -1165,9 +1189,9 @@ case class SessionWindowStateStoreSaveExec(
 
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
-      stateManager.getStateKeySchema,
-      stateManager.getStateValueSchema,
-      PrefixKeyScanStateEncoderSpec(stateManager.getStateKeySchema,
+      stateKeySchema,
+      stateValueSchema,
+      PrefixKeyScanStateEncoderSpec(stateKeySchema,
         stateManager.getNumColsForPrefixKey),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
@@ -1251,7 +1275,8 @@ case class SessionWindowStateStoreSaveExec(
     }
   }
 
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
@@ -1355,14 +1380,19 @@ abstract class BaseStreamingDeduplicateExec
   protected val schemaForValueRow: StructType
   protected val extraOptionOnStateStore: Map[String, String]
 
+  protected lazy val stateKeySchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType)
+  protected lazy val stateValueSchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(schemaForValueRow)
+
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
 
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
-      keyExpressions.toStructType,
-      schemaForValueRow,
-      NoPrefixKeyStateEncoderSpec(keyExpressions.toStructType),
+      stateKeySchema,
+      stateValueSchema,
+      NoPrefixKeyStateEncoderSpec(stateKeySchema),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator),
       extraOptions = extraOptionOnStateStore) { (store, iter) =>
@@ -1422,7 +1452,8 @@ abstract class BaseStreamingDeduplicateExec
 
   protected def evictDupInfoFromState(store: StateStore): Unit
 
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
@@ -1476,7 +1507,7 @@ case class StreamingDeduplicateExec(
       hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int):
     List[StateSchemaValidationResult] = {
     val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
-      keyExpressions.toStructType, 0, schemaForValueRow))
+      stateKeySchema, 0, stateValueSchema))
     List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
       newStateSchema, session.sessionState, stateSchemaVersion,
       extraOptions = extraOptionOnStateStore))
@@ -1562,7 +1593,7 @@ case class StreamingDeduplicateWithinWatermarkExec(
       hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int):
     List[StateSchemaValidationResult] = {
     val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
-      keyExpressions.toStructType, 0, schemaForValueRow))
+      stateKeySchema, 0, stateValueSchema))
     List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
       newStateSchema, session.sessionState, stateSchemaVersion,
       extraOptions = extraOptionOnStateStore))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala
index 6816be103f6e..da54c0ce0fe6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
 import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, SortOrder, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{LimitExec, SparkPlan, UnaryExecNode}
@@ -98,7 +99,8 @@ case class StreamingGlobalLimitExec(
     }
   }
 
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala
index b200bde96cbc..f0e3003b2b71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -35,6 +36,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils}
 
 /**
@@ -88,6 +90,12 @@ case class TransformWithStateExec(
     initialState)
   with ObjectProducerExec {
 
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(super.output)
+
+  private lazy val stateKeySchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType)
+
   override def shortName: String = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
 
   // We need to just initialize key and value deserializer once per partition.
@@ -133,12 +141,11 @@ case class TransformWithStateExec(
   override def getColFamilySchemas(
       shouldBeNullable: Boolean
   ): Map[String, StateStoreColFamilySchema] = {
-    val keySchema = keyExpressions.toStructType
     // we have to add the default column family schema because the RocksDBStateEncoder
     // expects this entry to be present in the stateSchemaProvider.
     val defaultSchema = StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
-      0, keyExpressions.toStructType, 0, DUMMY_VALUE_ROW_SCHEMA,
-      Some(NoPrefixKeyStateEncoderSpec(keySchema)))
+      0, stateKeySchema, 0, DUMMY_VALUE_ROW_SCHEMA,
+      Some(NoPrefixKeyStateEncoderSpec(stateKeySchema)))
 
     // For Scala, the user can't explicitly set nullability on schema, so there is
     // no reason to throw an error, and we can simply set the schema to nullable.
@@ -147,9 +154,37 @@ case class TransformWithStateExec(
         shouldCheckNullable = false, shouldSetNullable = shouldBeNullable) ++
         Map(StateStore.DEFAULT_COL_FAMILY_NAME -> defaultSchema)
     closeProcessorHandle()
-    columnFamilySchemas
+    widenColFamilyGroupingKeys(columnFamilySchemas)
   }
 
+  private def widenColFamilyGroupingKeys(
+      schemas: Map[String, StateStoreColFamilySchema])
+      : Map[String, StateStoreColFamilySchema] = {
+    if (!WidenStatefulOpNullability.isEnabled) return schemas
+    val original = keyEncoder.schema
+    val widened = stateKeySchema
+    def widenKey(ks: StructType): StructType =
+      WidenStatefulOpNullability.widenGroupingKeyInSchema(ks, original, widened)
+    schemas.map { case (name, cf) =>
+      val widenedSpec = cf.keyStateEncoderSpec.map {
+        case NoPrefixKeyStateEncoderSpec(ks) =>
+          NoPrefixKeyStateEncoderSpec(widenKey(ks))
+        case PrefixKeyScanStateEncoderSpec(ks, n) =>
+          PrefixKeyScanStateEncoderSpec(widenKey(ks), n)
+        case RangeKeyScanStateEncoderSpec(ks, o) =>
+          RangeKeyScanStateEncoderSpec(widenKey(ks), o)
+        case TimestampAsPrefixKeyStateEncoderSpec(ks) =>
+          TimestampAsPrefixKeyStateEncoderSpec(widenKey(ks))
+        case TimestampAsPostfixKeyStateEncoderSpec(ks) =>
+          TimestampAsPostfixKeyStateEncoderSpec(widenKey(ks))
+      }
+      name -> cf.copy(
+        keySchema = widenKey(cf.keySchema),
+        keyStateEncoderSpec = widenedSpec)
+    }
+  }
+
+
   override def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
     val stateVariableInfos = getDriverProcessorHandle().getStateVariableInfos
     closeProcessorHandle()
@@ -401,9 +436,9 @@ case class TransformWithStateExec(
             val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId)
             val store = StateStore.get(
               storeProviderId = storeProviderId,
-              keyEncoder.schema,
+              stateKeySchema,
               DUMMY_VALUE_ROW_SCHEMA,
-              NoPrefixKeyStateEncoderSpec(keyEncoder.schema),
+              NoPrefixKeyStateEncoderSpec(stateKeySchema),
               version = stateInfo.get.storeVersion,
               stateStoreCkptId = stateInfo.get.getStateStoreCkptId(partitionId).map(_.head),
               stateSchemaBroadcast = stateInfo.get.stateSchemaMetadata,
@@ -423,9 +458,9 @@ case class TransformWithStateExec(
       if (isStreaming) {
         child.execute().mapPartitionsWithStateStore[InternalRow](
           getStateInfo,
-          keyEncoder.schema,
+          stateKeySchema,
           DUMMY_VALUE_ROW_SCHEMA,
-          NoPrefixKeyStateEncoderSpec(keyEncoder.schema),
+          NoPrefixKeyStateEncoderSpec(stateKeySchema),
           session.sessionState,
           Some(session.streams.stateStoreCoordinator),
           useColumnFamilies = true
@@ -473,9 +508,9 @@ case class TransformWithStateExec(
     // Create StateStoreProvider for this partition
     val stateStoreProvider = StateStoreProvider.createAndInit(
       providerId,
-      keyEncoder.schema,
+      stateKeySchema,
       DUMMY_VALUE_ROW_SCHEMA,
-      NoPrefixKeyStateEncoderSpec(keyEncoder.schema),
+      NoPrefixKeyStateEncoderSpec(stateKeySchema),
       useColumnFamilies = true,
       storeConf = storeConf,
       hadoopConf = hadoopConfBroadcast.value.value,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
index 9fc72241e83b..0d2e4a6941a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{BATCH_TIMESTAMP, ERROR}
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOperatorAttributeNullability
 import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, ExpressionWithRandomSeed}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -133,7 +134,7 @@ class IncrementalExecution(
       // of sink information.
       case w: WriteToMicroBatchDataSourceV1 => w.child
     }
-    sparkSession.sessionState.optimizer.executeAndTrack(preOptimized,
+    val optimized = sparkSession.sessionState.optimizer.executeAndTrack(preOptimized,
       tracker).transformAllExpressionsWithPruning(
       _.containsAnyPattern(CURRENT_LIKE, EXPRESSION_WITH_RANDOM_SEED)) {
       case ts @ CurrentBatchTimestamp(timestamp, _, _) =>
@@ -141,6 +142,7 @@ class IncrementalExecution(
         ts.toLiteral
       case e: ExpressionWithRandomSeed => e.withNewSeed(Utils.random.nextLong())
     }
+    WidenStatefulOperatorAttributeNullability(optimized)
   }
 
   // Use `this` for explain so the already-open transaction and executedPlan are reused.
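
To illustrate the scoping that the commit message describes for the rule applied after optimization, here is a minimal sketch on a toy plan tree. `Attr`, `Plan`, `Source`, `Project`, `StatefulAgg`, and `WidenSketch` are hypothetical stand-ins, not Catalyst classes, and the real rule works on expressions and multi-child operators that this sketch leaves out. It only shows the intended shape: rewrite bottom-up, widen at a stateful operator and at ancestors whose subtree contains one, and leave everything else alone.

```scala
// Toy model only: these types are assumptions for illustration, not Spark APIs.
final case class Attr(name: String, nullable: Boolean)

sealed trait Plan {
  def children: Seq[Plan]
  def output: Seq[Attr]
}
final case class Source(output: Seq[Attr]) extends Plan {
  def children: Seq[Plan] = Nil
}
final case class Project(exprs: Seq[Attr], child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
  def output: Seq[Attr] = exprs
}
// Stands in for any stateful operator (streaming aggregate, dedup, join, ...).
final case class StatefulAgg(keys: Seq[Attr], child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
  def output: Seq[Attr] = keys
}

object WidenSketch {
  // True when the subtree rooted at `plan` contains a stateful operator.
  def containsStateful(plan: Plan): Boolean = plan match {
    case _: StatefulAgg => true
    case other => other.children.exists(containsStateful)
  }

  // Bottom-up rewrite: children are rewritten before their parent is decided.
  def apply(plan: Plan): Plan = plan match {
    case s: Source => s
    case StatefulAgg(keys, child) =>
      // At the stateful operator, its own attribute references are widened.
      StatefulAgg(keys.map(_.copy(nullable = true)), apply(child))
    case Project(exprs, child) =>
      val newChild = apply(child)
      if (containsStateful(newChild)) {
        // Ancestor of a stateful operator: propagate the widening upward.
        Project(exprs.map(_.copy(nullable = true)), newChild)
      } else {
        // No stateful operator below: leave nullability untouched.
        Project(exprs, newChild)
      }
  }
}
```

For a plan shaped `Project -> StatefulAgg -> Project -> Source`, the sketch widens the outer `Project` and the `StatefulAgg` but leaves the inner `Project` non-nullable, which is the same expectation the "rule skips non-stateful nodes" test in this commit checks against the real rule.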
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 1e1aa451a0ae..c46f0076721b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -1181,16 +1181,16 @@ abstract class StreamingInnerJoinSuite extends StreamingInnerJoinBase {
       val hadoopConf = spark.sessionState.newHadoopConf()
       val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)
 
-      val keySchemaForNums = new StructType().add("field0", IntegerType, nullable = false)
+      val keySchemaForNums = new StructType().add("field0", IntegerType)
       val keySchemaForIndex = keySchemaForNums.add("index", LongType)
       val numSchema: StructType = new StructType().add("value", LongType)
       val leftIndexSchema: StructType = new StructType()
-        .add("key", IntegerType, nullable = false)
-        .add("leftValue", IntegerType, nullable = false)
+        .add("key", IntegerType)
+        .add("leftValue", IntegerType)
         .add("matched", BooleanType)
       val rightIndexSchema: StructType = new StructType()
-        .add("key", IntegerType, nullable = false)
-        .add("rightValue", IntegerType, nullable = false)
+        .add("key", IntegerType)
+        .add("rightValue", IntegerType)
         .add("matched", BooleanType)
 
       val schemaLeftIndex = StateStoreColFamilySchema(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala
index e58af3b2bf65..6d4a97861efe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala
@@ -112,16 +112,16 @@ class StreamingInnerJoinV4Suite
         CheckpointFileManager.create(stateSchemaPath, hadoopConf)
 
       val keySchemaWithTimestamp = new StructType()
-        .add("field0", IntegerType, nullable = false)
-        .add("__event_time", LongType, nullable = false)
+        .add("field0", IntegerType)
+        .add("__event_time", LongType)
 
       val leftValueSchema: StructType = new StructType()
-        .add("key", IntegerType, nullable = false)
-        .add("leftValue", IntegerType, nullable = false)
+        .add("key", IntegerType)
+        .add("leftValue", IntegerType)
         .add("matched", BooleanType)
       val rightValueSchema: StructType = new StructType()
-        .add("key", IntegerType, nullable = false)
-        .add("rightValue", IntegerType, nullable = false)
+        .add("key", IntegerType)
+        .add("rightValue", IntegerType)
         .add("matched", BooleanType)
 
       val dummyValueSchema =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStatefulOperatorNullabilityDriftSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStatefulOperatorNullabilityDriftSuite.scala
new file mode 100644
index 000000000000..5278f68fbb0a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStatefulOperatorNullabilityDriftSuite.scala
@@ -0,0 +1,534 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkUnsupportedOperationException
+import org.apache.spark.sql.{DataFrame, Encoders}
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
+import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateStore}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+
+/**
+ * Regression suite for stateful-operator nullability drift.
+ *
+ * Driver: `PropagateEmptyRelation` drops empty `Union` branches without a streaming
+ * guard, so the surviving branch's per-column nullability becomes the Union's
+ * nullability and propagates into a stateful operator above -- across microbatches or
+ * restarts.
+ *
+ * Coverage:
+ *   - New-query (default conf): originally-failing scenarios now complete cleanly.
+ *   - Existing-query (conf forced false): pre-fix behavior preserved verbatim.
+ *   - Helper invariant: `WidenStatefulOpNullability.deepWidenAttribute` recurses into
+ *     nested types.
+ */
+class StreamingStatefulOperatorNullabilityDriftSuite extends StreamTest {
+
+  import testImplicits._
+
+  private def buildTwoSources(): (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame) = {
+    val inputA = MemoryStream[Int]
+    val inputB = MemoryStream[Int]
+
+    val dfA = inputA.toDF().select($"value".as("key"))
+    val dfB = inputB.toDF()
+      .select(when($"value" > Int.MinValue, $"value")
+        .otherwise(lit(null).cast("int"))
+        .as("key"))
+
+    (inputA, inputB, dfA, dfB)
+  }
+
+  private def buildTwoSourcesWithWatermark()
+      : (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame) = {
+    val inputA = MemoryStream[Int]
+    val inputB = MemoryStream[Int]
+
+    val dfA = inputA.toDF()
+      .select($"value".as("key"),
+        timestamp_seconds($"value").as("ts"))
+      .withWatermark("ts", "1 minute")
+    val dfB = inputB.toDF()
+      .select(when($"value" > Int.MinValue, $"value")
+        .otherwise(lit(null).cast("int")).as("key"),
+        timestamp_seconds($"value").as("ts"))
+      .withWatermark("ts", "1 minute")
+
+    (inputA, inputB, dfA, dfB)
+  }
+
+  private def runUnionBranchDropRestart(
+      buildSources: () => (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame),
+      buildQuery: (DataFrame, DataFrame) => DataFrame,
+      outputMode: OutputMode,
+      nullableToNonNullable: Boolean): Unit = {
+    withTempDir { checkpointDir =>
+      val checkpointPath = checkpointDir.getAbsolutePath
+
+      val (inputA, inputB, dfA, dfB) = buildSources()
+      val q = buildQuery(dfA, dfB)
+
+      if (nullableToNonNullable) {
+        testStream(q, outputMode)(
+          StartStream(checkpointLocation = checkpointPath),
+          MultiAddData(inputA, 1, 2, 3)(inputB, 4, 5),
+          ProcessAllAvailable(),
+          StopStream
+        )
+      } else {
+        testStream(q, outputMode)(
+          StartStream(checkpointLocation = checkpointPath),
+          AddData(inputA, 1, 2, 3),
+          ProcessAllAvailable(),
+          StopStream
+        )
+      }
+
+      assertJournaledStateSchemaAllNullable(checkpointPath)
+
+      if (nullableToNonNullable) {
+        testStream(q, outputMode)(
+          StartStream(checkpointLocation = checkpointPath),
+          AddData(inputA, 6),
+          ProcessAllAvailable()
+        )
+      } else {
+        testStream(q, outputMode)(
+          StartStream(checkpointLocation = checkpointPath),
+          MultiAddData(inputA, 6)(inputB, 7),
+          ProcessAllAvailable()
+        )
+      }
+    }
+  }
+
+  private def assertJournaledStateSchemaAllNullable(checkpointPath: String): Unit = {
+    val partId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
+    val operatorRoot = new Path(checkpointPath, "state/0")
+    val partitionRoot = new Path(operatorRoot, s"$partId")
+    val hadoopConf = spark.sessionState.newHadoopConf()
+    val fm = CheckpointFileManager.create(operatorRoot, hadoopConf)
+    val fs = operatorRoot.getFileSystem(hadoopConf)
+
+    def collectSchemaFiles(dir: Path): Seq[Path] = {
+      if (!fm.exists(dir)) return Seq.empty
+      if (fs.getFileStatus(dir).isDirectory) {
+        fs.listStatus(dir).filter(_.isFile).map(_.getPath).toSeq
+      } else {
+        Seq(dir)
+      }
+    }
+
+    val schemaFiles = scala.collection.mutable.ArrayBuffer.empty[Path]
+
+    val storeDirs = scala.collection.mutable.ArrayBuffer(partitionRoot)
+    if (fs.exists(partitionRoot)) {
+      fs.listStatus(partitionRoot)
+        .filter(_.isDirectory)
+        .filterNot(_.getPath.getName.startsWith("_"))
+        .foreach(d => storeDirs += d.getPath)
+    }
+    storeDirs.foreach { storeDir =>
+      schemaFiles ++= collectSchemaFiles(
+        new Path(storeDir, "_metadata/schema"))
+    }
+
+    val stateSchemaRoot = new Path(operatorRoot, "_stateSchema")
+    if (fs.exists(stateSchemaRoot)) {
+      fs.listStatus(stateSchemaRoot)
+        .filter(_.isDirectory)
+        .foreach { storeDir =>
+          schemaFiles ++= collectSchemaFiles(storeDir.getPath)
+        }
+    }
+
+    assert(schemaFiles.nonEmpty,
+      s"expected at least one schema file under $operatorRoot")
+    schemaFiles.foreach { schemaFile =>
+      val inStream = fm.open(schemaFile)
+      try {
+        val schemas = StateSchemaCompatibilityChecker.readSchemaFile(inStream)
+        schemas.foreach { s =>
+          assertSchemaAllNullable(s.keySchema,
+            s"$schemaFile: key schema for col family ${s.colFamilyName}")
+        }
+      } finally inStream.close()
+    }
+  }
+
+  private def assertSchemaAllNullable(schema: StructType, label: String): Unit = {
+    schema.fields.foreach { f =>
+      assert(f.nullable, s"$label: field ${f.name} should be nullable")
+      assertDataTypeAllNullable(f.dataType, s"$label.${f.name}")
+    }
+  }
+
+  private def assertDataTypeAllNullable(dataType: DataType, label: String): Unit = dataType match {
+    case s: StructType => assertSchemaAllNullable(s, label)
+    case ArrayType(elementType, containsNull) =>
+      assert(containsNull, s"$label: array element should be nullable")
+      assertDataTypeAllNullable(elementType, s"$label[]")
+    case MapType(keyType, valueType, valueContainsNull) =>
+      assert(valueContainsNull, s"$label: map value should be nullable")
+      assertDataTypeAllNullable(keyType, s"$label.key")
+      assertDataTypeAllNullable(valueType, s"$label.value")
+    case _ =>
+  }
+
+  test("streaming aggregate: non-nullable -> nullable widening remains restart-compatible") {
+    runUnionBranchDropRestart(
+      buildSources = () => buildTwoSources(),
+      buildQuery = (dfA, dfB) => dfA.union(dfB).groupBy($"key").count(),
+      outputMode = OutputMode.Update(),
+      nullableToNonNullable = false)
+  }
+
+  test("streaming aggregate: nullable -> non-nullable narrowing remains restart-compatible") {
+    runUnionBranchDropRestart(
+      buildSources = () => buildTwoSources(),
+      buildQuery = (dfA, dfB) => dfA.union(dfB).groupBy($"key").count(),
+      outputMode = OutputMode.Update(),
+      nullableToNonNullable = true)
+  }
+
+  test("streaming dropDuplicates: non-nullable -> nullable widening remains restart-compatible") {
+    runUnionBranchDropRestart(
+      buildSources = () => buildTwoSources(),
+      buildQuery = (dfA, dfB) => dfA.union(dfB).dropDuplicates(Seq("key")),
+      outputMode = OutputMode.Append(),
+      nullableToNonNullable = false)
+  }
+
+  test("streaming dropDuplicatesWithinWatermark: " +
+    "non-nullable -> nullable widening remains restart-compatible") {
+    runUnionBranchDropRestart(
+      buildSources = () => buildTwoSourcesWithWatermark(),
+      buildQuery = (dfA, dfB) => dfA.union(dfB).dropDuplicatesWithinWatermark(Seq("key")),
+      outputMode = OutputMode.Append(),
+      nullableToNonNullable = false)
+  }
+
+  test("streaming aggregate (Complete mode): no codegen NPE on state-restored null " +
+    "struct grouping key after fix") {
+    import org.apache.spark.sql.functions.struct
+
+    def mkQuery(inNullableK: MemoryStream[Int], inNonNullK: MemoryStream[Int]): DataFrame = {
+      val dfNullable = inNullableK.toDF()
+        .select(
+          when($"value" > 0, struct($"value".as("v")))
+            .otherwise(lit(null).cast("struct<v:int>"))
+            .as("key"),
+          lit(1).as("metric"))
+
+      val dfNonNull = inNonNullK.toDF()
+        .select(
+          struct($"value".as("v")).as("key"),
+          lit(1).as("metric"))
+
+      dfNullable.union(dfNonNull)
+        .groupBy($"key")
+        .agg(sum($"metric").as("c"))
+        .select($"key.v".as("v"), $"c")
+    }
+
+    withTempDir { checkpointDir =>
+      withSQLConf(
+        SQLConf.STATE_SCHEMA_CHECK_ENABLED.key -> "false",
+        SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED.key -> "false",
+        SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        val inNullable = MemoryStream[Int]
+        val inNonNull = MemoryStream[Int]
+        val q = mkQuery(inNullable, inNonNull)
+        testStream(q, OutputMode.Complete())(
+          StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+          AddData(inNullable, 0),
+          ProcessAllAvailable(),
+          StopStream
+        )
+
+        testStream(q, OutputMode.Complete())(
+          StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+          AddData(inNonNull, 1),
+          ProcessAllAvailable()
+        )
+      }
+    }
+  }
+
+  test("streaming aggregate: with widening forced off (existing-query path), " +
+    "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE still triggers on restart") {
+    withTempDir { checkpointDir =>
+      withSQLConf(
+        SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT.key -> "false") {
+        val (inputA, inputB, dfA, dfB) = buildTwoSources()
+        val aggregated = dfA.union(dfB).groupBy($"key").count()
+        testStream(aggregated, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+          AddData(inputA, 1, 2, 3),
+          ProcessAllAvailable(),
+          StopStream
+        )
+
+        inputA.addData(4)
+        inputB.addData(5)
+
+        val ex = intercept[SparkUnsupportedOperationException] {
+          testStream(aggregated, OutputMode.Update())(
+            StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+            ProcessAllAvailable()
+          )
+        }
+
+        checkError(
+          ex,
+          condition = "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE",
+          parameters = Map(
+            "storedKeySchema" -> ".*",
+            "newKeySchema" -> ".*"),
+          matchPVals = true
+        )
+      }
+    }
+  }
+
+  test("stream-stream join: non-nullable -> nullable widening remains restart-compatible") {
+    withTempDir { checkpointDir =>
+      val checkpointPath = checkpointDir.getAbsolutePath
+
+      def buildJoinQuery(): (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
+        val leftInput = MemoryStream[Int]
+        val rightInput = MemoryStream[Int]
+
+        val left = leftInput.toDF()
+          .select($"value".as("key"),
+            timestamp_seconds($"value").as("leftTime"))
+          .withWatermark("leftTime", "10 seconds")
+        val right = rightInput.toDF()
+          .select(
+            when($"value" > Int.MinValue, $"value")
+              .otherwise(lit(null).cast("int")).as("key"),
+            timestamp_seconds($"value").as("rightTime"))
+          .withWatermark("rightTime", "10 seconds")
+
+        val joined = left.join(right,
+          left("key") === right("key") &&
+            left("leftTime") > right("rightTime") - expr("INTERVAL 10 SECONDS") &&
+            left("leftTime") < right("rightTime") + expr("INTERVAL 10 SECONDS"),
+          "inner")
+        (leftInput, rightInput, joined)
+      }
+
+      val (leftInput1, rightInput1, joined1) = buildJoinQuery()
+      testStream(joined1, OutputMode.Append())(
+        StartStream(checkpointLocation = checkpointPath),
+        MultiAddData(leftInput1, 1, 2, 3)(rightInput1, 1, 2),
+        ProcessAllAvailable(),
+        StopStream
+      )
+
+      assertJournaledStateSchemaAllNullable(checkpointPath)
+
+      val (leftInput2, rightInput2, joined2) = buildJoinQuery()
+      testStream(joined2, OutputMode.Append())(
+        StartStream(checkpointLocation = checkpointPath),
+        MultiAddData(leftInput2, 4)(rightInput2, 5),
+        ProcessAllAvailable()
+      )
+    }
+  }
+
+  test("streaming flatMapGroupsWithState: " +
+    "non-nullable -> nullable widening remains restart-compatible") {
+    val stateFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
+      val sum = values.sum + state.getOption.getOrElse(0)
+      state.update(sum)
+      Iterator((key, sum))
+    }
+
+    withTempDir { checkpointDir =>
+      val checkpointPath = checkpointDir.getAbsolutePath
+
+      def buildFmgwsQuery()
+          : (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
+        val (inputA, inputB, dfA, dfB) = buildTwoSources()
+        val result = dfA.union(dfB)
+          .as[Int]
+          .groupByKey(identity)
+          .flatMapGroupsWithState(
+            OutputMode.Update(), GroupStateTimeout.NoTimeout())(stateFunc)
+          .toDF("key", "sum")
+        (inputA, inputB, result)
+      }
+
+      val (inputA1, inputB1, q1) = buildFmgwsQuery()
+      testStream(q1, OutputMode.Update())(
+        StartStream(checkpointLocation = checkpointPath),
+        AddData(inputA1, 1, 2, 3),
+        ProcessAllAvailable(),
+        StopStream
+      )
+
+      assertJournaledStateSchemaAllNullable(checkpointPath)
+
+      val (inputA2, inputB2, q2) = buildFmgwsQuery()
+      testStream(q2, OutputMode.Update())(
+        StartStream(checkpointLocation = checkpointPath),
+        MultiAddData(inputA2, 4)(inputB2, 5),
+        ProcessAllAvailable()
+      )
+    }
+  }
+
+  test("streaming transformWithState: " +
+    "non-nullable -> nullable widening remains restart-compatible") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+    withTempDir { checkpointDir =>
+      val checkpointPath = checkpointDir.getAbsolutePath
+
+      def buildTwsQuery()
+          : (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
+        val (inputA, inputB, dfA, dfB) = buildTwoSources()
+        val result = dfA.union(dfB)
+          .as[Int]
+          .groupByKey(identity)
+          .transformWithState(
+            new NullabilityDriftCountProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+        (inputA, inputB, result.toDF())
+      }
+
+      val (inputA1, inputB1, q1) = buildTwsQuery()
+      testStream(q1, OutputMode.Update())(
+        StartStream(checkpointLocation = checkpointPath),
+        AddData(inputA1, 1, 2, 3),
+        ProcessAllAvailable(),
+        StopStream
+      )
+
+      assertJournaledStateSchemaAllNullable(checkpointPath)
+
+      val (inputA2, inputB2, q2) = buildTwsQuery()
+      testStream(q2, OutputMode.Update())(
+        StartStream(checkpointLocation = checkpointPath),
+        MultiAddData(inputA2, 4)(inputB2, 5),
+        ProcessAllAvailable()
+      )
+    }
+    }
+  }
+
+  test("rule skips non-stateful nodes whose subtree has no stateful operator") {
+    import org.apache.spark.sql.catalyst.analysis.WidenStatefulOperatorAttributeNullability
+    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
+    import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, Project}
+    import org.apache.spark.sql.types.IntegerType
+
+    val key = AttributeReference("key", IntegerType, nullable = false)()
+    val payload = AttributeReference("payload", IntegerType, nullable = false)()
+    val source = LocalRelation(Seq(key, payload), isStreaming = true)
+    val project = Project(Seq(key, payload), source)
+    val agg = Aggregate(
+      groupingExpressions = Seq(key),
+      aggregateExpressions = Seq(key.asInstanceOf[NamedExpression]),
+      child = project)
+
+    val widened = WidenStatefulOperatorAttributeNullability(agg)
+
+    val projectAfter = widened.collectFirst { case p: Project => p }.getOrElse(
+      fail(s"expected to find a Project node in the rewritten plan: $widened"))
+    assert(projectAfter.projectList.forall {
+      case ar: AttributeReference => !ar.nullable
+      case _ => true
+    }, s"Project.projectList below a stateful op should remain non-nullable: " +
+       s"${projectAfter.projectList}")
+
+    val aggAfter = widened.asInstanceOf[Aggregate]
+    assert(aggAfter.aggregateExpressions.forall {
+      case ar: AttributeReference => ar.nullable
+      case _ => true
+    }, s"Aggregate.aggregateExpressions should be widened to nullable: " +
+       s"${aggAfter.aggregateExpressions}")
+    assert(aggAfter.groupingExpressions.forall {
+      case ar: AttributeReference => ar.nullable
+      case _ => true
+    }, s"Aggregate.groupingExpressions should be widened to nullable: " +
+       s"${aggAfter.groupingExpressions}")
+  }
+
+  test("deepWidenAttribute recurses into struct fields, array elements, map values") {
+    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
+    import org.apache.spark.sql.types._
+
+    val nestedStruct = StructType(Seq(
+      StructField("inner_nn", IntegerType, nullable = false),
+      StructField("inner_nl", StringType, nullable = true)))
+    val arrayOfNonNull = ArrayType(IntegerType, containsNull = false)
+    val mapWithNonNullValue = MapType(StringType, IntegerType, valueContainsNull = false)
+    val combined = StructType(Seq(
+      StructField("s", nestedStruct, nullable = false),
+      StructField("a", arrayOfNonNull, nullable = false),
+      StructField("m", mapWithNonNullValue, nullable = false)))
+
+    val attr = AttributeReference("complex", combined, nullable = false)(ExprId(42L))
+    val widened = WidenStatefulOpNullability.deepWidenAttribute(attr)
+
+    assert(widened.nullable, "outer attribute should be widened to nullable")
+    val widenedStruct = widened.dataType.asInstanceOf[StructType]
+    val widenedNested = widenedStruct("s").dataType.asInstanceOf[StructType]
+    assert(
+      widenedStruct("s").nullable && widenedStruct("a").nullable && widenedStruct("m").nullable,
+      "all top-level fields should be widened to nullable")
+    assert(widenedNested("inner_nn").nullable && widenedNested("inner_nl").nullable,
+      "nested struct fields should be widened to nullable")
+    val widenedArray = widenedStruct("a").dataType.asInstanceOf[ArrayType]
+    assert(widenedArray.containsNull, "array element nullability should be widened")
+    val widenedMap = widenedStruct("m").dataType.asInstanceOf[MapType]
+    assert(widenedMap.valueContainsNull, "map value nullability should be widened")
+
+    assert(widened.exprId == attr.exprId)
+    assert(widened.name == attr.name)
+    assert(widened.qualifier == attr.qualifier)
+  }
+}
+
+class NullabilityDriftCountProcessor
+    extends StatefulProcessor[Int, Int, (Int, Long)] {
+  @transient private var countState: ValueState[Long] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    countState = getHandle.getValueState[Long](
+      "count", Encoders.scalaLong, TTLConfig.NONE)
+  }
+
+  override def handleInputRows(
+      key: Int,
+      rows: Iterator[Int],
+      timerValues: TimerValues): Iterator[(Int, Long)] = {
+    val count = (if (countState.exists()) countState.get() else 0L) + rows.size
+    countState.update(count)
+    Iterator((key, count))
+  }
+}
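
The deep-widening invariant that the last test in the new suite checks can be sketched without Spark on the classpath. The types below (`DType`, `Field`, `StructT`, `ArrayT`, `MapT`) and the `DeepWiden` object are hypothetical stand-ins for Spark's `DataType` hierarchy, and the body of `WidenStatefulOpNullability.deepWidenAttribute` is not shown in this part of the diff, so this is an assumption about the general technique rather than a copy of the helper.

```scala
// Toy stand-ins for a nested type system; NOT the real Spark classes.
sealed trait DType
case object IntT extends DType
case object StrT extends DType
final case class Field(name: String, dataType: DType, nullable: Boolean)
final case class StructT(fields: Seq[Field]) extends DType
final case class ArrayT(elementType: DType, containsNull: Boolean) extends DType
final case class MapT(keyType: DType, valueType: DType, valueContainsNull: Boolean)
  extends DType

object DeepWiden {
  // Recursively mark every nullability flag reachable from `dt` as nullable.
  def widen(dt: DType): DType = dt match {
    case StructT(fields) =>
      // Widen the field itself and whatever type it carries.
      StructT(fields.map(f => f.copy(dataType = widen(f.dataType), nullable = true)))
    case ArrayT(elementType, _) =>
      ArrayT(widen(elementType), containsNull = true)
    case MapT(keyType, valueType, _) =>
      // Map keys carry no nullability flag of their own, but may nest types that do.
      MapT(widen(keyType), widen(valueType), valueContainsNull = true)
    case atomic =>
      atomic // atomic types have nothing to widen
  }
}
```

Calling `DeepWiden.widen` on a struct whose fields, array elements, and map values are all declared non-nullable returns the same shape with every flag set to nullable. Applying it a second time returns an equal value, which is the property that keeps a widened state schema stable across restarts.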
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 0454c67f6a61..de1bc0d9c3d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -1787,14 +1787,14 @@ abstract class TransformWithStateSuite extends StateStoreMetricsTest
             TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
           SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) {
           withTempDir { checkpointDir =>
-            // When Avro is used, we want to set the StructFields to nullable
-            val shouldBeNullable = encoding == "avro"
             val metadataPathPostfix = "state/0/_stateSchema/default"
             val stateSchemaPath = new Path(checkpointDir.toString,
               s"$metadataPathPostfix")
             val hadoopConf = spark.sessionState.newHadoopConf()
             val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)
 
+            // When Avro is used, we want to set the StructFields to nullable
+            val shouldBeNullable = encoding == "avro"
             val keySchema = new StructType().add("value", StringType)
             val schema0 = StateStoreColFamilySchema(
               "countState", 0,

