comphead commented on code in PR #2632:
URL: https://github.com/apache/datafusion-comet/pull/2632#discussion_r2453464967


##########
fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala:
##########
@@ -55,104 +133,100 @@ object QueryRunner {
     try {
       querySource
         .getLines()
-        .foreach(sql => {
-
-          try {
-            // execute with Spark
-            spark.conf.set("spark.comet.enabled", "false")
-            val df = spark.sql(sql)
-            val sparkRows = df.collect()
-            val sparkPlan = df.queryExecution.executedPlan.toString
-
-            // execute with Comet
-            try {
-              spark.conf.set("spark.comet.enabled", "true")
-              // complex type support until we support it natively
-              spark.conf.set("spark.comet.sparkToColumnar.enabled", "true")
-              spark.conf.set("spark.comet.convert.parquet.enabled", "true")
-              val df = spark.sql(sql)
-              val cometRows = df.collect()
-              val cometPlan = df.queryExecution.executedPlan.toString
-
-              if (sparkRows.length == cometRows.length) {
-                var i = 0
-                while (i < sparkRows.length) {
-                  val l = sparkRows(i)
-                  val r = cometRows(i)
-                  assert(l.length == r.length)
-                  for (j <- 0 until l.length) {
-                    if (!same(l(j), r(j))) {
-                      showSQL(w, sql)
-                      showPlans(w, sparkPlan, cometPlan)
-                      w.write(s"First difference at row $i:\n")
-                      w.write("Spark: `" + formatRow(l) + "`\n")
-                      w.write("Comet: `" + formatRow(r) + "`\n")
-                      i = sparkRows.length
-                    }
-                  }
-                  i += 1
-                }
-              } else {
-                showSQL(w, sql)
-                showPlans(w, sparkPlan, cometPlan)
-                w.write(
-                  s"[ERROR] Spark produced ${sparkRows.length} rows and " +
-                    s"Comet produced ${cometRows.length} rows.\n")
-              }
-            } catch {
-              case e: Exception =>
-                // the query worked in Spark but failed in Comet, so this is likely a bug in Comet
-                showSQL(w, sql)
-                w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n")
-                w.write("```\n")
-                val sw = new StringWriter()
-                val p = new PrintWriter(sw)
-                e.printStackTrace(p)
-                p.close()
-                w.write(s"${sw.toString}\n")
-                w.write("```\n")
-            }
-
-            // flush after every query so that results are saved in the event of the driver crashing
-            w.flush()
-
-          } catch {
-            case e: Exception =>
-              // we expect many generated queries to be invalid
-              if (showFailedSparkQueries) {
-                showSQL(w, sql)
-                w.write(s"Query failed in Spark: ${e.getMessage}\n")
-              }
-          }
-        })
+        .foreach(sql => assertCorrectness(spark, sql, showFailedSparkQueries, output = w))
 
     } finally {
       w.close()
       querySource.close()
     }
   }
 
+  def runTPCQueries(
+      spark: SparkSession,
+      dataFolderName: String,
+      queriesFolderName: String): Unit = {
+    val output = QueryRunner.createOutputMdFile()
+
+    // Load data tables from dataFolder
+    val dataFolder = new File(dataFolderName)
+    if (!dataFolder.exists() || !dataFolder.isDirectory) {
+      // scalastyle:off println
+      println(s"Error: Data folder $dataFolder does not exist or is not a directory")
+      // scalastyle:on println
+      sys.exit(-1)
+    }
+
+    // Traverse data folder and create temp views
+    dataFolder.listFiles().filter(_.isDirectory).foreach { tableDir =>
+      val tableName = tableDir.getName
+      val parquetPath = s"${tableDir.getAbsolutePath}/*.parquet"
+      spark.read.parquet(parquetPath).createOrReplaceTempView(tableName)
+      // scalastyle:off println
+      println(s"Created temp view: $tableName from $parquetPath")
+    // scalastyle:on println
+    }
+
+    // Load and run queries from queriesFolder
+    val queriesFolder = new File(queriesFolderName)
+    if (!queriesFolder.exists() || !queriesFolder.isDirectory) {
+      // scalastyle:off println
+      println(s"Error: Queries folder $queriesFolder does not exist or is not a directory")
+      // scalastyle:on println
+      sys.exit(-1)
+    }
+
+    // Traverse queries folder and run each .sql file
+    queriesFolder.listFiles().filter(f => f.isFile && f.getName.endsWith(".sql")).foreach {
+      sqlFile =>
+        // scalastyle:off println
+        println(s"Running query from: ${sqlFile.getName}")
+        // scalastyle:on println
+
+        val querySource = Source.fromFile(sqlFile)
+        try {
+          val sql = querySource.mkString
+          QueryRunner.assertCorrectness(spark, sql, showFailedSparkQueries = false, output)
+        } finally {
+          querySource.close()
+        }
+    }
+
+    output.close()
+  }
+
   private def same(l: Any, r: Any): Boolean = {
+    if (l == null || r == null) {
+      return l == null && r == null
+    }
     (l, r) match {
+      case (a: Float, b: Float) if a.isPosInfinity => b.isPosInfinity
+      case (a: Float, b: Float) if a.isNegInfinity => b.isNegInfinity
       case (a: Float, b: Float) if a.isInfinity => b.isInfinity
       case (a: Float, b: Float) if a.isNaN => b.isNaN
       case (a: Float, b: Float) => (a - b).abs <= 0.000001f
+      case (a: Double, b: Double) if a.isPosInfinity => b.isPosInfinity
+      case (a: Double, b: Double) if a.isNegInfinity => b.isNegInfinity
       case (a: Double, b: Double) if a.isInfinity => b.isInfinity
       case (a: Double, b: Double) if a.isNaN => b.isNaN
       case (a: Double, b: Double) => (a - b).abs <= 0.000001
       case (a: Array[_], b: Array[_]) =>
         a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
-      case (a: WrappedArray[_], b: WrappedArray[_]) =>
+      case (a: mutable.WrappedArray[_], b: mutable.WrappedArray[_]) =>

Review Comment:
   moved it from #2614 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to