This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new df3fc194f [CELEBORN-744] Add Benchmark framework and ComputeIfAbsentBenchmark
df3fc194f is described below

commit df3fc194fba3bfcc339a0f31d6a0b4f90553e898
Author: Cheng Pan <[email protected]>
AuthorDate: Thu Jun 29 20:19:30 2023 +0800

    [CELEBORN-744] Add Benchmark framework and ComputeIfAbsentBenchmark
    
    ### What changes were proposed in this pull request?
    
    The benchmark shows that `computeIfAbsent` still performs at least as well as `putIfAbsent` in the simple case (1.3X faster on `HashMap`, on par on `ConcurrentHashMap`):
    
    ```
    ================================================================================================
    HashMap
    ================================================================================================
    
    OpenJDK 64-Bit Server VM 1.8.0_332-b09 on Mac OS X 13.4.1
    Apple M1 Pro
    HashMap:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------------------------------
    putIfAbsent                                         701            702           0         95.7          10.4       1.0X
    computeIfAbsent                                     534            535           1        125.6           8.0       1.3X
    
    ================================================================================================
    ConcurrentHashMap
    ================================================================================================
    
    OpenJDK 64-Bit Server VM 1.8.0_332-b09 on Mac OS X 13.4.1
    Apple M1 Pro
    ConcurrentHashMap:                        Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------------------------------
    putIfAbsent                                         712            716           3         94.2          10.6       1.0X
    computeIfAbsent                                     702            705           2         95.6          10.5       1.0X
    ```
    
    ### Why are the changes needed?
    
    Introduce a benchmark framework for measuring future performance-sensitive cases.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Pass GA.
    
    Closes #1657 from pan3793/CELEBORN-744.
    
    Authored-by: Cheng Pan <[email protected]>
    Signed-off-by: Cheng Pan <[email protected]>
---
 .../ComputeIfAbsentBenchmark-results.txt           |  24 ++
 .../org/apache/celeborn/common/util/Utils.scala    |  72 +++++-
 .../org/apache/celeborn/benchmark/Benchmark.scala  | 248 +++++++++++++++++++++
 .../apache/celeborn/benchmark/BenchmarkBase.scala  |  84 +++++++
 .../org/apache/celeborn/benchmark/Benchmarks.scala |  97 ++++++++
 .../celeborn/common/ComputeIfAbsentBenchmark.scala |  71 ++++++
 6 files changed, 593 insertions(+), 3 deletions(-)

diff --git a/common/benchmarks/ComputeIfAbsentBenchmark-results.txt b/common/benchmarks/ComputeIfAbsentBenchmark-results.txt
new file mode 100644
index 000000000..5344d2710
--- /dev/null
+++ b/common/benchmarks/ComputeIfAbsentBenchmark-results.txt
@@ -0,0 +1,24 @@
+================================================================================================
+HashMap
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_332-b09 on Mac OS X 13.4.1
+Apple M1 Pro
+HashMap:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+putIfAbsent                                         701            702           0         95.7          10.4       1.0X
+computeIfAbsent                                     534            535           1        125.6           8.0       1.3X
+
+
+================================================================================================
+ConcurrentHashMap
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_332-b09 on Mac OS X 13.4.1
+Apple M1 Pro
+ConcurrentHashMap:                        Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+putIfAbsent                                         712            716           3         94.2          10.6       1.0X
+computeIfAbsent                                     702            705           2         95.6          10.5       1.0X
+
+
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
index 1fd7b9c61..6338e35e8 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
@@ -17,25 +17,24 @@
 
 package org.apache.celeborn.common.util
 
-import java.io.{File, FileInputStream, InputStreamReader, IOException}
+import java.io.{File, FileInputStream, InputStream, InputStreamReader, 
IOException}
 import java.lang.management.ManagementFactory
 import java.math.{MathContext, RoundingMode}
 import java.net._
 import java.nio.ByteBuffer
 import java.nio.channels.FileChannel
 import java.nio.charset.StandardCharsets
-import java.text.SimpleDateFormat
 import java.util
 import java.util.{Locale, Properties, Random, UUID}
 import java.util.concurrent.{Callable, ThreadPoolExecutor, TimeoutException, 
TimeUnit}
 
 import scala.annotation.tailrec
 import scala.collection.JavaConverters._
+import scala.io.Source
 import scala.reflect.ClassTag
 import scala.util.Try
 import scala.util.control.{ControlThrowable, NonFatal}
 
-import com.google.common.net.InetAddresses
 import com.google.protobuf.{ByteString, GeneratedMessageV3}
 import io.netty.channel.unix.Errors.NativeIoException
 import org.apache.commons.lang3.SystemUtils
@@ -649,6 +648,73 @@ object Utils extends Logging {
     readProcessStdout(process)
   }
 
+  /**
+   * Starts `command` as a child process and hands back the running [[Process]] without
+   * waiting for it to finish.
+   *
+   * @param command the program followed by its arguments; the first element is also used to
+   *                name the stderr-forwarding thread
+   * @param workingDir directory the child process is started in
+   * @param extraEnvironment entries put into the child's environment on top of the inherited
+   *                         one
+   * @param redirectStderr when true, a background daemon thread forwards every stderr line of
+   *                       the child to `logInfo`
+   * @return the started process
+   */
+  def executeCommand(
+      command: Seq[String],
+      workingDir: File = new File("."),
+      extraEnvironment: Map[String, String] = Map.empty,
+      redirectStderr: Boolean = true): Process = {
+    val processBuilder = new ProcessBuilder(command: _*)
+    processBuilder.directory(workingDir)
+    val childEnv = processBuilder.environment()
+    extraEnvironment.foreach { case (key, value) => childEnv.put(key, value) }
+    val child = processBuilder.start()
+    if (redirectStderr) {
+      processStreamByLine(
+        "redirect stderr for command " + command.head,
+        child.getErrorStream,
+        line => logInfo(line))
+    }
+    child
+  }
+
+  /**
+   * Runs `command` to completion and returns everything it printed on stdout.
+   *
+   * Stdout is drained on a separate reader thread. Each line is appended followed by "\n",
+   * and the reader is joined before the captured text is used. The remaining parameters are
+   * passed to [[executeCommand]] unchanged.
+   *
+   * @return the captured stdout
+   * @throws CelebornException if the process exits with a non-zero code; the captured stdout
+   *                           is logged at error level before throwing
+   */
+  def executeAndGetOutput(
+      command: Seq[String],
+      workingDir: File = new File("."),
+      extraEnvironment: Map[String, String] = Map.empty,
+      redirectStderr: Boolean = true): String = {
+    val child = executeCommand(command, workingDir, extraEnvironment, redirectStderr)
+    val captured = new StringBuilder
+    val reader = processStreamByLine(
+      "read stdout for " + command.head,
+      child.getInputStream,
+      line => captured.append(line).append("\n"))
+    val exitCode = child.waitFor()
+    // The process may have exited while lines are still buffered; wait for the reader.
+    reader.join()
+    if (exitCode == 0) {
+      captured.toString
+    } else {
+      logError(s"Process $command exited with code $exitCode: $captured")
+      throw new CelebornException(s"Process $command exited with code $exitCode")
+    }
+  }
+
+  /**
+   * Starts a daemon thread that reads `inputStream` line by line and hands each line to
+   * `processLine`. The already-started thread is returned so callers can `join()` it.
+   *
+   * The stream is decoded with the implicit default `scala.io.Codec`. It is closed once it
+   * has been fully consumed, or when reading or `processLine` fails, so the underlying
+   * descriptor is not leaked.
+   *
+   * @param threadName name given to the reader thread
+   * @param inputStream stream to drain, e.g. a child process's stdout or stderr
+   * @param processLine callback invoked on the reader thread for each line, without its line
+   *                    terminator
+   * @return the started daemon thread
+   */
+  def processStreamByLine(
+      threadName: String,
+      inputStream: InputStream,
+      processLine: String => Unit): Thread = {
+    val t = new Thread(threadName) {
+      override def run(): Unit = {
+        val source = Source.fromInputStream(inputStream)
+        try {
+          source.getLines().foreach(processLine)
+        } finally {
+          // Closing the BufferedSource also closes the wrapped input stream.
+          source.close()
+        }
+      }
+    }
+    t.setDaemon(true)
+    t.start()
+    t
+  }
+
   /**
    * Create a directory inside the given parent directory. The directory is 
guaranteed to be
    * newly created, and is not marked for automatic deletion.
diff --git a/common/src/test/scala/org/apache/celeborn/benchmark/Benchmark.scala b/common/src/test/scala/org/apache/celeborn/benchmark/Benchmark.scala
new file mode 100644
index 000000000..27882539a
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/benchmark/Benchmark.scala
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.benchmark
+
+import java.io.{OutputStream, PrintStream}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.util.Try
+
+import org.apache.commons.io.output.TeeOutputStream
+import org.apache.commons.lang3.SystemUtils
+
+import org.apache.celeborn.common.util.Utils
+
+/**
+ * Utility class to benchmark components. An example of how to use this is:
+ * {{{
+ *   val benchmark = new Benchmark("My Benchmark", valuesPerIteration)
+ *   benchmark.addCase("V1")(<function>)
+ *   benchmark.addCase("V2")(<function>)
+ *   benchmark.run()
+ * }}}
+ * This will output the best and average time to run each function, the rate of each
+ * function, and its speed relative to the first case.
+ *
+ * The benchmark function takes one argument that is the iteration that's being run.
+ *
+ * @param name name of this benchmark.
+ * @param valuesPerIteration number of values used in the test case, used to compute rows/s.
+ * @param minNumIters the min number of iterations that will be run per case, not counting
+ *                    warm-up.
+ * @param warmupTime amount of time to spend running dummy case iterations for JIT warm-up.
+ * @param minTime further iterations will be run for each case until this time is used up.
+ * @param outputPerIteration if true, the timing for each run will be printed to stdout.
+ * @param output optional output stream to write benchmark results to; results are always
+ *               printed to stdout as well.
+ */
+private[celeborn] class Benchmark(
+    name: String,
+    valuesPerIteration: Long,
+    minNumIters: Int = 2,
+    warmupTime: FiniteDuration = 2.seconds,
+    minTime: FiniteDuration = 2.seconds,
+    outputPerIteration: Boolean = false,
+    output: Option[OutputStream] = None) {
+  import Benchmark._
+  val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case]
+
+  /** Stream that receives the result table: stdout, teed into `output` when one is given. */
+  val out: PrintStream =
+    output.fold(System.out)(o => new PrintStream(new TeeOutputStream(System.out, o)))
+
+  /**
+   * Adds a case to run when run() is called. The given function will be run for several
+   * iterations to collect timing statistics.
+   *
+   * @param name of the benchmark case
+   * @param numIters if non-zero, forces exactly this many iterations to be run
+   */
+  def addCase(name: String, numIters: Int = 0)(f: Int => Unit): Unit = {
+    addTimerCase(name, numIters) { timer =>
+      timer.startTiming()
+      f(timer.iteration)
+      timer.stopTiming()
+    }
+  }
+
+  /**
+   * Adds a case with manual timing control. When the function is run, timing does not start
+   * until timer.startTiming() is called within the given function. The corresponding
+   * timer.stopTiming() method must be called before the function returns.
+   *
+   * @param name of the benchmark case
+   * @param numIters if non-zero, forces exactly this many iterations to be run
+   */
+  def addTimerCase(name: String, numIters: Int = 0)(f: Benchmark.Timer => Unit): Unit = {
+    benchmarks += Benchmark.Case(name, f, numIters)
+  }
+
+  /**
+   * Runs the benchmark and outputs the results to stdout (and to `output`, if given). This
+   * should be copied and added as a comment with the benchmark. Although the results vary
+   * from machine to machine, it should provide some baseline.
+   */
+  def run(): Unit = {
+    require(benchmarks.nonEmpty, "No benchmark case has been added.")
+    // scalastyle:off
+    println("Running benchmark: " + name)
+
+    val results = benchmarks.map { c =>
+      println("  Running case: " + c.name)
+      measure(valuesPerIteration, c.numIters)(c.fn)
+    }
+    println()
+
+    val firstBest = results.head.bestMs
+    // The results are going to be processor specific so it is useful to include that.
+    out.println(Benchmark.getJVMOSInfo())
+    out.println(Benchmark.getProcessorName())
+    val nameLen = Math.max(40, Math.max(name.length, benchmarks.map(_.name.length).max))
+    val rowFormat = s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n"
+    out.printf(
+      rowFormat,
+      name + ":",
+      "Best Time(ms)",
+      "Avg Time(ms)",
+      "Stdev(ms)",
+      "Rate(M/s)",
+      "Per Row(ns)",
+      "Relative")
+    out.println("-" * (nameLen + 80))
+    results.zip(benchmarks).foreach { case (result, benchmark) =>
+      out.printf(
+        rowFormat,
+        benchmark.name,
+        "%5.0f" format result.bestMs,
+        "%4.0f" format result.avgMs,
+        "%5.0f" format result.stdevMs,
+        "%10.1f" format result.bestRate,
+        "%6.1f" format (1000 / result.bestRate),
+        "%3.1fX" format (firstBest / result.bestMs))
+    }
+    out.println()
+    // scalastyle:on
+  }
+
+  /**
+   * Runs a single function `f` for iters, returning the average time the function took and
+   * the rate of the function.
+   *
+   * After the warm-up, at least `minNumIters` iterations run and iterations continue until
+   * `minTime` is used up. A non-zero `overrideNumIters` runs exactly that many iterations.
+   */
+  def measure(num: Long, overrideNumIters: Int)(f: Timer => Unit): Result = {
+    System.gc() // ensures garbage from previous cases don't impact this one
+    val warmupDeadline = warmupTime.fromNow
+    while (!warmupDeadline.isOverdue) {
+      f(new Benchmark.Timer(-1))
+    }
+    val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters
+    val minDuration = if (overrideNumIters != 0) 0 else minTime.toNanos
+    val runTimes = ArrayBuffer[Long]()
+    var totalTime = 0L
+    var i = 0
+    while (i < minIters || totalTime < minDuration) {
+      val timer = new Benchmark.Timer(i)
+      f(timer)
+      val runTime = timer.totalTime()
+      runTimes += runTime
+      totalTime += runTime
+
+      if (outputPerIteration) {
+        // scalastyle:off
+        println(s"Iteration $i took ${NANOSECONDS.toMicros(runTime)} microseconds")
+        // scalastyle:on
+      }
+      i += 1
+    }
+    // scalastyle:off
+    println(s"  Stopped after $i iterations, ${NANOSECONDS.toMillis(runTimes.sum)} ms")
+    // scalastyle:on
+    assert(runTimes.nonEmpty)
+    val best = runTimes.min
+    val avg = runTimes.sum / runTimes.size
+    // Deviations are squared in Double. Squared as Long nanoseconds, a single run that
+    // deviates from the mean by more than ~3 seconds would overflow.
+    val stdev =
+      if (runTimes.size > 1) {
+        val mean = runTimes.sum.toDouble / runTimes.size
+        math.sqrt(runTimes.map(t => (t - mean) * (t - mean)).sum / (runTimes.size - 1))
+      } else {
+        0.0
+      }
+    Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0, stdev / 1000000.0)
+  }
+}
+
+private[celeborn] object Benchmark {
+
+  /**
+   * Object available to benchmark code to control timing e.g. to exclude set-up time.
+   *
+   * @param iteration specifies this is the nth iteration of running the benchmark case;
+   *                  warm-up runs use -1
+   */
+  class Timer(val iteration: Int) {
+    private var accumulatedTime: Long = 0L
+    // 0 means "not currently timing".
+    private var timeStart: Long = 0L
+
+    /** Starts a timed section; must not already be timing. */
+    def startTiming(): Unit = {
+      assert(timeStart == 0L, "Already started timing.")
+      timeStart = System.nanoTime
+    }
+
+    /** Ends the current timed section and adds its duration to the total. */
+    def stopTiming(): Unit = {
+      assert(timeStart != 0L, "Have not started timing.")
+      accumulatedTime += System.nanoTime - timeStart
+      timeStart = 0L
+    }
+
+    /** Total nanoseconds spent in timed sections; timing must be stopped. */
+    def totalTime(): Long = {
+      assert(timeStart == 0L, "Have not stopped timing.")
+      accumulatedTime
+    }
+  }
+
+  /** A case to run: its display name, the timed body, and a forced iteration count (0 = auto). */
+  case class Case(name: String, fn: Timer => Unit, numIters: Int)
+
+  /**
+   * Timing summary of one case. Average, best and standard deviation of the per-iteration
+   * time are in milliseconds; the rate is in millions of values per second for the best run.
+   */
+  case class Result(avgMs: Double, bestRate: Double, bestMs: Double, stdevMs: Double)
+
+  /**
+   * Returns a human-readable processor description, e.g. "Intel(R) Core(TM) i7-4870HQ CPU @
+   * 2.50GHz" or "Apple M1 Pro". The source depends on the OS: `sysctl` on macOS,
+   * `/proc/cpuinfo` on Linux, and the `PROCESSOR_IDENTIFIER` environment variable elsewhere.
+   * Falls back to "Unknown processor" when the lookup fails or the variable is unset.
+   */
+  def getProcessorName(): String = {
+    val unknown = "Unknown processor"
+    if (SystemUtils.IS_OS_MAC_OSX) {
+      Try {
+        Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string"))
+          .stripLineEnd
+      }.getOrElse(unknown)
+    } else if (SystemUtils.IS_OS_LINUX) {
+      Try {
+        val grepPath = Utils.executeAndGetOutput(Seq("which", "grep")).stripLineEnd
+        // Lines look like "model name\t: <cpu>"; strip the key and separator.
+        Utils.executeAndGetOutput(Seq(grepPath, "-m", "1", "model name", "/proc/cpuinfo"))
+          .stripLineEnd.replaceFirst("model name\\s*:\\s*", "")
+      }.getOrElse(unknown)
+    } else {
+      Option(System.getenv("PROCESSOR_IDENTIFIER")).getOrElse(unknown)
+    }
+  }
+
+  /**
+   * Returns JVM and OS information built from system properties, something like
+   * "OpenJDK 64-Bit Server VM 1.8.0_65-b17 on Linux 4.1.13-100.fc21.x86_64".
+   */
+  def getJVMOSInfo(): String = {
+    val vmName = System.getProperty("java.vm.name")
+    val runtimeVersion = System.getProperty("java.runtime.version")
+    val osName = System.getProperty("os.name")
+    val osVersion = System.getProperty("os.version")
+    s"$vmName $runtimeVersion on $osName $osVersion"
+  }
+}
diff --git a/common/src/test/scala/org/apache/celeborn/benchmark/BenchmarkBase.scala b/common/src/test/scala/org/apache/celeborn/benchmark/BenchmarkBase.scala
new file mode 100644
index 000000000..3641cf677
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/benchmark/BenchmarkBase.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.benchmark
+
+import java.io.{File, FileOutputStream, OutputStream}
+
+/**
+ * A base class for benchmarks that can also write their results to a file.
+ *
+ * When the environment variable `CELEBORN_GENERATE_BENCHMARK_FILES` is `1`, [[main]] writes
+ * the results to `benchmarks/<ClassName><jdk><suffix>-results.txt`. The path is resolved
+ * against [[Benchmarks.currentProjectRoot]] when that is set, and against the working
+ * directory otherwise. For JDK 9+, `<jdk>` is `-jdk<major version>` so results from different
+ * JDKs are kept apart; for JDK 8 it is empty.
+ */
+abstract class BenchmarkBase {
+  /** Destination of the generated result file, if one is being written. */
+  var output: Option[OutputStream] = None
+
+  /**
+   * Main process of the whole benchmark.
+   * Implementations of this method are supposed to use the wrapper method `runBenchmark`
+   * for each benchmark scenario.
+   */
+  def runBenchmarkSuite(mainArgs: Array[String]): Unit
+
+  /**
+   * Runs one benchmark scenario. In the result file (if any), its output is preceded by a
+   * `=`-framed header naming the scenario and followed by a blank line.
+   */
+  final def runBenchmark(benchmarkName: String)(func: => Any): Unit = {
+    val separator = "=" * 96
+    val testHeader = (separator + '\n' + benchmarkName + '\n' + separator + '\n' + '\n').getBytes
+    output.foreach(_.write(testHeader))
+    func
+    output.foreach(_.write('\n'))
+  }
+
+  def main(args: Array[String]): Unit = {
+    val regenerateBenchmarkFiles: Boolean =
+      System.getenv("CELEBORN_GENERATE_BENCHMARK_FILES") == "1"
+    if (regenerateBenchmarkFiles) {
+      // "1.8.0_332" parses as 1 and "17.0.7" as 17, so only JDK 9+ gets a suffix.
+      val version = System.getProperty("java.version").split("\\D+")(0).toInt
+      val jdkString = if (version > 8) s"-jdk$version" else ""
+      val resultFileName =
+        s"${this.getClass.getSimpleName.replace("$", "")}$jdkString$suffix-results.txt"
+      val prefix = Benchmarks.currentProjectRoot.map(_ + "/").getOrElse("")
+      val dir = new File(s"${prefix}benchmarks/")
+      if (!dir.exists()) {
+        // scalastyle:off println
+        println(s"Creating ${dir.getAbsolutePath} for benchmark results.")
+        // scalastyle:on println
+        dir.mkdirs()
+      }
+      // FileOutputStream creates the file if it is missing and truncates it otherwise.
+      output = Some(new FileOutputStream(new File(dir, resultFileName)))
+    }
+
+    // Close the result file and run the shutdown hook even if the suite fails.
+    try {
+      runBenchmarkSuite(args)
+    } finally {
+      try {
+        output.foreach(_.close())
+      } finally {
+        afterAll()
+      }
+    }
+  }
+
+  /** Extra part of the result file name, placed right before "-results.txt". */
+  def suffix: String = ""
+
+  /**
+   * Any shutdown code to ensure a clean shutdown
+   */
+  def afterAll(): Unit = {}
+}
diff --git a/common/src/test/scala/org/apache/celeborn/benchmark/Benchmarks.scala b/common/src/test/scala/org/apache/celeborn/benchmark/Benchmarks.scala
new file mode 100644
index 000000000..84cac0d5a
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/benchmark/Benchmarks.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.benchmark
+
+import java.io.File
+import java.lang.reflect.Modifier
+import java.nio.file.{FileSystems, Paths}
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+import scala.util.Try
+
+import com.google.common.reflect.ClassPath
+
+object Benchmarks {
+
+  /**
+   * Root directory of the project that contains the benchmark currently being run. It is set
+   * just before each benchmark is invoked so that [[BenchmarkBase]] can place result files.
+   */
+  var currentProjectRoot: Option[String] = None
+
+  /**
+   * Runs every class under `org.apache.celeborn` whose name ends with "Benchmark" and matches
+   * the glob given as the first argument. The remaining arguments are passed on to each
+   * benchmark's `main`.
+   *
+   * Environment variables:
+   *  - CELEBORN_BENCHMARK_FAILFAST: when "false", a failing benchmark's exception is printed
+   *    and the remaining benchmarks still run; otherwise the failure propagates (default: true).
+   *  - CELEBORN_BENCHMARK_NUM_SPLITS / CELEBORN_BENCHMARK_CUR_SPLIT: deal the matching
+   *    benchmarks into NUM_SPLITS groups and only run the 1-based group CUR_SPLIT
+   *    (defaults: 1 / 1).
+   */
+  def main(args: Array[String]): Unit = {
+    // Checked up front: `args.head` below would otherwise fail with NoSuchElementException.
+    require(args.nonEmpty, "Benchmark class to run should be specified.")
+
+    def envValue(key: String): Option[String] =
+      sys.env.get(key).map(_.toLowerCase(Locale.ROOT).trim)
+    val isFailFast = envValue("CELEBORN_BENCHMARK_FAILFAST").forall(_.toBoolean)
+    val numOfSplits = envValue("CELEBORN_BENCHMARK_NUM_SPLITS").map(_.toInt).getOrElse(1)
+    // The environment value is 1-based; the split index used below is 0-based.
+    val currentSplit = envValue("CELEBORN_BENCHMARK_CUR_SPLIT").map(_.toInt - 1).getOrElse(0)
+    require(numOfSplits > 0, s"CELEBORN_BENCHMARK_NUM_SPLITS must be positive: $numOfSplits")
+    require(
+      currentSplit >= 0 && currentSplit < numOfSplits,
+      s"CELEBORN_BENCHMARK_CUR_SPLIT must be within [1, $numOfSplits]: ${currentSplit + 1}")
+    var numBenchmark = 0
+
+    var isBenchmarkFound = false
+    val benchmarkClasses = ClassPath.from(
+      Thread.currentThread.getContextClassLoader).getTopLevelClassesRecursive(
+      "org.apache.celeborn").asScala.toArray
+    val matcher = FileSystems.getDefault.getPathMatcher(s"glob:${args.head}")
+
+    benchmarkClasses.foreach { info =>
+      lazy val clazz = info.load
+      lazy val runBenchmark = clazz.getMethod("main", classOf[Array[String]])
+      // isAssignableFrom seems not working with the reflected class from Guava's
+      // getTopLevelClassesRecursive, so candidates are recognized by name and by having a
+      // `main(Array[String])` method instead.
+      if (info.getName.endsWith("Benchmark") &&
+        matcher.matches(Paths.get(info.getName)) &&
+        Try(runBenchmark).isSuccess && // Does this have a main method?
+        !Modifier.isAbstract(clazz.getModifiers) // Is this a regular class?
+      ) {
+        numBenchmark += 1
+        // Matching benchmarks are assigned to splits round-robin.
+        if (numBenchmark % numOfSplits == currentSplit) {
+          isBenchmarkFound = true
+
+          val targetDirOrProjDir =
+            new File(clazz.getProtectionDomain.getCodeSource.getLocation.toURI)
+              .getParentFile.getParentFile
+
+          // The root path to be referred in each benchmark.
+          currentProjectRoot = Some {
+            if (targetDirOrProjDir.getName == "target") {
+              // SBT build
+              targetDirOrProjDir.getParentFile.getCanonicalPath
+            } else {
+              // Maven build
+              targetDirOrProjDir.getCanonicalPath
+            }
+          }
+
+          // scalastyle:off println
+          println(s"Running ${clazz.getName}:")
+          // scalastyle:on println
+          // Force GC to minimize the side effect.
+          System.gc()
+          try {
+            runBenchmark.invoke(null, args.tail)
+          } catch {
+            case e: Throwable if !isFailFast =>
+              // scalastyle:off println
+              println(s"${clazz.getName} failed with the exception below:")
+              // scalastyle:on println
+              e.printStackTrace()
+          }
+        }
+      }
+    }
+
+    if (!isBenchmarkFound) throw new RuntimeException("No benchmark found to run.")
+  }
+}
diff --git a/common/src/test/scala/org/apache/celeborn/common/ComputeIfAbsentBenchmark.scala b/common/src/test/scala/org/apache/celeborn/common/ComputeIfAbsentBenchmark.scala
new file mode 100644
index 000000000..eeee4d78a
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/common/ComputeIfAbsentBenchmark.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common
+
+import java.util.{HashMap => JHashMap, Map => JMap}
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.function.{Function => JFunction}
+
+import scala.util.Random
+
+import org.apache.celeborn.benchmark.{Benchmark, BenchmarkBase}
+
+/**
+ * ComputeIfAbsent benchmark: compares `putIfAbsent` with `computeIfAbsent` on a
+ * [[java.util.HashMap]] and on a [[java.util.concurrent.ConcurrentHashMap]].
+ *
+ * To run this benchmark:
+ * {{{
+ *   1. build/sbt "common/test:runMain <this class>"
+ *   2. generate result:
+ *      CELEBORN_GENERATE_BENCHMARK_FILES=1 build/sbt "common/test:runMain <this class>"
+ *      Results will be written to "benchmarks/ComputeIfAbsentBenchmark-results.txt".
+ * }}}
+ */
+object ComputeIfAbsentBenchmark extends BenchmarkBase {
+
+  /**
+   * Benchmarks both insertion styles against `map`. Each measured iteration performs `iters`
+   * calls with keys drawn uniformly from [0, 32).
+   *
+   * `map` is shared by the warm-up runs and by both cases, so in practice nearly every
+   * measured call hits a key that is already present. This measures the "key present" path
+   * of each method.
+   *
+   * @param name benchmark (and result section) name
+   * @param map the map implementation under test
+   * @param iters number of map calls per iteration; also passed as `valuesPerIteration`
+   */
+  def test(name: String, map: JMap[Int, AtomicInteger], iters: Int): Unit = {
+    runBenchmark(name) {
+      val benchmark = new Benchmark(name, iters, output = output)
+      benchmark.addCase("putIfAbsent") { _: Int =>
+        var i = 0
+        while (i < iters) {
+          // The candidate value is constructed on every call, whether or not it is inserted.
+          map.putIfAbsent(Random.nextInt(32), new AtomicInteger(0))
+          i += 1
+        }
+      }
+
+      benchmark.addCase("computeIfAbsent") { _: Int =>
+        var i = 0
+        while (i < iters) {
+          // The mapping function only runs for absent keys, but the function object itself
+          // is constructed on every call.
+          map.computeIfAbsent(
+            Random.nextInt(32),
+            new JFunction[Int, AtomicInteger] {
+              override def apply(v1: Int): AtomicInteger = new AtomicInteger(0)
+            })
+          i += 1
+        }
+      }
+      benchmark.run()
+    }
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    // 1 << 26 = 67,108,864 map calls per measured iteration.
+    test("HashMap", new JHashMap[Int, AtomicInteger], 1 << 26)
+    test("ConcurrentHashMap", new ConcurrentHashMap[Int, AtomicInteger], 1 << 26)
+  }
+}

Reply via email to