This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2605b87990c [SPARK-45655][SQL][SS] Allow non-deterministic expressions inside AggregateFunctions in CollectMetrics
2605b87990c is described below
commit 2605b87990c9826d05ad0943045e8dfa79af13e9
Author: Bhuwan Sahni <[email protected]>
AuthorDate: Sun Nov 12 16:50:01 2023 +0900
[SPARK-45655][SQL][SS] Allow non-deterministic expressions inside AggregateFunctions in CollectMetrics
### What changes were proposed in this pull request?
This PR allows non-deterministic expressions wrapped inside an `AggregateFunction` such as `count` inside a `CollectMetrics` node. `CollectMetrics` is used to collect arbitrary metrics from a query; in certain scenarios a user would like to collect metrics about filtering based on non-deterministic expressions (see the query example below).
Currently, the Analyzer does not allow non-deterministic expressions inside an `AggregateFunction` for `CollectMetrics`. This constraint is relaxed to allow collection of such metrics. Note that the metrics are relevant for a completed batch and can change if the batch is replayed (because a non-deterministic expression can behave differently across runs).
While working on this feature, I found an issue with the `checkMetric` logic that validates non-deterministic expressions inside an AggregateExpression. An expression is considered non-deterministic if any of its children is non-deterministic; hence we need to match the `!e.deterministic && !seenAggregate` case after we have checked whether the current expression is an AggregateExpression. If the current expression is an AggregateExpression, we should validate further down the tree recursively [...]
```
val inputData = MemoryStream[Timestamp]
inputData.toDF()
.filter("value < current_date()")
.observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
.writeStream
.queryName("ts_metrics_test")
.format("memory")
.outputMode("append")
.start()
```
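To illustrate the matching-order issue described above, here is a minimal, hypothetical sketch (not the actual Catalyst code; `Expr`, `AggExpr`, and `NonDetExpr` are simplified stand-ins for `Expression`, `AggregateExpression`, and a non-deterministic expression):
```
// Simplified, hypothetical model of the checkMetric ordering fix; the real
// logic lives in CheckAnalysis.checkMetric (see the diff below).
sealed trait Expr {
  def deterministic: Boolean
  def children: Seq[Expr]
}
// Stand-in for an AggregateExpression such as count(...).
case class AggExpr(children: Seq[Expr]) extends Expr {
  // As in Catalyst, a parent is non-deterministic if any child is.
  val deterministic: Boolean = children.forall(_.deterministic)
}
// Stand-in for a non-deterministic expression such as current_date()
// (CurrentBatchTimestamp on a streaming query).
case class NonDetExpr(children: Seq[Expr] = Nil) extends Expr {
  val deterministic: Boolean = false
}

def checkMetric(e: Expr, seenAggregate: Boolean = false): Unit = e match {
  // Aggregates must be matched BEFORE the non-determinism check: an
  // aggregate over a non-deterministic child is allowed, so recurse with
  // seenAggregate = true instead of failing.
  case a: AggExpr => a.children.foreach(checkMetric(_, seenAggregate = true))
  // Matching this case first would wrongly reject count(current_date()),
  // because the aggregate itself reports deterministic = false when its
  // child is non-deterministic.
  case _ if !e.deterministic && !seenAggregate =>
    sys.error("non-aggregate function argument is non-deterministic")
  case _ => e.children.foreach(checkMetric(_, seenAggregate))
}

checkMetric(AggExpr(Seq(NonDetExpr())))  // passes with the fixed ordering
```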
### Why are the changes needed?
1. Added a test case to calculate dropped rows (by `CurrentBatchTimestamp`) and ensure the query is successful.
As an example, the query below fails (without this change) due to the `observe` call on the DataFrame.
```
val inputData = MemoryStream[Timestamp]
inputData.toDF()
.filter("value < current_date()")
.observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
.writeStream
.queryName("ts_metrics_test")
.format("memory")
.outputMode("append")
.start()
```
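With the constraint relaxed, the observed value can then be read from the streaming query progress. A minimal sketch, assuming `query` is the handle returned by `.start()` above (this mirrors the new test in StreamingQueryStatusAndProgressSuite):
```
// Sum the "dropped" metric across all recent progress updates; a batch may
// not report observed metrics, hence the Option around the map lookup.
val dropped = query.recentProgress.map { p =>
  Option(p.observedMetrics.get("metrics")).map(_.getLong(0)).getOrElse(0L)
}.sum
```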
2. Added tests in AnalysisSuite for non-deterministic expressions inside an `AggregateFunction`.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit test cases added.
```
[warn] 20 warnings found
WARNING: Using incubator modules: jdk.incubator.vector, jdk.incubator.foreign
[info] StreamingQueryStatusAndProgressSuite:
09:14:39.684 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
[info] Passed: Total 0, Failed 0, Errors 0, Passed 0
[info] No tests to run for hive / Test / testOnly
[info] - StreamingQueryProgress - prettyJson (436 milliseconds)
[info] - StreamingQueryProgress - json (3 milliseconds)
[info] - StreamingQueryProgress - toString (5 milliseconds)
[info] - StreamingQueryProgress - jsonString and fromJson (163 milliseconds)
[info] - StreamingQueryStatus - prettyJson (1 millisecond)
[info] - StreamingQueryStatus - json (1 millisecond)
[info] - StreamingQueryStatus - toString (2 milliseconds)
09:14:41.674 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-34d2749f-f4d0-46d8-bc51-29da6411e1c5. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:41.710 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - progress classes should be Serializable (5 seconds, 552 milliseconds)
09:14:46.345 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-3a41d397-c3c1-490b-9cc7-d775b0c42208. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:46.345 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-19378: Continue reporting stateOp metrics even if there is no active trigger (1 second, 337 milliseconds)
09:14:47.677 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-29973: Make `processedRowsPerSecond` calculated more accurately and meaningfully (455 milliseconds)
09:14:48.174 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-360fc3b9-a2c5-430c-a892-c9869f1f8339. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:48.174 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-45655: Use current batch timestamp in observe API (587 milliseconds)
09:14:48.768 WARN org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite:
```
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43517 from sahnib/SPARK-45655.
Authored-by: Bhuwan Sahni <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../sql/catalyst/analysis/CheckAnalysis.scala | 25 +++++++-----
.../sql/catalyst/analysis/AnalysisSuite.scala | 17 +++++++--
.../StreamingQueryStatusAndProgressSuite.scala | 44 ++++++++++++++++++++++
3 files changed, 74 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 352b3124a86..d41345f38c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Median, PercentileCont, PercentileDisc}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Median, PercentileCont, PercentileDisc}
import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery, InlineCTE}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -476,10 +476,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
e.failAnalysis(
"INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED",
Map("expr" -> toSQLExpr(s)))
- case _ if !e.deterministic && !seenAggregate =>
- e.failAnalysis(
- "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
- Map("expr" -> toSQLExpr(s)))
case a: AggregateExpression if seenAggregate =>
e.failAnalysis(
"INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED",
@@ -492,12 +488,18 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
e.failAnalysis(
"INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED",
Map("expr" -> toSQLExpr(s)))
+ case _: AggregateExpression | _: AggregateFunction =>
+ e.children.foreach(checkMetric (s, _, seenAggregate = true))
case _: Attribute if !seenAggregate =>
e.failAnalysis(
"INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
Map("expr" -> toSQLExpr(s)))
- case _: AggregateExpression =>
- e.children.foreach(checkMetric (s, _, seenAggregate = true))
+ case a: Alias =>
+ checkMetric(s, a.child, seenAggregate)
+ case a if !e.deterministic && !seenAggregate =>
+ e.failAnalysis(
+ "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
+ Map("expr" -> toSQLExpr(s)))
case _ =>
e.children.foreach(checkMetric (s, _, seenAggregate))
}
@@ -734,8 +736,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
"dataType" -> toSQLType(mapCol.dataType)))
case o if o.expressions.exists(!_.deterministic) &&
- !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] &&
- !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] &&
+ !o.isInstanceOf[Project] &&
+ // non-deterministic expressions inside CollectMetrics have been
+ // already validated inside checkMetric function
+ !o.isInstanceOf[CollectMetrics] &&
+ !o.isInstanceOf[Filter] &&
+ !o.isInstanceOf[Aggregate] &&
+ !o.isInstanceOf[Window] &&
!o.isInstanceOf[Expand] &&
!o.isInstanceOf[Generate] &&
!o.isInstanceOf[CreateVariable] &&
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ca22c55b49e..8e514e245cb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -794,9 +794,20 @@ class AnalysisSuite extends AnalysisTest with Matchers {
// No columns
assert(!CollectMetrics("evt", Nil, testRelation, 0).resolved)
- def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = {
- assertAnalysisError(CollectMetrics("event", exprs, testRelation, 0), errors)
- }
+ // non-deterministic expression inside an aggregate function is valid
+ val tsLiteral = Literal.create(java.sql.Timestamp.valueOf("2023-11-30 21:05:00.000000"),
+ TimestampType)
+
+ assertAnalysisSuccess(
+ CollectMetrics(
+ "invalid",
+ Count(
+ GreaterThan(tsLiteral, CurrentBatchTimestamp(1699485296000L, TimestampType))
+ ).as("count") :: Nil,
+ testRelation,
+ 0
+ )
+ )
// Unwrapped attribute
assertAnalysisErrorClass(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
index 8fe4ef39b25..8ff71473f27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.streaming
+import java.sql.Timestamp
+import java.time.Instant
+import java.time.temporal.ChronoUnit
import java.util.UUID
import scala.jdk.CollectionConverters._
@@ -355,6 +358,47 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
)
}
+ test("SPARK-45655: Use current batch timestamp in observe API") {
+ import testImplicits._
+
+ val inputData = MemoryStream[Timestamp]
+
+ // current_date() internally uses current batch timestamp on streaming query
+ val query = inputData.toDF()
+ .filter("value < current_date()")
+ .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
+ .writeStream
+ .queryName("ts_metrics_test")
+ .format("memory")
+ .outputMode("append")
+ .start()
+
+ val timeNow = Instant.now().truncatedTo(ChronoUnit.SECONDS)
+
+ // this value would be accepted by the filter and would not count towards
+ // dropped metrics.
+ val validValue = Timestamp.from(timeNow.minus(2, ChronoUnit.DAYS))
+ inputData.addData(validValue)
+
+ // would be dropped by the filter and count towards dropped metrics
+ inputData.addData(Timestamp.from(timeNow.plus(2, ChronoUnit.DAYS)))
+
+ query.processAllAvailable()
+ query.stop()
+
+ val dropped = query.recentProgress.map { p =>
+ val metricVal = Option(p.observedMetrics.get("metrics"))
+ metricVal.map(_.getLong(0)).getOrElse(0L)
+ }.sum
+ // ensure dropped metrics are correct
+ assert(dropped == 1)
+
+ val data = spark.read.table("ts_metrics_test").collect()
+
+ // ensure valid value ends up in output
+ assert(data(0).getAs[Timestamp](0).equals(validValue))
+ }
+
def waitUntilBatchProcessed: AssertOnQuery = Execute { q =>
eventually(Timeout(streamingTimeout)) {
if (q.exception.isEmpty) {