This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new 0d4a2e6 [SPARK-35296][SQL] Allow Dataset.observe to work even if
CollectMetricsExec in a task handles multiple partitions
0d4a2e6 is described below
commit 0d4a2e67677cea4e144ab14480f321068fc00961
Author: Kousuke Saruta <[email protected]>
AuthorDate: Fri Jun 11 01:20:35 2021 +0800
[SPARK-35296][SQL] Allow Dataset.observe to work even if CollectMetricsExec
in a task handles multiple partitions
### What changes were proposed in this pull request?
This PR fixes an issue where `Dataset.observe` doesn't work if
`CollectMetricsExec` in a task handles multiple partitions.
If `coalesce` follows `observe` and the number of partitions shrinks after
`coalesce`, `CollectMetricsExec` can handle multiple partitions in a task.
### Why are the changes needed?
The current implementation of `CollectMetricsExec` doesn't consider the
case where it handles multiple partitions within a single task.
Because a new `updater` is created for each partition even though those
partitions belong to the same task, `collector.setState(updater)` raises an
assertion error.
Here is a simple example that reproduces the issue.
```
$ bin/spark-shell --master "local[1]"
scala> spark.range(1, 4, 1, 3).observe("my_event",
count($"id").as("count_val")).coalesce(2).collect
```
```
java.lang.AssertionError: assertion failed
at scala.Predef$.assert(Predef.scala:208)
at
org.apache.spark.sql.execution.AggregatingAccumulator.setState(AggregatingAccumulator.scala:204)
at
org.apache.spark.sql.execution.CollectMetricsExec.$anonfun$doExecute$2(CollectMetricsExec.scala:72)
at
org.apache.spark.sql.execution.CollectMetricsExec.$anonfun$doExecute$2$adapted(CollectMetricsExec.scala:71)
at
org.apache.spark.TaskContext$$anon$1.onTaskCompletion(TaskContext.scala:125)
at
org.apache.spark.TaskContextImpl.$anonfun$markTaskCompleted$1(TaskContextImpl.scala:124)
at
org.apache.spark.TaskContextImpl.$anonfun$markTaskCompleted$1$adapted(TaskContextImpl.scala:124)
at
org.apache.spark.TaskContextImpl.$anonfun$invokeListeners$1(TaskContextImpl.scala:137)
at
org.apache.spark.TaskContextImpl.$anonfun$invokeListeners$1$adapted(TaskContextImpl.scala:135)
```
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New test.
Closes #32786 from sarutak/fix-collectmetricsexec.
Authored-by: Kousuke Saruta <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 44b695fbb06b0d89783b4838941c68543c5a5c8b)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/execution/AggregatingAccumulator.scala | 16 +++++-----
.../spark/sql/execution/CollectMetricsExec.scala | 6 +++-
.../spark/sql/util/DataFrameCallbackSuite.scala | 34 ++++++++++++++++++++++
3 files changed, 48 insertions(+), 8 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
index 94e159c..0fa4e6c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
@@ -33,7 +33,7 @@ class AggregatingAccumulator private(
bufferSchema: Seq[DataType],
initialValues: Seq[Expression],
updateExpressions: Seq[Expression],
- @transient private val mergeExpressions: Seq[Expression],
+ mergeExpressions: Seq[Expression],
@transient private val resultExpressions: Seq[Expression],
imperatives: Array[ImperativeAggregate],
typedImperatives: Array[TypedImperativeAggregate[_]],
@@ -95,13 +95,14 @@ class AggregatingAccumulator private(
/**
* Driver side operations like `merge` and `value` are executed in the
DAGScheduler thread. This
- * thread does not have a SQL configuration so we attach our own here. Note
that we can't (and
- * shouldn't) call `merge` or `value` on an accumulator originating from an
executor so we just
- * return a default value here.
+ * thread does not have a SQL configuration so we attach our own here.
*/
- private[this] def withSQLConf[T](default: => T)(body: => T): T = {
+ private[this] def withSQLConf[T](canRunOnExecutor: Boolean, default: =>
T)(body: => T): T = {
if (conf != null) {
+ // When we can reach here, we are on the driver side.
SQLConf.withExistingConf(conf)(body)
+ } else if (canRunOnExecutor) {
+ body
} else {
default
}
@@ -147,7 +148,8 @@ class AggregatingAccumulator private(
}
}
- override def merge(other: AccumulatorV2[InternalRow, InternalRow]): Unit =
withSQLConf(()) {
+ override def merge(
+ other: AccumulatorV2[InternalRow, InternalRow]): Unit =
withSQLConf(true, ()) {
if (!other.isZero) {
other match {
case agg: AggregatingAccumulator =>
@@ -171,7 +173,7 @@ class AggregatingAccumulator private(
}
}
- override def value: InternalRow = withSQLConf(InternalRow.empty) {
+ override def value: InternalRow = withSQLConf(false, InternalRow.empty) {
// Either use the existing buffer or create a temporary one.
val input = if (!isZero) {
buffer
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
index b0bbb52..2883bc0 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
@@ -69,7 +69,11 @@ case class CollectMetricsExec(
// - Performance issues due to excessive serialization.
val updater = collector.copyAndReset()
TaskContext.get().addTaskCompletionListener[Unit] { _ =>
- collector.setState(updater)
+ if (collector.isZero) {
+ collector.setState(updater)
+ } else {
+ collector.merge(updater)
+ }
}
rows.map { r =>
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index b17c935..296c810 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -281,6 +281,40 @@ class DataFrameCallbackSuite extends QueryTest
}
}
+ test("SPARK-35296: observe should work even if a task contains multiple
partitions") {
+ val metricMaps = ArrayBuffer.empty[Map[String, Row]]
+ val listener = new QueryExecutionListener {
+ override def onSuccess(funcName: String, qe: QueryExecution, duration:
Long): Unit = {
+ metricMaps += qe.observedMetrics
+ }
+
+ override def onFailure(funcName: String, qe: QueryExecution, exception:
Exception): Unit = {
+ // No-op
+ }
+ }
+ spark.listenerManager.register(listener)
+ try {
+ val df = spark.range(1, 4, 1, 3)
+ .observe(
+ name = "my_event",
+ count($"id").as("count_val"))
+ .coalesce(2)
+
+ def checkMetrics(metrics: Map[String, Row]): Unit = {
+ assert(metrics.size === 1)
+ assert(metrics("my_event") === Row(3L))
+ }
+
+ df.collect()
+ sparkContext.listenerBus.waitUntilEmpty()
+ assert(metricMaps.size === 1)
+ checkMetrics(metricMaps.head)
+ metricMaps.clear()
+ } finally {
+ spark.listenerManager.unregister(listener)
+ }
+ }
+
testQuietly("SPARK-31144: QueryExecutionListener should receive
`java.lang.Error`") {
var e: Exception = null
val listener = new QueryExecutionListener {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]