This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e03319fd9219 [SPARK-49676][SS][PYTHON] Add Support for Chaining of Operators in transformWithStateInPandas API
e03319fd9219 is described below
commit e03319fd9219da7162c12a15998d5718edc4c49e
Author: jingz-db <[email protected]>
AuthorDate: Wed Nov 27 15:27:59 2024 +0900
[SPARK-49676][SS][PYTHON] Add Support for Chaining of Operators in transformWithStateInPandas API
### What changes were proposed in this pull request?
This PR adds support for defining an event time column in the output dataset of the
`TransformWithStateInPandas` operator. The new event time column is used
to evaluate watermark expressions in downstream operators.
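For example, a downstream windowed aggregation can then key its watermark logic off the new column. A minimal sketch of such a chain (`df`, `stateful_processor`, and `output_schema` are illustrative placeholders, not part of this patch):
```
from pyspark.sql.functions import window

result = (
    df.groupBy("id")
    .transformWithStateInPandas(
        statefulProcessor=stateful_processor,
        outputStructType=output_schema,
        outputMode="Append",
        timeMode="eventTime",
        eventTimeColumnName="outputTimestamp",
    )
    # "outputTimestamp" is tagged as the event time column, so the
    # downstream aggregation can evaluate its watermark against it.
    .groupBy(window("outputTimestamp", "10 seconds"))
    .count()
)
```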
### Why are the changes needed?
This change pairs the Python API with the Scala implementation of chaining of
operators. Scala PR: https://github.com/apache/spark/pull/45376
### Does this PR introduce _any_ user-facing change?
Yes. Users can now specify an event time column as follows:
```
df.groupBy("id")
    .transformWithStateInPandas(
        statefulProcessor=stateful_processor,
        outputStructType=output_schema,
        outputMode="Update",
        timeMode=timeMode,
        eventTimeColumnName="outputTimestamp"
    )
```
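For reference, a minimal `stateful_processor` compatible with the example above might look like the sketch below, modeled on the `StatefulProcessorChainingOps` processor added in the tests (the `eventTime` input column name is an assumption carried over from those tests):
```
import pandas as pd
from typing import Iterator
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle

class EchoTimestampProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        pass

    def handleInputRows(
        self, key, rows, timer_values, expired_timer_info
    ) -> Iterator[pd.DataFrame]:
        # Emit the first event time seen per key into the declared
        # "outputTimestamp" column (assumes an "eventTime" input column).
        for pdf in rows:
            timestamp_list = pdf["eventTime"].tolist()
            yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]})

    def close(self) -> None:
        pass

stateful_processor = EchoTimestampProcessor()
```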
### How was this patch tested?
Integration tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48124 from jingz-db/python-chaining-op.
Lead-authored-by: jingz-db <[email protected]>
Co-authored-by: Jing Zhan <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
python/pyspark/sql/pandas/group_ops.py | 2 +
.../pandas/test_pandas_transform_with_state.py | 158 ++++++++++++++++++---
.../spark/sql/catalyst/analysis/Analyzer.scala | 1 -
.../analysis/UnsupportedOperationChecker.scala | 1 +
.../spark/sql/catalyst/optimizer/Optimizer.scala | 3 +
.../apache/spark/sql/KeyValueGroupedDataset.scala | 6 +-
.../spark/sql/RelationalGroupedDataset.scala | 30 +++-
.../spark/sql/execution/SparkStrategies.scala | 10 +-
.../python/TransformWithStateInPandasExec.scala | 45 +++++-
.../execution/streaming/IncrementalExecution.scala | 17 +++
.../streaming/TransformWithStateExec.scala | 4 +-
11 files changed, 245 insertions(+), 32 deletions(-)
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 56efe0676c08..d8f22e434374 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -374,6 +374,7 @@ class PandasGroupedOpsMixin:
outputMode: str,
timeMode: str,
initialState: Optional["GroupedData"] = None,
+ eventTimeColumnName: str = "",
) -> DataFrame:
"""
Invokes methods defined in the stateful processor used in arbitrary state API v2. It
@@ -662,6 +663,7 @@ class PandasGroupedOpsMixin:
outputMode,
timeMode,
initial_state_java_obj,
+ eventTimeColumnName,
)
return DataFrame(jdf, self.session)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 514339249818..f385d7cd1abc 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -27,13 +27,7 @@ from typing import cast
from pyspark import SparkConf
from pyspark.errors import PySparkRuntimeError
from pyspark.sql.functions import split
-from pyspark.sql.types import (
- StringType,
- StructType,
- StructField,
- Row,
- IntegerType,
-)
+from pyspark.sql.types import StringType, StructType, StructField, Row, IntegerType, TimestampType
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -247,11 +241,15 @@ class TransformWithStateInPandasTestsMixin:
# test list state with ttl has the same behavior as list state when state doesn't expire.
def test_transform_with_state_in_pandas_list_state_large_ttl(self):
- def check_results(batch_df, _):
- assert set(batch_df.sort("id").collect()) == {
- Row(id="0", countAsString="2"),
- Row(id="1", countAsString="2"),
- }
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", countAsString="2"),
+ Row(id="1", countAsString="2"),
+ }
+ else:
+ for q in self.spark.streams.active:
+ q.stop()
self._test_transform_with_state_in_pandas_basic(
ListStateLargeTTLProcessor(), check_results, True, "processingTime"
@@ -268,11 +266,15 @@ class TransformWithStateInPandasTestsMixin:
# test map state with ttl has the same behavior as map state when state doesn't expire.
def test_transform_with_state_in_pandas_map_state_large_ttl(self):
- def check_results(batch_df, _):
- assert set(batch_df.sort("id").collect()) == {
- Row(id="0", countAsString="2"),
- Row(id="1", countAsString="2"),
- }
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", countAsString="2"),
+ Row(id="1", countAsString="2"),
+ }
+ else:
+ for q in self.spark.streams.active:
+ q.stop()
self._test_transform_with_state_in_pandas_basic(
MapStateLargeTTLProcessor(), check_results, True, "processingTime"
@@ -287,11 +289,14 @@ class TransformWithStateInPandasTestsMixin:
Row(id="0", countAsString="2"),
Row(id="1", countAsString="2"),
}
- else:
+ elif batch_id == 1:
assert set(batch_df.sort("id").collect()) == {
Row(id="0", countAsString="3"),
Row(id="1", countAsString="2"),
}
+ else:
+ for q in self.spark.streams.active:
+ q.stop()
self._test_transform_with_state_in_pandas_basic(
SimpleTTLStatefulProcessor(), check_results, False, "processingTime"
@@ -348,6 +353,9 @@ class TransformWithStateInPandasTestsMixin:
Row(id="ttl-map-state-count-1", count=3),
],
)
+ else:
+ for q in self.spark.streams.active:
+ q.stop()
if batch_id == 0 or batch_id == 1:
time.sleep(6)
@@ -466,7 +474,7 @@ class TransformWithStateInPandasTestsMixin:
).first()["timeValues"]
check_timestamp(batch_df)
- else:
+ elif batch_id == 2:
assert set(batch_df.sort("id").select("id",
"countAsString").collect()) == {
Row(id="0", countAsString="3"),
Row(id="0", countAsString="-1"),
@@ -480,6 +488,10 @@ class TransformWithStateInPandasTestsMixin:
).first()["timeValues"]
assert current_batch_expired_timestamp > self.first_expired_timestamp
+ else:
+ for q in self.spark.streams.active:
+ q.stop()
+
self._test_transform_with_state_in_pandas_proc_timer(
ProcTimeStatefulProcessor(), check_results
)
@@ -552,12 +564,15 @@ class TransformWithStateInPandasTestsMixin:
Row(id="a", timestamp="20"),
Row(id="a-expired", timestamp="0"),
}
- else:
+ elif batch_id == 2:
# verify that rows and expired timer produce the expected result
assert set(batch_df.sort("id").collect()) == {
Row(id="a", timestamp="15"),
Row(id="a-expired", timestamp="10000"),
}
+ else:
+ for q in self.spark.streams.active:
+ q.stop()
self._test_transform_with_state_in_pandas_event_time(
EventTimeStatefulProcessor(), check_results
@@ -679,6 +694,9 @@ class TransformWithStateInPandasTestsMixin:
Row(id1="0", id2="1", value=str(123 + 46)),
Row(id1="1", id2="2", value=str(146 + 346)),
}
+ else:
+ for q in self.spark.streams.active:
+ q.stop()
self._test_transform_with_state_non_contiguous_grouping_cols(
SimpleStatefulProcessorWithInitialState(), check_results
@@ -692,6 +710,9 @@ class TransformWithStateInPandasTestsMixin:
Row(id1="0", id2="1", value=str(789 + 123 + 46)),
Row(id1="1", id2="2", value=str(146 + 346)),
}
+ else:
+ for q in self.spark.streams.active:
+ q.stop()
# grouping key of initial state is also not starting from the beginning of attributes
data = [(789, "0", "1"), (987, "3", "2")]
@@ -703,6 +724,88 @@ class TransformWithStateInPandasTestsMixin:
SimpleStatefulProcessorWithInitialState(), check_results, initial_state
)
+ def _test_transform_with_state_in_pandas_chaining_ops(
+ self, stateful_processor, check_results, timeMode="None", grouping_cols=["outputTimestamp"]
+ ):
+ import pyspark.sql.functions as f
+
+ input_path = tempfile.mkdtemp()
+ self._prepare_input_data(input_path + "/text-test3.txt", ["a", "b"],
[10, 15])
+ time.sleep(2)
+ self._prepare_input_data(input_path + "/text-test4.txt", ["a", "c"],
[11, 25])
+ time.sleep(2)
+ self._prepare_input_data(input_path + "/text-test1.txt", ["a"], [5])
+
+ df = self._build_test_df(input_path)
+ df = df.select(
+ "id",
f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")
+ ).withWatermark("eventTime", "5 seconds")
+
+ for q in self.spark.streams.active:
+ q.stop()
+ self.assertTrue(df.isStreaming)
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("outputTimestamp", TimestampType(), True),
+ ]
+ )
+
+ q = (
+ df.groupBy("id")
+ .transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Append",
+ timeMode=timeMode,
+ eventTimeColumnName="outputTimestamp",
+ )
+ .groupBy(grouping_cols)
+ .count()
+ .writeStream.queryName("chaining_ops_query")
+ .foreachBatch(check_results)
+ .outputMode("append")
+ .start()
+ )
+
+ self.assertEqual(q.name, "chaining_ops_query")
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.awaitTermination(10)
+
+ def test_transform_with_state_in_pandas_chaining_ops(self):
+ def check_results(batch_df, batch_id):
+ import datetime
+
+ if batch_id == 0:
+ assert batch_df.isEmpty()
+ elif batch_id == 1:
+ # eviction watermark = 15 - 5 = 10 (max event time from batch 0),
+ # late event watermark = 0 (eviction event time from batch 0)
+ assert set(
+ batch_df.sort("outputTimestamp").select("outputTimestamp", "count").collect()
+ ) == {
+ Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 10), count=1),
+ }
+ elif batch_id == 2:
+ # eviction watermark = 25 - 5 = 20, late event watermark = 10;
+ # row with watermark=5<10 is dropped so it does not show up in the results;
+ # rows with eventTime<=20 are finalized and emitted
+ assert set(
+ batch_df.sort("outputTimestamp").select("outputTimestamp", "count").collect()
+ ) == {
+ Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 11), count=1),
+ Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 15), count=1),
+ }
+
+ self._test_transform_with_state_in_pandas_chaining_ops(
+ StatefulProcessorChainingOps(), check_results, "eventTime"
+ )
+ self._test_transform_with_state_in_pandas_chaining_ops(
+ StatefulProcessorChainingOps(), check_results, "eventTime",
["outputTimestamp", "id"]
+ )
+
class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
# this dict is the same as input initial state dataframe
@@ -888,6 +991,21 @@ class SimpleStatefulProcessor(StatefulProcessor, unittest.TestCase):
pass
+class StatefulProcessorChainingOps(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ pass
+
+ def handleInputRows(
+ self, key, rows, timer_values, expired_timer_info
+ ) -> Iterator[pd.DataFrame]:
+ for pdf in rows:
+ timestamp_list = pdf["eventTime"].tolist()
+ yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]})
+
+ def close(self) -> None:
+ pass
+
+
# A stateful processor that inherits all behavior of SimpleStatefulProcessor except that it uses
# ttl state with a large timeout.
class SimpleTTLStatefulProcessor(SimpleStatefulProcessor, unittest.TestCase):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bed7bea61597..e05f3533ae3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3653,7 +3653,6 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper {
/**
* Ignore event time watermark in batch query, which is only supported in Structured Streaming.
- * TODO: add this rule into analyzer rule list.
*/
object EliminateEventTimeWatermark extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 4f33c26d5c3c..5b7583c763c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -103,6 +103,7 @@ object UnsupportedOperationChecker extends Logging {
case d: Deduplicate if d.isStreaming && d.keys.exists(hasEventTimeCol) => true
case d: DeduplicateWithinWatermark if d.isStreaming => true
case t: TransformWithState if t.isStreaming => true
+ case t: TransformWithStateInPandas if t.isStreaming => true
case _ => false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 29216523fefc..0772c67ea27e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1031,6 +1031,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
// Can't prune the columns on LeafNode
case p @ Project(_, _: LeafNode) => p
+ // Can't prune the columns on UpdateEventTimeWatermarkColumn
+ case p @ Project(_, _: UpdateEventTimeWatermarkColumn) => p
+
case NestedColumnAliasing(rewrittenPlan) => rewrittenPlan
// for all other logical plans that inherit the output from their children
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 392c3edab989..6dcf01d3a9db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.analysis.{EliminateEventTimeWatermark, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -289,11 +289,11 @@ class KeyValueGroupedDataset[K, V] private[sql](
transformWithState
)
- Dataset[U](sparkSession, EliminateEventTimeWatermark(
+ Dataset[U](sparkSession,
UpdateEventTimeWatermarkColumn(
UnresolvedAttribute(eventTimeColumnName),
None,
- transformWithStateDataset.logicalPlan)))
+ transformWithStateDataset.logicalPlan))
}
/** @inheritdoc */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 0974df55a6d8..6f0db42ec1f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkRuntimeException
import org.apache.spark.annotation.Stable
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -475,7 +475,8 @@ class RelationalGroupedDataset protected[sql](
outputStructType: StructType,
outputModeStr: String,
timeModeStr: String,
- initialState: RelationalGroupedDataset): DataFrame = {
+ initialState: RelationalGroupedDataset,
+ eventTimeColumnName: String): DataFrame = {
def exprToAttr(expr: Seq[Expression]): Seq[Attribute] = {
expr.map {
case ne: NamedExpression => ne
@@ -529,7 +530,30 @@ class RelationalGroupedDataset protected[sql](
initialStateSchema = initialState.df.schema
)
}
- Dataset.ofRows(df.sparkSession, plan)
+ if (eventTimeColumnName.isEmpty) {
+ Dataset.ofRows(df.sparkSession, plan)
+ } else {
+ updateEventTimeColumnAfterTransformWithState(plan, eventTimeColumnName)
+ }
+ }
+
+ /**
+ * Creates a new dataset with updated eventTimeColumn after the transformWithState
+ * logical node.
+ */
+ private def updateEventTimeColumnAfterTransformWithState(
+ transformWithStateInPandas: LogicalPlan,
+ eventTimeColumnName: String): DataFrame = {
+ val transformWithStateDataset = Dataset.ofRows(
+ df.sparkSession,
+ transformWithStateInPandas
+ )
+
+ Dataset.ofRows(df.sparkSession,
+ UpdateEventTimeWatermarkColumn(
+ UnresolvedAttribute(eventTimeColumnName),
+ None,
+ transformWithStateDataset.logicalPlan))
}
override def toString: String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 22082aca81a2..c621c151c0bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, AnalysisException, Strategy}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper, NormalizeFloatingNumbers}
@@ -966,6 +966,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _: FlatMapGroupsInPandasWithState =>
// TODO(SPARK-40443): support applyInPandasWithState in batch query
throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3176")
+ case t: TransformWithStateInPandas =>
+ // TODO(SPARK-50428): support TransformWithStateInPandas in batch query
+ throw new ExtendedAnalysisException(
+ new AnalysisException(
+ "_LEGACY_ERROR_TEMP_3102",
+ Map(
+ "msg" -> "TransformWithStateInPandas is not supported with batch
DataFrames/Datasets")
+ ), plan = t)
case logical.CoGroup(
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, lOrder, rOrder, oAttr, left, right) =>
execution.CoGroupExec(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
index 7dd4d4647eeb..617c20c3a782 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, PythonUDF, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.ProcessingTime
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
@@ -72,6 +73,8 @@ case class TransformWithStateInPandasExec(
initialStateSchema: StructType)
extends BinaryExecNode with StateStoreWriter with WatermarkSupport {
+ override def shortName: String = "transformWithStateInPandasExec"
+
private val pythonUDF = functionExpr.asInstanceOf[PythonUDF]
private val pythonFunction = pythonUDF.func
private val chainedFunc =
@@ -126,6 +129,37 @@ case class TransformWithStateInPandasExec(
List.empty
}
+ override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
+ if (timeMode == ProcessingTime) {
+ // TODO SPARK-50180: check if we can return true only if actual timers are registered,
+ // or there is expired state
+ true
+ } else if (outputMode == OutputMode.Append || outputMode == OutputMode.Update) {
+ eventTimeWatermarkForEviction.isDefined &&
+ newInputWatermark > eventTimeWatermarkForEviction.get
+ } else {
+ false
+ }
+ }
+
+ /**
+ * Controls watermark propagation to downstream nodes. If timeMode is
+ * ProcessingTime, the output rows cannot be interpreted in eventTime, hence
+ * this node will not propagate watermark in this timeMode.
+ *
+ * For timeMode EventTime, output watermark is same as input Watermark because
+ * transformWithState does not allow users to set the event time column to be
+ * earlier than the watermark.
+ */
+ override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = {
+ timeMode match {
+ case ProcessingTime =>
+ None
+ case _ =>
+ Some(inputWatermarkMs)
+ }
+ }
+
override def customStatefulOperatorMetrics: Seq[StatefulOperatorCustomMetric] = {
Seq(
// metrics around state variables
@@ -214,8 +248,15 @@ case class TransformWithStateInPandasExec(
val updatesStartTimeNs = currentTimeNs
val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, groupingAttributes)
- val data =
- groupAndProject(dataIterator, groupingAttributes, child.output, dedupAttributes)
+ // If timeout is based on event time, then filter late data based on watermark
+ val filteredIter = watermarkPredicateForDataForLateEvents match {
+ case Some(predicate) =>
+ applyRemovingRowsOlderThanWatermark(dataIterator, predicate)
+ case _ =>
+ dataIterator
+ }
+
+ val data = groupAndProject(filteredIter, groupingAttributes, child.output, dedupAttributes)
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
groupingKeyExprEncoder, timeMode, isStreaming = true, batchTimestampMs, metrics)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 2a7e9818aedd..719c4da14d72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -439,6 +439,23 @@ class IncrementalExecution(
eventTimeWatermarkForEviction = iwEviction)
))
+ // UpdateEventTimeColumnExec is used to tag the eventTime column, and validate
+ // emitted rows adhere to watermark in the output of transformWithStateInPandas.
+ // Hence, this node shares the same watermark value as TransformWithStateInPandasExec.
+ // This is the same as above in TransformWithStateExec.
+ // The only difference is TransformWithStateInPandasExec is analysed slightly differently
+ // with no SerializeFromObjectExec wrapper.
+ case UpdateEventTimeColumnExec(eventTime, delay, None, t: TransformWithStateInPandasExec)
+ if t.stateInfo.isDefined =>
+ val stateInfo = t.stateInfo.get
+ val iwLateEvents = inputWatermarkForLateEvents(stateInfo)
+ val iwEviction = inputWatermarkForEviction(stateInfo)
+
+ UpdateEventTimeColumnExec(eventTime, delay, iwLateEvents,
+ t.copy(
+ eventTimeWatermarkForLateEvents = iwLateEvents,
+ eventTimeWatermarkForEviction = iwEviction)
+ )
case t: TransformWithStateExec if t.stateInfo.isDefined =>
t.copy(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index f4705b89d5a8..9c31ff0a7443 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -85,8 +85,8 @@ case class TransformWithStateExec(
override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
if (timeMode == ProcessingTime) {
- // TODO: check if we can return true only if actual timers are registered, or there is
- // expired state
+ // TODO SPARK-50180: check if we can return true only if actual timers are registered,
+ // or there is expired state
true
} else if (outputMode == OutputMode.Append || outputMode == OutputMode.Update) {
eventTimeWatermarkForEviction.isDefined &&
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]