This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new be3eeea8c3 [VL] Gluten-it: Several enhancements (#11600)
be3eeea8c3 is described below
commit be3eeea8c33ddfb5352a37ad7d169e326c4dc1ba
Author: Hongze Zhang <[email protected]>
AuthorDate: Fri Feb 13 22:47:03 2026 +0000
[VL] Gluten-it: Several enhancements (#11600)
---
.../org/apache/gluten/integration/BaseMixin.java | 38 ++++--
.../java/org/apache/gluten/integration/Cli.java | 8 ++
.../gluten/integration/command/DataGenMixin.java | 2 +-
.../gluten/integration/command/Parameterized.java | 3 +
.../apache/gluten/integration/command/Queries.java | 3 +-
.../gluten/integration/command/QueriesMixin.java | 2 +-
.../org/apache/gluten/integration/Constants.scala | 39 +++++-
.../org/apache/gluten/integration/DataGen.scala | 4 +-
.../apache/gluten/integration/QueryRunner.scala | 20 ++-
.../org/apache/gluten/integration/Suite.scala | 90 +++++++++++--
.../apache/gluten/integration/TableAnalyzer.scala | 55 ++++++++
.../apache/gluten/integration/TableCreator.scala | 17 +--
.../gluten/integration/action/DataGenOnly.scala | 51 ++++++--
.../gluten/integration/action/Parameterized.scala | 58 +++++----
.../apache/gluten/integration/action/Queries.scala | 70 ++++++-----
.../gluten/integration/action/QueriesCompare.scala | 75 ++++++-----
.../gluten/integration/action/SparkShell.scala | 2 +-
.../integration/clickbench/ClickBenchDataGen.scala | 4 +-
.../integration/clickbench/ClickBenchSuite.scala | 17 ++-
.../gluten/integration/ds/TpcdsDataGen.scala | 15 +--
.../apache/gluten/integration/ds/TpcdsSuite.scala | 29 ++---
.../apache/gluten/integration/h/TpchDataGen.scala | 26 ++--
.../apache/gluten/integration/h/TpchSuite.scala | 28 +++--
.../gluten/integration/metrics/MetricMapper.scala | 11 +-
.../gluten/integration/metrics/MetricTag.scala | 22 +---
.../gluten/integration/metrics/PlanMetric.scala | 99 +++++++++++++--
.../gluten/integration/report/TestReporter.scala | 140 +++++++++++++++++++++
.../{ShimUtils.scala => shim/Shim.scala} | 4 +-
.../org/apache/spark/sql/SparkQueryRunner.scala | 10 +-
29 files changed, 701 insertions(+), 241 deletions(-)
diff --git a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/BaseMixin.java b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/BaseMixin.java
index 88ccf9c512..b7ccf01f04 100644
--- a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/BaseMixin.java
+++ b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/BaseMixin.java
@@ -29,10 +29,8 @@ import org.apache.log4j.LogManager;
import org.apache.spark.SparkConf;
import picocli.CommandLine;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
+import java.io.*;
+import java.util.*;
public class BaseMixin {
@@ -118,7 +116,7 @@ public class BaseMixin {
private int hsUiPort;
@CommandLine.ArgGroup(exclusive = true, multiplicity = "1")
- SparkRunModes.Mode.Enumeration runModeEnumeration;
+ private SparkRunModes.Mode.Enumeration runModeEnumeration;
@CommandLine.Option(
names = {"--disable-aqe"},
@@ -138,6 +136,12 @@ public class BaseMixin {
defaultValue = "false")
private boolean disableWscg;
+ @CommandLine.Option(
+ names = {"--enable-cbo"},
+ description = "Enable Spark CBO and analyze all tables before running queries",
+ defaultValue = "false")
+ private boolean enableCbo;
+
@CommandLine.Option(
names = {"--shuffle-partitions"},
description = "Shuffle partition number",
@@ -163,6 +167,13 @@ public class BaseMixin {
"Extra Spark config entries applying to generated Spark session. E.g. --extra-conf=k1=v1 --extra-conf=k2=v2")
private Map<String, String> extraSparkConf = Collections.emptyMap();
+ @CommandLine.Option(
+ names = {"--report"},
+ description =
+ "The file path where the test report will be written. If not specified, the report will be printed to stdout only.",
+ defaultValue = "")
+ private String reportPath;
+
private SparkConf pickSparkConf(String preset) {
return Preset.get(preset).getConf();
}
@@ -222,11 +233,13 @@ public class BaseMixin {
disableAqe,
disableBhj,
disableWscg,
+ enableCbo,
shufflePartitions,
scanPartitions,
decimalAsDouble,
baselineMetricMapper,
- testMetricMapper);
+ testMetricMapper,
+ reportPath);
break;
case "ds":
suite =
@@ -249,11 +262,13 @@ public class BaseMixin {
disableAqe,
disableBhj,
disableWscg,
+ enableCbo,
shufflePartitions,
scanPartitions,
decimalAsDouble,
baselineMetricMapper,
- testMetricMapper);
+ testMetricMapper,
+ reportPath);
break;
case "clickbench":
suite =
@@ -275,21 +290,22 @@ public class BaseMixin {
disableAqe,
disableBhj,
disableWscg,
+ enableCbo,
shufflePartitions,
scanPartitions,
decimalAsDouble,
baselineMetricMapper,
- testMetricMapper);
+ testMetricMapper,
+ reportPath);
break;
default:
throw new IllegalArgumentException("TPC benchmark type not found: " + benchmarkType);
}
+
+ // Execute the suite.
final boolean succeed;
try {
succeed = suite.run();
- } catch (Throwable t) {
- t.printStackTrace();
- throw t;
} finally {
suite.close();
}
diff --git a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/Cli.java b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/Cli.java
index c6adb21f66..b9cbdb470a 100644
--- a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/Cli.java
+++ b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/Cli.java
@@ -24,6 +24,8 @@ import org.apache.gluten.integration.command.SparkShell;
import picocli.CommandLine;
+import java.util.Arrays;
+
@CommandLine.Command(
name = "gluten-it",
mixinStandardHelpOptions = true,
@@ -37,10 +39,16 @@ import picocli.CommandLine;
},
description = "Gluten integration test using various of benchmark's data and queries.")
public class Cli {
+ private static String[] COMMANDLINE_ARGS = new String[0];
private Cli() {}
+ public static String[] args() {
+ return Arrays.copyOf(COMMANDLINE_ARGS, COMMANDLINE_ARGS.length);
+ }
+
public static void main(String... args) {
+ COMMANDLINE_ARGS = args;
final CommandLine cmd = new CommandLine(new Cli());
final int exitCode = cmd.execute(args);
System.exit(exitCode);
diff --git a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/DataGenMixin.java b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/DataGenMixin.java
index 427b0c25c2..8f27b3db83 100644
--- a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/DataGenMixin.java
+++ b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/DataGenMixin.java
@@ -25,7 +25,7 @@ public class DataGenMixin {
@CommandLine.Option(
names = {"--data-gen"},
description = "The strategy of data generation, accepted values: skip, once, always",
- defaultValue = "always")
+ defaultValue = "once")
private String dataGenStrategy;
public Action[] makeActions() {
diff --git a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Parameterized.java b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Parameterized.java
index e6bb2237e1..80a5cb8ce1 100644
--- a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Parameterized.java
+++ b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Parameterized.java
@@ -76,6 +76,9 @@ public class Parameterized implements Callable<Integer> {
@Override
public Integer call() throws Exception {
+ if (dims.length == 0) {
+ throw new IllegalArgumentException("At least one dimension must be specified by -d / --dim");
+ }
final Map<String, Map<String, List<Map.Entry<String, String>>>> parsed =
new LinkedHashMap<>();
final scala.collection.immutable.Seq<
diff --git a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Queries.java b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Queries.java
index 6632ab04b5..bd87008f27 100644
--- a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Queries.java
+++ b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/Queries.java
@@ -50,7 +50,8 @@ public class Queries implements Callable<Integer> {
@CommandLine.Option(
names = {"--sql-metrics"},
description =
- "Collect SQL metrics from run queries and generate a simple report based on them. Available types: execution-time")
+ "Collect SQL metrics from run queries and generate a simple report based on them. Available types: execution-time, join-selectivity",
+ split = ",")
private Set<String> collectSqlMetrics = Collections.emptySet();
@Override
diff --git a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesMixin.java b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesMixin.java
index 8c47bc8bf0..c5bf31f704 100644
--- a/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesMixin.java
+++ b/tools/gluten-it/common/src/main/java/org/apache/gluten/integration/command/QueriesMixin.java
@@ -108,7 +108,7 @@ public class QueriesMixin {
}
final Division div = Division.parse(shard);
querySet = querySet.getShard(div.shard - 1, div.shardCount);
- System.out.println("About to run queries: " + querySet.queryIds() + "... ");
+ System.out.println("About to run queries: " + querySet.queryIds().mkString(",") + "... ");
return querySet;
}
};
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala
index 1a10334a86..a37cf71617 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala
@@ -16,8 +16,8 @@
*/
package org.apache.gluten.integration
-import org.apache.gluten.integration.metrics.MetricMapper
-import org.apache.gluten.integration.metrics.MetricMapper.SelfTimeMapper
+import org.apache.gluten.integration.metrics.{MetricMapper, MetricTag}
+import org.apache.gluten.integration.metrics.MetricMapper.SimpleMetricMapper
import org.apache.spark.SparkConf
import org.apache.spark.sql.TypeUtils
@@ -78,7 +78,8 @@ object Constants {
.set("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold", "0")
.set("spark.gluten.sql.columnar.physicalJoinOptimizeEnable", "false")
- val VANILLA_METRIC_MAPPER: MetricMapper = SelfTimeMapper(
+ val VANILLA_METRIC_MAPPER: MetricMapper = SimpleMetricMapper(
+ Seq(MetricTag.IsSelfTime),
Map(
"FileSourceScanExec" -> Set("metadataTime", "scanTime"),
"HashAggregateExec" -> Set("aggTime"),
@@ -91,10 +92,12 @@ object Constants {
"ShuffleExchangeExec" -> Set("fetchWaitTime", "shuffleWriteTime"),
"ShuffledHashJoinExec" -> Set("buildTime"),
"WindowGroupLimitExec" -> Set() // No available metrics provided by vanilla Spark.
- ))
+ )
+ )
- val VELOX_METRIC_MAPPER: MetricMapper = VANILLA_METRIC_MAPPER.and(
- SelfTimeMapper(
+ val VELOX_METRIC_MAPPER: MetricMapper = VANILLA_METRIC_MAPPER
+ .and(SimpleMetricMapper(
+ Seq(MetricTag.IsSelfTime),
Map(
"FileSourceScanExecTransformer" -> Set("scanTime", "pruningTime", "remainingFilterTime"),
"ProjectExecTransformer" -> Set("wallNanos"),
@@ -121,6 +124,30 @@ object Constants {
"ExpandExecTransformer" -> Set("wallNanos")
)
))
+ .and(SimpleMetricMapper(
+ Seq(MetricTag.IsJoinProbeInputNumRows),
+ Map(
+ "BroadcastHashJoinExecTransformer" -> Set("hashProbeInputRows"),
+ "ShuffledHashJoinExecTransformer" -> Set("hashProbeInputRows"),
+ "BroadcastNestedLoopJoinExecTransformer" -> Set("nestedLoopJoinProbeInputRows")
+ )
+ ))
+ .and(SimpleMetricMapper(
+ Seq(MetricTag.IsJoinProbeOutputNumRows),
+ Map(
+ "BroadcastHashJoinExecTransformer" -> Set("hashProbeOutputRows"),
+ "ShuffledHashJoinExecTransformer" -> Set("hashProbeOutputRows"),
+ "BroadcastNestedLoopJoinExecTransformer" -> Set("nestedLoopJoinProbeOutputRows")
+ )
+ ))
+ .and(SimpleMetricMapper(
+ Seq(MetricTag.IsJoinOutputNumRows),
+ Map(
+ "BroadcastHashJoinExecTransformer" -> Set("numOutputRows"),
+ "ShuffledHashJoinExecTransformer" -> Set("numOutputRows"),
+ "BroadcastNestedLoopJoinExecTransformer" -> Set("numOutputRows")
+ )
+ ))
@deprecated
val TYPE_MODIFIER_DATE_AS_DOUBLE: TypeModifier =
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/DataGen.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/DataGen.scala
index 9ada805c6e..b7bdc122b0 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/DataGen.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/DataGen.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType}
import scala.collection.mutable
trait DataGen {
- def gen(): Unit
+ def gen(spark: SparkSession): Unit
}
abstract class TypeModifier(val predicate: DataType => Boolean, val to: DataType)
@@ -71,7 +71,7 @@ object DataGen {
trait Feature extends Serializable {
def name(): String
- def run(spark: SparkSession, source: String)
+ def run(spark: SparkSession, source: String): Unit
}
object Feature {
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/QueryRunner.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/QueryRunner.scala
index 8f817f418a..efbd6dbd7c 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/QueryRunner.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/QueryRunner.scala
@@ -30,13 +30,10 @@ import java.net.URI
class QueryRunner(val source: String, val dataPath: String) {
import QueryRunner._
- Preconditions.checkState(
- fileExists(dataPath),
- s"Data not found at $dataPath, try using command `<gluten-it> data-gen-only <options>` to generate it first.",
- Array(): _*)
- def createTables(creator: TableCreator, spark: SparkSession): Unit = {
+ def createTables(creator: TableCreator, analyzer: TableAnalyzer, spark: SparkSession): Unit = {
creator.create(spark, source, dataPath)
+ analyzer.analyze(spark)
}
def runQuery(
@@ -48,6 +45,13 @@ class QueryRunner(val source: String, val dataPath: String) {
executorMetrics: Seq[String] = Nil,
randomKillTasks: Boolean = false): QueryResult = {
try {
+ val path = new Path(dataPath)
+ val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
+ Preconditions.checkState(
+ fs.exists(path),
+ s"Data not found at $dataPath, try using command `<gluten-it> data-gen-only <options>` to generate it first.",
+ Array(): _*)
+
val r =
SparkQueryRunner.runQuery(
spark,
@@ -66,12 +70,6 @@ class QueryRunner(val source: String, val dataPath: String) {
}
}
- private def fileExists(datapath: String): Boolean = {
- if (datapath.startsWith("hdfs:") || datapath.startsWith("s3a:")) {
- val uri = URI.create(datapath)
- FileSystem.get(uri, new Configuration()).exists(new Path(uri.getPath))
- } else new File(datapath).exists()
- }
}
object QueryRunner {
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Suite.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Suite.scala
index 2ea814df27..5503889d19 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Suite.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Suite.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.integration
import org.apache.gluten.integration.Constants.TYPE_MODIFIER_DECIMAL_AS_DOUBLE
import org.apache.gluten.integration.action.Action
import org.apache.gluten.integration.metrics.MetricMapper
+import org.apache.gluten.integration.report.TestReporter
import org.apache.spark.SparkConf
import org.apache.spark.deploy.history.HistoryServerHelper
@@ -26,9 +27,13 @@ import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.ConfUtils.ConfImplicits._
import org.apache.spark.sql.SparkSessionSwitcher
+import org.apache.commons.io.output.{NullOutputStream, TeeOutputStream}
+import org.apache.commons.lang3.StringUtils
import org.apache.log4j.{Level, LogManager}
-import java.io.File
+import java.io.{BufferedOutputStream, File, FileNotFoundException, FileOutputStream, OutputStream, PrintStream}
+import java.time.{Instant, ZoneId}
+import java.time.format.DateTimeFormatter
import java.util.Scanner
abstract class Suite(
@@ -45,14 +50,17 @@ abstract class Suite(
private val disableAqe: Boolean,
private val disableBhj: Boolean,
private val disableWscg: Boolean,
+ private val enableCbo: Boolean,
private val shufflePartitions: Int,
private val scanPartitions: Int,
private val decimalAsDouble: Boolean,
private val baselineMetricMapper: MetricMapper,
- private val testMetricMapper: MetricMapper) {
+ private val testMetricMapper: MetricMapper,
+ private val reportPath: String) {
resetLogLevel()
+ private val reporter: TestReporter = TestReporter.create()
private var hsUiBoundPort: Int = -1
private[integration] val sessionSwitcher: SparkSessionSwitcher =
@@ -105,6 +113,19 @@ abstract class Suite(
sessionSwitcher.addDefaultConf("spark.sql.codegen.wholeStage", "false")
}
+ if (enableCbo) {
+ sessionSwitcher.addDefaultConf("spark.sql.cbo.enabled", "true")
+ sessionSwitcher.addDefaultConf("spark.sql.cbo.planStats.enabled", "true")
+ sessionSwitcher.addDefaultConf("spark.sql.cbo.joinReorder.enabled", "true")
+ sessionSwitcher.addDefaultConf("spark.sql.cbo.joinReorder.dp.threshold", "12")
+ sessionSwitcher.addDefaultConf("spark.sql.cbo.joinReorder.card.weight", "0.7")
+ sessionSwitcher.addDefaultConf("spark.sql.cbo.joinReorder.dp.star.filter", "true")
+ sessionSwitcher.addDefaultConf("spark.sql.cbo.starSchemaDetection", "true")
+ sessionSwitcher.addDefaultConf("spark.sql.cbo.starJoinFTRatio", "0.9")
+ sessionSwitcher.addDefaultConf("spark.sql.statistics.histogram.enabled", "true")
+ sessionSwitcher.addDefaultConf("spark.sql.statistics.histogram.numBins", "254")
+ }
+
if (scanPartitions != -1) {
// Scan partition number.
sessionSwitcher.addDefaultConf(
@@ -135,12 +156,54 @@ abstract class Suite(
}
def run(): Boolean = {
- val succeed = actions.forall {
+ // Report metadata.
+ val formatter =
+ DateTimeFormatter
+ .ofPattern("yyyy-MM-dd HH:mm:ss")
+ .withZone(ZoneId.systemDefault())
+ val formattedTime = formatter.format(Instant.ofEpochMilli(System.currentTimeMillis()))
+ reporter.addMetadata("Timestamp", formattedTime)
+ reporter.addMetadata("Arguments", Cli.args().mkString(" "))
+
+ // Construct the output streams for writing test reports.
+ var fileOut: OutputStream = null
+ if (!StringUtils.isBlank(reportPath)) try {
+ val file = new File(reportPath)
+ if (file.isDirectory) throw new FileNotFoundException("Is a directory: " + reportPath)
+ println("Test report will be written to " + file.getAbsolutePath)
+ fileOut = new BufferedOutputStream(new FileOutputStream(file))
+ } catch {
+ case e: FileNotFoundException =>
+ throw new RuntimeException(e)
+ }
+ else fileOut = NullOutputStream.NULL_OUTPUT_STREAM
+ val combinedOut = new PrintStream(new TeeOutputStream(System.out, fileOut), true)
+ val combinedErr = new PrintStream(new TeeOutputStream(System.err, fileOut), true)
+
+ // Execute the suite.
+ val succeeded =
+ try {
+ runActions()
+ } catch {
+ case t: Exception =>
+ t.printStackTrace(reporter.rootAppender.err)
+ false
+ }
+ if (succeeded) {
+ reporter.write(combinedOut)
+ } else {
+ reporter.write(combinedErr)
+ }
+ succeeded
+ }
+
+ private def runActions(): Boolean = {
+ val succeeded = actions.forall {
action =>
resetLogLevel() // to prevent log level from being set by unknown external codes
action.execute(this)
}
- succeed
+ succeeded
}
def close(): Unit = {
@@ -155,9 +218,17 @@ abstract class Suite(
def tableCreator(): TableCreator
- private def resetLogLevel(): Unit = {
- System.setProperty(org.slf4j.impl.SimpleLogger.DEFAULT_LOG_LEVEL_KEY, logLevel.toString)
- LogManager.getRootLogger.setLevel(logLevel)
+ final def tableAnalyzer(): TableAnalyzer = {
+ if (enableCbo) {
+ return tableAnalyzer0()
+ }
+ TableAnalyzer.noop()
+ }
+
+ protected def tableAnalyzer0(): TableAnalyzer
+
+ def getReporter(): TestReporter = {
+ reporter
}
private[integration] def getBaselineConf(): SparkConf = {
@@ -195,6 +266,11 @@ abstract class Suite(
private[integration] def allQueries(): QuerySet
private[integration] def desc(): String
+
+ private def resetLogLevel(): Unit = {
+ System.setProperty(org.slf4j.impl.SimpleLogger.DEFAULT_LOG_LEVEL_KEY, logLevel.toString)
+ LogManager.getRootLogger.setLevel(logLevel)
+ }
}
object Suite {}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/TableAnalyzer.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/TableAnalyzer.scala
new file mode 100644
index 0000000000..16707aa138
--- /dev/null
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/TableAnalyzer.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.integration
+
+import org.apache.spark.sql.SparkSession
+
+trait TableAnalyzer {
+ def analyze(spark: SparkSession): Unit
+}
+
+object TableAnalyzer {
+ def noop(): TableAnalyzer = {
+ Noop
+ }
+
+ def analyzeAll(): TableAnalyzer = {
+ AnalyzeAll
+ }
+
+ private object Noop extends TableAnalyzer {
+ override def analyze(spark: SparkSession): Unit = {
+ // Do nothing.
+ }
+ }
+
+ private object AnalyzeAll extends TableAnalyzer {
+ override def analyze(spark: SparkSession): Unit = {
+ val tables = spark.catalog.listTables().collect()
+ tables.foreach {
+ tab =>
+ val tableName = tab.name
+ val tableColumnNames = spark.catalog.listColumns(tableName).collect().map(c => c.name)
+ println(s"Analyzing catalog table: $tableName [${tableColumnNames.mkString(", ")}]...")
+ spark.sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
+ spark.sql(
+ s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS ${tableColumnNames.mkString(", ")}")
+ println(s"Catalog table analyzed: $tableName.")
+ }
+ }
+ }
+}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/TableCreator.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/TableCreator.scala
index f382f9aad7..2819645410 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/TableCreator.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/TableCreator.scala
@@ -31,14 +31,15 @@ trait TableCreator {
}
object TableCreator {
- def discoverSchema(): TableCreator = {
- DiscoverSchema
+ def discoverFromFiles(): TableCreator = {
+ DiscoverFromFiles
}
- private object DiscoverSchema extends TableCreator {
+ /** Discover tables automatically from a given file system path. */
+ private object DiscoverFromFiles extends TableCreator {
override def create(spark: SparkSession, source: String, dataPath: String): Unit = {
val uri = URI.create(dataPath)
- val fs = FileSystem.get(uri, new Configuration())
+ val fs = FileSystem.get(uri, spark.sessionState.newHadoopConf())
val basePath = new Path(dataPath)
val statuses = fs.listStatus(basePath)
@@ -57,7 +58,7 @@ object TableCreator {
tableNames += tableName
}
- println("Creating catalog tables: " + tableNames.mkString(", "))
+ println("Creating catalog tables: " + tableNames.mkString(", ") + "...")
tableDirs.foreach {
tablePath =>
@@ -80,13 +81,13 @@ object TableCreator {
return
}
if (existedTableNames.nonEmpty) {
- println("Tables already exists: " + existedTableNames.mkString(", "))
+ println("Tables already exists: " + existedTableNames.mkString(", ") + ".")
}
if (createdTableNames.nonEmpty) {
- println("Tables created: " + createdTableNames.mkString(", "))
+ println("Tables created: " + createdTableNames.mkString(", ") + ".")
}
if (recoveredPartitionTableNames.nonEmpty) {
- println("Recovered partition tables: " + recoveredPartitionTableNames.mkString(", "))
+ println("Recovered partition tables: " + recoveredPartitionTableNames.mkString(", ") + ".")
}
}
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/DataGenOnly.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/DataGenOnly.scala
index fbac7f284b..deee1d008e 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/DataGenOnly.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/DataGenOnly.scala
@@ -18,31 +18,66 @@ package org.apache.gluten.integration.action
import org.apache.gluten.integration.Suite
-import java.io.File
+import org.apache.hadoop.fs.{FileSystem, Path}
case class DataGenOnly(strategy: DataGenOnly.Strategy) extends Action {
+
override def execute(suite: Suite): Boolean = {
+ suite.sessionSwitcher.useSession("baseline", "Data Gen")
+ val fs = this.fs(suite)
+ val markerPath = this.markerPath(suite)
+
strategy match {
case DataGenOnly.Skip =>
- // Do nothing
+ ()
+
case DataGenOnly.Once =>
- val dataPath = suite.dataWritePath()
- val alreadyExists = new File(dataPath).exists()
- if (alreadyExists) {
- println(s"Data already exists at $dataPath, skipping generating it.")
+ val dataPath = this.dataPath(suite)
+ if (fs.exists(dataPath) && fs.exists(markerPath)) {
+ println(s"Test data already generated at $dataPath. Skipping.")
} else {
+ if (fs.exists(dataPath)) {
+ println(
+ s"Test data exists at $dataPath but no completion marker found. Regenerating."
+ )
+ fs.delete(dataPath, true)
+ }
+ if (fs.exists(markerPath)) {
+ fs.delete(markerPath, true)
+ }
gen(suite)
+ // Create marker after successful generation.
+ fs.create(markerPath, false).close()
}
+
case DataGenOnly.Always =>
gen(suite)
+ // Create marker after successful generation.
+ fs.create(markerPath, false).close()
}
true
}
+ private def fs(suite: Suite): FileSystem = {
+ val configuration =
suite.sessionSwitcher.spark().sessionState.newHadoopConf()
+ dataPath(suite).getFileSystem(configuration)
+ }
+
+ private def markerPath(suite: Suite): Path =
+ new Path(suite.dataWritePath() + ".completed")
+
+ private def dataPath(suite: Suite): Path =
+ new Path(suite.dataWritePath())
+
private def gen(suite: Suite): Unit = {
- suite.sessionSwitcher.useSession("baseline", "Data Gen")
+ val dataPath = suite.dataWritePath()
+
+ println(s"Generating test data to $dataPath...")
+
val dataGen = suite.createDataGen()
- dataGen.gen()
+ dataGen.gen(suite.sessionSwitcher.spark())
+
+ println(s"All test data successfully generated at $dataPath.")
}
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Parameterized.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Parameterized.scala
index 91a7391695..3368135122 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Parameterized.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Parameterized.scala
@@ -26,6 +26,7 @@ import org.apache.gluten.integration.stat.RamStat
import org.apache.spark.sql.ConfUtils.ConfImplicits._
import org.apache.spark.sql.SparkSession
+import java.io.PrintStream
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
import scala.collection.mutable
@@ -132,7 +133,7 @@ class Parameterized(
entry =>
val coordinate = entry._1
sessionSwitcher.useSession(coordinate.toString, "Parameterized %s".format(coordinate))
- runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ runner.createTables(suite.tableCreator(), suite.tableAnalyzer(), sessionSwitcher.spark())
runQueryIds.flatMap {
queryId =>
@@ -150,7 +151,10 @@ class Parameterized(
} finally {
if (noSessionReuse) {
sessionSwitcher.renewSession()
- runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ runner.createTables(
+ suite.tableCreator(),
+ suite.tableAnalyzer(),
+ sessionSwitcher.spark())
}
}
}
@@ -173,7 +177,10 @@ class Parameterized(
} finally {
if (noSessionReuse) {
sessionSwitcher.renewSession()
- runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ runner.createTables(
+ suite.tableCreator(),
+ suite.tableAnalyzer(),
+ sessionSwitcher.spark())
}
}
TestResultLine.CoordMark(iteration, queryId, r)
@@ -207,18 +214,19 @@ class Parameterized(
RamStat.getJvmHeapTotal(),
RamStat.getProcessRamUsed()
)
-
- println("")
- println("Test report: ")
- println("")
- printf(
- "Summary: %d out of %d queries successfully run on all config combinations. \n",
- succeededCount,
- totalCount)
- println("")
- println("Configurations:")
- coordinates.foreach(coord => println(s"${coord._1.id}. ${coord._1}"))
- println("")
+ println()
+
+ // Write out test report.
+ val reportAppender = suite.getReporter().actionAppender(getClass.getSimpleName)
+ reportAppender.out.println("Test report: ")
+ reportAppender.out.println()
+ reportAppender.out.println(
+ "Summary: %d out of %d queries successfully run on all config combinations."
+ .format(succeededCount, totalCount))
+ reportAppender.out.println()
+ reportAppender.out.println("Configurations:")
+ coordinates.foreach(coord => reportAppender.out.println(s"${coord._1.id}. ${coord._1}"))
+ reportAppender.out.println()
val succeeded = results.filter(_.succeeded())
val all = succeeded match {
case Nil => None
@@ -240,22 +248,22 @@ class Parameterized(
))
}
TestResultLines(coordinates.map(_._1.id).toSeq, configDimensions, metrics, succeeded ++ all)
- .print()
- println("")
+ .print(reportAppender.out)
+ reportAppender.out.println()
if (succeededCount == totalCount) {
- println("No failed queries. ")
- println("")
+ reportAppender.out.println("No failed queries. ")
+ reportAppender.out.println()
} else {
- println("Failed queries: ")
- println("")
+ reportAppender.err.println("Failed queries: ")
+ reportAppender.err.println()
TestResultLines(
coordinates.map(_._1.id).toSeq,
configDimensions,
metrics,
results.filter(!_.succeeded()))
- .print()
- println("")
+ .print(reportAppender.err)
+ reportAppender.err.println()
}
if (succeededCount != totalCount) {
@@ -341,7 +349,7 @@ object Parameterized {
configDimensions: Seq[Dim],
metricNames: Seq[String],
lines: Iterable[TestResultLine]) {
- def print(): Unit = {
+ def print(out: PrintStream): Unit = {
val coordFields: Seq[Field] = coordIds.map(id => Field.Leaf(id.toString))
val fields: Seq[Field] =
Seq(Field.Leaf("Query ID")) ++
@@ -357,7 +365,7 @@ object Parameterized {
)
lines.foreach(line => render.appendRow(line))
- render.print(System.out)
+ render.print(out)
}
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Queries.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Queries.scala
index 4cf308835e..c88cdbb32a 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Queries.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/Queries.scala
@@ -25,6 +25,8 @@ import org.apache.gluten.integration.stat.RamStat
import org.apache.spark.sql.SparkSession
+import java.io.PrintStream
+
case class Queries(
queries: QuerySelector,
explain: Boolean,
@@ -42,7 +44,7 @@ case class Queries(
new QueryRunner(suite.dataSource(), suite.dataWritePath())
val sessionSwitcher = suite.sessionSwitcher
sessionSwitcher.useSession("test", "Run Queries")
- runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ runner.createTables(suite.tableCreator(), suite.tableAnalyzer(), sessionSwitcher.spark())
val results = (0 until iterations).flatMap {
iteration =>
println(s"Running tests (iteration $iteration)...")
@@ -61,7 +63,10 @@ case class Queries(
} finally {
if (noSessionReuse) {
sessionSwitcher.renewSession()
- runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ runner.createTables(
+ suite.tableCreator(),
+ suite.tableAnalyzer(),
+ sessionSwitcher.spark())
}
}
}
@@ -73,19 +78,6 @@ case class Queries(
val failedQueries = results.filter(!_.queryResult.succeeded())
println()
-
- if (failedQueries.nonEmpty) {
- println(s"There are failed queries.")
- if (!suppressFailureMessages) {
- println()
- failedQueries.foreach {
- failedQuery =>
- println(
- s"Query ${failedQuery.queryResult.caseId()} failed by error: ${failedQuery.queryResult.asFailure().error}")
- }
- }
- }
-
// RAM stats
println("Performing GC to collect RAM statistics... ")
System.gc()
@@ -96,33 +88,47 @@ case class Queries(
RamStat.getJvmHeapTotal(),
RamStat.getProcessRamUsed()
)
- println("")
+ println()
+
+ // Write out test report.
+ val reportAppender = suite.getReporter().actionAppender(getClass.getSimpleName)
+ if (failedQueries.nonEmpty) {
+ reportAppender.err.println(s"There are failed queries.")
+ if (!suppressFailureMessages) {
+ reportAppender.err.println()
+ failedQueries.foreach {
+ failedQuery =>
+ println(
+ s"Query ${failedQuery.queryResult.caseId()} failed by error: ${failedQuery.queryResult.asFailure().error}")
+ }
+ }
+ }
val sqlMetrics = succeededQueries.flatMap(_.queryResult.asSuccess().runResult.sqlMetrics)
metricsReporters.foreach {
r =>
val report = r.toString(sqlMetrics)
- println(report)
- println("")
+ reportAppender.out.println(report)
+ reportAppender.out.println()
}
- println("Test report: ")
- println("")
- printf("Summary: %d out of %d queries passed. \n", passedCount, count)
- println("")
+ reportAppender.out.println("Test report: ")
+ reportAppender.out.println()
+ reportAppender.out.println("Summary: %d out of %d queries passed.".format(passedCount, count))
+ reportAppender.out.println()
val all =
succeededQueries.map(_.queryResult).asSuccesses().agg("all").map(s => TestResultLine(s))
- Queries.printResults(succeededQueries ++ all)
- println("")
+ Queries.printResults(reportAppender.out, succeededQueries ++ all)
+ reportAppender.out.println()
if (failedQueries.isEmpty) {
- println("No failed queries. ")
- println("")
+ reportAppender.out.println("No failed queries. ")
+ reportAppender.out.println()
} else {
- println("Failed queries: ")
- println("")
- Queries.printResults(failedQueries)
- println("")
+ reportAppender.err.println("Failed queries: ")
+ reportAppender.err.println()
+ Queries.printResults(reportAppender.err, failedQueries)
+ reportAppender.err.println()
}
if (passedCount != count) {
@@ -155,7 +161,7 @@ object Queries {
}
}
- private def printResults(results: Seq[TestResultLine]): Unit = {
+ private def printResults(out: PrintStream, results: Seq[TestResultLine]): Unit = {
val render = TableRender.plain[TestResultLine](
"Query ID",
"Was Passed",
@@ -165,7 +171,7 @@ object Queries {
results.foreach(line => render.appendRow(line))
- render.print(System.out)
+ render.print(out)
}
private def runQuery(
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/QueriesCompare.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/QueriesCompare.scala
index 412150ce80..e5f0c381e9 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/QueriesCompare.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/QueriesCompare.scala
@@ -25,6 +25,8 @@ import org.apache.gluten.integration.stat.RamStat
import org.apache.spark.sql.{SparkSession, TestUtils}
+import java.io.PrintStream
+
case class QueriesCompare(
queries: QuerySelector,
explain: Boolean,
@@ -40,7 +42,7 @@ case class QueriesCompare(
val sessionSwitcher = suite.sessionSwitcher
sessionSwitcher.useSession("baseline", "Run Baseline Queries")
- runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ runner.createTables(suite.tableCreator(), suite.tableAnalyzer(), sessionSwitcher.spark())
val baselineResults = (0 until iterations).flatMap {
iteration =>
querySet.queries.map {
@@ -56,14 +58,17 @@ case class QueriesCompare(
} finally {
if (noSessionReuse) {
sessionSwitcher.renewSession()
- runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ runner.createTables(
+ suite.tableCreator(),
+ suite.tableAnalyzer(),
+ sessionSwitcher.spark())
}
}
}
}.toList
sessionSwitcher.useSession("test", "Run Test Queries")
- runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ runner.createTables(suite.tableCreator(), suite.tableAnalyzer(), sessionSwitcher.spark())
val testResults = (0 until iterations).flatMap {
iteration =>
querySet.queries.map {
@@ -79,7 +84,10 @@ case class QueriesCompare(
} finally {
if (noSessionReuse) {
sessionSwitcher.renewSession()
- runner.createTables(suite.tableCreator(), sessionSwitcher.spark())
+ runner.createTables(
+ suite.tableCreator(),
+ suite.tableAnalyzer(),
+ sessionSwitcher.spark())
}
}
}
@@ -98,19 +106,6 @@ case class QueriesCompare(
val succeededQueries = results.filter(_.testPassed())
val failedQueries = results.filter(!_.testPassed)
- println()
-
- if (failedQueries.nonEmpty) {
- println(s"There are failed queries.")
- if (!suppressFailureMessages) {
- println()
- failedQueries.foreach {
- failedQuery =>
- println(s"Query ${failedQuery.queryId} failed by error: ${failedQuery.error()}")
- }
- }
- }
-
// RAM stats
println("Performing GC to collect RAM statistics... ")
System.gc()
@@ -121,12 +116,27 @@ case class QueriesCompare(
RamStat.getJvmHeapTotal(),
RamStat.getProcessRamUsed()
)
+ println()
+
+ // Write out test report.
+ val reportAppender = suite.getReporter().actionAppender(getClass.getSimpleName)
+ if (failedQueries.nonEmpty) {
+ reportAppender.err.println(s"There are failed queries.")
+ if (!suppressFailureMessages) {
+ reportAppender.err.println()
+ failedQueries.foreach {
+ failedQuery =>
+ reportAppender.err.println(
+ s"Query ${failedQuery.queryId} failed by error: ${failedQuery.error()}")
+ }
+ }
+ }
- println("")
- println("Test report: ")
- println("")
- printf("Summary: %d out of %d queries passed. \n", passedCount, count)
- println("")
+ reportAppender.out.println()
+ reportAppender.out.println("Test report: ")
+ reportAppender.out.println()
+ reportAppender.out.println("Summary: %d out of %d queries passed.".format(passedCount, count))
+ reportAppender.out.println()
val all = succeededQueries match {
case Nil => None
case several =>
@@ -134,17 +144,18 @@ case class QueriesCompare(
val allActual = several.map(_.actual).asSuccesses().agg("all actual").get
Some(TestResultLine("all", allExpected, allActual))
}
- QueriesCompare.printResults(succeededQueries ++ all)
- println("")
+ QueriesCompare.printResults(reportAppender.out, succeededQueries ++ all)
+ reportAppender.out.println()
if (failedQueries.isEmpty) {
- println("No failed queries. ")
- println("")
+ reportAppender.out.println("No failed queries. ")
+ reportAppender.out.println()
} else {
- println("Failed queries (a failed query with correct row count indicates value mismatches): ")
- println("")
- QueriesCompare.printResults(failedQueries)
- println("")
+ reportAppender.err.println(
+ "Failed queries (a failed query with correct row count indicates value mismatches): ")
+ reportAppender.err.println()
+ QueriesCompare.printResults(reportAppender.err, failedQueries)
+ reportAppender.err.println()
}
if (passedCount != count) {
@@ -204,7 +215,7 @@ object QueriesCompare {
}
}
- private def printResults(results: Seq[TestResultLine]): Unit = {
+ private def printResults(out: PrintStream, results: Seq[TestResultLine]): Unit = {
import org.apache.gluten.integration.action.TableRender.Field._
val render = TableRender.create[TestResultLine](
@@ -218,7 +229,7 @@ object QueriesCompare {
results.foreach(line => render.appendRow(line))
- render.print(System.out)
+ render.print(out)
}
private def runBaselineQuery(
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/SparkShell.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/SparkShell.scala
index f920977eea..24977c2540 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/SparkShell.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/action/SparkShell.scala
@@ -25,7 +25,7 @@ case class SparkShell() extends Action {
suite.sessionSwitcher.useSession("test", "Spark CLI")
val runner: QueryRunner =
new QueryRunner(suite.dataSource(), suite.dataWritePath())
- runner.createTables(suite.tableCreator(), suite.sessionSwitcher.spark())
+ runner.createTables(suite.tableCreator(), suite.tableAnalyzer(), suite.sessionSwitcher.spark())
Main.sparkSession = suite.sessionSwitcher.spark()
Main.sparkContext = suite.sessionSwitcher.spark().sparkContext
Main.main(Array("-usejavacp"))
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchDataGen.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchDataGen.scala
index 7e75b153ee..83ffb270f1 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchDataGen.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchDataGen.scala
@@ -27,9 +27,9 @@ import java.io.File
import scala.language.postfixOps
import scala.sys.process._
-class ClickBenchDataGen(spark: SparkSession, dir: String) extends DataGen {
+class ClickBenchDataGen(dir: String) extends DataGen {
import ClickBenchDataGen._
- override def gen(): Unit = {
+ override def gen(spark: SparkSession): Unit = {
println(s"Start to download ClickBench Parquet dataset from URL: $DATA_URL... ")
// Directly download from official URL.
val tempFile = new File(dir + File.separator + TMP_FILE_NAME)
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchSuite.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchSuite.scala
index 5e07211af3..200bf01e06 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchSuite.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/clickbench/ClickBenchSuite.scala
@@ -16,9 +16,10 @@
*/
package org.apache.gluten.integration.clickbench
-import org.apache.gluten.integration.{DataGen, QuerySet, Suite, TableCreator}
+import org.apache.gluten.integration.{DataGen, QuerySet, Suite, TableAnalyzer, TableCreator}
import org.apache.gluten.integration.action.Action
import org.apache.gluten.integration.metrics.MetricMapper
+import org.apache.gluten.integration.report.TestReporter
import org.apache.spark.SparkConf
@@ -49,11 +50,13 @@ class ClickBenchSuite(
val disableAqe: Boolean,
val disableBhj: Boolean,
val disableWscg: Boolean,
+ val enableCbo: Boolean,
val shufflePartitions: Int,
val scanPartitions: Int,
val decimalAsDouble: Boolean,
val baselineMetricMapper: MetricMapper,
- val testMetricMapper: MetricMapper)
+ val testMetricMapper: MetricMapper,
+ val reportPath: String)
extends Suite(
masterUrl,
actions,
@@ -68,11 +71,13 @@ class ClickBenchSuite(
disableAqe,
disableBhj,
disableWscg,
+ enableCbo,
shufflePartitions,
scanPartitions,
decimalAsDouble,
baselineMetricMapper,
- testMetricMapper
+ testMetricMapper,
+ reportPath
) {
import ClickBenchSuite._
@@ -84,7 +89,7 @@ class ClickBenchSuite(
override private[integration] def createDataGen(): DataGen = {
checkDataGenArgs(dataSource, dataScale, genPartitionedData)
- new ClickBenchDataGen(sessionSwitcher.spark(), dataWritePath())
+ new ClickBenchDataGen(dataWritePath())
}
override private[integration] def allQueries(): QuerySet = {
@@ -94,6 +99,10 @@ class ClickBenchSuite(
override private[integration] def desc(): String = "ClickBench"
override def tableCreator(): TableCreator = ClickBenchTableCreator
+
+ override def tableAnalyzer0(): TableAnalyzer = {
+ TableAnalyzer.analyzeAll()
+ }
}
private object ClickBenchSuite {
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/ds/TpcdsDataGen.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/ds/TpcdsDataGen.scala
index 0c4bf94c71..b80709edb7 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/ds/TpcdsDataGen.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/ds/TpcdsDataGen.scala
@@ -16,7 +16,8 @@
*/
package org.apache.gluten.integration.ds
-import org.apache.gluten.integration.{DataGen, ShimUtils, TypeModifier}
+import org.apache.gluten.integration.{DataGen, TypeModifier}
+import org.apache.gluten.integration.shim.Shim
import org.apache.spark.sql.{Column, Row, SaveMode, SparkSession}
import org.apache.spark.sql.types._
@@ -28,7 +29,6 @@ import java.io.File
import scala.collection.JavaConverters._
class TpcdsDataGen(
- spark: SparkSession,
scale: Double,
partitions: Int,
source: String,
@@ -46,7 +46,7 @@ class TpcdsDataGen(
private val features = featureNames.map(featureRegistry.getFeature)
- def writeParquetTable(t: Table): Unit = {
+ def writeParquetTable(spark: SparkSession, t: Table): Unit = {
val name = t.getName
if (name.equals("dbgen_version")) {
return
@@ -88,10 +88,11 @@ class TpcdsDataGen(
}
}
- writeParquetTable(name, t, schema, partitionBy)
+ writeParquetTable(spark, name, t, schema, partitionBy)
}
private def writeParquetTable(
+ spark: SparkSession,
tableName: String,
t: Table,
schema: StructType,
@@ -124,7 +125,7 @@ class TpcdsDataGen(
val array: Array[String] = parentAndChildRow.get(0).asScala.toArray
Row(array: _*)
}
- }(ShimUtils.getExpressionEncoder(stringSchema))
+ }(Shim.getExpressionEncoder(stringSchema))
.select(columns: _*)
.write
.format(source)
@@ -134,8 +135,8 @@ class TpcdsDataGen(
.saveAsTable(tableName)
}
- override def gen(): Unit = {
- Table.getBaseTables.forEach(t => writeParquetTable(t))
+ override def gen(spark: SparkSession): Unit = {
+ Table.getBaseTables.forEach(t => writeParquetTable(spark, t))
features.foreach(feature => DataGen.Feature.run(spark, source, feature))
}
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/ds/TpcdsSuite.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/ds/TpcdsSuite.scala
index 9293e1a09d..0de9ad1c36 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/ds/TpcdsSuite.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/ds/TpcdsSuite.scala
@@ -16,16 +16,16 @@
*/
package org.apache.gluten.integration.ds
-import org.apache.gluten.integration.{DataGen, QuerySet, Suite, TableCreator}
+import org.apache.gluten.integration.{DataGen, QuerySet, Suite, TableAnalyzer, TableCreator}
import org.apache.gluten.integration.action.Action
import org.apache.gluten.integration.metrics.MetricMapper
+import org.apache.gluten.integration.report.TestReporter
import org.apache.spark.SparkConf
+import org.apache.hadoop.fs.Path
import org.apache.log4j.Level
-import java.io.File
-
class TpcdsSuite(
val masterUrl: String,
val actions: Array[Action],
@@ -45,11 +45,13 @@ class TpcdsSuite(
val disableAqe: Boolean,
val disableBhj: Boolean,
val disableWscg: Boolean,
+ val enableCbo: Boolean,
val shufflePartitions: Int,
val scanPartitions: Int,
val decimalAsDouble: Boolean,
val baselineMetricMapper: MetricMapper,
- val testMetricMapper: MetricMapper)
+ val testMetricMapper: MetricMapper,
+ val reportPath: String)
extends Suite(
masterUrl,
actions,
@@ -64,11 +66,13 @@ class TpcdsSuite(
disableAqe,
disableBhj,
disableWscg,
+ enableCbo,
shufflePartitions,
scanPartitions,
decimalAsDouble,
baselineMetricMapper,
- testMetricMapper
+ testMetricMapper,
+ reportPath
) {
import TpcdsSuite._
@@ -81,19 +85,14 @@ class TpcdsSuite(
"non_partitioned"
}
val featureFlags = dataGenFeatures.map(feature => s"-$feature").mkString("")
- if (dataDir.startsWith("hdfs://") || dataDir.startsWith("s3a://")) {
- return s"$dataDir/$TPCDS_WRITE_RELATIVE_PATH-$dataScale-$dataSource-$partitionedFlag$featureFlags"
- }
- new File(dataDir).toPath
- .resolve(s"$TPCDS_WRITE_RELATIVE_PATH-$dataScale-$dataSource-$partitionedFlag$featureFlags")
- .toFile
- .getAbsolutePath
+ val relative =
+ s"$TPCDS_WRITE_RELATIVE_PATH-$dataScale-$dataSource-$partitionedFlag$featureFlags"
+ new Path(dataDir, relative).toString
}
override private[integration] def createDataGen(): DataGen = {
checkDataGenArgs(dataSource, dataScale, genPartitionedData)
new TpcdsDataGen(
- sessionSwitcher.spark(),
dataScale,
shufflePartitions,
dataSource,
@@ -109,7 +108,9 @@ class TpcdsSuite(
override private[integration] def desc(): String = "TPC-DS"
- override def tableCreator(): TableCreator = TableCreator.discoverSchema()
+ override def tableCreator(): TableCreator = TableCreator.discoverFromFiles()
+
+ override def tableAnalyzer0(): TableAnalyzer = TableAnalyzer.analyzeAll()
}
object TpcdsSuite {
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/h/TpchDataGen.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/h/TpchDataGen.scala
index aed8653f62..54c549f835 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/h/TpchDataGen.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/h/TpchDataGen.scala
@@ -16,7 +16,8 @@
*/
package org.apache.gluten.integration.h
-import org.apache.gluten.integration.{DataGen, ShimUtils, TypeModifier}
+import org.apache.gluten.integration.{DataGen, TypeModifier}
+import org.apache.gluten.integration.shim.Shim
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.types._
@@ -29,7 +30,6 @@ import java.sql.Date
import scala.collection.JavaConverters._
class TpchDataGen(
- spark: SparkSession,
scale: Double,
partitions: Int,
source: String,
@@ -42,21 +42,22 @@ class TpchDataGen(
private val featureRegistry = new DataGen.FeatureRegistry
private val features = featureNames.map(featureRegistry.getFeature)
- override def gen(): Unit = {
- generate(dir, "lineitem", lineItemSchema, partitions, lineItemGenerator, lineItemParser)
- generate(dir, "customer", customerSchema, partitions, customerGenerator, customerParser)
- generate(dir, "orders", orderSchema, partitions, orderGenerator, orderParser)
+ generate(spark, dir, "lineitem", lineItemSchema, partitions, lineItemGenerator, lineItemParser)
+ generate(spark, dir, "customer", customerSchema, partitions, customerGenerator, customerParser)
+ generate(spark, dir, "orders", orderSchema, partitions, orderGenerator, orderParser)
orderParser)
generate(
+ spark,
dir,
"partsupp",
partSupplierSchema,
partitions,
partSupplierGenerator,
partSupplierParser)
- generate(dir, "supplier", supplierSchema, partitions, supplierGenerator, supplierParser)
- generate(dir, "nation", nationSchema, nationGenerator, nationParser)
- generate(dir, "part", partSchema, partitions, partGenerator, partParser)
- generate(dir, "region", regionSchema, regionGenerator, regionParser)
+ generate(spark, dir, "supplier", supplierSchema, partitions, supplierGenerator, supplierParser)
+ generate(spark, dir, "nation", nationSchema, nationGenerator, nationParser)
+ generate(spark, dir, "part", partSchema, partitions, partGenerator, partParser)
+ generate(spark, dir, "region", regionSchema, regionGenerator, regionParser)
features.foreach(feature => DataGen.Feature.run(spark, source, feature))
}
@@ -294,12 +295,14 @@ class TpchDataGen(
// gen tpc-h data
private def generate[U](
+ spark: SparkSession,
dir: String,
tableName: String,
schema: StructType,
gen: () => java.lang.Iterable[U],
parser: U => Row): Unit = {
generate(
+ spark,
dir,
tableName,
schema,
@@ -311,6 +314,7 @@ class TpchDataGen(
}
private def generate[U](
+ spark: SparkSession,
dir: String,
tableName: String,
schema: StructType,
@@ -341,7 +345,7 @@ class TpchDataGen(
modifiedRow
}
rows
- }(ShimUtils.getExpressionEncoder(modifiedSchema))
+ }(Shim.getExpressionEncoder(modifiedSchema))
.write
.format(source)
.mode(SaveMode.Overwrite)
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/h/TpchSuite.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/h/TpchSuite.scala
index af36cc4946..6437b301ba 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/h/TpchSuite.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/h/TpchSuite.scala
@@ -16,12 +16,15 @@
*/
package org.apache.gluten.integration.h
-import org.apache.gluten.integration.{DataGen, QuerySet, Suite, TableCreator}
+import org.apache.gluten.integration.{DataGen, QuerySet, Suite, TableAnalyzer, TableCreator}
import org.apache.gluten.integration.action.Action
+import org.apache.gluten.integration.ds.TpcdsSuite.TPCDS_WRITE_RELATIVE_PATH
import org.apache.gluten.integration.metrics.MetricMapper
+import org.apache.gluten.integration.report.TestReporter
import org.apache.spark.SparkConf
+import org.apache.hadoop.fs.Path
import org.apache.log4j.Level
import java.io.File
@@ -45,11 +48,13 @@ class TpchSuite(
val disableAqe: Boolean,
val disableBhj: Boolean,
val disableWscg: Boolean,
+ val enableCbo: Boolean,
val shufflePartitions: Int,
val scanPartitions: Int,
val decimalAsDouble: Boolean,
val baselineMetricMapper: MetricMapper,
- val testMetricMapper: MetricMapper)
+ val testMetricMapper: MetricMapper,
+ val reportPath: String)
extends Suite(
masterUrl,
actions,
@@ -64,11 +69,13 @@ class TpchSuite(
disableAqe,
disableBhj,
disableWscg,
+ enableCbo,
shufflePartitions,
scanPartitions,
decimalAsDouble,
baselineMetricMapper,
- testMetricMapper
+ testMetricMapper,
+ reportPath
) {
import TpchSuite._
@@ -76,19 +83,14 @@ class TpchSuite(
override private[integration] def dataWritePath(): String = {
val featureFlags = dataGenFeatures.map(feature => s"-$feature").mkString("")
- if (dataDir.startsWith("hdfs://") || dataDir.startsWith("s3a://")) {
- return s"$dataDir/$TPCH_WRITE_RELATIVE_PATH-$dataScale-$dataSource$featureFlags"
- }
- new File(dataDir).toPath
- .resolve(s"$TPCH_WRITE_RELATIVE_PATH-$dataScale-$dataSource$featureFlags")
- .toFile
- .getAbsolutePath
+ val relative =
+ s"$TPCH_WRITE_RELATIVE_PATH-$dataScale-$dataSource$featureFlags"
+ new Path(dataDir, relative).toString
}
override private[integration] def createDataGen(): DataGen = {
checkDataGenArgs(dataSource, dataScale, genPartitionedData)
new TpchDataGen(
- sessionSwitcher.spark(),
dataScale,
shufflePartitions,
dataSource,
@@ -103,7 +105,9 @@ class TpchSuite(
override private[integration] def desc(): String = "TPC-H"
- override def tableCreator(): TableCreator = TableCreator.discoverSchema()
+ override def tableCreator(): TableCreator = TableCreator.discoverFromFiles()
+
+ override def tableAnalyzer0(): TableAnalyzer = TableAnalyzer.analyzeAll()
}
object TpchSuite {
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/MetricMapper.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/MetricMapper.scala
index 27189275db..9133bb7e26 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/MetricMapper.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/MetricMapper.scala
@@ -20,18 +20,19 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
trait MetricMapper {
- def map(node: SparkPlan, key: String, metric: SQLMetric): Seq[MetricTag[_]]
+ def map(node: SparkPlan, key: String, metric: SQLMetric): Seq[MetricTag]
}
object MetricMapper {
val dummy: MetricMapper = (node: SparkPlan, key: String, metric: SQLMetric) => Nil
- case class SelfTimeMapper(selfTimeKeys: Map[String, Set[String]]) extends MetricMapper {
- override def map(node: SparkPlan, key: String, metric: SQLMetric): Seq[MetricTag[_]] = {
+ case class SimpleMetricMapper(tags: Seq[MetricTag], selfTimeKeys: Map[String, Set[String]])
+ extends MetricMapper {
+ override def map(node: SparkPlan, key: String, metric: SQLMetric): Seq[MetricTag] = {
val className = node.getClass.getSimpleName
if (selfTimeKeys.contains(className)) {
if (selfTimeKeys(className).contains(key)) {
- return Seq(MetricTag.IsSelfTime())
+ return tags
}
}
Nil
@@ -52,7 +53,7 @@ object MetricMapper {
private class ChainedTypeMetricMapper(val mappers: Seq[MetricMapper]) extends MetricMapper {
assert(!mappers.exists(_.isInstanceOf[ChainedTypeMetricMapper]))
-    override def map(node: SparkPlan, key: String, metric: SQLMetric): Seq[MetricTag[_]] = {
+    override def map(node: SparkPlan, key: String, metric: SQLMetric): Seq[MetricTag] = {
mappers.flatMap(m => m.map(node, key, metric))
}
}
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/MetricTag.scala
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/MetricTag.scala
index 5cc2de04e8..b5c489fac4 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/MetricTag.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/MetricTag.scala
@@ -16,23 +16,13 @@
*/
package org.apache.gluten.integration.metrics
-import scala.reflect.{classTag, ClassTag}
-
-trait MetricTag[T] {
- import MetricTag._
- final def name(): String = nameOf(ClassTag(this.getClass))
- def value(): T
+trait MetricTag {
+  final def name(): String = s"${this.getClass}-${System.identityHashCode(this)}"
}
object MetricTag {
- def nameOf[T <: MetricTag[_]: ClassTag]: String = {
- val clazz = classTag[T].runtimeClass
- assert(classOf[MetricTag[_]].isAssignableFrom(clazz))
- clazz.getSimpleName
- }
- case class IsSelfTime() extends MetricTag[Nothing] {
- override def value(): Nothing = {
- throw new UnsupportedOperationException()
- }
- }
+ object IsSelfTime extends MetricTag
+ object IsJoinProbeInputNumRows extends MetricTag
+ object IsJoinProbeOutputNumRows extends MetricTag
+ object IsJoinOutputNumRows extends MetricTag
}
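A quick hedged illustration of the simplified tag model above: tags are now singleton objects,
so membership checks are plain set lookups instead of the previous ClassTag-based name
resolution. The tag set below is invented for the example and is not taken from any mapper:

    import org.apache.gluten.integration.metrics.MetricTag

    object MetricTagSketch {
      def main(args: Array[String]): Unit = {
        // A tag set as a mapper might emit it; the contents are illustrative only.
        val tags: Set[MetricTag] = Set(MetricTag.IsSelfTime, MetricTag.IsJoinOutputNumRows)
        assert(tags.contains(MetricTag.IsSelfTime))
        assert(!tags.contains(MetricTag.IsJoinProbeInputNumRows))
      }
    }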
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/PlanMetric.scala
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/PlanMetric.scala
index 611f5c4067..239bf08afe 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/PlanMetric.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/metrics/PlanMetric.scala
@@ -34,16 +34,10 @@ case class PlanMetric(
plan: SparkPlan,
key: String,
metric: SQLMetric,
- tags: Map[String, Seq[MetricTag[_]]]) {
+ tags: Set[MetricTag]) {
- def containsTags[T <: MetricTag[_]: ClassTag]: Boolean = {
- val name = MetricTag.nameOf[T]
- tags.contains(name)
- }
- def getTags[T <: MetricTag[_]: ClassTag]: Seq[T] = {
- require(containsTags[T])
- val name = MetricTag.nameOf[T]
- tags(name).asInstanceOf[Seq[T]]
+ def containsTags(tag: MetricTag): Boolean = {
+ tags.contains(tag)
}
}
@@ -55,6 +49,8 @@ object PlanMetric {
new NodeTimeReporter(10),
new StepTimeReporter(30)
))
+ case "join-selectivity" =>
+ new ChainedReporter(Seq(new JoinSelectivityReporter(30)))
case other => throw new IllegalArgumentException(s"Metric reporter type $other not defined")
}
@@ -83,7 +79,7 @@ object PlanMetric {
override def toString(metrics: Seq[PlanMetric]): String = {
val sb = new StringBuilder()
val selfTimes = metrics
- .filter(_.containsTags[MetricTag.IsSelfTime])
+ .filter(_.containsTags(MetricTag.IsSelfTime))
val sorted = selfTimes.sortBy(m => toNanoTime(m.metric))(Ordering.Long.reverse)
sb.append(s"Top $topN computation steps that took longest time to execute: ")
sb.append(System.lineSeparator())
@@ -96,10 +92,10 @@ object PlanMetric {
Leaf("Step Time (ns)"))
for (i <- 0 until (topN.min(sorted.size))) {
val m = sorted(i)
- val f = new File(m.queryPath).toPath.getFileName.toString
+ val queryPath = new File(m.queryPath).toPath.getFileName.toString
tr.appendRow(
Seq(
- f,
+ queryPath,
m.plan.id.toString,
m.plan.nodeName,
s"[${m.metric.name.getOrElse("")}] ${toNanoTime(m.metric).toString}"))
@@ -121,7 +117,7 @@ object PlanMetric {
override def toString(metrics: Seq[PlanMetric]): String = {
val sb = new StringBuilder()
val selfTimes = metrics
- .filter(_.containsTags[MetricTag.IsSelfTime])
+ .filter(_.containsTags(MetricTag.IsSelfTime))
val rows: Seq[TableRow] = selfTimes
.groupBy(m => m.plan.id)
.toSeq
@@ -165,4 +161,81 @@ object PlanMetric {
selfTimeNs: Long,
metrics: Seq[(String, SQLMetric)])
}
+
+ private class JoinSelectivityReporter(topN: Int) extends Reporter {
+ import JoinSelectivityReporter._
+ private def toNumRows(m: SQLMetric): Long = m.metricType match {
+ case "sum" => m.value
+ }
+
+ override def toString(metrics: Seq[PlanMetric]): String = {
+ val sb = new StringBuilder()
+      sb.append(s"Top $topN join operations that have the lowest selectivity: ")
+ sb.append(System.lineSeparator())
+ sb.append(System.lineSeparator())
+
+ val tr: TableRender[Seq[String]] =
+ TableRender.create(
+ Leaf("Query"),
+ Leaf("Node ID"),
+ Leaf("Node Name"),
+ Leaf("Input Row Count"),
+ Leaf("Output Row Count"),
+ Leaf("Selectivity"))
+ val probeInputNumRows = metrics
+ .filter(_.containsTags(MetricTag.IsJoinProbeInputNumRows))
+ .groupBy(m => m.plan.id)
+ .toSeq
+ .sortBy(_._1)
+ val probeOutputNumRows = metrics
+ .filter(_.containsTags(MetricTag.IsJoinProbeOutputNumRows))
+ .groupBy(m => m.plan.id)
+ .toSeq
+ .sortBy(_._1)
+ assert(probeInputNumRows.size == probeOutputNumRows.size)
+ val rows = probeInputNumRows
+ .zip(probeOutputNumRows)
+ .map {
+ case ((id1, inputMetrics), (id2, outputMetrics)) =>
+ assert(id1 == id2)
+          val queryPath = new File(inputMetrics.head.queryPath).toPath.getFileName.toString
+ val inputNumRows = inputMetrics.map(m => toNumRows(m.metric)).sum
+ val outputNumRows = outputMetrics.map(m => toNumRows(m.metric)).sum
+ val selectivity = outputNumRows.toDouble / inputNumRows.toDouble
+ TableRow(
+ queryPath,
+ id1,
+ inputMetrics.head.plan.nodeName,
+ inputNumRows,
+ outputNumRows,
+ selectivity)
+ }
+ .sortBy(_.selectivity)
+ for (i <- 0 until (topN.min(rows.size))) {
+ val row = rows(i)
+ tr.appendRow(
+ Seq(
+ row.queryPath,
+ row.planId.toString,
+ row.planNodeName,
+ row.probeInputNumRows.toString,
+ row.probeOutputNumRows.toString,
+ "%.3f".format(row.selectivity)))
+ }
+ val out = new ByteArrayOutputStream()
+ tr.print(out)
+ sb.append(out.toString(Charset.defaultCharset))
+ sb.toString()
+ }
+ }
+
+ private object JoinSelectivityReporter {
+ private case class TableRow(
+ queryPath: String,
+ planId: Long,
+ planNodeName: String,
+ probeInputNumRows: Long,
+ probeOutputNumRows: Long,
+ selectivity: Double)
+ }
}
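To make the new join-selectivity figure concrete: the reporter sums probe-side input and output
row counts per plan node and divides output by input, so a smaller value means a more selective
join. A tiny worked example with made-up row counts:

    object SelectivitySketch {
      def main(args: Array[String]): Unit = {
        val probeInputRows = 6001215L  // hypothetical probe-side input row count
        val joinOutputRows = 1478870L  // hypothetical join output row count
        val selectivity = joinOutputRows.toDouble / probeInputRows.toDouble
        println("%.3f".format(selectivity)) // prints 0.246
      }
    }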
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/report/TestReporter.scala
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/report/TestReporter.scala
new file mode 100644
index 0000000000..5c5822d1dd
--- /dev/null
+++
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/report/TestReporter.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.integration.report
+
+import java.io.{ByteArrayOutputStream, OutputStream, PrintStream, PrintWriter}
+import java.time.{Instant, ZoneId}
+import java.time.format.DateTimeFormatter
+
+import scala.collection.mutable
+
+trait TestReporter {
+ def addMetadata(key: String, value: String): Unit
+ def rootAppender(): TestReporter.Appender
+ def actionAppender(actionName: String): TestReporter.Appender
+ def write(out: OutputStream): Unit
+}
+
+object TestReporter {
+ trait Appender {
+ def out: PrintStream
+ def err: PrintStream
+ }
+
+ private class AppenderImpl extends Appender {
+ val outStream = new ByteArrayOutputStream()
+ val errStream = new ByteArrayOutputStream()
+
+ override val out: PrintStream = new PrintStream(outStream)
+ override val err: PrintStream = new PrintStream(errStream)
+ }
+
+ def create(): TestReporter = {
+ new Impl()
+ }
+
+ private class Impl() extends TestReporter {
+ private val rootAppenderName = "__ROOT__"
+ private val metadataMap = mutable.LinkedHashMap[String, String]()
+ private val appenderMap = mutable.LinkedHashMap[String, AppenderImpl]()
+
+ override def addMetadata(key: String, value: String): Unit = {
+ metadataMap += key -> value
+ }
+
+ override def rootAppender(): Appender = {
+ appenderMap.getOrElseUpdate(rootAppenderName, new AppenderImpl)
+ }
+
+ override def actionAppender(actionName: String): Appender = {
+ require(actionName != rootAppenderName)
+ appenderMap.getOrElseUpdate(actionName, new AppenderImpl)
+ }
+
+ override def write(out: OutputStream): Unit = {
+ val writer = new PrintWriter(out)
+
+ def line(): Unit =
+ writer.println("========================================")
+
+ def subLine(): Unit =
+ writer.println("----------------------------------------")
+
+      def printStreamBlock(label: String, content: String, indent: String): Unit = {
+ if (content.nonEmpty) {
+ writer.println(s"$indent$label:")
+ subLine()
+ content.linesIterator.foreach(l => writer.println(s"$indent $l"))
+ writer.println()
+ }
+ }
+
+ line()
+ writer.println(" TEST REPORT ")
+ line()
+ metadataMap.foreach {
+ case (k, v) =>
+ writer.println(s"$k : $v")
+ }
+ writer.println()
+
+ // ---- ROOT (suite-level) ----
+ appenderMap.get(rootAppenderName).foreach {
+ root =>
+ val stdout = root.outStream.toString("UTF-8").trim
+ val stderr = root.errStream.toString("UTF-8").trim
+
+ writer.println("SUITE OUTPUT")
+ subLine()
+ writer.println()
+
+ printStreamBlock("STDOUT", stdout, "")
+ printStreamBlock("STDERR", stderr, "")
+ }
+
+ // ---- ACTIONS ----
+ val actions =
+ appenderMap.iterator.filterNot(_._1 == rootAppenderName)
+
+ if (actions.nonEmpty) {
+ writer.println("ACTIONS")
+ subLine()
+ writer.println()
+
+ actions.foreach {
+ case (name, appender) =>
+ val stdout = appender.outStream.toString("UTF-8").trim
+ val stderr = appender.errStream.toString("UTF-8").trim
+
+ writer.println(s"[ $name ]")
+ writer.println()
+
+ if (stdout.isEmpty && stderr.isEmpty) {
+ writer.println(" (no output)")
+ writer.println()
+ } else {
+ printStreamBlock("STDOUT", stdout, " ")
+ printStreamBlock("STDERR", stderr, " ")
+ }
+ }
+ }
+
+ line()
+ writer.flush()
+ }
+ }
+}
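A hedged usage sketch for the new reporter; the metadata keys, action name, and messages below
are invented for the example, while the API calls match the trait introduced above:

    import org.apache.gluten.integration.report.TestReporter

    object TestReporterSketch {
      def main(args: Array[String]): Unit = {
        val reporter = TestReporter.create()
        reporter.addMetadata("Suite", "TPC-H")       // hypothetical metadata
        reporter.addMetadata("Scale Factor", "1.0")
        val queries = reporter.actionAppender("queries") // hypothetical action name
        queries.out.println("q1 finished in 1.2s")
        queries.err.println("q2 fell back to vanilla Spark")
        // Renders the "TEST REPORT" block, including per-action STDOUT/STDERR, to stdout.
        reporter.write(System.out)
      }
    }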
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/ShimUtils.scala
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/shim/Shim.scala
similarity index 96%
rename from
tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/ShimUtils.scala
rename to
tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/shim/Shim.scala
index 3507af81f6..30987c81f0 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/ShimUtils.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/shim/Shim.scala
@@ -14,14 +14,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.integration
+package org.apache.gluten.integration.shim
import org.apache.spark.VersionUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.types.StructType
-object ShimUtils {
+object Shim {
def getExpressionEncoder(schema: StructType): ExpressionEncoder[Row] = {
val sparkVersion = VersionUtils.majorMinorVersion()
if (VersionUtils.compareMajorMinorVersion(sparkVersion, (3, 5)) < 0) {
diff --git
a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/SparkQueryRunner.scala
b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/SparkQueryRunner.scala
index ed84746cb7..6a5117be83 100644
---
a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/SparkQueryRunner.scala
+++
b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/SparkQueryRunner.scala
@@ -153,15 +153,7 @@ object SparkQueryRunner {
p.metrics.map {
case keyValue @ (k, m) =>
val tags = mapper.map(p, k, m)
-          val tagMapMutable = mutable.Map[String, mutable.Buffer[MetricTag[_]]]()
- tags.foreach {
- tag: MetricTag[_] =>
- val buffer =
-                tagMapMutable.getOrElseUpdate(tag.name(), mutable.ListBuffer[MetricTag[_]]())
- buffer += tag
- }
-          val tagMap = tagMapMutable.map { case (k, v) => (k, v.toSeq) }.toMap
- PlanMetric(queryPath, p, k, m, tagMap)
+ PlanMetric(queryPath, p, k, m, tags.toSet)
}
}
all.toSeq
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]