cloud-fan commented on code in PR #56061:
URL: https://github.com/apache/spark/pull/56061#discussion_r3301377768


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/WidenStatefulOperatorAttributeNullability.scala:
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, ExprId}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Shared helpers for the stateful-operator nullability fix. The fix has three
+ * independent components, all gated by
+ * [[SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT]] (pinned per-query via the
+ * offset log so existing queries keep their pre-fix behavior on restart):
+ *
+ *   - (a) `widenStateSchema`: explicit `asNullable` at every state-schema construction
+ *         site in each stateful physical exec.
+ *   - (b) `widenOutputForStatefulOp`: a per-op `output` override on every stateful logical
+ *         and physical operator, used by the operator's `output` definition.
+ *   - (c) [[WidenStatefulOperatorAttributeNullability]] (defined below in this file): a
+ *         custom optimizer rule that widens `AttributeReference`s inside stateful ops'
+ *         internal expressions and propagates upward to ancestor expressions.
+ */
+object WidenStatefulOpNullability {
+
+  def isEnabled: Boolean =
+    SQLConf.get.getConf(SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT)
+
+  /**
+   * Recursively widens an attribute to be fully nullable: outer `nullable = true` plus
+   * every nested `StructField.nullable`, `ArrayType.containsNull`, and
+   * `MapType.valueContainsNull` flipped to `true` via
+   * [[org.apache.spark.sql.types.DataType#asNullable]].
+   */
+  def deepWidenAttribute(a: Attribute): Attribute = a match {
+    case ref: AttributeReference =>
+      AttributeReference(
+        ref.name, ref.dataType.asNullable, nullable = true, ref.metadata)(
+        ref.exprId, ref.qualifier)
+    case other => other.withNullability(true)
+  }
+
+  /**
+   * Component (a): widens a state schema to fully nullable. Stateful physical execs apply
+   * this at every `validateAndMaybeEvolveStateSchema(...)` call site and every
+   * `mapPartitionsWith*StateStore(...)` call site. When the conf is off, returns the
+   * schema unchanged.
+   */
+  def widenStateSchema(schema: StructType): StructType =
+    if (isEnabled) schema.asNullable else schema
+
+  /**
+   * Component (b): wraps a stateful operator's `output` to be fully nullable. The caller
+   * is responsible for only calling this from within an `output` definition on a stateful
+   * operator; gating is handled here via [[isEnabled]].
+   */
+  def widenOutputForStatefulOp(base: Seq[Attribute]): Seq[Attribute] =
+    if (isEnabled) base.map(deepWidenAttribute) else base
+}
+
+/**
+ * Component (c) of the stateful-operator nullability fix: a custom optimizer rule that
+ * widens `AttributeReference`s inside streaming-stateful operators' internal expressions
+ * and propagates the widening upward to ancestor operators' expressions.
+ *
+ * The rule does NOT introduce any new logical or physical node. It is purely an
+ * attribute-rewrite pass: for every node whose subtree contains a stateful operator,
+ * collect `exprId`s from `p.output` plus children's output, then deep-widen every
+ * `AttributeReference` in the node's expressions whose `exprId` is in that set via
+ * [[WidenStatefulOpNullability#deepWidenAttribute]].
+ *
+ * At a stateful operator itself, all children's output attributes are included because
+ * the operator's internal expressions (e.g. grouping keys) reference them directly.
+ * At non-stateful ancestor operators, only children whose subtrees contain a stateful
+ * operator are included, to avoid unnecessary widening of non-stateful siblings.
+ *
+ * '''Scope.''' The walk only fires on nodes whose subtree contains a stateful operator.
+ *
+ * '''Ordering constraint.''' This rule must run AFTER every `UpdateAttributeNullability`
+ * invocation in both the main optimizer and AQE.
+ *
+ * '''Idempotence.''' [[WidenStatefulOpNullability#deepWidenAttribute]] is idempotent.
+ */
+object WidenStatefulOperatorAttributeNullability extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.getConf(SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT) ||
+        !plan.containsStatefulOperator) {
+      return plan
+    }
+    plan.resolveOperatorsUp {
+      case p if !p.resolved => p
+      case p: LeafNode => p
+      case p if !p.containsStatefulOperator => p
+      case p =>
+        val childOutputs = if (p.isStateful) {
+          p.children.flatMap(_.output)
+        } else {
+          p.children.filter(_.containsStatefulOperator).flatMap(_.output)
+        }
+        val widenableExprIds: Set[ExprId] =
+          (p.output ++ childOutputs)

Review Comment:
   Filtering `childOutputs` to stateful subtrees only tightens half of the 
union — `p.output` is still pulled in unconditionally. For an Inner / Outer / 
Full `Join`, `Join.output = left.output ++ right.output` carries the 
non-stateful side's `exprId`s, so for a non-stream-stream `Join` above 
`[stateful, batch]`, references to the batch side in the join condition still 
end up in `widenableExprIds` and still get widened. (For `LeftSemi` / 
`LeftAnti` where `Join.output = left.output`, the filter does fully help.)
   
   Not a blocker for this PR, but it means the comment at lines 90-91 ("only 
children whose subtrees contain a stateful operator are included, to avoid 
unnecessary widening of non-stateful siblings") slightly overstates the scope 
for mixed-stateful `Join`s. Two ways to reconcile in a follow-up:
   - Tighten the code: compute `statefulExprIds` from the stateful children's 
outputs and restrict `p.output` to attributes whose `exprId` is in that set, 
i.e. `p.output.filter(a => statefulExprIds.contains(a.exprId))` plus the 
stateful children's attributes themselves, so the Inner/Outer-join case is 
also handled.
   - Or keep the partial fix and weaken the comment to call out the 
mixed-stateful `Join` caveat.
   
   Happy with either path.
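
To make the mixed-stateful `Join` case concrete, here is a minimal, self-contained sketch. It does not use Spark's classes: `Attr`, `Plan`, `Scan`, `StatefulAgg`, `InnerJoin`, `currentIds`, and `tightenedIds` are all hypothetical stand-ins that model only `output`, `isStateful`, and `containsStatefulOperator`. `currentIds` mirrors the union in the diff above; `tightenedIds` is one possible reading of the first option, and it assumes ancestors do not introduce new `exprId`s (no aliases), which a real follow-up would need to handle.

```scala
object WidenSketch {
  // Toy stand-ins, NOT Spark's Attribute / LogicalPlan.
  final case class Attr(name: String, exprId: Long)

  sealed trait Plan {
    def children: Seq[Plan]
    def output: Seq[Attr]
    def isStateful: Boolean = false
    def containsStatefulOperator: Boolean =
      isStateful || children.exists(_.containsStatefulOperator)
  }

  final case class Scan(output: Seq[Attr]) extends Plan {
    val children: Seq[Plan] = Nil
  }

  final case class StatefulAgg(child: Plan) extends Plan {
    val children: Seq[Plan] = Seq(child)
    def output: Seq[Attr] = child.output
    override def isStateful: Boolean = true
  }

  // Inner join: output carries both sides' attributes, as in the comment above.
  final case class InnerJoin(left: Plan, right: Plan) extends Plan {
    val children: Seq[Plan] = Seq(left, right)
    def output: Seq[Attr] = left.output ++ right.output
  }

  // Same child filter as the quoted diff.
  private def childOutputs(p: Plan): Seq[Attr] =
    if (p.isStateful) p.children.flatMap(_.output)
    else p.children.filter(_.containsStatefulOperator).flatMap(_.output)

  // Mirrors the diff: p.output joins the union unconditionally.
  def currentIds(p: Plan): Set[Long] =
    (p.output ++ childOutputs(p)).map(_.exprId).toSet

  // Assumed tightening: at non-stateful ancestors, keep only the part of
  // p.output whose exprId comes from a stateful child.
  def tightenedIds(p: Plan): Set[Long] = {
    val statefulExprIds = childOutputs(p).map(_.exprId).toSet
    if (p.isStateful) statefulExprIds ++ p.output.map(_.exprId)
    else p.output.map(_.exprId).filter(statefulExprIds.contains).toSet ++ statefulExprIds
  }

  // Non-stream-stream join above [stateful, batch].
  val streamKey: Attr = Attr("key", 1L)
  val batchDim: Attr = Attr("dim", 2L)
  val join: Plan = InnerJoin(StatefulAgg(Scan(Seq(streamKey))), Scan(Seq(batchDim)))
}
```

In this toy plan, `WidenSketch.currentIds(WidenSketch.join)` contains both `1L` and `2L`, so the batch-side `dim` would still be widened, while `WidenSketch.tightenedIds(WidenSketch.join)` contains only `1L`.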



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

