This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 2605b87990c [SPARK-45655][SQL][SS] Allow non-deterministic expressions inside AggregateFunctions in CollectMetrics

2605b87990c is described below

commit 2605b87990c9826d05ad0943045e8dfa79af13e9
Author: Bhuwan Sahni <bhuwan.sa...@databricks.com>
AuthorDate: Sun Nov 12 16:50:01 2023 +0900

    [SPARK-45655][SQL][SS] Allow non-deterministic expressions inside AggregateFunctions in CollectMetrics

### What changes were proposed in this pull request?

This PR allows non-deterministic expressions inside a `CollectMetrics` node when they are wrapped in an `AggregateFunction` such as `count`. `CollectMetrics` is used to collect arbitrary metrics from a query; in certain scenarios users would like to collect metrics about filtering that is based on non-deterministic expressions (see the query example below). Currently the Analyzer does not allow non-deterministic expressions inside an `AggregateFunction` for `CollectMetrics`; this constraint is relaxed to allow collection of such metrics. Note that the metrics are relevant for a completed batch and can change if the batch is replayed, because a non-deterministic expression can behave differently on different runs.

While working on this feature, I found an issue in the `checkMetric` logic that validates non-deterministic expressions inside an `AggregateExpression`. An expression is considered non-deterministic if any of its children is non-deterministic, so the case `!e.deterministic && !seenAggregate` must only be matched after we have checked whether the current expression is an `AggregateExpression`. If the current expression is an `AggregateExpression`, we should validate further down the tree recursively.

```
val inputData = MemoryStream[Timestamp]

inputData.toDF()
  .filter("value < current_date()")
  .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
  .writeStream
  .queryName("ts_metrics_test")
  .format("memory")
  .outputMode("append")
  .start()
```

### Why are the changes needed?

1. Added a test case to calculate dropped rows (by `CurrentBatchTimestamp`) and ensure the query is successful. As an example, the query below fails (without this change) due to the `observe` call on the DataFrame.

```
val inputData = MemoryStream[Timestamp]

inputData.toDF()
  .filter("value < current_date()")
  .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
  .writeStream
  .queryName("ts_metrics_test")
  .format("memory")
  .outputMode("append")
  .start()
```

2. Added testing in AnalysisSuite for non-deterministic expressions inside an `AggregateFunction`.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test cases added.

```
[warn] 20 warnings found
WARNING: Using incubator modules: jdk.incubator.vector, jdk.incubator.foreign
[info] StreamingQueryStatusAndProgressSuite:
09:14:39.684 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
[info] Passed: Total 0, Failed 0, Errors 0, Passed 0
[info] No tests to run for hive / Test / testOnly
[info] - StreamingQueryProgress - prettyJson (436 milliseconds)
[info] - StreamingQueryProgress - json (3 milliseconds)
[info] - StreamingQueryProgress - toString (5 milliseconds)
[info] - StreamingQueryProgress - jsonString and fromJson (163 milliseconds)
[info] - StreamingQueryStatus - prettyJson (1 millisecond)
[info] - StreamingQueryStatus - json (1 millisecond)
[info] - StreamingQueryStatus - toString (2 milliseconds)
09:14:41.674 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-34d2749f-f4d0-46d8-bc51-29da6411e1c5. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:41.710 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - progress classes should be Serializable (5 seconds, 552 milliseconds)
09:14:46.345 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-3a41d397-c3c1-490b-9cc7-d775b0c42208. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:46.345 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-19378: Continue reporting stateOp metrics even if there is no active trigger (1 second, 337 milliseconds)
09:14:47.677 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-29973: Make `processedRowsPerSecond` calculated more accurately and meaningfully (455 milliseconds)
09:14:48.174 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-360fc3b9-a2c5-430c-a892-c9869f1f8339. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:48.174 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-45655: Use current batch timestamp in observe API (587 milliseconds)
09:14:48.768 WARN org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite:
```

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43517 from sahnib/SPARK-45655.
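In short: a non-deterministic expression is now legal under an aggregate function and still illegal otherwise. A minimal sketch of the two cases (assuming a DataFrame `df` with a timestamp column `value`; both are illustrative, not part of the patch):

```
import org.apache.spark.sql.functions.{count, expr}

// Accepted after this change: the non-deterministic current_date() is
// wrapped in the aggregate function count, so the metric is well-defined
// once per completed batch.
df.observe("metrics", count(expr("value >= current_date()")).alias("dropped"))

// Still rejected with
// INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC,
// because the non-deterministic expression sits outside any aggregate.
df.observe("metrics", expr("rand()").alias("r"))
```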
Authored-by: Bhuwan Sahni <bhuwan.sa...@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 25 +++++++-----
 .../sql/catalyst/analysis/AnalysisSuite.scala      | 17 +++++++--
 .../StreamingQueryStatusAndProgressSuite.scala     | 44 ++++++++++++++++++++++
 3 files changed, 74 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 352b3124a86..d41345f38c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Median, PercentileCont, PercentileDisc}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Median, PercentileCont, PercentileDisc}
 import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery, InlineCTE}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -476,10 +476,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
           e.failAnalysis(
             "INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED",
             Map("expr" -> toSQLExpr(s)))
-        case _ if !e.deterministic && !seenAggregate =>
-          e.failAnalysis(
-            "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
-            Map("expr" -> toSQLExpr(s)))
         case a: AggregateExpression if seenAggregate =>
           e.failAnalysis(
             "INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED",
@@ -492,12 +488,18 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
           e.failAnalysis(
             "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED",
             Map("expr" -> toSQLExpr(s)))
+        case _: AggregateExpression | _: AggregateFunction =>
+          e.children.foreach(checkMetric (s, _, seenAggregate = true))
         case _: Attribute if !seenAggregate =>
           e.failAnalysis(
             "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
             Map("expr" -> toSQLExpr(s)))
-        case _: AggregateExpression =>
-          e.children.foreach(checkMetric (s, _, seenAggregate = true))
+        case a: Alias =>
+          checkMetric(s, a.child, seenAggregate)
+        case a if !e.deterministic && !seenAggregate =>
+          e.failAnalysis(
+            "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
+            Map("expr" -> toSQLExpr(s)))
         case _ =>
           e.children.foreach(checkMetric (s, _, seenAggregate))
       }
@@ -734,8 +736,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
               "dataType" -> toSQLType(mapCol.dataType)))

       case o if o.expressions.exists(!_.deterministic) &&
-        !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] &&
-        !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] &&
+        !o.isInstanceOf[Project] &&
+        // non-deterministic expressions inside CollectMetrics have been
+        // already validated inside checkMetric function
+        !o.isInstanceOf[CollectMetrics] &&
+        !o.isInstanceOf[Filter] &&
+        !o.isInstanceOf[Aggregate] &&
+        !o.isInstanceOf[Window] &&
         !o.isInstanceOf[Expand] && !o.isInstanceOf[Generate] &&
         !o.isInstanceOf[CreateVariable] &&
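The case ordering above is the crux of the `checkMetric` fix: because `deterministic` is false for any expression with a non-deterministic descendant, the non-determinism case must only be reached after an enclosing aggregate has had a chance to recurse with `seenAggregate = true`. A simplified, self-contained model of that traversal (hypothetical types, not Spark's actual API):

```
// Toy model of the checkMetric recursion over an expression tree.
sealed trait Expr {
  def children: Seq[Expr]
  def selfDeterministic: Boolean = true
  // As in Catalyst: an expression is non-deterministic if any descendant is.
  def deterministic: Boolean = selfDeterministic && children.forall(_.deterministic)
}
case class Agg(child: Expr) extends Expr {
  def children: Seq[Expr] = Seq(child)
}
case object NonDet extends Expr {
  def children: Seq[Expr] = Nil
  override def selfDeterministic: Boolean = false
}

def check(e: Expr, seenAggregate: Boolean = false): Boolean = e match {
  // Aggregates must be matched first; recursing with seenAggregate = true
  // is what makes non-deterministic descendants legal.
  case Agg(child) => check(child, seenAggregate = true)
  // Reached only when the non-determinism is outside every aggregate.
  case _ if !e.deterministic && !seenAggregate => false
  case _ => e.children.forall(check(_, seenAggregate))
}

assert(check(Agg(NonDet)))  // count(rand())-style metric: now allowed
assert(!check(NonDet))      // bare rand(): still rejected
```

With the previous ordering, the non-determinism case fired before the aggregate case could recurse, so even `Agg(NonDet)`-shaped metrics were rejected.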
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ca22c55b49e..8e514e245cb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -794,9 +794,20 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     // No columns
     assert(!CollectMetrics("evt", Nil, testRelation, 0).resolved)

-    def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = {
-      assertAnalysisError(CollectMetrics("event", exprs, testRelation, 0), errors)
-    }
+    // non-deterministic expression inside an aggregate function is valid
+    val tsLiteral = Literal.create(java.sql.Timestamp.valueOf("2023-11-30 21:05:00.000000"),
+      TimestampType)
+
+    assertAnalysisSuccess(
+      CollectMetrics(
+        "invalid",
+        Count(
+          GreaterThan(tsLiteral, CurrentBatchTimestamp(1699485296000L, TimestampType))
+        ).as("count") :: Nil,
+        testRelation,
+        0
+      )
+    )

     // Unwrapped attribute
     assertAnalysisErrorClass(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
index 8fe4ef39b25..8ff71473f27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
@@ -17,6 +17,9 @@

 package org.apache.spark.sql.streaming

+import java.sql.Timestamp
+import java.time.Instant
+import java.time.temporal.ChronoUnit
 import java.util.UUID

 import scala.jdk.CollectionConverters._
@@ -355,6 +358,47 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
     )
   }

+  test("SPARK-45655: Use current batch timestamp in observe API") {
+    import testImplicits._
+
+    val inputData = MemoryStream[Timestamp]
+
+    // current_date() internally uses current batch timestamp on streaming query
+    val query = inputData.toDF()
+      .filter("value < current_date()")
+      .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
+      .writeStream
+      .queryName("ts_metrics_test")
+      .format("memory")
+      .outputMode("append")
+      .start()
+
+    val timeNow = Instant.now().truncatedTo(ChronoUnit.SECONDS)
+
+    // this value would be accepted by the filter and would not count towards
+    // dropped metrics.
+    val validValue = Timestamp.from(timeNow.minus(2, ChronoUnit.DAYS))
+    inputData.addData(validValue)
+
+    // would be dropped by the filter and count towards dropped metrics
+    inputData.addData(Timestamp.from(timeNow.plus(2, ChronoUnit.DAYS)))
+
+    query.processAllAvailable()
+    query.stop()
+
+    val dropped = query.recentProgress.map { p =>
+      val metricVal = Option(p.observedMetrics.get("metrics"))
+      metricVal.map(_.getLong(0)).getOrElse(0L)
+    }.sum
+    // ensure dropped metrics are correct
+    assert(dropped == 1)
+
+    val data = spark.read.table("ts_metrics_test").collect()
+
+    // ensure valid value ends up in output
+    assert(data(0).getAs[Timestamp](0).equals(validValue))
+  }
+
   def waitUntilBatchProcessed: AssertOnQuery = Execute { q =>
     eventually(Timeout(streamingTimeout)) {
       if (q.exception.isEmpty) {
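As a usage note, outside of tests the observed metric is typically consumed through a `StreamingQueryListener` rather than by polling `recentProgress`. A minimal sketch, not part of this patch (assumes the `observe("metrics", ...)` query from the description above and an active `spark` session):

```
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

// Consume the "dropped" metric emitted by the observe() example above.
spark.streams.addListener(new StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit = ()
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = ()
  override def onQueryProgress(event: QueryProgressEvent): Unit = {
    // observedMetrics maps each observation name to a Row of its metrics
    Option(event.progress.observedMetrics.get("metrics")).foreach { row =>
      println(s"dropped rows in this batch: ${row.getAs[Long]("dropped")}")
    }
  }
})
```

Because the metric is evaluated against the current batch timestamp, a replayed batch may report a different value, as the description notes.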