This is an automated email from the ASF dual-hosted git repository.

granthenke pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kudu.git

commit 252b6965dd816587ddccdffa5fe5f83e28504bac
Author: Grant Henke <[email protected]>
AuthorDate: Tue May 7 09:29:55 2019 -0500

    KUDU-2775: Deflake DefaultSourceTest repartition tests
    
    This deflakes testRepartition and testRepartitionAndSort
    by waiting for the Spark job to complete, which gives the
    Spark listener a chance to report all of the job's tasks.
    
    Change-Id: I2302170df3bf3ebac6cc06381d764419c2d48303
    Reviewed-on: http://gerrit.cloudera.org:8080/13263
    Tested-by: Kudu Jenkins
    Reviewed-by: Adar Dembo <[email protected]>
---
 .../spark/tools/DistributedDataGeneratorTest.scala | 32 ++++---------
 .../apache/kudu/spark/kudu/DefaultSourceTest.scala | 24 ++++------
 .../apache/kudu/spark/kudu/SparkListenerUtil.scala | 55 ++++++++++++++++++++++
 3 files changed, 74 insertions(+), 37 deletions(-)

diff --git a/java/kudu-spark-tools/src/test/scala/org/apache/kudu/spark/tools/DistributedDataGeneratorTest.scala b/java/kudu-spark-tools/src/test/scala/org/apache/kudu/spark/tools/DistributedDataGeneratorTest.scala
index 37e2624..902abab 100644
--- a/java/kudu-spark-tools/src/test/scala/org/apache/kudu/spark/tools/DistributedDataGeneratorTest.scala
+++ b/java/kudu-spark-tools/src/test/scala/org/apache/kudu/spark/tools/DistributedDataGeneratorTest.scala
@@ -22,9 +22,8 @@ import org.apache.kudu.spark.kudu.KuduTestSuite
 import org.apache.kudu.test.RandomUtils
 import org.apache.kudu.util.DecimalUtil
 import org.apache.kudu.util.SchemaGenerator
+import org.apache.kudu.spark.kudu.SparkListenerUtil.withJobTaskCounter
 import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.SparkListener
-import org.apache.spark.scheduler.SparkListenerTaskEnd
 import org.apache.spark.sql.Row
 import org.junit.Test
 import org.junit.Assert.assertEquals
@@ -88,15 +87,6 @@ class DistributedDataGeneratorTest extends KuduTestSuite {
 
   @Test
   def testNumTasks() {
-    // Add a SparkListener to count the number of tasks that end.
-    var actualNumTasks = 0
-    val listener = new SparkListener {
-      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-        actualNumTasks += 1
-      }
-    }
-    ss.sparkContext.addSparkListener(listener)
-
     val numTasks = 8
     val numRows = 100
     val args = Array(
@@ -104,22 +94,16 @@ class DistributedDataGeneratorTest extends KuduTestSuite {
       s"--num-tasks=$numTasks",
       randomTableName,
       harness.getMasterAddressesAsString)
-    runGeneratorTest(args)
 
+    // count the number of tasks that end.
+    val actualNumTasks = withJobTaskCounter(ss.sparkContext) { _ =>
+      runGeneratorTest(args)
+    }
     assertEquals(numTasks, actualNumTasks)
   }
 
   @Test
   def testNumTasksRepartition(): Unit = {
-    // Add a SparkListener to count the number of tasks that end.
-    var actualNumTasks = 0
-    val listener = new SparkListener {
-      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-        actualNumTasks += 1
-      }
-    }
-    ss.sparkContext.addSparkListener(listener)
-
     val numTasks = 8
     val numRows = 100
     val args = Array(
@@ -128,7 +112,11 @@ class DistributedDataGeneratorTest extends KuduTestSuite {
       "--repartition=true",
       randomTableName,
       harness.getMasterAddressesAsString)
-    runGeneratorTest(args)
+
+    // count the number of tasks that end.
+    val actualNumTasks = withJobTaskCounter(ss.sparkContext) { _ =>
+      runGeneratorTest(args)
+    }
 
     val table = kuduContext.syncClient.openTable(randomTableName)
     val numPartitions = new 
KuduPartitioner.KuduPartitionerBuilder(table).build().numPartitions()
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
index 6afd32c..101546f 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
@@ -33,8 +33,7 @@ import org.apache.kudu.client.CreateTableOptions
 import org.apache.kudu.Schema
 import org.apache.kudu.Type
 import org.apache.kudu.test.RandomUtils
-import org.apache.spark.scheduler.SparkListener
-import org.apache.spark.scheduler.SparkListenerTaskEnd
+import org.apache.kudu.spark.kudu.SparkListenerUtil.withJobTaskCounter
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.junit.Before
 import org.junit.Test
@@ -428,15 +427,6 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
     options.addSplitRow(split)
     val table = kuduClient.createTable(tableName, simpleSchema, options)
 
-    // Add a SparkListener to count the number of tasks that end.
-    var actualNumTasks = 0
-    val listener = new SparkListener {
-      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-        actualNumTasks += 1
-      }
-    }
-    ss.sparkContext.addSparkListener(listener)
-
     val random = Random.javaRandomToRandom(RandomUtils.getRandom)
     val data = random.shuffle(
       Seq(
@@ -459,10 +449,14 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
     // Capture the rows so we can validate the insert order.
     kuduContext.captureRows = true
 
-    kuduContext.insertRows(
-      dataDF,
-      tableName,
-      new KuduWriteOptions(repartition = true, repartitionSort = 
repartitionSort))
+    // Count the number of tasks that end.
+    val actualNumTasks = withJobTaskCounter(ss.sparkContext) { _ =>
+      kuduContext.insertRows(
+        dataDF,
+        tableName,
+        new KuduWriteOptions(repartition = true, repartitionSort = 
repartitionSort))
+    }
+
     // 2 tasks from the parallelize call, and 2 from the repartitioning.
     assertEquals(4, actualNumTasks)
     val rows = kuduContext.rowsAccumulator.value.asScala
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/SparkListenerUtil.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/SparkListenerUtil.scala
new file mode 100644
index 0000000..b9bb885
--- /dev/null
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/SparkListenerUtil.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kudu.spark.kudu
+
+import org.apache.kudu.test.junit.AssertHelpers
+import org.apache.kudu.test.junit.AssertHelpers.BooleanExpression
+import org.apache.spark.SparkContext
+import org.apache.spark.scheduler.SparkListener
+import org.apache.spark.scheduler.SparkListenerJobEnd
+import org.apache.spark.scheduler.SparkListenerTaskEnd
+
+object SparkListenerUtil {
+
+  // TODO: Use org.apache.spark.TestUtils.withListener if it becomes public test API
+  /**
+   * Runs `body` while a [[SparkListener]] counts the task-end events reported by Spark,
+   * and returns that count.
+   *
+   * SparkListener events are processed on an asynchronous queue. So after `body` returns
+   * successfully, this method waits (up to `timeoutMillis`) for a job-end event before
+   * reading the count. The job-end event is the signal that the job's task-end events
+   * have been processed.
+   *
+   * NOTE(review): the wait ends at the FIRST job-end event observed. If `body` runs more
+   * than one Spark job, task-end events from later jobs may not have been counted yet.
+   * Confirm that callers run a single job inside `body`.
+   *
+   * @param sc the SparkContext to register the counting listener on
+   * @param timeoutMillis how long to wait for a job-end event after `body` returns
+   * @param body the code to run; its argument is unused and is always `()`
+   * @return the number of task-end events seen while the listener was registered
+   */
+  def withJobTaskCounter(sc: SparkContext, timeoutMillis: Long = 5000)(
+      body: Any => Unit): Int = {
+    import java.util.concurrent.atomic.AtomicBoolean
+    import java.util.concurrent.atomic.AtomicInteger
+
+    // Listener callbacks are invoked from Spark's async event queue, while this method
+    // polls from the calling thread. Plain vars give no cross-thread visibility guarantee,
+    // so the shared state uses atomics.
+    val numTasks = new AtomicInteger(0)
+    val jobDone = new AtomicBoolean(false)
+    val listener: SparkListener = new SparkListener {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        numTasks.incrementAndGet()
+      }
+      override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+        jobDone.set(true)
+      }
+    }
+    sc.addSparkListener(listener)
+    try {
+      // Pass the unit value explicitly instead of relying on the compiler adapting `body()`.
+      body(())
+      // Because the SparkListener events are processed on an async queue which is behind
+      // private API, we use the jobEnd event to know that all of the taskEnd events
+      // must have been processed. The wait only runs when `body` succeeded, so a failure
+      // inside `body` is not masked by a timeout here.
+      AssertHelpers.assertEventuallyTrue("Spark job did not complete", new BooleanExpression {
+        override def get(): Boolean = jobDone.get()
+      }, timeoutMillis)
+    } finally {
+      // Always unregister, even if `body` or the wait fails, so the listener never leaks.
+      sc.removeSparkListener(listener)
+    }
+    numTasks.get()
+  }
+}

Reply via email to