This is an automated email from the ASF dual-hosted git repository.
chengpan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new df3fc194f [CELEBORN-744] Add Benchmark framework and
ComputeIfAbsentBenchmark
df3fc194f is described below
commit df3fc194fba3bfcc339a0f31d6a0b4f90553e898
Author: Cheng Pan <[email protected]>
AuthorDate: Thu Jun 29 20:19:30 2023 +0800
[CELEBORN-744] Add Benchmark framework and ComputeIfAbsentBenchmark
### What changes were proposed in this pull request?
The benchmark shows that `computeIfAbsent` still has better performance on simple cases
```
================================================================================================
HashMap
================================================================================================
OpenJDK 64-Bit Server VM 1.8.0_332-b09 on Mac OS X 13.4.1
Apple M1 Pro
HashMap:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
putIfAbsent                                          701            702           0         95.7          10.4       1.0X
computeIfAbsent                                      534            535           1        125.6           8.0       1.3X
================================================================================================
ConcurrentHashMap
================================================================================================
OpenJDK 64-Bit Server VM 1.8.0_332-b09 on Mac OS X 13.4.1
Apple M1 Pro
ConcurrentHashMap:                        Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
putIfAbsent                                          712            716           3         94.2          10.6       1.0X
computeIfAbsent                                      702            705           2         95.6          10.5       1.0X
```
### Why are the changes needed?
Introduce a Benchmark framework for future performance-sensitive case measurement.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Pass GA.
Closes #1657 from pan3793/CELEBORN-744.
Authored-by: Cheng Pan <[email protected]>
Signed-off-by: Cheng Pan <[email protected]>
---
.../ComputeIfAbsentBenchmark-results.txt | 24 ++
.../org/apache/celeborn/common/util/Utils.scala | 72 +++++-
.../org/apache/celeborn/benchmark/Benchmark.scala | 248 +++++++++++++++++++++
.../apache/celeborn/benchmark/BenchmarkBase.scala | 84 +++++++
.../org/apache/celeborn/benchmark/Benchmarks.scala | 97 ++++++++
.../celeborn/common/ComputeIfAbsentBenchmark.scala | 71 ++++++
6 files changed, 593 insertions(+), 3 deletions(-)
diff --git a/common/benchmarks/ComputeIfAbsentBenchmark-results.txt
b/common/benchmarks/ComputeIfAbsentBenchmark-results.txt
new file mode 100644
index 000000000..5344d2710
--- /dev/null
+++ b/common/benchmarks/ComputeIfAbsentBenchmark-results.txt
@@ -0,0 +1,24 @@
+================================================================================================
+HashMap
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_332-b09 on Mac OS X 13.4.1
+Apple M1 Pro
+HashMap:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+putIfAbsent                                          701            702           0         95.7          10.4       1.0X
+computeIfAbsent                                      534            535           1        125.6           8.0       1.3X
+
+
+================================================================================================
+ConcurrentHashMap
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_332-b09 on Mac OS X 13.4.1
+Apple M1 Pro
+ConcurrentHashMap:                        Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+putIfAbsent                                          712            716           3         94.2          10.6       1.0X
+computeIfAbsent                                      702            705           2         95.6          10.5       1.0X
+
+
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
index 1fd7b9c61..6338e35e8 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
@@ -17,25 +17,24 @@
package org.apache.celeborn.common.util
-import java.io.{File, FileInputStream, InputStreamReader, IOException}
+import java.io.{File, FileInputStream, InputStream, InputStreamReader,
IOException}
import java.lang.management.ManagementFactory
import java.math.{MathContext, RoundingMode}
import java.net._
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.nio.charset.StandardCharsets
-import java.text.SimpleDateFormat
import java.util
import java.util.{Locale, Properties, Random, UUID}
import java.util.concurrent.{Callable, ThreadPoolExecutor, TimeoutException,
TimeUnit}
import scala.annotation.tailrec
import scala.collection.JavaConverters._
+import scala.io.Source
import scala.reflect.ClassTag
import scala.util.Try
import scala.util.control.{ControlThrowable, NonFatal}
-import com.google.common.net.InetAddresses
import com.google.protobuf.{ByteString, GeneratedMessageV3}
import io.netty.channel.unix.Errors.NativeIoException
import org.apache.commons.lang3.SystemUtils
@@ -649,6 +648,73 @@ object Utils extends Logging {
readProcessStdout(process)
}
+ /**
+ * Execute a command and return the process running the command.
+ */
+ def executeCommand(
+ command: Seq[String],
+ workingDir: File = new File("."),
+ extraEnvironment: Map[String, String] = Map.empty,
+ redirectStderr: Boolean = true): Process = {
+ val builder = new ProcessBuilder(command: _*).directory(workingDir)
+ val environment = builder.environment()
+ for ((key, value) <- extraEnvironment) {
+ environment.put(key, value)
+ }
+ val process = builder.start()
+ if (redirectStderr) {
+ val threadName = "redirect stderr for command " + command.head
+
+ def log(s: String): Unit = logInfo(s)
+
+ processStreamByLine(threadName, process.getErrorStream, log)
+ }
+ process
+ }
+
+ /**
+ * Execute a command and get its output, throwing an exception if it yields
a code other than 0.
+ */
+ def executeAndGetOutput(
+ command: Seq[String],
+ workingDir: File = new File("."),
+ extraEnvironment: Map[String, String] = Map.empty,
+ redirectStderr: Boolean = true): String = {
+ val process = executeCommand(command, workingDir, extraEnvironment,
redirectStderr)
+ val output = new StringBuilder
+ val threadName = "read stdout for " + command.head
+
+ def appendToOutput(s: String): Unit = output.append(s).append("\n")
+
+ val stdoutThread = processStreamByLine(threadName, process.getInputStream,
appendToOutput)
+ val exitCode = process.waitFor()
+ stdoutThread.join() // Wait for it to finish reading output
+ if (exitCode != 0) {
+ logError(s"Process $command exited with code $exitCode: $output")
+ throw new CelebornException(s"Process $command exited with code
$exitCode")
+ }
+ output.toString
+ }
+
+ /**
+ * Return and start a daemon thread that processes the content of the input
stream line by line.
+ */
+ def processStreamByLine(
+ threadName: String,
+ inputStream: InputStream,
+ processLine: String => Unit): Thread = {
+ val t = new Thread(threadName) {
+ override def run(): Unit = {
+ for (line <- Source.fromInputStream(inputStream).getLines()) {
+ processLine(line)
+ }
+ }
+ }
+ t.setDaemon(true)
+ t.start()
+ t
+ }
+
/**
* Create a directory inside the given parent directory. The directory is
guaranteed to be
* newly created, and is not marked for automatic deletion.
diff --git
a/common/src/test/scala/org/apache/celeborn/benchmark/Benchmark.scala
b/common/src/test/scala/org/apache/celeborn/benchmark/Benchmark.scala
new file mode 100644
index 000000000..27882539a
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/benchmark/Benchmark.scala
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.benchmark
+
+import java.io.{OutputStream, PrintStream}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.util.Try
+
+import org.apache.commons.io.output.TeeOutputStream
+import org.apache.commons.lang3.SystemUtils
+
+import org.apache.celeborn.common.util.Utils
+
+/**
+ * Utility class to benchmark components. An example of how to use this is:
+ * val benchmark = new Benchmark("My Benchmark", valuesPerIteration)
+ * benchmark.addCase("V1")(<function>)
+ * benchmark.addCase("V2")(<function>)
+ * benchmark.run
+ * This will output the average time to run each function and the rate of each
function.
+ *
+ * The benchmark function takes one argument that is the iteration that's
being run.
+ *
+ * @param name name of this benchmark.
+ * @param valuesPerIteration number of values used in the test case, used to
compute rows/s.
+ * @param minNumIters the min number of iterations that will be run per case,
not counting warm-up.
+ * @param warmupTime amount of time to spend running dummy case iterations for
JIT warm-up.
+ * @param minTime further iterations will be run for each case until this time
is used up.
+ * @param outputPerIteration if true, the timing for each run will be printed
to stdout.
+ * @param output optional output stream to write benchmark results to
+ */
+private[celeborn] class Benchmark(
+ name: String,
+ valuesPerIteration: Long,
+ minNumIters: Int = 2,
+ warmupTime: FiniteDuration = 2.seconds,
+ minTime: FiniteDuration = 2.seconds,
+ outputPerIteration: Boolean = false,
+ output: Option[OutputStream] = None) {
+ import Benchmark._
+ val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case]
+
+ val out =
+ if (output.isDefined) {
+ new PrintStream(new TeeOutputStream(System.out, output.get))
+ } else {
+ System.out
+ }
+
+ /**
+ * Adds a case to run when run() is called. The given function will be run
for several
+ * iterations to collect timing statistics.
+ *
+ * @param name of the benchmark case
+ * @param numIters if non-zero, forces exactly this many iterations to be run
+ */
+ def addCase(name: String, numIters: Int = 0)(f: Int => Unit): Unit = {
+ addTimerCase(name, numIters) { timer =>
+ timer.startTiming()
+ f(timer.iteration)
+ timer.stopTiming()
+ }
+ }
+
+ /**
+ * Adds a case with manual timing control. When the function is run, timing
does not start
+ * until timer.startTiming() is called within the given function. The
corresponding
+ * timer.stopTiming() method must be called before the function returns.
+ *
+ * @param name of the benchmark case
+ * @param numIters if non-zero, forces exactly this many iterations to be run
+ */
+ def addTimerCase(name: String, numIters: Int = 0)(f: Benchmark.Timer =>
Unit): Unit = {
+ benchmarks += Benchmark.Case(name, f, numIters)
+ }
+
+ /**
+ * Runs the benchmark and outputs the results to stdout. This should be
copied and added as
+ * a comment with the benchmark. Although the results vary from machine to
machine, it should
+ * provide some baseline.
+ */
+ def run(): Unit = {
+ require(benchmarks.nonEmpty)
+ // scalastyle:off
+ println("Running benchmark: " + name)
+
+ val results = benchmarks.map { c =>
+ println(" Running case: " + c.name)
+ measure(valuesPerIteration, c.numIters)(c.fn)
+ }
+ println
+
+ val firstBest = results.head.bestMs
+ // The results are going to be processor specific so it is useful to
include that.
+ out.println(Benchmark.getJVMOSInfo())
+ out.println(Benchmark.getProcessorName())
+ val nameLen = Math.max(40, Math.max(name.length,
benchmarks.map(_.name.length).max))
+ out.printf(
+ s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n",
+ name + ":",
+ "Best Time(ms)",
+ "Avg Time(ms)",
+ "Stdev(ms)",
+ "Rate(M/s)",
+ "Per Row(ns)",
+ "Relative")
+ out.println("-" * (nameLen + 80))
+ results.zip(benchmarks).foreach { case (result, benchmark) =>
+ out.printf(
+ s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n",
+ benchmark.name,
+ "%5.0f" format result.bestMs,
+ "%4.0f" format result.avgMs,
+ "%5.0f" format result.stdevMs,
+ "%10.1f" format result.bestRate,
+ "%6.1f" format (1000 / result.bestRate),
+ "%3.1fX" format (firstBest / result.bestMs))
+ }
+ out.println()
+ // scalastyle:on
+ }
+
+ /**
+ * Runs a single function `f` for iters, returning the average time the
function took and
+ * the rate of the function.
+ */
+ def measure(num: Long, overrideNumIters: Int)(f: Timer => Unit): Result = {
+ System.gc() // ensures garbage from previous cases don't impact this one
+ val warmupDeadline = warmupTime.fromNow
+ while (!warmupDeadline.isOverdue) {
+ f(new Benchmark.Timer(-1))
+ }
+ val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters
+ val minDuration = if (overrideNumIters != 0) 0 else minTime.toNanos
+ val runTimes = ArrayBuffer[Long]()
+ var totalTime = 0L
+ var i = 0
+ while (i < minIters || totalTime < minDuration) {
+ val timer = new Benchmark.Timer(i)
+ f(timer)
+ val runTime = timer.totalTime()
+ runTimes += runTime
+ totalTime += runTime
+
+ if (outputPerIteration) {
+ // scalastyle:off
+ println(s"Iteration $i took ${NANOSECONDS.toMicros(runTime)}
microseconds")
+ // scalastyle:on
+ }
+ i += 1
+ }
+ // scalastyle:off
+ println(s" Stopped after $i iterations,
${NANOSECONDS.toMillis(runTimes.sum)} ms")
+ // scalastyle:on
+ assert(runTimes.nonEmpty)
+ val best = runTimes.min
+ val avg = runTimes.sum / runTimes.size
+ val stdev =
+ if (runTimes.size > 1) {
+ math.sqrt(runTimes.map(time => (time - avg) * (time - avg)).sum /
(runTimes.size - 1))
+ } else 0
+ Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0, stdev /
1000000.0)
+ }
+}
+
+private[celeborn] object Benchmark {
+
+ /**
+ * Object available to benchmark code to control timing e.g. to exclude
set-up time.
+ *
+ * @param iteration specifies this is the nth iteration of running the
benchmark case
+ */
+ class Timer(val iteration: Int) {
+ private var accumulatedTime: Long = 0L
+ private var timeStart: Long = 0L
+
+ def startTiming(): Unit = {
+ assert(timeStart == 0L, "Already started timing.")
+ timeStart = System.nanoTime
+ }
+
+ def stopTiming(): Unit = {
+ assert(timeStart != 0L, "Have not started timing.")
+ accumulatedTime += System.nanoTime - timeStart
+ timeStart = 0L
+ }
+
+ def totalTime(): Long = {
+ assert(timeStart == 0L, "Have not stopped timing.")
+ accumulatedTime
+ }
+ }
+
+ case class Case(name: String, fn: Timer => Unit, numIters: Int)
+ case class Result(avgMs: Double, bestRate: Double, bestMs: Double, stdevMs:
Double)
+
+ /**
+ * This should return a user helpful processor information. Getting at this
depends on the OS.
+ * This should return something like "Intel(R) Core(TM) i7-4870HQ CPU @
2.50GHz"
+ */
+ def getProcessorName(): String = {
+ val cpu =
+ if (SystemUtils.IS_OS_MAC_OSX) {
+ Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n",
"machdep.cpu.brand_string"))
+ .stripLineEnd
+ } else if (SystemUtils.IS_OS_LINUX) {
+ Try {
+ val grepPath = Utils.executeAndGetOutput(Seq("which",
"grep")).stripLineEnd
+ Utils.executeAndGetOutput(Seq(grepPath, "-m", "1", "model name",
"/proc/cpuinfo"))
+ .stripLineEnd.replaceFirst("model name[\\s*]:[\\s*]", "")
+ }.getOrElse("Unknown processor")
+ } else {
+ System.getenv("PROCESSOR_IDENTIFIER")
+ }
+ cpu
+ }
+
+ /**
+ * This should return a user helpful JVM & OS information.
+ * This should return something like
+ * "OpenJDK 64-Bit Server VM 1.8.0_65-b17 on Linux 4.1.13-100.fc21.x86_64"
+ */
+ def getJVMOSInfo(): String = {
+ val vmName = System.getProperty("java.vm.name")
+ val runtimeVersion = System.getProperty("java.runtime.version")
+ val osName = System.getProperty("os.name")
+ val osVersion = System.getProperty("os.version")
+ s"${vmName} ${runtimeVersion} on ${osName} ${osVersion}"
+ }
+}
diff --git
a/common/src/test/scala/org/apache/celeborn/benchmark/BenchmarkBase.scala
b/common/src/test/scala/org/apache/celeborn/benchmark/BenchmarkBase.scala
new file mode 100644
index 000000000..3641cf677
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/benchmark/BenchmarkBase.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.benchmark
+
+import java.io.{File, FileOutputStream, OutputStream}
+
+/**
+ * A base class for generating benchmark results to a file.
+ * For JDK9+, the JDK major version number is added to the file names to distinguish the results.
+ */
+abstract class BenchmarkBase {
+ var output: Option[OutputStream] = None
+
+ /**
+ * Main process of the whole benchmark.
+ * Implementations of this method are supposed to use the wrapper method
`runBenchmark`
+ * for each benchmark scenario.
+ */
+ def runBenchmarkSuite(mainArgs: Array[String]): Unit
+
+ final def runBenchmark(benchmarkName: String)(func: => Any): Unit = {
+ val separator = "=" * 96
+ val testHeader = (separator + '\n' + benchmarkName + '\n' + separator +
'\n' + '\n').getBytes
+ output.foreach(_.write(testHeader))
+ func
+ output.foreach(_.write('\n'))
+ }
+
+ def main(args: Array[String]): Unit = {
+ val regenerateBenchmarkFiles: Boolean =
+ System.getenv("CELEBORN_GENERATE_BENCHMARK_FILES") == "1"
+ if (regenerateBenchmarkFiles) {
+ val version = System.getProperty("java.version").split("\\D+")(0).toInt
+ val jdkString = if (version > 8) s"-jdk$version" else ""
+ val resultFileName =
+ s"${this.getClass.getSimpleName.replace("$",
"")}$jdkString$suffix-results.txt"
+ val prefix = Benchmarks.currentProjectRoot.map(_ + "/").getOrElse("")
+ val dir = new File(s"${prefix}benchmarks/")
+ if (!dir.exists()) {
+ // scalastyle:off println
+ println(s"Creating ${dir.getAbsolutePath} for benchmark results.")
+ // scalastyle:on println
+ dir.mkdirs()
+ }
+ val file = new File(dir, resultFileName)
+ if (!file.exists()) {
+ file.createNewFile()
+ }
+ output = Some(new FileOutputStream(file))
+ }
+
+ runBenchmarkSuite(args)
+
+ output.foreach { o =>
+ if (o != null) {
+ o.close()
+ }
+ }
+
+ afterAll()
+ }
+
+ def suffix: String = ""
+
+ /**
+ * Any shutdown code to ensure a clean shutdown
+ */
+ def afterAll(): Unit = {}
+}
diff --git
a/common/src/test/scala/org/apache/celeborn/benchmark/Benchmarks.scala
b/common/src/test/scala/org/apache/celeborn/benchmark/Benchmarks.scala
new file mode 100644
index 000000000..84cac0d5a
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/benchmark/Benchmarks.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.benchmark
+
+import java.io.File
+import java.lang.reflect.Modifier
+import java.nio.file.{FileSystems, Paths}
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+import scala.util.Try
+
+import com.google.common.reflect.ClassPath
+
+object Benchmarks {
+ var currentProjectRoot: Option[String] = None
+
+ def main(args: Array[String]): Unit = {
+ val isFailFast = sys.env.get(
+
"CELEBORN_BENCHMARK_FAILFAST").forall(_.toLowerCase(Locale.ROOT).trim.toBoolean)
+ val numOfSplits = sys.env.get(
+
"CELEBORN_BENCHMARK_NUM_SPLITS").map(_.toLowerCase(Locale.ROOT).trim.toInt).getOrElse(1)
+ val currentSplit = sys.env.get(
+
"CELEBORN_BENCHMARK_CUR_SPLIT").map(_.toLowerCase(Locale.ROOT).trim.toInt -
1).getOrElse(0)
+ var numBenchmark = 0
+
+ var isBenchmarkFound = false
+ val benchmarkClasses = ClassPath.from(
+ Thread.currentThread.getContextClassLoader).getTopLevelClassesRecursive(
+ "org.apache.celeborn").asScala.toArray
+ val matcher = FileSystems.getDefault.getPathMatcher(s"glob:${args.head}")
+
+ benchmarkClasses.foreach { info =>
+ lazy val clazz = info.load
+ lazy val runBenchmark = clazz.getMethod("main", classOf[Array[String]])
+ // isAssignableFrom seems not working with the reflected class from
Guava's
+ // getTopLevelClassesRecursive.
+ require(args.length > 0, "Benchmark class to run should be specified.")
+ if (info.getName.endsWith("Benchmark") &&
+ matcher.matches(Paths.get(info.getName)) &&
+ Try(runBenchmark).isSuccess && // Does this have a main method?
+ !Modifier.isAbstract(clazz.getModifiers) // Is this a regular class?
+ ) {
+ numBenchmark += 1
+ if (numBenchmark % numOfSplits == currentSplit) {
+ isBenchmarkFound = true
+
+ val targetDirOrProjDir =
+ new File(clazz.getProtectionDomain.getCodeSource.getLocation.toURI)
+ .getParentFile.getParentFile
+
+ // The root path to be referred to in each benchmark.
+ currentProjectRoot = Some {
+ if (targetDirOrProjDir.getName == "target") {
+ // SBT build
+ targetDirOrProjDir.getParentFile.getCanonicalPath
+ } else {
+ // Maven build
+ targetDirOrProjDir.getCanonicalPath
+ }
+ }
+
+ // scalastyle:off println
+ println(s"Running ${clazz.getName}:")
+ // scalastyle:on println
+ // Force GC to minimize the side effect.
+ System.gc()
+ try {
+ runBenchmark.invoke(null, args.tail)
+ } catch {
+ case e: Throwable if !isFailFast =>
+ // scalastyle:off println
+ println(s"${clazz.getName} failed with the exception below:")
+ // scalastyle:on println
+ e.printStackTrace()
+ }
+ }
+ }
+ }
+
+ if (!isBenchmarkFound) throw new RuntimeException("No benchmark found to
run.")
+ }
+}
diff --git
a/common/src/test/scala/org/apache/celeborn/common/ComputeIfAbsentBenchmark.scala
b/common/src/test/scala/org/apache/celeborn/common/ComputeIfAbsentBenchmark.scala
new file mode 100644
index 000000000..eeee4d78a
--- /dev/null
+++
b/common/src/test/scala/org/apache/celeborn/common/ComputeIfAbsentBenchmark.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common
+
+import java.util.{HashMap => JHashMap, Map => JMap}
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.function.{Function => JFunction}
+
+import scala.util.Random
+
+import org.apache.celeborn.benchmark.{Benchmark, BenchmarkBase}
+
+/**
+ * ComputeIfAbsent benchmark.
+ * To run this benchmark:
+ * {{{
+ * 1. build/sbt "common/test:runMain <this class>"
+ * 2. generate result:
+ * CELEBORN_GENERATE_BENCHMARK_FILES=1 build/sbt "common/test:runMain
<this class>"
+ * Results will be written to
"benchmarks/ComputeIfAbsentBenchmark-results.txt".
+ * }}}
+ */
+object ComputeIfAbsentBenchmark extends BenchmarkBase {
+
+ def test(name: String, map: JMap[Int, AtomicInteger], iters: Int): Unit = {
+ runBenchmark(name) {
+ val benchmark = new Benchmark(name, iters, output = output)
+ benchmark.addCase("putIfAbsent") { _: Int =>
+ var i = 0
+ while (i < iters) {
+ map.putIfAbsent(Random.nextInt(32), new AtomicInteger(0))
+ i += 1
+ }
+ }
+
+ benchmark.addCase("computeIfAbsent") { _: Int =>
+ var i = 0
+ while (i < iters) {
+ map.computeIfAbsent(
+ Random.nextInt(32),
+ new JFunction[Int, AtomicInteger] {
+ override def apply(v1: Int): AtomicInteger = new AtomicInteger(0)
+ })
+ i += 1
+ }
+ }
+ benchmark.run()
+ }
+ }
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ test("HashMap", new JHashMap[Int, AtomicInteger], 1 << 26)
+ test("ConcurrentHashMap", new ConcurrentHashMap[Int, AtomicInteger], 1 <<
26)
+ }
+}