Github user maropu commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19188#discussion_r137999669
  
    --- Diff: 
sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
 ---
    @@ -32,76 +33,76 @@ import org.apache.spark.util.Benchmark
      * Benchmark to measure TPCDS query performance.
      * To run this:
      *  spark-submit --class <this class> --jars <spark sql test jar>
    + *
    + * By default, this class runs all the TPC-DS queries. If you want to run 
some of them,
    + * you can use an option to filter the queries that it runs, e.g.,
    + * to run q2, q4, and q6 only:
    + *  spark-submit --class <this class> --conf 
spark.sql.tpcds.queryFilter="q2,q4,q6"
    + *    --jars <spark sql test jar>
      */
    -object TPCDSQueryBenchmark {
    -  val conf =
    -    new SparkConf()
    -      .setMaster("local[1]")
    -      .setAppName("test-sql-context")
    -      .set("spark.sql.parquet.compression.codec", "snappy")
    -      .set("spark.sql.shuffle.partitions", "4")
    -      .set("spark.driver.memory", "3g")
    -      .set("spark.executor.memory", "3g")
    -      .set("spark.sql.autoBroadcastJoinThreshold", (20 * 1024 * 
1024).toString)
    -      .set("spark.sql.crossJoin.enabled", "true")
    -
    -  val spark = SparkSession.builder.config(conf).getOrCreate()
    +object TPCDSQueryBenchmark extends Logging {
     
    -  val tables = Seq("catalog_page", "catalog_returns", "customer", 
"customer_address",
    -    "customer_demographics", "date_dim", "household_demographics", 
"inventory", "item",
    -    "promotion", "store", "store_returns", "catalog_sales", "web_sales", 
"store_sales",
    -    "web_returns", "web_site", "reason", "call_center", "warehouse", 
"ship_mode", "income_band",
    -    "time_dim", "web_page")
    -
    -  def setupTables(dataLocation: String): Map[String, Long] = {
    -    tables.map { tableName =>
    -      
spark.read.parquet(s"$dataLocation/$tableName").createOrReplaceTempView(tableName)
    -      tableName -> spark.table(tableName).count()
    -    }.toMap
    -  }
    +  case class TpcdsQueries(spark: SparkSession, queries: Seq[String], 
dataLocation: String) {
     
    -  def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = {
         require(dataLocation.nonEmpty,
           "please modify the value of dataLocation to point to your local 
TPCDS data")
    -    val tableSizes = setupTables(dataLocation)
    -    queries.foreach { name =>
    -      val queryString = fileToString(new 
File(Thread.currentThread().getContextClassLoader
    -        .getResource(s"tpcds/$name.sql").getFile))
    -
    -      // This is an indirect hack to estimate the size of each query's 
input by traversing the
    -      // logical plan and adding up the sizes of all tables that appear in 
the plan. Note that this
    -      // currently doesn't take WITH subqueries into account which might 
lead to fairly inaccurate
    -      // per-row processing time for those cases.
    -      val queryRelations = scala.collection.mutable.HashSet[String]()
    -      spark.sql(queryString).queryExecution.logical.map {
    -        case UnresolvedRelation(t: TableIdentifier) =>
    -          queryRelations.add(t.table)
    -        case lp: LogicalPlan =>
    -          lp.expressions.foreach { _ foreach {
    -            case subquery: SubqueryExpression =>
    -              subquery.plan.foreach {
    -                case UnresolvedRelation(t: TableIdentifier) =>
    -                  queryRelations.add(t.table)
    -                case _ =>
    -              }
    -            case _ =>
    +
    +    private val tables = Seq("catalog_page", "catalog_returns", 
"customer", "customer_address",
    +      "customer_demographics", "date_dim", "household_demographics", 
"inventory", "item",
    +      "promotion", "store", "store_returns", "catalog_sales", "web_sales", 
"store_sales",
    +      "web_returns", "web_site", "reason", "call_center", "warehouse", 
"ship_mode", "income_band",
    +      "time_dim", "web_page")
    +
    +    private def setupTables(dataLocation: String): Map[String, Long] = {
    +      tables.map { tableName =>
    +        
spark.read.parquet(s"$dataLocation/$tableName").createOrReplaceTempView(tableName)
    +        tableName -> spark.table(tableName).count()
    +      }.toMap
    +    }
    +
    +    def run(): Unit = {
    +      val tableSizes = setupTables(dataLocation)
    +      queries.foreach { name =>
    +        val queryString = fileToString(new 
File(Thread.currentThread().getContextClassLoader
    +          .getResource(s"tpcds/$name.sql").getFile))
    +
    +        // This is an indirect hack to estimate the size of each query's 
input by traversing the
    +        // logical plan and adding up the sizes of all tables that appear 
in the plan. Note that this
    +        // currently doesn't take WITH subqueries into account which might 
lead to fairly inaccurate
    +        // per-row processing time for those cases.
    +        val queryRelations = scala.collection.mutable.HashSet[String]()
    +        spark.sql(queryString).queryExecution.logical.map {
    +          case UnresolvedRelation(t: TableIdentifier) =>
    +            queryRelations.add(t.table)
    +          case lp: LogicalPlan =>
    +            lp.expressions.foreach { _ foreach {
    +              case subquery: SubqueryExpression =>
    +                subquery.plan.foreach {
    +                  case UnresolvedRelation(t: TableIdentifier) =>
    +                    queryRelations.add(t.table)
    +                  case _ =>
    +                }
    +              case _ =>
    +            }
               }
    +          case _ =>
             }
    -        case _ =>
    -      }
    -      val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum
    -      val benchmark = new Benchmark(s"TPCDS Snappy", numRows, 5)
    -      benchmark.addCase(name) { i =>
    -        spark.sql(queryString).collect()
    +        val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum
    +        val benchmark = new Benchmark(s"TPCDS Snappy", numRows, 5)
    +        benchmark.addCase(name) { i =>
    +          spark.sql(queryString).collect()
    +        }
    +        logInfo(s"\n\n===== TPCDS QUERY BENCHMARK OUTPUT FOR $name 
=====\n")
    --- End diff --
    
    This change is not directly related to this PR, but I added the log here 
because it is a trivial change. I think this log makes it easier to identify 
which query fails.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to