This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 7020ed376 [VL] CI: Reformat gluten-it code with Spark331's scalafmt configuration (#5615)
7020ed376 is described below
commit 7020ed3768ae1315481d6091a6aec33e3f93b66f
Author: Hongze Zhang <[email protected]>
AuthorDate: Mon May 6 17:06:45 2024 +0800
[VL] CI: Reformat gluten-it code with Spark331's scalafmt configuration (#5615)
---
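For context: the formatting was produced by scalafmt, driven by the configuration
that ships with Spark 3.3.1. A minimal sketch of what such a configuration looks
like, assuming standard scalafmt option names (the exact values and scalafmt
version are assumptions here, not a copy of Spark's file):

    align = none
    align.openParenDefnSite = false
    align.openParenCallSite = false
    align.tokens = []
    optIn {
      configStyleArguments = false
    }
    danglingParentheses.preset = false
    maxColumn = 98

These settings account for the shape of the diff below: closing parentheses no
longer dangle on their own line, long argument lists wrap at 98 columns, and
single-parameter lambdas fold onto the opening brace.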
.../apache/gluten/integration/tpc/Constants.scala | 28 ++--
.../apache/gluten/integration/tpc/DataGen.scala | 43 ++++---
.../apache/gluten/integration/tpc/ShimUtils.scala | 12 +-
.../apache/gluten/integration/tpc/TpcRunner.scala | 25 ++--
.../apache/gluten/integration/tpc/TpcSuite.scala | 15 ++-
.../integration/tpc/action/Parameterized.scala | 142 +++++++++------------
.../gluten/integration/tpc/action/Queries.scala | 72 +++++------
.../integration/tpc/action/QueriesCompare.scala | 105 +++++++--------
.../gluten/integration/tpc/ds/TpcdsDataGen.scala | 112 +++++++---------
.../gluten/integration/tpc/ds/TpcdsSuite.scala | 44 ++++---
.../gluten/integration/tpc/h/TpchDataGen.scala | 141 ++++++++------------
.../gluten/integration/tpc/h/TpchSuite.scala | 35 +++--
.../history/GlutenItHistoryServerPlugin.scala | 85 ++++++------
.../spark/deploy/history/HistoryServerHelper.scala | 16 +--
.../scala/org/apache/spark/sql/ConfUtils.scala | 10 +-
.../scala/org/apache/spark/sql/QueryRunner.scala | 25 ++--
.../scala/org/apache/spark/sql/TestUtils.scala | 29 ++---
17 files changed, 430 insertions(+), 509 deletions(-)
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/Constants.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/Constants.scala
index 7564f6dce..d39a16c32 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/Constants.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/Constants.scala
@@ -18,7 +18,14 @@ package org.apache.gluten.integration.tpc
import org.apache.spark.SparkConf
import org.apache.spark.sql.TypeUtils
-import org.apache.spark.sql.types.{DateType, DecimalType, DoubleType, IntegerType, LongType, StringType}
+import org.apache.spark.sql.types.{
+ DateType,
+ DecimalType,
+ DoubleType,
+ IntegerType,
+ LongType,
+ StringType
+}
import java.sql.Date
@@ -33,16 +40,15 @@ object Constants {
    .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
.set("spark.sql.optimizer.runtime.bloomFilter.enabled", "true")
    .set("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold", "0")
- .set(
- "spark.gluten.sql.columnar.physicalJoinOptimizeEnable",
- "false"
- ) // q72 slow if false, q64 fails if true
+    .set("spark.gluten.sql.columnar.physicalJoinOptimizeEnable", "false") // q72 slow if false, q64 fails if true
val VELOX_WITH_CELEBORN_CONF: SparkConf = new SparkConf(false)
.set("spark.gluten.sql.columnar.forceShuffledHashJoin", "true")
.set("spark.sql.parquet.enableVectorizedReader", "true")
.set("spark.plugins", "org.apache.gluten.GlutenPlugin")
-    .set("spark.shuffle.manager", "org.apache.spark.shuffle.gluten.celeborn.CelebornShuffleManager")
+ .set(
+ "spark.shuffle.manager",
+ "org.apache.spark.shuffle.gluten.celeborn.CelebornShuffleManager")
.set("spark.celeborn.shuffle.writer", "hash")
.set("spark.celeborn.push.replicate.enabled", "false")
.set("spark.celeborn.client.shuffle.compression.codec", "none")
@@ -51,10 +57,7 @@ object Constants {
.set("spark.dynamicAllocation.enabled", "false")
.set("spark.sql.optimizer.runtime.bloomFilter.enabled", "true")
    .set("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold", "0")
- .set(
- "spark.gluten.sql.columnar.physicalJoinOptimizeEnable",
- "false"
- ) // q72 slow if false, q64 fails if true
+    .set("spark.gluten.sql.columnar.physicalJoinOptimizeEnable", "false") // q72 slow if false, q64 fails if true
.set("spark.celeborn.push.data.timeout", "600s")
.set("spark.celeborn.push.limit.inFlight.timeout", "1200s")
@@ -72,10 +75,7 @@ object Constants {
.set("spark.dynamicAllocation.enabled", "false")
.set("spark.sql.optimizer.runtime.bloomFilter.enabled", "true")
    .set("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold", "0")
- .set(
- "spark.gluten.sql.columnar.physicalJoinOptimizeEnable",
- "false"
- )
+ .set("spark.gluten.sql.columnar.physicalJoinOptimizeEnable", "false")
@deprecated
val TYPE_MODIFIER_DATE_AS_DOUBLE: TypeModifier =
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/DataGen.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/DataGen.scala
index 5c092089d..e810a4dc2 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/DataGen.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/DataGen.scala
@@ -23,7 +23,7 @@ trait DataGen {
}
abstract class TypeModifier(val predicate: DataType => Boolean, val to: DataType)
- extends Serializable {
+ extends Serializable {
def modValue(value: Any): Any
}
@@ -32,29 +32,30 @@ class NoopModifier(t: DataType) extends TypeModifier(_ => true, t) {
}
object DataGen {
-  def getRowModifier(schema: StructType, typeModifiers: List[TypeModifier]): Int => TypeModifier = {
- val modifiers = schema.fields.map {
- f =>
- val matchedModifiers = typeModifiers.flatMap {
- m =>
- if (m.predicate.apply(f.dataType)) {
- Some(m)
- } else {
- None
- }
- }
- if (matchedModifiers.isEmpty) {
- new NoopModifier(f.dataType)
+ def getRowModifier(
+ schema: StructType,
+ typeModifiers: List[TypeModifier]): Int => TypeModifier = {
+ val modifiers = schema.fields.map { f =>
+ val matchedModifiers = typeModifiers.flatMap { m =>
+ if (m.predicate.apply(f.dataType)) {
+ Some(m)
} else {
- if (matchedModifiers.size > 1) {
- println(
-            s"More than one type modifiers specified for type ${f.dataType}, " +
- s"use first one in the list")
- }
- matchedModifiers.head // use the first one that matches
+ None
+ }
+ }
+ if (matchedModifiers.isEmpty) {
+ new NoopModifier(f.dataType)
+ } else {
+ if (matchedModifiers.size > 1) {
+ println(
+        s"More than one type modifiers specified for type ${f.dataType}, " +
+ s"use first one in the list")
}
+ matchedModifiers.head // use the first one that matches
+ }
}
- i => modifiers(i)
+ i =>
+ modifiers(i)
}
def modifySchema(schema: StructType, rowModifier: Int => TypeModifier): StructType = {
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ShimUtils.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ShimUtils.scala
index c64fa160f..19e15df5c 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ShimUtils.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ShimUtils.scala
@@ -25,13 +25,17 @@ object ShimUtils {
def getExpressionEncoder(schema: StructType): ExpressionEncoder[Row] = {
try {
- RowEncoder.getClass.getMethod("apply", classOf[StructType])
- .invoke(RowEncoder, schema).asInstanceOf[ExpressionEncoder[Row]]
+ RowEncoder.getClass
+ .getMethod("apply", classOf[StructType])
+ .invoke(RowEncoder, schema)
+ .asInstanceOf[ExpressionEncoder[Row]]
} catch {
case _: Exception =>
// to be compatible with Spark 3.5 and later
- ExpressionEncoder.getClass.getMethod("apply", classOf[StructType])
-          .invoke(ExpressionEncoder, schema).asInstanceOf[ExpressionEncoder[Row]]
+ ExpressionEncoder.getClass
+ .getMethod("apply", classOf[StructType])
+ .invoke(ExpressionEncoder, schema)
+ .asInstanceOf[ExpressionEncoder[Row]]
}
}
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/TpcRunner.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/TpcRunner.scala
index ab76dc68c..908b8206e 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/TpcRunner.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/TpcRunner.scala
@@ -48,20 +48,19 @@ class TpcRunner(val queryResourceFolder: String, val dataPath: String) {
object TpcRunner {
def createTables(spark: SparkSession, dataPath: String): Unit = {
val files = new File(dataPath).listFiles()
- files.foreach(
- file => {
- if (spark.catalog.tableExists(file.getName)) {
- println("Table exists: " + file.getName)
- } else {
- println("Creating catalog table: " + file.getName)
-        spark.catalog.createTable(file.getName, file.getAbsolutePath, "parquet")
- try {
- spark.catalog.recoverPartitions(file.getName)
- } catch {
- case _: AnalysisException =>
- }
+ files.foreach(file => {
+ if (spark.catalog.tableExists(file.getName)) {
+ println("Table exists: " + file.getName)
+ } else {
+ println("Creating catalog table: " + file.getName)
+      spark.catalog.createTable(file.getName, file.getAbsolutePath, "parquet")
+ try {
+ spark.catalog.recoverPartitions(file.getName)
+ } catch {
+ case _: AnalysisException =>
}
- })
+ }
+ })
}
private def delete(path: String): Unit = {
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/TpcSuite.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/TpcSuite.scala
index 058657976..f7605e273 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/TpcSuite.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/TpcSuite.scala
@@ -68,7 +68,9 @@ abstract class TpcSuite(
.setWarningOnOverriding("spark.executor.metrics.pollingInterval", "0")
sessionSwitcher.defaultConf().setWarningOnOverriding("spark.network.timeout", "3601s")
sessionSwitcher.defaultConf().setWarningOnOverriding("spark.sql.broadcastTimeout", "1800")
-    sessionSwitcher.defaultConf().setWarningOnOverriding("spark.network.io.preferDirectBufs", "false")
+ sessionSwitcher
+ .defaultConf()
+ .setWarningOnOverriding("spark.network.io.preferDirectBufs", "false")
sessionSwitcher
.defaultConf()
      .setWarningOnOverriding("spark.unsafe.exceptionOnMemoryLeak", s"$errorOnMemLeak")
@@ -113,8 +115,8 @@ abstract class TpcSuite(
sessionSwitcher.defaultConf().setWarningOnOverriding("spark.default.parallelism", "1")
}
- extraSparkConf.toStream.foreach {
- kv => sessionSwitcher.defaultConf().setWarningOnOverriding(kv._1, kv._2)
+ extraSparkConf.toStream.foreach { kv =>
+ sessionSwitcher.defaultConf().setWarningOnOverriding(kv._1, kv._2)
}
// register sessions
@@ -134,10 +136,9 @@ abstract class TpcSuite(
}
def run(): Boolean = {
- val succeed = actions.forall {
- action =>
-        resetLogLevel() // to prevent log level from being set by unknown external codes
- action.execute(this)
+ val succeed = actions.forall { action =>
+      resetLogLevel() // to prevent log level from being set by unknown external codes
+ action.execute(this)
}
succeed
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/Parameterized.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/Parameterized.scala
index f066659ef..b4f7a5394 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/Parameterized.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/Parameterized.scala
@@ -36,26 +36,22 @@ class Parameterized(
configDimensions: Seq[Dim],
excludedCombinations: Seq[Set[DimKv]],
metrics: Array[String])
- extends Action {
+ extends Action {
private def validateDims(configDimensions: Seq[Dim]): Unit = {
- if (
- configDimensions
- .map(
- dim => {
+ if (configDimensions
+ .map(dim => {
dim.name
})
- .toSet
- .size != configDimensions.size
- ) {
+ .toSet
+ .size != configDimensions.size) {
throw new IllegalStateException("Duplicated dimension name found")
}
- configDimensions.foreach {
- dim =>
-        if (dim.dimValues.map(dimValue => dimValue.name).toSet.size != dim.dimValues.size) {
- throw new IllegalStateException("Duplicated dimension value found")
- }
+ configDimensions.foreach { dim =>
+      if (dim.dimValues.map(dimValue => dimValue.name).toSet.size != dim.dimValues.size) {
+ throw new IllegalStateException("Duplicated dimension value found")
+ }
}
}
@@ -70,26 +66,23 @@ class Parameterized(
intermediateConf: Seq[(String, String)]): Unit = {
if (dimOffset == dimCount) {
// we got one coordinate
- excludedCombinations.foreach {
- ec: Set[DimKv] =>
- if (ec.forall {
- kv =>
+ excludedCombinations.foreach { ec: Set[DimKv] =>
+ if (ec.forall { kv =>
intermediateCoordinates.contains(kv.k) &&
intermediateCoordinates(kv.k) == kv.v
- }) {
-          println(s"Coordinate ${Coordinate(intermediateCoordinates)} excluded by $ec.")
- return
- }
+ }) {
+        println(s"Coordinate ${Coordinate(intermediateCoordinates)} excluded by $ec.")
+ return
+ }
}
coordinateMap(Coordinate(intermediateCoordinates)) = intermediateConf
return
}
val dim = configDimensions(dimOffset)
- dim.dimValues.foreach {
- dimValue =>
- fillCoordinates(
- dimOffset + 1,
- intermediateCoordinates + (dim.name -> dimValue.name),
- intermediateConf ++ dimValue.conf)
+ dim.dimValues.foreach { dimValue =>
+ fillCoordinates(
+ dimOffset + 1,
+ intermediateCoordinates + (dim.name -> dimValue.name),
+ intermediateConf ++ dimValue.conf)
}
}
@@ -110,45 +103,40 @@ class Parameterized(
case (c, idx) =>
println(s" $idx: $c")
}
- coordinates.foreach {
- entry =>
- // register one session per coordinate
- val coordinate = entry._1
- val coordinateConf = entry._2
- val conf = testConf.clone()
- conf.setAllWarningOnOverriding(coordinateConf)
- sessionSwitcher.registerSession(coordinate.toString, conf)
+ coordinates.foreach { entry =>
+ // register one session per coordinate
+ val coordinate = entry._1
+ val coordinateConf = entry._2
+ val conf = testConf.clone()
+ conf.setAllWarningOnOverriding(coordinateConf)
+ sessionSwitcher.registerSession(coordinate.toString, conf)
}
val runQueryIds = queries.select(tpcSuite)
// warm up
- (0 until warmupIterations).foreach {
- _ =>
- runQueryIds.foreach {
-        queryId => Parameterized.warmUp(queryId, tpcSuite.desc(), sessionSwitcher, runner)
- }
+ (0 until warmupIterations).foreach { _ =>
+ runQueryIds.foreach { queryId =>
+ Parameterized.warmUp(queryId, tpcSuite.desc(), sessionSwitcher, runner)
+ }
}
- val results = coordinates.flatMap {
- entry =>
- val coordinate = entry._1
- val coordinateResults = (0 until iterations).flatMap {
- iteration =>
-          println(s"Running tests (iteration $iteration) with coordinate $coordinate...")
- runQueryIds.map {
- queryId =>
- Parameterized.runTpcQuery(
- runner,
- sessionSwitcher,
- queryId,
- coordinate,
- tpcSuite.desc(),
- explain,
- metrics)
- }
- }.toList
- coordinateResults
+ val results = coordinates.flatMap { entry =>
+ val coordinate = entry._1
+ val coordinateResults = (0 until iterations).flatMap { iteration =>
+      println(s"Running tests (iteration $iteration) with coordinate $coordinate...")
+ runQueryIds.map { queryId =>
+ Parameterized.runTpcQuery(
+ runner,
+ sessionSwitcher,
+ queryId,
+ coordinate,
+ tpcSuite.desc(),
+ explain,
+ metrics)
+ }
+ }.toList
+ coordinateResults
}
val dimNames = configDimensions.map(dim => dim.name)
@@ -164,8 +152,7 @@ class Parameterized(
      "RAM statistics: JVM Heap size: %d KiB (total %d KiB), Process RSS: %d KiB\n",
RamStat.getJvmHeapUsed(),
RamStat.getJvmHeapTotal(),
- RamStat.getProcessRamUsed()
- )
+ RamStat.getProcessRamUsed())
println("")
println("Test report: ")
@@ -225,25 +212,22 @@ case class TestResultLines(
fields.append("Row Count")
fields.append("Query Time (Millis)")
printf(fmt, fields: _*)
- lines.foreach {
- line =>
- val values = ArrayBuffer[Any](line.queryId, line.succeed)
- dimNames.foreach {
- dimName =>
- val coordinate = line.coordinate.coordinate
- if (!coordinate.contains(dimName)) {
-            throw new IllegalStateException("Dimension name not found" + dimName)
- }
- values.append(coordinate(dimName))
- }
- metricNames.foreach {
- metricName =>
- val metrics = line.metrics
- values.append(metrics.getOrElse(metricName, "N/A"))
+ lines.foreach { line =>
+ val values = ArrayBuffer[Any](line.queryId, line.succeed)
+ dimNames.foreach { dimName =>
+ val coordinate = line.coordinate.coordinate
+ if (!coordinate.contains(dimName)) {
+ throw new IllegalStateException("Dimension name not found" + dimName)
}
- values.append(line.rowCount.getOrElse("N/A"))
- values.append(line.executionTimeMillis.getOrElse("N/A"))
- printf(fmt, values: _*)
+ values.append(coordinate(dimName))
+ }
+ metricNames.foreach { metricName =>
+ val metrics = line.metrics
+ values.append(metrics.getOrElse(metricName, "N/A"))
+ }
+ values.append(line.rowCount.getOrElse("N/A"))
+ values.append(line.executionTimeMillis.getOrElse("N/A"))
+ printf(fmt, values: _*)
}
}
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/Queries.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/Queries.scala
index dc4ffe622..c5f883189 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/Queries.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/Queries.scala
@@ -27,24 +27,22 @@ case class Queries(
explain: Boolean,
iterations: Int,
randomKillTasks: Boolean)
- extends Action {
+ extends Action {
override def execute(tpcSuite: TpcSuite): Boolean = {
val runQueryIds = queries.select(tpcSuite)
val runner: TpcRunner = new TpcRunner(tpcSuite.queryResource(), tpcSuite.dataWritePath(scale))
- val results = (0 until iterations).flatMap {
- iteration =>
- println(s"Running tests (iteration $iteration)...")
- runQueryIds.map {
- queryId =>
- Queries.runTpcQuery(
- runner,
- tpcSuite.sessionSwitcher,
- queryId,
- tpcSuite.desc(),
- explain,
- randomKillTasks)
- }
+ val results = (0 until iterations).flatMap { iteration =>
+ println(s"Running tests (iteration $iteration)...")
+ runQueryIds.map { queryId =>
+ Queries.runTpcQuery(
+ runner,
+ tpcSuite.sessionSwitcher,
+ queryId,
+ tpcSuite.desc(),
+ explain,
+ randomKillTasks)
+ }
}.toList
val passedCount = results.count(l => l.testPassed)
@@ -58,8 +56,7 @@ case class Queries(
      "RAM statistics: JVM Heap size: %d KiB (total %d KiB), Process RSS: %d KiB\n",
RamStat.getJvmHeapUsed(),
RamStat.getJvmHeapTotal(),
- RamStat.getProcessRamUsed()
- )
+ RamStat.getProcessRamUsed())
println("")
println("Test report: ")
@@ -112,17 +109,14 @@ object Queries {
"Query ID",
"Was Passed",
"Row Count",
- "Query Time (Millis)"
- )
- results.foreach {
- line =>
- printf(
- "|%15s|%15s|%30s|%30s|\n",
- line.queryId,
- line.testPassed,
- line.rowCount.getOrElse("N/A"),
- line.executionTimeMillis.getOrElse("N/A")
- )
+ "Query Time (Millis)")
+ results.foreach { line =>
+ printf(
+ "|%15s|%15s|%30s|%30s|\n",
+ line.queryId,
+ line.testPassed,
+ line.rowCount.getOrElse("N/A"),
+ line.executionTimeMillis.getOrElse("N/A"))
}
}
@@ -131,19 +125,17 @@ object Queries {
return Nil
}
List(
- succeed.reduce(
- (r1, r2) =>
- TestResultLine(
- name,
- testPassed = true,
- if (r1.rowCount.nonEmpty && r2.rowCount.nonEmpty)
- Some(r1.rowCount.get + r2.rowCount.get)
- else None,
-          if (r1.executionTimeMillis.nonEmpty && r2.executionTimeMillis.nonEmpty)
- Some(r1.executionTimeMillis.get + r2.executionTimeMillis.get)
- else None,
- None
- )))
+ succeed.reduce((r1, r2) =>
+ TestResultLine(
+ name,
+ testPassed = true,
+ if (r1.rowCount.nonEmpty && r2.rowCount.nonEmpty)
+ Some(r1.rowCount.get + r2.rowCount.get)
+ else None,
+          if (r1.executionTimeMillis.nonEmpty && r2.executionTimeMillis.nonEmpty)
+ Some(r1.executionTimeMillis.get + r2.executionTimeMillis.get)
+ else None,
+ None)))
}
private def runTpcQuery(
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/QueriesCompare.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/QueriesCompare.scala
index f841b5827..5e8e2d613 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/QueriesCompare.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/action/QueriesCompare.scala
@@ -27,23 +27,21 @@ case class QueriesCompare(
queries: QuerySelector,
explain: Boolean,
iterations: Int)
- extends Action {
+ extends Action {
override def execute(tpcSuite: TpcSuite): Boolean = {
val runner: TpcRunner = new TpcRunner(tpcSuite.queryResource(), tpcSuite.dataWritePath(scale))
val runQueryIds = queries.select(tpcSuite)
- val results = (0 until iterations).flatMap {
- iteration =>
- println(s"Running tests (iteration $iteration)...")
- runQueryIds.map {
- queryId =>
- QueriesCompare.runTpcQuery(
- queryId,
- explain,
- tpcSuite.desc(),
- tpcSuite.sessionSwitcher,
- runner)
- }
+ val results = (0 until iterations).flatMap { iteration =>
+ println(s"Running tests (iteration $iteration)...")
+ runQueryIds.map { queryId =>
+ QueriesCompare.runTpcQuery(
+ queryId,
+ explain,
+ tpcSuite.desc(),
+ tpcSuite.sessionSwitcher,
+ runner)
+ }
}.toList
val passedCount = results.count(l => l.testPassed)
@@ -57,8 +55,7 @@ case class QueriesCompare(
      "RAM statistics: JVM Heap size: %d KiB (total %d KiB), Process RSS: %d KiB\n",
RamStat.getJvmHeapUsed(),
RamStat.getJvmHeapTotal(),
- RamStat.getProcessRamUsed()
- )
+ RamStat.getProcessRamUsed())
println("")
println("Test report: ")
@@ -73,7 +70,8 @@ case class QueriesCompare(
println("No failed queries. ")
println("")
} else {
-      println("Failed queries (a failed query with correct row count indicates value mismatches): ")
+ println(
+        "Failed queries (a failed query with correct row count indicates value mismatches): ")
println("")
QueriesCompare.printResults(results.filter(!_.testPassed))
println("")
@@ -116,28 +114,23 @@ object QueriesCompare {
"Actual Row Count",
"Baseline Query Time (Millis)",
"Query Time (Millis)",
- "Query Time Variation"
- )
- results.foreach {
- line =>
- val timeVariation =
- if (
-          line.expectedExecutionTimeMillis.nonEmpty && line.actualExecutionTimeMillis.nonEmpty
- ) {
- Some(
-            ((line.expectedExecutionTimeMillis.get - line.actualExecutionTimeMillis.get).toDouble
- / line.actualExecutionTimeMillis.get.toDouble) * 100)
- } else None
- printf(
- "|%15s|%15s|%30s|%30s|%30s|%30s|%30s|\n",
- line.queryId,
- line.testPassed,
- line.expectedRowCount.getOrElse("N/A"),
- line.actualRowCount.getOrElse("N/A"),
- line.expectedExecutionTimeMillis.getOrElse("N/A"),
- line.actualExecutionTimeMillis.getOrElse("N/A"),
- timeVariation.map("%15.2f%%".format(_)).getOrElse("N/A")
- )
+ "Query Time Variation")
+ results.foreach { line =>
+ val timeVariation =
+      if (line.expectedExecutionTimeMillis.nonEmpty && line.actualExecutionTimeMillis.nonEmpty) {
+ Some(
+          ((line.expectedExecutionTimeMillis.get - line.actualExecutionTimeMillis.get).toDouble
+ / line.actualExecutionTimeMillis.get.toDouble) * 100)
+ } else None
+ printf(
+ "|%15s|%15s|%30s|%30s|%30s|%30s|%30s|\n",
+ line.queryId,
+ line.testPassed,
+ line.expectedRowCount.getOrElse("N/A"),
+ line.actualRowCount.getOrElse("N/A"),
+ line.expectedExecutionTimeMillis.getOrElse("N/A"),
+ line.actualExecutionTimeMillis.getOrElse("N/A"),
+ timeVariation.map("%15.2f%%".format(_)).getOrElse("N/A"))
}
}
@@ -146,25 +139,23 @@ object QueriesCompare {
return Nil
}
List(
- succeed.reduce(
- (r1, r2) =>
- TestResultLine(
- name,
- testPassed = true,
- if (r1.expectedRowCount.nonEmpty && r2.expectedRowCount.nonEmpty)
- Some(r1.expectedRowCount.get + r2.expectedRowCount.get)
- else None,
- if (r1.actualRowCount.nonEmpty && r2.actualRowCount.nonEmpty)
- Some(r1.actualRowCount.get + r2.actualRowCount.get)
- else None,
-          if (r1.expectedExecutionTimeMillis.nonEmpty && r2.expectedExecutionTimeMillis.nonEmpty)
-            Some(r1.expectedExecutionTimeMillis.get + r2.expectedExecutionTimeMillis.get)
- else None,
-          if (r1.actualExecutionTimeMillis.nonEmpty && r2.actualExecutionTimeMillis.nonEmpty)
-            Some(r1.actualExecutionTimeMillis.get + r2.actualExecutionTimeMillis.get)
- else None,
- None
- )))
+ succeed.reduce((r1, r2) =>
+ TestResultLine(
+ name,
+ testPassed = true,
+ if (r1.expectedRowCount.nonEmpty && r2.expectedRowCount.nonEmpty)
+ Some(r1.expectedRowCount.get + r2.expectedRowCount.get)
+ else None,
+ if (r1.actualRowCount.nonEmpty && r2.actualRowCount.nonEmpty)
+ Some(r1.actualRowCount.get + r2.actualRowCount.get)
+ else None,
+        if (r1.expectedExecutionTimeMillis.nonEmpty && r2.expectedExecutionTimeMillis.nonEmpty)
+          Some(r1.expectedExecutionTimeMillis.get + r2.expectedExecutionTimeMillis.get)
+ else None,
+        if (r1.actualExecutionTimeMillis.nonEmpty && r2.actualExecutionTimeMillis.nonEmpty)
+          Some(r1.actualExecutionTimeMillis.get + r2.actualExecutionTimeMillis.get)
+ else None,
+ None)))
}
private[tpc] def runTpcQuery(
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ds/TpcdsDataGen.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ds/TpcdsDataGen.scala
index 081e54747..82d16dd90 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ds/TpcdsDataGen.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ds/TpcdsDataGen.scala
@@ -33,8 +33,8 @@ class TpcdsDataGen(
dir: String,
typeModifiers: List[TypeModifier] = List(),
val genPartitionedData: Boolean)
- extends Serializable
- with DataGen {
+ extends Serializable
+ with DataGen {
def writeParquetTable(t: Table): Unit = {
val name = t.getName
@@ -97,25 +97,23 @@ class TpcdsDataGen(
val tablePath = dir + File.separator + tableName
spark
.range(0, partitions, 1L, partitions)
- .mapPartitions {
- itr =>
- val id = itr.toArray
- if (id.length != 1) {
- throw new IllegalStateException()
- }
- val options = new Options()
- options.scale = scale
- options.parallelism = partitions
- val session = options.toSession
- val chunkSession = session.withChunkNumber(id(0).toInt + 1)
-          val results = Results.constructResults(t, chunkSession).asScala.toIterator
- results.map {
- parentAndChildRow =>
- // Skip child table when generating parent table,
-              // we generate every table individually no matter it is parent or child.
-              val array: Array[String] = parentAndChildRow.get(0).asScala.toArray
- Row(array: _*)
- }
+ .mapPartitions { itr =>
+ val id = itr.toArray
+ if (id.length != 1) {
+ throw new IllegalStateException()
+ }
+ val options = new Options()
+ options.scale = scale
+ options.parallelism = partitions
+ val session = options.toSession
+ val chunkSession = session.withChunkNumber(id(0).toInt + 1)
+        val results = Results.constructResults(t, chunkSession).asScala.toIterator
+ results.map { parentAndChildRow =>
+ // Skip child table when generating parent table,
+          // we generate every table individually no matter it is parent or child.
+ val array: Array[String] = parentAndChildRow.get(0).asScala.toArray
+ Row(array: _*)
+ }
}(ShimUtils.getExpressionEncoder(stringSchema))
.select(columns: _*)
.write
@@ -168,8 +166,7 @@ object TpcdsDataGen {
StructField("cs_net_paid_inc_tax", DecimalType(7, 2)),
StructField("cs_net_paid_inc_ship", DecimalType(7, 2)),
StructField("cs_net_paid_inc_ship_tax", DecimalType(7, 2)),
- StructField("cs_net_profit", DecimalType(7, 2))
- ))
+ StructField("cs_net_profit", DecimalType(7, 2))))
}
private def catalogReturnsSchema = {
@@ -201,8 +198,7 @@ object TpcdsDataGen {
StructField("cr_refunded_cash", DecimalType(7, 2)),
StructField("cr_reversed_charge", DecimalType(7, 2)),
StructField("cr_store_credit", DecimalType(7, 2)),
- StructField("cr_net_loss", DecimalType(7, 2))
- ))
+ StructField("cr_net_loss", DecimalType(7, 2))))
}
private def inventorySchema = {
@@ -211,8 +207,7 @@ object TpcdsDataGen {
StructField("inv_date_sk", LongType),
StructField("inv_item_sk", LongType),
StructField("inv_warehouse_sk", LongType),
- StructField("inv_quantity_on_hand", LongType)
- ))
+ StructField("inv_quantity_on_hand", LongType)))
}
private def storeSalesSchema = {
@@ -240,8 +235,7 @@ object TpcdsDataGen {
StructField("ss_coupon_amt", DecimalType(7, 2)),
StructField("ss_net_paid", DecimalType(7, 2)),
StructField("ss_net_paid_inc_tax", DecimalType(7, 2)),
- StructField("ss_net_profit", DecimalType(7, 2))
- ))
+ StructField("ss_net_profit", DecimalType(7, 2))))
}
private def storeReturnsSchema = {
@@ -266,8 +260,7 @@ object TpcdsDataGen {
StructField("sr_refunded_cash", DecimalType(7, 2)),
StructField("sr_reversed_charge", DecimalType(7, 2)),
StructField("sr_store_credit", DecimalType(7, 2)),
- StructField("sr_net_loss", DecimalType(7, 2))
- ))
+ StructField("sr_net_loss", DecimalType(7, 2))))
}
private def webSalesSchema = {
@@ -306,8 +299,7 @@ object TpcdsDataGen {
StructField("ws_net_paid_inc_tax", DecimalType(7, 2)),
StructField("ws_net_paid_inc_ship", DecimalType(7, 2)),
StructField("ws_net_paid_inc_ship_tax", DecimalType(7, 2)),
- StructField("ws_net_profit", DecimalType(7, 2))
- ))
+ StructField("ws_net_profit", DecimalType(7, 2))))
}
private def webReturnsSchema = {
@@ -336,8 +328,7 @@ object TpcdsDataGen {
StructField("wr_refunded_cash", DecimalType(7, 2)),
StructField("wr_reversed_charge", DecimalType(7, 2)),
StructField("wr_account_credit", DecimalType(7, 2)),
- StructField("wr_net_loss", DecimalType(7, 2))
- ))
+ StructField("wr_net_loss", DecimalType(7, 2))))
}
private def callCenterSchema = {
@@ -373,8 +364,7 @@ object TpcdsDataGen {
StructField("cc_zip", StringType),
StructField("cc_country", StringType),
StructField("cc_gmt_offset", DecimalType(5, 2)),
- StructField("cc_tax_percentage", DecimalType(5, 2))
- ))
+ StructField("cc_tax_percentage", DecimalType(5, 2))))
}
private def catalogPageSchema = {
@@ -388,8 +378,7 @@ object TpcdsDataGen {
StructField("cp_catalog_number", LongType),
StructField("cp_catalog_page_number", LongType),
StructField("cp_description", StringType),
- StructField("cp_type", StringType)
- ))
+ StructField("cp_type", StringType)))
}
private def customerSchema = {
@@ -412,8 +401,7 @@ object TpcdsDataGen {
StructField("c_birth_country", StringType),
StructField("c_login", StringType),
StructField("c_email_address", StringType),
- StructField("c_last_review_date", StringType)
- ))
+ StructField("c_last_review_date", StringType)))
}
private def customerAddressSchema = {
@@ -431,8 +419,7 @@ object TpcdsDataGen {
StructField("ca_zip", StringType),
StructField("ca_country", StringType),
StructField("ca_gmt_offset", DecimalType(5, 2)),
- StructField("ca_location_type", StringType)
- ))
+ StructField("ca_location_type", StringType)))
}
private def customerDemographicsSchema = {
@@ -446,8 +433,7 @@ object TpcdsDataGen {
StructField("cd_credit_rating", StringType),
StructField("cd_dep_count", LongType),
StructField("cd_dep_employed_count", LongType),
- StructField("cd_dep_college_count", LongType)
- ))
+ StructField("cd_dep_college_count", LongType)))
}
private def dateDimSchema = {
@@ -480,8 +466,7 @@ object TpcdsDataGen {
StructField("d_current_week", StringType),
StructField("d_current_month", StringType),
StructField("d_current_quarter", StringType),
- StructField("d_current_year", StringType)
- ))
+ StructField("d_current_year", StringType)))
}
private def householdDemographicsSchema = {
@@ -491,8 +476,7 @@ object TpcdsDataGen {
StructField("hd_income_band_sk", LongType),
StructField("hd_buy_potential", StringType),
StructField("hd_dep_count", LongType),
- StructField("hd_vehicle_count", LongType)
- ))
+ StructField("hd_vehicle_count", LongType)))
}
private def incomeBandSchema = {
@@ -500,8 +484,7 @@ object TpcdsDataGen {
Seq(
StructField("ib_income_band_sk", LongType),
StructField("ib_lower_bound", LongType),
- StructField("ib_upper_bound", LongType)
- ))
+ StructField("ib_upper_bound", LongType)))
}
private def itemSchema = {
@@ -528,8 +511,7 @@ object TpcdsDataGen {
StructField("i_units", StringType),
StructField("i_container", StringType),
StructField("i_manager_id", LongType),
- StructField("i_product_name", StringType)
- ))
+ StructField("i_product_name", StringType)))
}
private def promotionSchema = {
@@ -553,8 +535,7 @@ object TpcdsDataGen {
StructField("p_channel_demo", StringType),
StructField("p_channel_details", StringType),
StructField("p_purpose", StringType),
- StructField("p_discount_active", StringType)
- ))
+ StructField("p_discount_active", StringType)))
}
private def reasonSchema = {
@@ -562,8 +543,7 @@ object TpcdsDataGen {
Seq(
StructField("r_reason_sk", LongType),
StructField("r_reason_id", StringType),
- StructField("r_reason_desc", StringType)
- ))
+ StructField("r_reason_desc", StringType)))
}
private def shipModeSchema = {
@@ -574,8 +554,7 @@ object TpcdsDataGen {
StructField("sm_type", StringType),
StructField("sm_code", StringType),
StructField("sm_carrier", StringType),
- StructField("sm_contract", StringType)
- ))
+ StructField("sm_contract", StringType)))
}
private def storeSchema = {
@@ -609,8 +588,7 @@ object TpcdsDataGen {
StructField("s_zip", StringType),
StructField("s_country", StringType),
StructField("s_gmt_offset", DecimalType(5, 2)),
- StructField("s_tax_precentage", DecimalType(5, 2))
- ))
+ StructField("s_tax_precentage", DecimalType(5, 2))))
}
private def timeDimSchema = {
@@ -625,8 +603,7 @@ object TpcdsDataGen {
StructField("t_am_pm", StringType),
StructField("t_shift", StringType),
StructField("t_sub_shift", StringType),
- StructField("t_meal_time", StringType)
- ))
+ StructField("t_meal_time", StringType)))
}
private def warehouseSchema = {
@@ -645,8 +622,7 @@ object TpcdsDataGen {
StructField("w_state", StringType),
StructField("w_zip", StringType),
StructField("w_country", StringType),
- StructField("w_gmt_offset", DecimalType(5, 2))
- ))
+ StructField("w_gmt_offset", DecimalType(5, 2))))
}
private def webPageSchema = {
@@ -665,8 +641,7 @@ object TpcdsDataGen {
StructField("wp_char_count", LongType),
StructField("wp_link_count", LongType),
StructField("wp_image_count", LongType),
- StructField("wp_max_ad_count", LongType)
- ))
+ StructField("wp_max_ad_count", LongType)))
}
private def webSiteSchema = {
@@ -697,7 +672,6 @@ object TpcdsDataGen {
StructField("web_zip", StringType),
StructField("web_country", StringType),
StructField("web_gmt_offset", StringType),
- StructField("web_tax_percentage", DecimalType(5, 2))
- ))
+ StructField("web_tax_percentage", DecimalType(5, 2))))
}
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ds/TpcdsSuite.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ds/TpcdsSuite.scala
index 37a88d446..c703821c1 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ds/TpcdsSuite.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/ds/TpcdsSuite.scala
@@ -18,7 +18,11 @@ package org.apache.gluten.integration.tpc.ds
import org.apache.gluten.integration.tpc.{Constants, DataGen, TpcSuite, TypeModifier}
import org.apache.gluten.integration.tpc.action.Action
-import org.apache.gluten.integration.tpc.ds.TpcdsSuite.{ALL_QUERY_IDS, HISTORY_WRITE_PATH, TPCDS_WRITE_PATH}
+import org.apache.gluten.integration.tpc.ds.TpcdsSuite.{
+ ALL_QUERY_IDS,
+ HISTORY_WRITE_PATH,
+ TPCDS_WRITE_PATH
+}
import org.apache.spark.SparkConf
@@ -41,24 +45,23 @@ class TpcdsSuite(
val disableWscg: Boolean,
val shufflePartitions: Int,
val minimumScanPartitions: Boolean)
- extends TpcSuite(
- masterUrl,
- actions,
- testConf,
- baselineConf,
- extraSparkConf,
- logLevel,
- errorOnMemLeak,
- enableUi,
- enableHsUi,
- hsUiPort,
- offHeapSize,
- disableAqe,
- disableBhj,
- disableWscg,
- shufflePartitions,
- minimumScanPartitions
- ) {
+ extends TpcSuite(
+ masterUrl,
+ actions,
+ testConf,
+ baselineConf,
+ extraSparkConf,
+ logLevel,
+ errorOnMemLeak,
+ enableUi,
+ enableHsUi,
+ hsUiPort,
+ offHeapSize,
+ disableAqe,
+ disableBhj,
+ disableWscg,
+ shufflePartitions,
+ minimumScanPartitions) {
override protected def historyWritePath(): String = HISTORY_WRITE_PATH
@@ -191,7 +194,6 @@ object TpcdsSuite {
"q96",
"q97",
"q98",
- "q99"
- )
+ "q99")
private val HISTORY_WRITE_PATH = "/tmp/tpcds-history"
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/h/TpchDataGen.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/h/TpchDataGen.scala
index 18c557045..fa574f59c 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/h/TpchDataGen.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/h/TpchDataGen.scala
@@ -34,8 +34,8 @@ class TpchDataGen(
partitions: Int,
path: String,
typeModifiers: List[TypeModifier] = List())
- extends Serializable
- with DataGen {
+ extends Serializable
+ with DataGen {
override def gen(): Unit = {
generate(path, "lineitem", lineItemSchema, partitions, lineItemGenerator, lineItemParser)
@@ -55,8 +55,8 @@ class TpchDataGen(
}
// lineitem
- private def lineItemGenerator = {
-    (part: Int, partCount: Int) => new LineItemGenerator(scale, part, partCount)
+ private def lineItemGenerator = { (part: Int, partCount: Int) =>
+ new LineItemGenerator(scale, part, partCount)
}
private def lineItemSchema = {
@@ -77,8 +77,7 @@ class TpchDataGen(
StructField("l_shipinstruct", StringType),
StructField("l_shipmode", StringType),
StructField("l_comment", StringType),
- StructField("l_shipdate", DateType)
- ))
+ StructField("l_shipdate", DateType)))
}
private def lineItemParser: LineItem => Row =
@@ -99,12 +98,11 @@ class TpchDataGen(
lineItem.getShipInstructions,
lineItem.getShipMode,
lineItem.getComment,
- Date.valueOf(GenerateUtils.formatDate(lineItem.getShipDate))
- )
+ Date.valueOf(GenerateUtils.formatDate(lineItem.getShipDate)))
// customer
- private def customerGenerator = {
-    (part: Int, partCount: Int) => new CustomerGenerator(scale, part, partCount)
+ private def customerGenerator = { (part: Int, partCount: Int) =>
+ new CustomerGenerator(scale, part, partCount)
}
private def customerSchema = {
@@ -117,8 +115,7 @@ class TpchDataGen(
StructField("c_phone", StringType),
StructField("c_acctbal", DecimalType(12, 2)),
StructField("c_comment", StringType),
- StructField("c_mktsegment", StringType)
- ))
+ StructField("c_mktsegment", StringType)))
}
private def customerParser: Customer => Row =
@@ -131,12 +128,11 @@ class TpchDataGen(
customer.getPhone,
BigDecimal.valueOf(customer.getAccountBalance),
customer.getComment,
- customer.getMarketSegment
- )
+ customer.getMarketSegment)
// orders
- private def orderGenerator = {
- (part: Int, partCount: Int) => new OrderGenerator(scale, part, partCount)
+ private def orderGenerator = { (part: Int, partCount: Int) =>
+ new OrderGenerator(scale, part, partCount)
}
private def orderSchema = {
@@ -150,8 +146,7 @@ class TpchDataGen(
StructField("o_clerk", StringType),
StructField("o_shippriority", IntegerType),
StructField("o_comment", StringType),
- StructField("o_orderdate", DateType)
- ))
+ StructField("o_orderdate", DateType)))
}
private def orderParser: Order => Row =
@@ -165,12 +160,11 @@ class TpchDataGen(
order.getClerk,
order.getShipPriority,
order.getComment,
- Date.valueOf(GenerateUtils.formatDate(order.getOrderDate))
- )
+ Date.valueOf(GenerateUtils.formatDate(order.getOrderDate)))
// partsupp
- private def partSupplierGenerator = {
-    (part: Int, partCount: Int) => new PartSupplierGenerator(scale, part, partCount)
+ private def partSupplierGenerator = { (part: Int, partCount: Int) =>
+ new PartSupplierGenerator(scale, part, partCount)
}
private def partSupplierSchema = {
@@ -180,8 +174,7 @@ class TpchDataGen(
StructField("ps_suppkey", LongType),
StructField("ps_availqty", IntegerType),
StructField("ps_supplycost", DecimalType(12, 2)),
- StructField("ps_comment", StringType)
- ))
+ StructField("ps_comment", StringType)))
}
private def partSupplierParser: PartSupplier => Row =
@@ -191,12 +184,11 @@ class TpchDataGen(
ps.getSupplierKey,
ps.getAvailableQuantity,
BigDecimal.valueOf(ps.getSupplyCost),
- ps.getComment
- )
+ ps.getComment)
// supplier
- private def supplierGenerator = {
-    (part: Int, partCount: Int) => new SupplierGenerator(scale, part, partCount)
+ private def supplierGenerator = { (part: Int, partCount: Int) =>
+ new SupplierGenerator(scale, part, partCount)
}
private def supplierSchema = {
@@ -208,8 +200,7 @@ class TpchDataGen(
StructField("s_nationkey", LongType),
StructField("s_phone", StringType),
StructField("s_acctbal", DecimalType(12, 2)),
- StructField("s_comment", StringType)
- ))
+ StructField("s_comment", StringType)))
}
private def supplierParser: Supplier => Row =
@@ -221,11 +212,12 @@ class TpchDataGen(
s.getNationKey,
s.getPhone,
BigDecimal.valueOf(s.getAccountBalance),
- s.getComment
- )
+ s.getComment)
// nation
- private def nationGenerator = { () => new NationGenerator() }
+ private def nationGenerator = { () =>
+ new NationGenerator()
+ }
private def nationSchema = {
StructType(
@@ -233,22 +225,15 @@ class TpchDataGen(
StructField("n_nationkey", LongType),
StructField("n_name", StringType),
StructField("n_regionkey", LongType),
- StructField("n_comment", StringType)
- ))
+ StructField("n_comment", StringType)))
}
private def nationParser: Nation => Row =
- nation =>
- Row(
- nation.getNationKey,
- nation.getName,
- nation.getRegionKey,
- nation.getComment
- )
+    nation => Row(nation.getNationKey, nation.getName, nation.getRegionKey, nation.getComment)
// part
- private def partGenerator = {
- (part: Int, partCount: Int) => new PartGenerator(scale, part, partCount)
+ private def partGenerator = { (part: Int, partCount: Int) =>
+ new PartGenerator(scale, part, partCount)
}
private def partSchema = {
@@ -262,8 +247,7 @@ class TpchDataGen(
StructField("p_container", StringType),
StructField("p_retailprice", DecimalType(12, 2)),
StructField("p_comment", StringType),
- StructField("p_brand", StringType)
- ))
+ StructField("p_brand", StringType)))
}
private def partParser: Part => Row =
@@ -277,28 +261,23 @@ class TpchDataGen(
part.getContainer,
BigDecimal.valueOf(part.getRetailPrice),
part.getComment,
- part.getBrand
- )
+ part.getBrand)
// region
- private def regionGenerator = { () => new RegionGenerator() }
+ private def regionGenerator = { () =>
+ new RegionGenerator()
+ }
private def regionSchema = {
StructType(
Seq(
StructField("r_regionkey", LongType),
StructField("r_name", StringType),
- StructField("r_comment", StringType)
- ))
+ StructField("r_comment", StringType)))
}
private def regionParser: Region => Row =
- region =>
- Row(
- region.getRegionKey,
- region.getName,
- region.getComment
- )
+ region => Row(region.getRegionKey, region.getName, region.getComment)
// gen tpc-h data
private def generate[U](
@@ -307,15 +286,9 @@ class TpchDataGen(
schema: StructType,
gen: () => java.lang.Iterable[U],
parser: U => Row): Unit = {
- generate(
- dir,
- tableName,
- schema,
- 1,
- (_: Int, _: Int) => {
- gen.apply()
- },
- parser)
+ generate(dir, tableName, schema, 1, (_: Int, _: Int) => {
+ gen.apply()
+ }, parser)
}
private def generate[U](
@@ -330,25 +303,23 @@ class TpchDataGen(
val modifiedSchema = DataGen.modifySchema(schema, rowModifier)
spark
.range(0, partitions, 1L, partitions)
- .mapPartitions {
- itr =>
- val id = itr.toArray
- if (id.length != 1) {
- throw new IllegalStateException()
- }
- val data = gen.apply(id(0).toInt + 1, partitions)
- val dataItr = data.iterator()
- val rows = dataItr.asScala.map {
- item =>
- val row = parser(item)
- val modifiedRow = Row(row.toSeq.zipWithIndex.map {
- case (v, i) =>
- val modifier = rowModifier.apply(i)
- modifier.modValue(v)
- }.toArray: _*)
- modifiedRow
- }
- rows
+ .mapPartitions { itr =>
+ val id = itr.toArray
+ if (id.length != 1) {
+ throw new IllegalStateException()
+ }
+ val data = gen.apply(id(0).toInt + 1, partitions)
+ val dataItr = data.iterator()
+ val rows = dataItr.asScala.map { item =>
+ val row = parser(item)
+ val modifiedRow = Row(row.toSeq.zipWithIndex.map {
+ case (v, i) =>
+ val modifier = rowModifier.apply(i)
+ modifier.modValue(v)
+ }.toArray: _*)
+ modifiedRow
+ }
+ rows
}(ShimUtils.getExpressionEncoder(modifiedSchema))
.write
.mode(SaveMode.Overwrite)
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/h/TpchSuite.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/h/TpchSuite.scala
index 418c7ca6a..9fbd83dc2 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/h/TpchSuite.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/tpc/h/TpchSuite.scala
@@ -41,24 +41,23 @@ class TpchSuite(
val disableWscg: Boolean,
val shufflePartitions: Int,
val minimumScanPartitions: Boolean)
- extends TpcSuite(
- masterUrl,
- actions,
- testConf,
- baselineConf,
- extraSparkConf,
- logLevel,
- errorOnMemLeak,
- enableUi,
- enableHsUi,
- hsUiPort,
- offHeapSize,
- disableAqe,
- disableBhj,
- disableWscg,
- shufflePartitions,
- minimumScanPartitions
- ) {
+ extends TpcSuite(
+ masterUrl,
+ actions,
+ testConf,
+ baselineConf,
+ extraSparkConf,
+ logLevel,
+ errorOnMemLeak,
+ enableUi,
+ enableHsUi,
+ hsUiPort,
+ offHeapSize,
+ disableAqe,
+ disableBhj,
+ disableWscg,
+ shufflePartitions,
+ minimumScanPartitions) {
override protected def historyWritePath(): String = HISTORY_WRITE_PATH
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/spark/deploy/history/GlutenItHistoryServerPlugin.scala b/tools/gluten-it/common/src/main/scala/org/apache/spark/deploy/history/GlutenItHistoryServerPlugin.scala
index 33500c3e1..4720d3e4a 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/spark/deploy/history/GlutenItHistoryServerPlugin.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/spark/deploy/history/GlutenItHistoryServerPlugin.scala
@@ -20,7 +20,11 @@ import org.apache.spark.SparkConf
import org.apache.spark.deploy.history.HistoryServerHelper.LogServerRpcEnvs
import org.apache.spark.scheduler.SparkListener
import org.apache.spark.sql.ConfUtils.ConfImplicits.SparkConfWrapper
-import org.apache.spark.status.{AppHistoryServerPlugin, ElementTrackingStore, ExecutorSummaryWrapper}
+import org.apache.spark.status.{
+ AppHistoryServerPlugin,
+ ElementTrackingStore,
+ ExecutorSummaryWrapper
+}
import org.apache.spark.status.api.v1
import com.google.common.base.Preconditions
@@ -60,52 +64,53 @@ class GlutenItHistoryServerPlugin extends AppHistoryServerPlugin {
}
}
-  override def createListeners(conf: SparkConf, store: ElementTrackingStore): Seq[SparkListener] = {
+ override def createListeners(
+ conf: SparkConf,
+ store: ElementTrackingStore): Seq[SparkListener] = {
store.onFlush {
val wrappers = org.apache.spark.util.Utils
      .tryWithResource(store.view(classOf[ExecutorSummaryWrapper]).closeableIterator()) {
- iter => iter.asScala.toList
+ iter =>
+ iter.asScala.toList
}
// create new executor summaries
wrappers
- .map {
- wrapper =>
- Preconditions.checkArgument(wrapper.info.attributes.isEmpty)
- new ExecutorSummaryWrapper(
- new v1.ExecutorSummary(
- id = wrapper.info.id,
- hostPort = wrapper.info.hostPort,
- isActive = wrapper.info.isActive,
- rddBlocks = wrapper.info.rddBlocks,
- memoryUsed = wrapper.info.memoryUsed,
- diskUsed = wrapper.info.diskUsed,
- totalCores = wrapper.info.totalCores,
- maxTasks = wrapper.info.maxTasks,
- activeTasks = wrapper.info.activeTasks,
- failedTasks = wrapper.info.failedTasks,
- completedTasks = wrapper.info.completedTasks,
- totalTasks = wrapper.info.totalTasks,
- totalDuration = wrapper.info.totalDuration,
- totalGCTime = wrapper.info.totalGCTime,
- totalInputBytes = wrapper.info.totalInputBytes,
- totalShuffleRead = wrapper.info.totalShuffleRead,
- totalShuffleWrite = wrapper.info.totalShuffleWrite,
- isBlacklisted = wrapper.info.isBlacklisted,
- maxMemory = wrapper.info.maxMemory,
- addTime = wrapper.info.addTime,
- removeTime = wrapper.info.removeTime,
- removeReason = wrapper.info.removeReason,
-            executorLogs = rewriteLogs(wrapper.info.executorLogs, logServerRpcEnvs),
- memoryMetrics = wrapper.info.memoryMetrics,
- blacklistedInStages = wrapper.info.blacklistedInStages,
- peakMemoryMetrics = wrapper.info.peakMemoryMetrics,
- attributes = wrapper.info.attributes,
- resources = wrapper.info.resources,
- resourceProfileId = wrapper.info.resourceProfileId,
- isExcluded = wrapper.info.isExcluded,
- excludedInStages = wrapper.info.excludedInStages
- ))
+ .map { wrapper =>
+ Preconditions.checkArgument(wrapper.info.attributes.isEmpty)
+ new ExecutorSummaryWrapper(
+ new v1.ExecutorSummary(
+ id = wrapper.info.id,
+ hostPort = wrapper.info.hostPort,
+ isActive = wrapper.info.isActive,
+ rddBlocks = wrapper.info.rddBlocks,
+ memoryUsed = wrapper.info.memoryUsed,
+ diskUsed = wrapper.info.diskUsed,
+ totalCores = wrapper.info.totalCores,
+ maxTasks = wrapper.info.maxTasks,
+ activeTasks = wrapper.info.activeTasks,
+ failedTasks = wrapper.info.failedTasks,
+ completedTasks = wrapper.info.completedTasks,
+ totalTasks = wrapper.info.totalTasks,
+ totalDuration = wrapper.info.totalDuration,
+ totalGCTime = wrapper.info.totalGCTime,
+ totalInputBytes = wrapper.info.totalInputBytes,
+ totalShuffleRead = wrapper.info.totalShuffleRead,
+ totalShuffleWrite = wrapper.info.totalShuffleWrite,
+ isBlacklisted = wrapper.info.isBlacklisted,
+ maxMemory = wrapper.info.maxMemory,
+ addTime = wrapper.info.addTime,
+ removeTime = wrapper.info.removeTime,
+ removeReason = wrapper.info.removeReason,
+            executorLogs = rewriteLogs(wrapper.info.executorLogs, logServerRpcEnvs),
+ memoryMetrics = wrapper.info.memoryMetrics,
+ blacklistedInStages = wrapper.info.blacklistedInStages,
+ peakMemoryMetrics = wrapper.info.peakMemoryMetrics,
+ attributes = wrapper.info.attributes,
+ resources = wrapper.info.resources,
+ resourceProfileId = wrapper.info.resourceProfileId,
+ isExcluded = wrapper.info.isExcluded,
+ excludedInStages = wrapper.info.excludedInStages))
}
.foreach(store.write(_))
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/spark/deploy/history/HistoryServerHelper.scala b/tools/gluten-it/common/src/main/scala/org/apache/spark/deploy/history/HistoryServerHelper.scala
index fb991cb47..649ef130d 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/spark/deploy/history/HistoryServerHelper.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/spark/deploy/history/HistoryServerHelper.scala
@@ -46,10 +46,9 @@ object HistoryServerHelper {
private def findFreePort(): Int = {
val port = org.apache.spark.util.Utils
- .tryWithResource(new ServerSocket(0)) {
- socket =>
- socket.setReuseAddress(true)
- socket.getLocalPort
+ .tryWithResource(new ServerSocket(0)) { socket =>
+ socket.setReuseAddress(true)
+ socket.getLocalPort
}
if (port > 0) {
return port
@@ -80,11 +79,10 @@ object HistoryServerHelper {
conf,
conf.get(org.apache.spark.internal.config.Worker.SPARK_WORKER_RESOURCE_FILE))
- ShutdownHookManager.addShutdownHook(
- () => {
- workerRpcEnv.shutdown()
- rpcEnv.shutdown()
- })
+ ShutdownHookManager.addShutdownHook(() => {
+ workerRpcEnv.shutdown()
+ rpcEnv.shutdown()
+ })
LogServerRpcEnvs(rpcEnv, workerRpcEnv, webUiPort, workerWebUiPort)
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/ConfUtils.scala b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/ConfUtils.scala
index 966d46247..66eb4a48f 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/ConfUtils.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/ConfUtils.scala
@@ -30,17 +30,15 @@ object ConfUtils {
onOverriding => {
Console.err.println(
s"Overriding SparkConf key ${onOverriding.key}, old value: ${onOverriding.value}, new value: ${onOverriding.newValue}. ")
- }
- )
+ })
}
def setAllWarningOnOverriding(others: Iterable[(String, String)]): SparkConf = {
var tmp: SparkConf = conf
- others.foreach(
- c => {
- tmp = new SparkConfWrapper(tmp).setWarningOnOverriding(c._1, c._2)
- })
+ others.foreach(c => {
+ tmp = new SparkConfWrapper(tmp).setWarningOnOverriding(c._1, c._2)
+ })
tmp
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/QueryRunner.scala b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/QueryRunner.scala
index a4044c925..332e56043 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/QueryRunner.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/QueryRunner.scala
@@ -18,7 +18,12 @@ package org.apache.spark.sql
import org.apache.spark.{SparkContext, Success, TaskKilled}
import org.apache.spark.executor.ExecutorMetrics
-import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate, SparkListenerTaskEnd, SparkListenerTaskStart}
+import org.apache.spark.scheduler.{
+ SparkListener,
+ SparkListenerExecutorMetricsUpdate,
+ SparkListenerTaskEnd,
+ SparkListenerTaskStart
+}
import org.apache.spark.sql.KillTaskListener.INIT_WAIT_TIME_MS
import com.google.common.base.Preconditions
@@ -45,8 +50,7 @@ object QueryRunner {
"ProcessTreePythonVMemory",
"ProcessTreePythonRSSMemory",
"ProcessTreeOtherVMemory",
- "ProcessTreeOtherRSSMemory"
- )
+ "ProcessTreeOtherRSSMemory")
def runTpcQuery(
spark: SparkSession,
@@ -90,12 +94,11 @@ object QueryRunner {
RunResult(rows, millis, collectedMetrics)
} finally {
sc.removeSparkListener(metricsListener)
- killTaskListener.foreach(
- l => {
- sc.removeSparkListener(l)
- println(s"Successful kill rate ${"%.2f%%".format(
-          100 * l.successfulKillRate())} during execution of app: ${sc.applicationId}")
- })
+ killTaskListener.foreach(l => {
+ sc.removeSparkListener(l)
+ println(s"Successful kill rate ${"%.2f%%"
+        .format(100 * l.successfulKillRate())} during execution of app: ${sc.applicationId}")
+ })
sc.setJobDescription(null)
}
}
@@ -156,8 +159,8 @@ class KillTaskListener(val sc: SparkContext) extends SparkListener {
sync.synchronized {
val total = Math.min(
stageKillMaxWaitTimeLookup.computeIfAbsent(taskStart.stageId, _ => Long.MaxValue),
-          stageKillWaitTimeLookup.computeIfAbsent(taskStart.stageId, _ => INIT_WAIT_TIME_MS)
- )
+ stageKillWaitTimeLookup
+ .computeIfAbsent(taskStart.stageId, _ => INIT_WAIT_TIME_MS))
val elapsed = System.currentTimeMillis() - startMs
val remaining = total - elapsed
if (remaining <= 0L) {
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/TestUtils.scala b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/TestUtils.scala
index 03e70fbcc..c5af0e9b4 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/TestUtils.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/TestUtils.scala
@@ -46,7 +46,8 @@ object TestUtils {
override def toString: String = java.lang.Float.toString(value)
// unsupported
-  override def compareTo(anotherFloat: FuzzyFloat): Int = throw new UnsupportedOperationException
+ override def compareTo(anotherFloat: FuzzyFloat): Int =
+ throw new UnsupportedOperationException
override def hashCode(): Int = throw new UnsupportedOperationException
}
@@ -62,15 +63,14 @@ object TestUtils {
// For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
// equality test.
// This function is copied from Catalyst's QueryTest
- val converted: Seq[Row] = answer.map {
- s =>
- Row.fromSeq(s.toSeq.map {
- case d: java.math.BigDecimal => BigDecimal(d)
- case b: Array[Byte] => b.toSeq
- case f: Float => new FuzzyFloat(f)
- case db: Double => new FuzzyDouble(db)
- case o => o
- })
+ val converted: Seq[Row] = answer.map { s =>
+ Row.fromSeq(s.toSeq.map {
+ case d: java.math.BigDecimal => BigDecimal(d)
+ case b: Array[Byte] => b.toSeq
+ case f: Float => new FuzzyFloat(f)
+ case db: Double => new FuzzyDouble(db)
+ case o => o
+ })
}
if (sort) {
converted.sortBy(_.toString())
@@ -83,11 +83,10 @@ object TestUtils {
s"""
| == Results ==
| ${sideBySide(
- s"== Expected Answer - ${expectedAnswer.size} ==" +:
- prepareAnswer(expectedAnswer).map(_.toString()),
- s"== Actual Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString())
- ).mkString("\n")}
+ s"== Expected Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString()),
+ s"== Actual Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
""".stripMargin
Some(errorMessage)
} else {
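
For anyone reproducing the reformat locally: assuming the standalone scalafmt
CLI is available, an invocation along the lines of
`scalafmt --config <spark-scalafmt.conf> tools/gluten-it/` applies the same
style. The config path and the repository's actual formatting entry point are
assumptions here; the reformat may instead be wired through the build.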
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]