Github user cloud-fan commented on a diff in the pull request: https://github.com/apache/spark/pull/21898#discussion_r207704241 --- Diff: core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala --- @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.util.Random + +import org.apache.spark._ + +class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { + + test("global sync by barrier() call") { + val conf = new SparkConf() + // Init local cluster here so each barrier task runs in a separated process, thus `barrier()` + // call is actually useful. + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + Seq(System.currentTimeMillis()).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish global sync within a short time slot. 
+ assert(times.max - times.min <= 1000) + } + + test("support multiple barrier() call within a single task") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time1 = System.currentTimeMillis() + // Sleep for a random time between two global syncs. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time2 = System.currentTimeMillis() + Seq((time1, time2)).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish the first round of global sync within a short time slot. + val times1 = times.map(_._1) + assert(times1.max - times1.min <= 1000) + + // All the tasks shall finish the second round of global sync within a short time slot. + val times2 = times.map(_._2) + assert(times2.max - times2.min <= 1000) + } + + test("throw exception on barrier() call timeout") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Task 3 shall sleep 2000ms to ensure barrier() call timeout + if (context.taskAttemptId == 3) { + Thread.sleep(2000) + } + context.barrier() + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1s")) + } + + test("throw exception if barrier() call doesn't happen on every task") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + 
.setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + if (context.taskAttemptId != 0) { + context.barrier() + } + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1s")) + } + + test("throw exception if the number of barrier() calls are not the same on every task") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + try { --- End diff -- This seems like it can be simpler: ``` if (context.taskAttemptId == 0) { context.barrier() context.barrier() } else { context.barrier() } ```
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org