This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 7ac9983cd [VL] Gluten-it: Reuse Spark sessions that share same
configuration (#6117)
7ac9983cd is described below
commit 7ac9983cde6ed2f942eaf05c628055bc715b6990
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed Jun 19 13:07:22 2024 +0800
[VL] Gluten-it: Reuse Spark sessions that share same configuration (#6117)
---
.github/workflows/velox_docker.yml | 2 +-
.../gluten/integration/command/Parameterized.java | 19 +-
.../apache/gluten/integration/command/Queries.java | 2 +-
.../gluten/integration/command/QueriesCompare.java | 2 +-
.../gluten/integration/command/QueriesMixin.java | 7 +
.../apache/gluten/integration/QueryRunner.scala | 53 +++-
.../gluten/integration/action/Parameterized.scala | 350 +++++++++++----------
.../apache/gluten/integration/action/Queries.scala | 148 ++++-----
.../gluten/integration/action/QueriesCompare.scala | 222 +++++++------
.../gluten/integration/action/SparkShell.scala | 2 +-
.../apache/gluten/integration/action/package.scala | 28 +-
.../integration/clickbench/ClickBenchDataGen.scala | 5 +-
.../apache/spark/sql/SparkSessionSwitcher.scala | 10 +
13 files changed, 440 insertions(+), 410 deletions(-)
diff --git a/.github/workflows/velox_docker.yml
b/.github/workflows/velox_docker.yml
index 6c1be4344..b1d5cfdcf 100644
--- a/.github/workflows/velox_docker.yml
+++ b/.github/workflows/velox_docker.yml
@@ -367,7 +367,7 @@ jobs:
cd tools/gluten-it \
&& GLUTEN_IT_JVM_ARGS=-Xmx6G sbin/gluten-it.sh queries \
--local --preset=velox --benchmark-type=ds --error-on-memleak
-s=30.0 --off-heap-size=8g --threads=12 --shuffle-partitions=72 --iterations=1
\
- --skip-data-gen --random-kill-tasks
+ --skip-data-gen --random-kill-tasks --no-session-reuse
# run-tpc-test-ubuntu-sf30:
# needs: build-native-lib-centos-7
diff --git
a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Parameterized.java
b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Parameterized.java
index cadff0a2d..225b492ef 100644
---
a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Parameterized.java
+++
b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Parameterized.java
@@ -18,9 +18,6 @@ package org.apache.gluten.integration.command;
import com.google.common.base.Preconditions;
import org.apache.gluten.integration.BaseMixin;
-import org.apache.gluten.integration.action.Dim;
-import org.apache.gluten.integration.action.DimKv;
-import org.apache.gluten.integration.action.DimValue;
import org.apache.commons.lang3.ArrayUtils;
import picocli.CommandLine;
import scala.Tuple2;
@@ -67,17 +64,17 @@ public class Parameterized implements Callable<Integer> {
public Integer call() throws Exception {
final Map<String, Map<String, List<Map.Entry<String, String>>>> parsed =
new LinkedHashMap<>();
- final Seq<scala.collection.immutable.Set<DimKv>> excludedCombinations =
JavaConverters.asScalaBufferConverter(Arrays.stream(excludedDims).map(d -> {
+ final
Seq<scala.collection.immutable.Set<org.apache.gluten.integration.action.Parameterized.DimKv>>
excludedCombinations =
JavaConverters.asScalaBufferConverter(Arrays.stream(excludedDims).map(d -> {
final Matcher m = excludedDimsPattern.matcher(d);
Preconditions.checkArgument(m.matches(), "Unrecognizable excluded dims:
" + d);
- Set<DimKv> out = new HashSet<>();
+ Set<org.apache.gluten.integration.action.Parameterized.DimKv> out = new
HashSet<>();
final String[] dims = d.split(",");
for (String dim : dims) {
final String[] kv = dim.split(":");
Preconditions.checkArgument(kv.length == 2, "Unrecognizable excluded
dims: " + d);
- out.add(new DimKv(kv[0], kv[1]));
+ out.add(new
org.apache.gluten.integration.action.Parameterized.DimKv(kv[0], kv[1]));
}
- return JavaConverters.asScalaSetConverter(out).asScala().<DimKv>toSet();
+ return
JavaConverters.asScalaSetConverter(out).asScala().<org.apache.gluten.integration.action.Parameterized.DimKv>toSet();
}).collect(Collectors.toList())).asScala();
// parse dims
@@ -121,11 +118,11 @@ public class Parameterized implements Callable<Integer> {
}
// Convert Map<String, Map<String, List<Map.Entry<String, String>>>> to
List<Dim>
- Seq<Dim> parsedDims = JavaConverters.asScalaBufferConverter(
+ Seq<org.apache.gluten.integration.action.Parameterized.Dim> parsedDims =
JavaConverters.asScalaBufferConverter(
parsed.entrySet().stream().map(e ->
- new Dim(e.getKey(), JavaConverters.asScalaBufferConverter(
+ new
org.apache.gluten.integration.action.Parameterized.Dim(e.getKey(),
JavaConverters.asScalaBufferConverter(
e.getValue().entrySet().stream().map(e2 ->
- new DimValue(e2.getKey(),
JavaConverters.asScalaBufferConverter(
+ new
org.apache.gluten.integration.action.Parameterized.DimValue(e2.getKey(),
JavaConverters.asScalaBufferConverter(
e2.getValue().stream().map(e3 -> new
Tuple2<>(e3.getKey(), e3.getValue()))
.collect(Collectors.toList())).asScala())).collect(Collectors.toList())).asScala()
)).collect(Collectors.toList())).asScala();
@@ -133,7 +130,7 @@ public class Parameterized implements Callable<Integer> {
org.apache.gluten.integration.action.Parameterized parameterized =
new
org.apache.gluten.integration.action.Parameterized(dataGenMixin.getScale(),
dataGenMixin.genPartitionedData(), queriesMixin.queries(),
- queriesMixin.explain(), queriesMixin.iterations(),
warmupIterations, parsedDims,
+ queriesMixin.explain(), queriesMixin.iterations(),
warmupIterations, queriesMixin.noSessionReuse(), parsedDims,
excludedCombinations, metrics);
return mixin.runActions(ArrayUtils.addAll(dataGenMixin.makeActions(),
parameterized));
}
diff --git
a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Queries.java
b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Queries.java
index f0c07b415..c19d66bda 100644
---
a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Queries.java
+++
b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Queries.java
@@ -42,7 +42,7 @@ public class Queries implements Callable<Integer> {
public Integer call() throws Exception {
org.apache.gluten.integration.action.Queries queries =
new
org.apache.gluten.integration.action.Queries(dataGenMixin.getScale(),
dataGenMixin.genPartitionedData(), queriesMixin.queries(),
- queriesMixin.explain(), queriesMixin.iterations(),
randomKillTasks);
+ queriesMixin.explain(), queriesMixin.iterations(),
randomKillTasks, queriesMixin.noSessionReuse());
return mixin.runActions(ArrayUtils.addAll(dataGenMixin.makeActions(),
queries));
}
}
diff --git
a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesCompare.java
b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesCompare.java
index 42b00f94c..d194aad18 100644
---
a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesCompare.java
+++
b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesCompare.java
@@ -40,7 +40,7 @@ public class QueriesCompare implements Callable<Integer> {
org.apache.gluten.integration.action.QueriesCompare queriesCompare =
new
org.apache.gluten.integration.action.QueriesCompare(dataGenMixin.getScale(),
dataGenMixin.genPartitionedData(), queriesMixin.queries(),
- queriesMixin.explain(), queriesMixin.iterations());
+ queriesMixin.explain(), queriesMixin.iterations(),
queriesMixin.noSessionReuse());
return mixin.runActions(ArrayUtils.addAll(dataGenMixin.makeActions(),
queriesCompare));
}
}
diff --git
a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesMixin.java
b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesMixin.java
index fc93f968c..64e4b32ec 100644
---
a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesMixin.java
+++
b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesMixin.java
@@ -42,6 +42,9 @@ public class QueriesMixin {
@CommandLine.Option(names = {"--iterations"}, description = "How many
iterations to run", defaultValue = "1")
private int iterations;
+ @CommandLine.Option(names = {"--no-session-reuse"}, description = "Recreate
new Spark session each time a query is about to run", defaultValue = "false")
+ private boolean noSessionReuse;
+
public boolean explain() {
return explain;
}
@@ -50,6 +53,10 @@ public class QueriesMixin {
return iterations;
}
+ public boolean noSessionReuse() {
+ return noSessionReuse;
+ }
+
public Actions.QuerySelector queries() {
return new Actions.QuerySelector() {
@Override
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/QueryRunner.scala
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/QueryRunner.scala
index 88e8e2250..9791242f1 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/QueryRunner.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/QueryRunner.scala
@@ -17,11 +17,14 @@
package org.apache.gluten.integration
import com.google.common.base.Preconditions
+import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.spark.sql.{RunResult, SparkQueryRunner, SparkSession}
import java.io.File
class QueryRunner(val queryResourceFolder: String, val dataPath: String) {
+ import QueryRunner._
+
Preconditions.checkState(
new File(dataPath).exists(),
s"Data not found at $dataPath, try using command `<gluten-it>
data-gen-only <options>` to generate it first.",
@@ -37,10 +40,54 @@ class QueryRunner(val queryResourceFolder: String, val
dataPath: String) {
caseId: String,
explain: Boolean = false,
metrics: Array[String] = Array(),
- randomKillTasks: Boolean = false): RunResult = {
+ randomKillTasks: Boolean = false): QueryResult = {
val path = "%s/%s.sql".format(queryResourceFolder, caseId)
- SparkQueryRunner.runQuery(spark, desc, path, explain, metrics,
randomKillTasks)
+ try {
+ val r = SparkQueryRunner.runQuery(spark, desc, path, explain, metrics,
randomKillTasks)
+ println(s"Successfully ran query $caseId. Returned row count:
${r.rows.length}")
+ Success(caseId, r)
+ } catch {
+ case e: Exception =>
+ println(s"Error running query $caseId. Error:
${ExceptionUtils.getStackTrace(e)}")
+ Failure(caseId, e)
+ }
}
}
-object QueryRunner {}
+object QueryRunner {
+ sealed trait QueryResult {
+ def caseId(): String
+ def succeeded(): Boolean
+ }
+
+ implicit class QueryResultOps(r: QueryResult) {
+ def asSuccessOption(): Option[Success] = {
+ r match {
+ case s: Success => Some(s)
+ case _: Failure => None
+ }
+ }
+
+ def asFailureOption(): Option[Failure] = {
+ r match {
+ case _: Success => None
+ case f: Failure => Some(f)
+ }
+ }
+
+ def asSuccess(): Success = {
+ asSuccessOption().get
+ }
+
+ def asFailure(): Failure = {
+ asFailureOption().get
+ }
+ }
+
+ case class Success(override val caseId: String, runResult: RunResult)
extends QueryResult {
+ override def succeeded(): Boolean = true
+ }
+ case class Failure(override val caseId: String, error: Exception) extends
QueryResult {
+ override def succeeded(): Boolean = false
+ }
+}
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Parameterized.scala
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Parameterized.scala
index 74f22a05f..c9ebb9754 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Parameterized.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Parameterized.scala
@@ -17,15 +17,16 @@
package org.apache.gluten.integration.action
import org.apache.commons.lang3.exception.ExceptionUtils
+import org.apache.gluten.integration.QueryRunner.QueryResult
import org.apache.gluten.integration.action.Actions.QuerySelector
import org.apache.gluten.integration.action.TableRender.Field
import
org.apache.gluten.integration.action.TableRender.RowParser.FieldAppender.RowAppender
import org.apache.gluten.integration.stat.RamStat
-import org.apache.gluten.integration.{QueryRunner, Suite, TableCreator}
+import org.apache.gluten.integration.{QueryRunner, Suite}
import org.apache.spark.sql.ConfUtils.ConfImplicits._
-import org.apache.spark.sql.SparkSessionSwitcher
+import org.apache.spark.sql.SparkSession
-import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
@@ -36,20 +37,22 @@ class Parameterized(
explain: Boolean,
iterations: Int,
warmupIterations: Int,
- configDimensions: Seq[Dim],
- excludedCombinations: Seq[Set[DimKv]],
+ noSessionReuse: Boolean,
+ configDimensions: Seq[Parameterized.Dim],
+ excludedCombinations: Seq[Set[Parameterized.DimKv]],
metrics: Array[String])
extends Action {
+ import Parameterized._
validateDims(configDimensions)
private def validateDims(configDimensions: Seq[Dim]): Unit = {
if (configDimensions
- .map(dim => {
- dim.name
- })
- .toSet
- .size != configDimensions.size) {
+ .map(dim => {
+ dim.name
+ })
+ .toSet
+ .size != configDimensions.size) {
throw new IllegalStateException("Duplicated dimension name found")
}
@@ -73,9 +76,9 @@ class Parameterized(
// we got one coordinate
excludedCombinations.foreach { ec: Set[DimKv] =>
if (ec.forall { kv =>
- intermediateCoordinate.contains(kv.k) &&
intermediateCoordinate(kv.k) == kv.v
- }) {
- println(s"Coordinate ${intermediateCoordinate} excluded by $ec.")
+ intermediateCoordinate.contains(kv.k) &&
intermediateCoordinate(kv.k) == kv.v
+ }) {
+ println(s"Coordinate $intermediateCoordinate excluded by $ec.")
return
}
}
@@ -105,9 +108,8 @@ class Parameterized(
val testConf = suite.getTestConf()
println("Prepared coordinates: ")
- coordinates.toList.map(_._1).zipWithIndex.foreach {
- case (c, idx) =>
- println(s" $idx: $c")
+ coordinates.keys.foreach { c =>
+ println(s" ${c.id}: $c")
}
coordinates.foreach { entry =>
// register one session per coordinate
@@ -118,39 +120,67 @@ class Parameterized(
sessionSwitcher.registerSession(coordinate.toString, conf)
}
- val runQueryIds = queries.select(suite)
+ val runQueryIds = queries.select(suite).map(TestResultLine.QueryId(_))
- val results = (0 until iterations).flatMap { iteration =>
- runQueryIds.map { queryId =>
- val queryResult =
- TestResultLine(
- queryId,
- coordinates.map { entry =>
- val coordinate = entry._1
- println(s"Running tests (iteration $iteration) with coordinate
$coordinate...")
- // warm up
- (0 until warmupIterations).foreach { _ =>
- Parameterized.warmUp(
- runner,
- suite.tableCreator(),
- sessionSwitcher,
- queryId,
- suite.desc())
- }
- // run
+ val marks: Seq[TestResultLine.CoordMark] = coordinates.flatMap { entry =>
+ val coordinate = entry._1
+ sessionSwitcher.useSession(coordinate.toString, "Parameterized
%s".format(coordinate))
+ runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+
+ runQueryIds.flatMap { queryId =>
+ // warm up
+ (0 until warmupIterations).foreach { iteration =>
+ println(s"Warming up: Running query $queryId (iteration
$iteration)...")
+ try {
+ Parameterized.warmUp(
+ runner,
+ sessionSwitcher.spark(),
+ queryId.id,
+ coordinate,
+ suite.desc())
+ } finally {
+ if (noSessionReuse) {
+ sessionSwitcher.renewSession()
+ runner.createTables(suite.tableCreator(),
sessionSwitcher.spark())
+ }
+ }
+ }
+
+ // run
+ (0 until iterations).map { iteration =>
+ println(s"Running query $queryId with coordinate $coordinate
(iteration $iteration)...")
+ val r =
+ try {
Parameterized.runQuery(
runner,
- suite.tableCreator(),
- sessionSwitcher,
- queryId,
+ sessionSwitcher.spark(),
+ queryId.id,
coordinate,
suite.desc(),
explain,
metrics)
- }.toList)
- queryResult
+ } finally {
+ if (noSessionReuse) {
+ sessionSwitcher.renewSession()
+ runner.createTables(suite.tableCreator(),
sessionSwitcher.spark())
+ }
+ }
+ TestResultLine.CoordMark(iteration, queryId, r)
+ }
+ }
+ }.toSeq
+
+ val results: Seq[TestResultLine] = marks
+ .groupBy(m => (m.iteration, m.queryId))
+ .toSeq
+ .sortBy(_._1)
+ .map { e =>
+ val iteration = e._1._1
+ val queryId = e._1._2
+ val marks = e._2
+ val line = TestResultLine(queryId, marks.map(_.coord).toList)
+ line
}
- }
val succeededCount = results.count(l => l.succeeded())
val totalCount = results.count(_ => true)
@@ -174,16 +204,27 @@ class Parameterized(
totalCount)
println("")
println("Configurations:")
- coordinates.foreach { coord =>
- println(s"${coord._1.id}. ${coord._1}")
- }
+ coordinates.foreach(coord => println(s"${coord._1.id}. ${coord._1}"))
println("")
val succeeded = results.filter(_.succeeded())
- TestResultLines(
- coordinates.size,
- configDimensions,
- metrics,
- succeeded ++ TestResultLine.aggregate("all", succeeded))
+ val all = succeeded match {
+ case Nil => None
+ case several =>
+ Some(
+ TestResultLine(
+ TestResultLine.QueryId("all"),
+ coordinates.keys.map { c =>
+ TestResultLine.Coord(
+ c,
+ several
+ .map(_.coord(c.id))
+ .map(_.queryResult)
+ .asSuccesses()
+ .agg(s"coordinate $c")
+ .get)
+ }.toSeq))
+ }
+ TestResultLines(coordinates.map(_._1.id).toSeq, configDimensions, metrics,
succeeded ++ all)
.print()
println("")
@@ -193,7 +234,11 @@ class Parameterized(
} else {
println("Failed queries: ")
println("")
- TestResultLines(coordinates.size, configDimensions, metrics,
results.filter(!_.succeeded()))
+ TestResultLines(
+ coordinates.map(_._1.id).toSeq,
+ configDimensions,
+ metrics,
+ results.filter(!_.succeeded()))
.print()
println("")
}
@@ -205,157 +250,114 @@ class Parameterized(
}
}
-case class DimKv(k: String, v: String)
-case class Dim(name: String, dimValues: Seq[DimValue])
-case class DimValue(name: String, conf: Seq[(String, String)])
-// coordinate: [dim, dim value]
-case class Coordinate(id: Int, coordinate: Map[String, String]) {
- override def toString: String = coordinate.mkString(", ")
-}
+object Parameterized {
+ case class DimKv(k: String, v: String)
+
+ case class Dim(name: String, dimValues: Seq[DimValue])
+
+ case class DimValue(name: String, conf: Seq[(String, String)])
-case class TestResultLine(queryId: String, coordinates:
Seq[TestResultLine.Coord]) {
- def succeeded(): Boolean = {
- coordinates.forall(_.succeeded)
+ // coordinate: [dim, dim value]
+ case class Coordinate(id: Int, coordinate: Map[String, String]) {
+ override def toString: String = coordinate.mkString(", ")
}
-}
-object TestResultLine {
- case class Coord(
- coordinate: Coordinate,
- succeeded: Boolean,
- rowCount: Option[Long],
- planningTimeMillis: Option[Long],
- executionTimeMillis: Option[Long],
- metrics: Map[String, Long],
- errorMessage: Option[String])
-
- class Parser(metricNames: Seq[String]) extends
TableRender.RowParser[TestResultLine] {
- override def parse(rowAppender: RowAppender, line: TestResultLine): Unit =
{
- val inc = rowAppender.incremental()
- inc.next().write(line.queryId)
- val coords = line.coordinates
- coords.foreach(coord => inc.next().write(coord.succeeded))
- coords.foreach(coord => inc.next().write(coord.rowCount))
- metricNames.foreach(metricName =>
- coords.foreach(coord => inc.next().write(coord.metrics(metricName))))
- coords.foreach(coord => inc.next().write(coord.planningTimeMillis))
- coords.foreach(coord => inc.next().write(coord.executionTimeMillis))
+ case class TestResultLine(
+ queryId: TestResultLine.QueryId,
+ coordinates: Seq[TestResultLine.Coord]) {
+ private val coordMap = coordinates.map(c => c.coordinate.id -> c).toMap
+ def succeeded(): Boolean = {
+ coordinates.forall(_.queryResult.succeeded())
}
+
+ def coord(id: Int): TestResultLine.Coord = coordMap(id)
}
- def aggregate(name: String, lines: Iterable[TestResultLine]):
Iterable[TestResultLine] = {
- if (lines.isEmpty) {
- return Nil
+ object TestResultLine {
+ case class QueryId(id: String) {
+ import QueryId._
+ private val uid = nextUid.getAndIncrement()
+ override def toString: String = id
}
- if (lines.size == 1) {
- return Nil
+ object QueryId {
+ private val nextUid = new AtomicLong(0L)
+ implicit val o: Ordering[QueryId] = Ordering.by(_.uid)
}
- List(lines.reduce { (left, right) =>
- TestResultLine(name, left.coordinates.zip(right.coordinates).map {
- case (leftCoord, rightCoord) =>
- assert(leftCoord.coordinate == rightCoord.coordinate)
- Coord(
- leftCoord.coordinate,
- leftCoord.succeeded && rightCoord.succeeded,
- (leftCoord.rowCount, rightCoord.rowCount).onBothProvided(_ + _),
- (leftCoord.planningTimeMillis,
rightCoord.planningTimeMillis).onBothProvided(_ + _),
- (leftCoord.executionTimeMillis,
rightCoord.executionTimeMillis).onBothProvided(_ + _),
- (leftCoord.metrics, rightCoord.metrics).sumUp,
- (leftCoord.errorMessage ++ rightCoord.errorMessage).reduceOption(_
+ ", " + _))
- })
- })
+ case class Coord(coordinate: Coordinate, queryResult: QueryResult)
+ case class CoordMark(iteration: Int, queryId: QueryId, coord: Coord)
+
+ class Parser(coordIds: Seq[Int], metricNames: Seq[String])
+ extends TableRender.RowParser[TestResultLine] {
+ override def parse(rowAppender: RowAppender, line: TestResultLine): Unit
= {
+ val inc = rowAppender.incremental()
+ inc.next().write(line.queryId)
+ val coords = coordIds.map(id => line.coord(id))
+ coords.foreach(coord =>
inc.next().write(coord.queryResult.succeeded()))
+ coords.foreach(coord =>
+
inc.next().write(coord.queryResult.asSuccessOption().map(_.runResult.rows.size)))
+ metricNames.foreach(metricName =>
+ coords.foreach(coord =>
+ inc
+ .next()
+
.write(coord.queryResult.asSuccessOption().map(_.runResult.metrics(metricName)))))
+ coords.foreach(coord =>
+ inc
+ .next()
+
.write(coord.queryResult.asSuccessOption().map(_.runResult.planningTimeMillis)))
+ coords.foreach(coord =>
+ inc
+ .next()
+
.write(coord.queryResult.asSuccessOption().map(_.runResult.executionTimeMillis)))
+ }
+ }
}
-}
-case class TestResultLines(
- coordCount: Int,
- configDimensions: Seq[Dim],
- metricNames: Seq[String],
- lines: Iterable[TestResultLine]) {
- def print(): Unit = {
- val fields = ListBuffer[Field](Field.Leaf("Query ID"))
- val coordFields = (1 to coordCount).map(id => Field.Leaf(id.toString))
-
- fields.append(Field.Branch("Succeeded", coordFields))
- fields.append(Field.Branch("Row Count", coordFields))
- metricNames.foreach(metricName => fields.append(Field.Branch(metricName,
coordFields)))
- fields.append(Field.Branch("Planning Time (Millis)", coordFields))
- fields.append(Field.Branch("Query Time (Millis)", coordFields))
-
- val render =
- TableRender.create[TestResultLine](fields: _*)(new
TestResultLine.Parser(metricNames))
-
- lines.foreach { line =>
- render.appendRow(line)
- }
+ case class TestResultLines(
+ coordIds: Seq[Int],
+ configDimensions: Seq[Dim],
+ metricNames: Seq[String],
+ lines: Iterable[TestResultLine]) {
+ def print(): Unit = {
+ val fields = ListBuffer[Field](Field.Leaf("Query ID"))
+ val coordFields = coordIds.map(id => Field.Leaf(id.toString))
+
+ fields.append(Field.Branch("Succeeded", coordFields))
+ fields.append(Field.Branch("Row Count", coordFields))
+ metricNames.foreach(metricName => fields.append(Field.Branch(metricName,
coordFields)))
+ fields.append(Field.Branch("Planning Time (Millis)", coordFields))
+ fields.append(Field.Branch("Query Time (Millis)", coordFields))
- render.print(System.out)
+ val render =
+ TableRender.create[TestResultLine](fields: _*)(
+ new TestResultLine.Parser(coordIds, metricNames))
+
+ lines.foreach(line => render.appendRow(line))
+
+ render.print(System.out)
+ }
}
-}
-object Parameterized {
private def runQuery(
runner: QueryRunner,
- creator: TableCreator,
- sessionSwitcher: SparkSessionSwitcher,
+ spark: SparkSession,
id: String,
coordinate: Coordinate,
desc: String,
explain: Boolean,
metrics: Array[String]): TestResultLine.Coord = {
- println(s"Running query: $id...")
- try {
- val testDesc = "Gluten Spark %s [%s] %s".format(desc, id, coordinate)
- sessionSwitcher.useSession(coordinate.toString, testDesc)
- runner.createTables(creator, sessionSwitcher.spark())
- val result =
- runner.runQuery(sessionSwitcher.spark(), testDesc, id, explain,
metrics)
- val resultRows = result.rows
- println(
- s"Successfully ran query $id. " +
- s"Returned row count: ${resultRows.length}")
- TestResultLine.Coord(
- coordinate,
- succeeded = true,
- Some(resultRows.length),
- Some(result.planningTimeMillis),
- Some(result.executionTimeMillis),
- result.metrics,
- None)
- } catch {
- case e: Exception =>
- val error = Some(s"FATAL: ${ExceptionUtils.getStackTrace(e)}")
- println(
- s"Error running query $id. " +
- s" Error: ${error.get}")
- TestResultLine.Coord(coordinate, succeeded = false, None, None, None,
Map.empty, error)
- }
+ val testDesc = "Query %s [%s] %s".format(desc, id, coordinate)
+ val result = runner.runQuery(spark, testDesc, id, explain, metrics)
+ TestResultLine.Coord(coordinate, result)
}
private def warmUp(
runner: QueryRunner,
- creator: TableCreator,
- sessionSwitcher: SparkSessionSwitcher,
+ session: SparkSession,
id: String,
+ coordinate: Coordinate,
desc: String): Unit = {
- println(s"Warming up: Running query: $id...")
- try {
- val testDesc = "Gluten Spark %s [%s] Warm Up".format(desc, id)
- sessionSwitcher.useSession("test", testDesc)
- runner.createTables(creator, sessionSwitcher.spark())
- val result = runner.runQuery(sessionSwitcher.spark(), testDesc, id,
explain = false)
- val resultRows = result.rows
- println(
- s"Warming up: Successfully ran query $id. " +
- s"Returned row count: ${resultRows.length}")
- } catch {
- case e: Exception =>
- val error = Some(s"FATAL: ${ExceptionUtils.getStackTrace(e)}")
- println(
- s"Warming up: Error running query $id. " +
- s" Error: ${error.get}")
- }
+ runQuery(runner, session, id, coordinate, desc, explain = false,
Array.empty)
}
}
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Queries.scala
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Queries.scala
index de09d925e..b8a42f393 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Queries.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Queries.scala
@@ -16,11 +16,12 @@
*/
package org.apache.gluten.integration.action
-import org.apache.commons.lang3.exception.ExceptionUtils
+import org.apache.gluten.integration.QueryRunner.QueryResult
import org.apache.gluten.integration.action.Actions.QuerySelector
import
org.apache.gluten.integration.action.TableRender.RowParser.FieldAppender.RowAppender
import org.apache.gluten.integration.stat.RamStat
-import org.apache.gluten.integration.{QueryRunner, Suite}
+import org.apache.gluten.integration.{QueryRunner, Suite, TableCreator}
+import org.apache.spark.sql.{SparkSession}
case class Queries(
scale: Double,
@@ -28,28 +29,40 @@ case class Queries(
queries: QuerySelector,
explain: Boolean,
iterations: Int,
- randomKillTasks: Boolean)
+ randomKillTasks: Boolean,
+ noSessionReuse: Boolean)
extends Action {
+ import Queries._
override def execute(suite: Suite): Boolean = {
val runQueryIds = queries.select(suite)
val runner: QueryRunner =
new QueryRunner(suite.queryResource(), suite.dataWritePath(scale,
genPartitionedData))
+ val sessionSwitcher = suite.sessionSwitcher
+ sessionSwitcher.useSession("test", "Run Queries")
+ runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
val results = (0 until iterations).flatMap { iteration =>
println(s"Running tests (iteration $iteration)...")
runQueryIds.map { queryId =>
- Queries.runQuery(
- runner,
- suite.tableCreator(),
- suite.sessionSwitcher,
- queryId,
- suite.desc(),
- explain,
- randomKillTasks)
+ try {
+ Queries.runQuery(
+ runner,
+ suite.tableCreator(),
+ sessionSwitcher.spark(),
+ queryId,
+ suite.desc(),
+ explain,
+ randomKillTasks)
+ } finally {
+ if (noSessionReuse) {
+ sessionSwitcher.renewSession()
+ runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ }
+ }
}
}.toList
- val passedCount = results.count(l => l.testPassed)
+ val passedCount = results.count(l => l.queryResult.succeeded())
val count = results.count(_ => true)
// RAM stats
@@ -67,8 +80,9 @@ case class Queries(
println("")
printf("Summary: %d out of %d queries passed. \n", passedCount, count)
println("")
- val succeed = results.filter(_.testPassed)
- Queries.printResults(succeed)
+ val succeeded = results.filter(_.queryResult.succeeded())
+ val all = succeeded.map(_.queryResult).asSuccesses().agg("all").map(s =>
TestResultLine(s))
+ Queries.printResults(succeeded ++ all)
println("")
if (passedCount == count) {
@@ -77,21 +91,10 @@ case class Queries(
} else {
println("Failed queries: ")
println("")
- Queries.printResults(results.filter(!_.testPassed))
+ Queries.printResults(results.filter(!_.queryResult.succeeded()))
println("")
}
- var all = Queries.aggregate(results, "all")
-
- if (passedCount != count) {
- all = Queries.aggregate(succeed, "succeeded") ::: all
- }
-
- println("Overall: ")
- println("")
- Queries.printResults(all)
- println("")
-
if (passedCount != count) {
return false
}
@@ -100,28 +103,29 @@ case class Queries(
}
object Queries {
- case class TestResultLine(
- queryId: String,
- testPassed: Boolean,
- rowCount: Option[Long],
- planningTimeMillis: Option[Long],
- executionTimeMillis: Option[Long],
- errorMessage: Option[String])
+ case class TestResultLine(queryResult: QueryResult)
object TestResultLine {
implicit object Parser extends TableRender.RowParser[TestResultLine] {
override def parse(rowAppender: RowAppender, line: TestResultLine): Unit
= {
val inc = rowAppender.incremental()
- inc.next().write(line.queryId)
- inc.next().write(line.testPassed)
- inc.next().write(line.rowCount)
- inc.next().write(line.planningTimeMillis)
- inc.next().write(line.executionTimeMillis)
+ inc.next().write(line.queryResult.caseId())
+ inc.next().write(line.queryResult.succeeded())
+ line.queryResult match {
+ case QueryRunner.Success(_, runResult) =>
+ inc.next().write(runResult.rows.size)
+ inc.next().write(runResult.planningTimeMillis)
+ inc.next().write(runResult.executionTimeMillis)
+ case QueryRunner.Failure(_, error) =>
+ inc.next().write(None)
+ inc.next().write(None)
+ inc.next().write(None)
+ }
}
}
}
- private def printResults(results: List[TestResultLine]): Unit = {
+ private def printResults(results: Seq[TestResultLine]): Unit = {
val render = TableRender.plain[TestResultLine](
"Query ID",
"Was Passed",
@@ -136,64 +140,18 @@ object Queries {
render.print(System.out)
}
- private def aggregate(succeed: List[TestResultLine], name: String):
List[TestResultLine] = {
- if (succeed.isEmpty) {
- return Nil
- }
- List(
- succeed.reduce((r1, r2) =>
- TestResultLine(
- name,
- testPassed = true,
- if (r1.rowCount.nonEmpty && r2.rowCount.nonEmpty)
- Some(r1.rowCount.get + r2.rowCount.get)
- else None,
- if (r1.planningTimeMillis.nonEmpty && r2.planningTimeMillis.nonEmpty)
- Some(r1.planningTimeMillis.get + r2.planningTimeMillis.get)
- else None,
- if (r1.executionTimeMillis.nonEmpty &&
r2.executionTimeMillis.nonEmpty)
- Some(r1.executionTimeMillis.get + r2.executionTimeMillis.get)
- else None,
- None)))
- }
-
private def runQuery(
- runner: _root_.org.apache.gluten.integration.QueryRunner,
- creator: _root_.org.apache.gluten.integration.TableCreator,
- sessionSwitcher: _root_.org.apache.spark.sql.SparkSessionSwitcher,
- id: _root_.java.lang.String,
- desc: _root_.java.lang.String,
+ runner: QueryRunner,
+ creator: TableCreator,
+ session: SparkSession,
+ id: String,
+ desc: String,
explain: Boolean,
- randomKillTasks: Boolean) = {
+ randomKillTasks: Boolean): TestResultLine = {
println(s"Running query: $id...")
- try {
- val testDesc = "Gluten Spark %s %s".format(desc, id)
- sessionSwitcher.useSession("test", testDesc)
- runner.createTables(creator, sessionSwitcher.spark())
- val result = runner.runQuery(
- sessionSwitcher.spark(),
- testDesc,
- id,
- explain = explain,
- randomKillTasks = randomKillTasks)
- val resultRows = result.rows
- println(
- s"Successfully ran query $id. " +
- s"Returned row count: ${resultRows.length}")
- TestResultLine(
- id,
- testPassed = true,
- Some(resultRows.length),
- Some(result.planningTimeMillis),
- Some(result.executionTimeMillis),
- None)
- } catch {
- case e: Exception =>
- val error = Some(s"FATAL: ${ExceptionUtils.getStackTrace(e)}")
- println(
- s"Error running query $id. " +
- s" Error: ${error.get}")
- TestResultLine(id, testPassed = false, None, None, None, error)
- }
+ val testDesc = "Query %s [%s]".format(desc, id)
+ val result =
+ runner.runQuery(session, testDesc, id, explain = explain,
randomKillTasks = randomKillTasks)
+ TestResultLine(result)
}
}
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/QueriesCompare.scala
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/QueriesCompare.scala
index d7b6ffff8..804f1fbd7 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/QueriesCompare.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/QueriesCompare.scala
@@ -17,37 +17,78 @@
package org.apache.gluten.integration.action
import org.apache.commons.lang3.exception.ExceptionUtils
+import org.apache.gluten.integration.QueryRunner.QueryResult
import org.apache.gluten.integration.action.Actions.QuerySelector
+import org.apache.gluten.integration.action.QueriesCompare.TestResultLine
import
org.apache.gluten.integration.action.TableRender.RowParser.FieldAppender.RowAppender
import org.apache.gluten.integration.stat.RamStat
import org.apache.gluten.integration.{QueryRunner, Suite, TableCreator}
-import org.apache.spark.sql.{SparkSessionSwitcher, TestUtils}
+import org.apache.spark.sql.{RunResult, SparkSession, SparkSessionSwitcher,
TestUtils}
case class QueriesCompare(
scale: Double,
genPartitionedData: Boolean,
queries: QuerySelector,
explain: Boolean,
- iterations: Int)
+ iterations: Int,
+ noSessionReuse: Boolean)
extends Action {
override def execute(suite: Suite): Boolean = {
val runner: QueryRunner =
new QueryRunner(suite.queryResource(), suite.dataWritePath(scale,
genPartitionedData))
val runQueryIds = queries.select(suite)
- val results = (0 until iterations).flatMap { iteration =>
- println(s"Running tests (iteration $iteration)...")
+ val sessionSwitcher = suite.sessionSwitcher
+
+ sessionSwitcher.useSession("baseline", "Run Baseline Queries")
+ runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ val baselineResults = (0 until iterations).flatMap { iteration =>
+ runQueryIds.map { queryId =>
+ println(s"Running baseline query $queryId (iteration $iteration)...")
+ try {
+ QueriesCompare.runBaselineQuery(
+ runner,
+ sessionSwitcher.spark(),
+ suite.desc(),
+ queryId,
+ explain)
+ } finally {
+ if (noSessionReuse) {
+ sessionSwitcher.renewSession()
+ runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ }
+ }
+ }
+ }.toList
+
+ sessionSwitcher.useSession("test", "Run Test Queries")
+ runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ val testResults = (0 until iterations).flatMap { iteration =>
runQueryIds.map { queryId =>
- QueriesCompare.runQuery(
- suite.tableCreator(),
- queryId,
- explain,
- suite.desc(),
- suite.sessionSwitcher,
- runner)
+ println(s"Running test query $queryId (iteration $iteration)...")
+ try {
+ QueriesCompare.runTestQuery(
+ runner,
+ sessionSwitcher.spark(),
+ suite.desc(),
+ queryId,
+ explain)
+ } finally {
+ if (noSessionReuse) {
+ sessionSwitcher.renewSession()
+ runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ }
+ }
}
}.toList
+ assert(baselineResults.size == testResults.size)
+
+ val results: Seq[TestResultLine] = baselineResults.zip(testResults).map {
case (b, t) =>
+ assert(b.caseId() == t.caseId())
+ TestResultLine(b.caseId(), b, t)
+ }
+
val passedCount = results.count(l => l.testPassed)
val count = results.count(_ => true)
@@ -66,8 +107,15 @@ case class QueriesCompare(
println("")
printf("Summary: %d out of %d queries passed. \n", passedCount, count)
println("")
- val succeed = results.filter(_.testPassed)
- QueriesCompare.printResults(succeed)
+ val succeeded = results.filter(_.testPassed)
+ val all = succeeded match {
+ case Nil => None
+ case several =>
+ val allExpected = several.map(_.expected).asSuccesses().agg("all
expected").get
+ val allActual = several.map(_.actual).asSuccesses().agg("all
actual").get
+ Some(TestResultLine("all", allExpected, allActual))
+ }
+ QueriesCompare.printResults(succeeded ++ all)
println("")
if (passedCount == count) {
@@ -81,17 +129,6 @@ case class QueriesCompare(
println("")
}
- var all = QueriesCompare.aggregate("all", results)
-
- if (passedCount != count) {
- all = QueriesCompare.aggregate("succeeded", succeed) ::: all
- }
-
- println("Overall: ")
- println("")
- QueriesCompare.printResults(all)
- println("")
-
if (passedCount != count) {
return false
}
@@ -100,41 +137,46 @@ case class QueriesCompare(
}
object QueriesCompare {
- case class TestResultLine(
- queryId: String,
- testPassed: Boolean,
- expectedRowCount: Option[Long],
- actualRowCount: Option[Long],
- expectedPlanningTimeMillis: Option[Long],
- actualPlanningTimeMillis: Option[Long],
- expectedExecutionTimeMillis: Option[Long],
- actualExecutionTimeMillis: Option[Long],
- errorMessage: Option[String])
+ case class TestResultLine(queryId: String, expected: QueryResult, actual:
QueryResult) {
+ val testPassed: Boolean = {
+ expected.succeeded() && actual.succeeded() &&
+ TestUtils
+ .compareAnswers(
+ expected.asSuccess().runResult.rows,
+ actual.asSuccess().runResult.rows,
+ sort = true)
+ .isEmpty
+ }
+ }
object TestResultLine {
implicit object Parser extends TableRender.RowParser[TestResultLine] {
override def parse(rowAppender: RowAppender, line: TestResultLine): Unit
= {
val inc = rowAppender.incremental()
+ inc.next().write(line.queryId)
+ inc.next().write(line.testPassed)
+
inc.next().write(line.expected.asSuccessOption().map(_.runResult.rows.size))
+
inc.next().write(line.actual.asSuccessOption().map(_.runResult.rows.size))
+
inc.next().write(line.expected.asSuccessOption().map(_.runResult.planningTimeMillis))
+
inc.next().write(line.actual.asSuccessOption().map(_.runResult.planningTimeMillis))
+
inc.next().write(line.expected.asSuccessOption().map(_.runResult.executionTimeMillis))
+
inc.next().write(line.actual.asSuccessOption().map(_.runResult.executionTimeMillis))
+
val speedUp =
- if (line.expectedExecutionTimeMillis.nonEmpty &&
line.actualExecutionTimeMillis.nonEmpty) {
+ if (line.expected.succeeded() && line.actual.succeeded()) {
Some(
- ((line.expectedExecutionTimeMillis.get -
line.actualExecutionTimeMillis.get).toDouble
- / line.actualExecutionTimeMillis.get.toDouble) * 100)
+ ((line.expected.asSuccess().runResult.executionTimeMillis -
line.actual
+ .asSuccess()
+ .runResult
+ .executionTimeMillis).toDouble
+ / line.actual.asSuccess().runResult.executionTimeMillis) * 100)
} else None
- inc.next().write(line.queryId)
- inc.next().write(line.testPassed)
- inc.next().write(line.expectedRowCount)
- inc.next().write(line.actualRowCount)
- inc.next().write(line.expectedPlanningTimeMillis)
- inc.next().write(line.actualPlanningTimeMillis)
- inc.next().write(line.expectedExecutionTimeMillis)
- inc.next().write(line.actualExecutionTimeMillis)
inc.next().write(speedUp.map("%.2f%%".format(_)))
}
}
}
- private def printResults(results: List[TestResultLine]): Unit = {
+ private def printResults(results: Seq[TestResultLine]): Unit = {
import org.apache.gluten.integration.action.TableRender.Field._
val render = TableRender.create[TestResultLine](
@@ -152,79 +194,25 @@ object QueriesCompare {
render.print(System.out)
}
- private def aggregate(name: String, succeed: List[TestResultLine]):
List[TestResultLine] = {
- if (succeed.isEmpty) {
- return Nil
- }
- List(
- succeed.reduce((r1, r2) =>
- TestResultLine(
- name,
- r1.testPassed && r2.testPassed,
- (r1.expectedRowCount, r2.expectedRowCount).onBothProvided(_ + _),
- (r1.actualRowCount, r2.actualRowCount).onBothProvided(_ + _),
- (r1.expectedPlanningTimeMillis,
r2.expectedPlanningTimeMillis).onBothProvided(_ + _),
- (r1.actualPlanningTimeMillis,
r2.actualPlanningTimeMillis).onBothProvided(_ + _),
- (r1.expectedExecutionTimeMillis,
r2.expectedExecutionTimeMillis).onBothProvided(_ + _),
- (r1.actualExecutionTimeMillis,
r2.actualExecutionTimeMillis).onBothProvided(_ + _),
- None)))
+ private def runBaselineQuery(
+ runner: QueryRunner,
+ session: SparkSession,
+ desc: String,
+ id: String,
+ explain: Boolean): QueryResult = {
+ val testDesc = "Baseline %s [%s]".format(desc, id)
+ val result = runner.runQuery(session, testDesc, id, explain = explain)
+ result
}
- private[integration] def runQuery(
- creator: TableCreator,
- id: String,
- explain: Boolean,
+ private def runTestQuery(
+ runner: QueryRunner,
+ session: SparkSession,
desc: String,
- sessionSwitcher: SparkSessionSwitcher,
- runner: QueryRunner): TestResultLine = {
- println(s"Running query: $id...")
- try {
- val baseLineDesc = "Vanilla Spark %s %s".format(desc, id)
- sessionSwitcher.useSession("baseline", baseLineDesc)
- runner.createTables(creator, sessionSwitcher.spark())
- val expected =
- runner.runQuery(sessionSwitcher.spark(), baseLineDesc, id, explain =
explain)
- val expectedRows = expected.rows
- val testDesc = "Gluten Spark %s %s".format(desc, id)
- sessionSwitcher.useSession("test", testDesc)
- runner.createTables(creator, sessionSwitcher.spark())
- val result = runner.runQuery(sessionSwitcher.spark(), testDesc, id,
explain = explain)
- val resultRows = result.rows
- val error = TestUtils.compareAnswers(resultRows, expectedRows, sort =
true)
- if (error.isEmpty) {
- println(
- s"Successfully ran query $id, result check was passed. " +
- s"Returned row count: ${resultRows.length}, expected:
${expectedRows.length}")
- return TestResultLine(
- id,
- testPassed = true,
- Some(expectedRows.length),
- Some(resultRows.length),
- Some(expected.planningTimeMillis),
- Some(result.planningTimeMillis),
- Some(expected.executionTimeMillis),
- Some(result.executionTimeMillis),
- None)
- }
- println(s"Error running query $id, result check was not passed. " +
- s"Returned row count: ${resultRows.length}, expected:
${expectedRows.length}, error: ${error.get}")
- TestResultLine(
- id,
- testPassed = false,
- Some(expectedRows.length),
- Some(resultRows.length),
- Some(expected.planningTimeMillis),
- Some(result.planningTimeMillis),
- Some(expected.executionTimeMillis),
- Some(result.executionTimeMillis),
- error)
- } catch {
- case e: Exception =>
- val error = Some(s"FATAL: ${ExceptionUtils.getStackTrace(e)}")
- println(
- s"Error running query $id. " +
- s" Error: ${error.get}")
- TestResultLine(id, testPassed = false, None, None, None, None, None,
None, error)
- }
+ id: String,
+ explain: Boolean): QueryResult = {
+ val testDesc = "Query %s [%s]".format(desc, id)
+ val result = runner.runQuery(session, testDesc, id, explain = explain)
+ result
}
}
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/SparkShell.scala
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/SparkShell.scala
index 76f43cb71..1742b99c2 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/SparkShell.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/SparkShell.scala
@@ -21,7 +21,7 @@ import org.apache.spark.repl.Main
case class SparkShell(scale: Double, genPartitionedData: Boolean) extends
Action {
override def execute(suite: Suite): Boolean = {
- suite.sessionSwitcher.useSession("test", "Gluten Spark CLI")
+ suite.sessionSwitcher.useSession("test", "Spark CLI")
val runner: QueryRunner =
new QueryRunner(suite.queryResource(), suite.dataWritePath(scale,
genPartitionedData))
runner.createTables(suite.tableCreator(), suite.sessionSwitcher.spark())
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/package.scala
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/package.scala
index 6046ae4aa..a84915ebe 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/package.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/package.scala
@@ -17,13 +17,31 @@
package org.apache.gluten.integration
+import org.apache.spark.sql.RunResult
+
package object action {
- implicit class DualOptionsOps[T](value: (Option[T], Option[T])) {
- def onBothProvided[R](func: (T, T) => R): Option[R] = {
- if (value._1.isEmpty || value._2.isEmpty) {
- return None
+
+ implicit class QueryResultsOps(results: Iterable[QueryRunner.QueryResult]) {
+ def asSuccesses(): Iterable[QueryRunner.Success] = {
+ results.map(_.asSuccess())
+ }
+
+ def asFailures(): Iterable[QueryRunner.Failure] = {
+ results.map(_.asFailure())
+ }
+ }
+
+ implicit class CompletedOps(completed: Iterable[QueryRunner.Success]) {
+ def agg(name: String): Option[QueryRunner.Success] = {
+ completed.reduceOption { (c1, c2) =>
+ QueryRunner.Success(
+ name,
+ RunResult(
+ c1.runResult.rows ++ c2.runResult.rows,
+ c1.runResult.planningTimeMillis + c2.runResult.planningTimeMillis,
+ c1.runResult.executionTimeMillis +
c2.runResult.executionTimeMillis,
+ (c1.runResult.metrics, c2.runResult.metrics).sumUp))
}
- Some(func(value._1.get, value._2.get))
}
}
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchDataGen.scala
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchDataGen.scala
index ba772f165..add7b01fe 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchDataGen.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchDataGen.scala
@@ -31,7 +31,10 @@ class ClickBenchDataGen(val spark: SparkSession, dir:
String) extends DataGen {
// Directly download from official URL.
val target = new File(dir + File.separator + FILE_NAME)
FileUtils.forceMkdirParent(target)
- val code = Process(s"wget -P $dir $DATA_URL") !;
+ val cmd =
+ s"wget --no-verbose --show-progress --progress=bar:force:noscroll -O
$target $DATA_URL"
+ println(s"Executing command: $cmd")
+ val code = Process(cmd) !;
if (code != 0) {
throw new RuntimeException("Download failed")
}
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/SparkSessionSwitcher.scala
b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/SparkSessionSwitcher.scala
index 17a50fd29..0a1a25351 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/SparkSessionSwitcher.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/SparkSessionSwitcher.scala
@@ -65,6 +65,16 @@ class SparkSessionSwitcher(val masterUrl: String, val
logLevel: String) extends
useSession(SessionDesc(SessionToken(token), appName))
}
+ def renewSession(): Unit = synchronized {
+ if (!hasActiveSession()) {
+ return
+ }
+ val sd = _activeSessionDesc
+ println(s"Renewing $sd session... ")
+ stopActiveSession()
+ useSession(sd)
+ }
+
private def useSession(desc: SessionDesc): Unit = synchronized {
if (desc == _activeSessionDesc) {
return
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@gluten.apache.org
For additional commands, e-mail: commits-help@gluten.apache.org