This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 2ed096703 chore: extract comparison into separate tool (#2632)
2ed096703 is described below
commit 2ed096703b8082527f56befd1adc5195b8f4ab84
Author: Oleks V <[email protected]>
AuthorDate: Wed Oct 29 07:04:50 2025 -0700
chore: extract comparison into separate tool (#2632)
---
dev/benchmarks/tpcbench.py | 35 ++++-
fuzz-testing/README.md | 16 +++
.../org/apache/comet/fuzz/ComparisonTool.scala | 143 +++++++++++++++++++++
.../scala/org/apache/comet/fuzz/QueryRunner.scala | 112 +++++++++-------
4 files changed, 258 insertions(+), 48 deletions(-)
diff --git a/dev/benchmarks/tpcbench.py b/dev/benchmarks/tpcbench.py
index 39c34ca7c..0a91bf033 100644
--- a/dev/benchmarks/tpcbench.py
+++ b/dev/benchmarks/tpcbench.py
@@ -21,6 +21,22 @@ import json
from pyspark.sql import SparkSession
import time
+# Rename duplicate column names:
+# a, a, b, b -> a, a_1, b, b_1
+#
+# Important when writing data to formats that require unique column names
+def dedup_columns(df):
+    counts = {}
+    new_cols = []
+    for c in df.columns:
+        if c not in counts:
+            counts[c] = 0
+            new_cols.append(c)
+        else:
+            counts[c] += 1
+            new_cols.append(f"{c}_{counts[c]}")
+    return df.toDF(*new_cols)
+
def main(benchmark: str, data_path: str, query_path: str, iterations: int,
output: str, name: str, query_num: int = None, write_path: str = None):
# Initialize a SparkSession
@@ -91,9 +107,19 @@ def main(benchmark: str, data_path: str, query_path: str, iterations: int, outpu
df.explain()
if write_path is not None:
- output_path = f"{write_path}/q{query}"
- df.coalesce(1).write.mode("overwrite").parquet(output_path)
- print(f"Query {query} results written to {output_path}")
+ # skip results with an empty schema,
+ # which can occur when running DDL statements
+ if len(df.columns) > 0:
+ output_path = f"{write_path}/q{query}"
+ # rename duplicate column names for the output
+ # a, a, b, b => a, a_1, b, b_1
+ # the output doesn't allow non-unique column names
+ deduped = dedup_columns(df)
+ # sort by all columns to produce a predictable output dataset for comparison
+ deduped.orderBy(*deduped.columns).coalesce(1).write.mode("overwrite").parquet(output_path)
+ print(f"Query {query} results written to {output_path}")
+ else:
+ print(f"Skipping write: DataFrame has no schema for {output_path}")
else:
rows = df.collect()
print(f"Query {query} returned {len(rows)} rows")
@@ -132,4 +158,5 @@ if __name__ == "__main__":
parser.add_argument("--write", required=False, help="Path to save query
results to, in Parquet format.")
args = parser.parse_args()
- main(args.benchmark, args.data, args.queries, int(args.iterations), args.output, args.name, args.query, args.write)
\ No newline at end of file
+ main(args.benchmark, args.data, args.queries, int(args.iterations), args.output, args.name, args.query, args.write)
+
diff --git a/fuzz-testing/README.md b/fuzz-testing/README.md
index c8cea5be8..74141fbf3 100644
--- a/fuzz-testing/README.md
+++ b/fuzz-testing/README.md
@@ -103,3 +103,19 @@ $SPARK_HOME/bin/spark-submit \
```
Note that the output filename is currently hard-coded as `results-${System.currentTimeMillis()}.md`
+
+### Compare existing datasets
+
+To compare a pair of existing datasets, you can use the comparison tool.
+The example below compares TPC-H query results generated by pure Spark and by Comet.
+
+
+```shell
+$SPARK_HOME/bin/spark-submit \
+ --master $SPARK_MASTER \
+ --class org.apache.comet.fuzz.ComparisonTool \
+ target/comet-fuzz-spark3.5_2.12-0.12.0-SNAPSHOT-jar-with-dependencies.jar \
+ compareParquet --input-spark-folder=/tmp/tpch/spark --input-comet-folder=/tmp/tpch/comet
+```
+
+The tool takes a pair of existing folders with the same layout and compares the matching subfolders, treating each one as a Parquet-based dataset.
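+
+For reference, the two input folders in this example can be produced with `tpcbench.py`
+and its `--write` option, which writes one Parquet subfolder per query (`q1`, `q2`, ...).
+The sketch below is illustrative only: apart from `--write`, the flag names and paths are
+assumptions and may differ from the script's actual arguments.
+
+```shell
+# plain Spark run
+$SPARK_HOME/bin/spark-submit dev/benchmarks/tpcbench.py \
+  --benchmark tpch --data /path/to/tpch-data --queries /path/to/queries \
+  --iterations 1 --output . --name spark --write /tmp/tpch/spark
+
+# Comet-enabled run (with the usual Comet jars and --conf settings added to spark-submit)
+$SPARK_HOME/bin/spark-submit dev/benchmarks/tpcbench.py \
+  --benchmark tpch --data /path/to/tpch-data --queries /path/to/queries \
+  --iterations 1 --output . --name comet --write /tmp/tpch/comet
+```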
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/ComparisonTool.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/ComparisonTool.scala
new file mode 100644
index 000000000..a4fd011fe
--- /dev/null
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/ComparisonTool.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.fuzz
+
+import java.io.File
+
+import org.rogach.scallop.{ScallopConf, ScallopOption, Subcommand}
+
+import org.apache.spark.sql.{functions, SparkSession}
+
+class ComparisonToolConf(arguments: Seq[String]) extends ScallopConf(arguments) {
+ object compareParquet extends Subcommand("compareParquet") {
+ val inputSparkFolder: ScallopOption[String] =
+ opt[String](required = true, descr = "Folder with Spark produced results in Parquet format")
+ val inputCometFolder: ScallopOption[String] =
+ opt[String](required = true, descr = "Folder with Comet produced results in Parquet format")
+ }
+ addSubcommand(compareParquet)
+ verify()
+}
+
+object ComparisonTool {
+
+ lazy val spark: SparkSession = SparkSession
+ .builder()
+ .getOrCreate()
+
+ def main(args: Array[String]): Unit = {
+ val conf = new ComparisonToolConf(args.toIndexedSeq)
+ conf.subcommand match {
+ case Some(conf.compareParquet) =>
+ compareParquetFolders(
+ spark,
+ conf.compareParquet.inputSparkFolder(),
+ conf.compareParquet.inputCometFolder())
+
+ case _ =>
+ // scalastyle:off println
+ println("Invalid subcommand")
+ // scalastyle:on println
+ sys.exit(-1)
+ }
+ }
+
+ private def compareParquetFolders(
+ spark: SparkSession,
+ sparkFolderPath: String,
+ cometFolderPath: String): Unit = {
+
+ val output = QueryRunner.createOutputMdFile()
+
+ try {
+ val sparkFolder = new File(sparkFolderPath)
+ val cometFolder = new File(cometFolderPath)
+
+ if (!sparkFolder.exists() || !sparkFolder.isDirectory) {
+ throw new IllegalArgumentException(
+ s"Spark folder does not exist or is not a directory:
$sparkFolderPath")
+ }
+
+ if (!cometFolder.exists() || !cometFolder.isDirectory) {
+ throw new IllegalArgumentException(
+ s"Comet folder does not exist or is not a directory:
$cometFolderPath")
+ }
+
+ // Get all subdirectories from the Spark folder
+ val sparkSubfolders = sparkFolder
+ .listFiles()
+ .filter(_.isDirectory)
+ .map(_.getName)
+ .sorted
+
+ output.write("# Comparing Parquet Folders\n\n")
+ output.write(s"Spark folder: $sparkFolderPath\n")
+ output.write(s"Comet folder: $cometFolderPath\n")
+ output.write(s"Found ${sparkSubfolders.length} subfolders to
compare\n\n")
+
+ // Compare each subfolder
+ sparkSubfolders.foreach { subfolderName =>
+ val sparkSubfolderPath = new File(sparkFolder, subfolderName)
+ val cometSubfolderPath = new File(cometFolder, subfolderName)
+
+ if (!cometSubfolderPath.exists() || !cometSubfolderPath.isDirectory) {
+ output.write(s"## Subfolder: $subfolderName\n")
+ output.write(
+ s"[WARNING] Comet subfolder not found:
${cometSubfolderPath.getAbsolutePath}\n\n")
+ } else {
+ output.write(s"## Comparing subfolder: $subfolderName\n\n")
+
+ try {
+ // Read Spark parquet files
+ spark.conf.set("spark.comet.enabled", "false")
+ val sparkDf = spark.read.parquet(sparkSubfolderPath.getAbsolutePath)
+ val sparkRows = sparkDf.orderBy(sparkDf.columns.map(functions.col): _*).collect()
+
+ // Read Comet parquet files
+ val cometDf = spark.read.parquet(cometSubfolderPath.getAbsolutePath)
+ val cometRows = cometDf.orderBy(cometDf.columns.map(functions.col): _*).collect()
+
+ // Compare the results
+ if (QueryComparison.assertSameRows(sparkRows, cometRows, output)) {
+ output.write(s"Subfolder $subfolderName: ${sparkRows.length}
rows matched\n\n")
+ }
+ } catch {
+ case e: Exception =>
+ output.write(
+ s"[ERROR] Failed to compare subfolder $subfolderName:
${e.getMessage}\n")
+ val sw = new java.io.StringWriter()
+ val p = new java.io.PrintWriter(sw)
+ e.printStackTrace(p)
+ p.close()
+ output.write(s"```\n${sw.toString}\n```\n\n")
+ }
+ }
+
+ output.flush()
+ }
+
+ output.write("\n# Comparison Complete\n")
+ output.write(s"Compared ${sparkSubfolders.length} subfolders\n")
+
+ } finally {
+ output.close()
+ }
+ }
+}
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
index bcc9f98d0..f4f345296 100644
--- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
@@ -21,13 +21,24 @@ package org.apache.comet.fuzz
import java.io.{BufferedWriter, FileWriter, PrintWriter, StringWriter}
-import scala.collection.mutable.WrappedArray
+import scala.collection.mutable
import scala.io.Source
import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.comet.fuzz.QueryComparison.showPlans
+
object QueryRunner {
+ def createOutputMdFile(): BufferedWriter = {
+ val outputFilename = s"results-${System.currentTimeMillis()}.md"
+ // scalastyle:off println
+ println(s"Writing results to $outputFilename")
+ // scalastyle:on println
+
+ new BufferedWriter(new FileWriter(outputFilename))
+ }
+
def runQueries(
spark: SparkSession,
numFiles: Int,
@@ -39,12 +50,7 @@ object QueryRunner {
var cometFailureCount = 0
var cometSuccessCount = 0
- val outputFilename = s"results-${System.currentTimeMillis()}.md"
- // scalastyle:off println
- println(s"Writing results to $outputFilename")
- // scalastyle:on println
-
- val w = new BufferedWriter(new FileWriter(outputFilename))
+ val w = createOutputMdFile()
// register input files
for (i <- 0 until numFiles) {
@@ -76,46 +82,21 @@ object QueryRunner {
val cometRows = df.collect()
val cometPlan = df.queryExecution.executedPlan.toString
- var success = true
- if (sparkRows.length == cometRows.length) {
- var i = 0
- while (i < sparkRows.length) {
- val l = sparkRows(i)
- val r = cometRows(i)
- assert(l.length == r.length)
- for (j <- 0 until l.length) {
- if (!same(l(j), r(j))) {
- success = false
- showSQL(w, sql)
- showPlans(w, sparkPlan, cometPlan)
- w.write(s"First difference at row $i:\n")
- w.write("Spark: `" + formatRow(l) + "`\n")
- w.write("Comet: `" + formatRow(r) + "`\n")
- i = sparkRows.length
- }
- }
- i += 1
- }
- } else {
- success = false
- showSQL(w, sql)
- showPlans(w, sparkPlan, cometPlan)
- w.write(
- s"[ERROR] Spark produced ${sparkRows.length} rows and " +
- s"Comet produced ${cometRows.length} rows.\n")
- }
+ var success = QueryComparison.assertSameRows(sparkRows, cometRows, output = w)
// check that the plan contains Comet operators
if (!cometPlan.contains("Comet")) {
success = false
- showSQL(w, sql)
- showPlans(w, sparkPlan, cometPlan)
w.write("[ERROR] Comet did not accelerate any part of the
plan\n")
}
+ QueryComparison.showSQL(w, sql)
+
if (success) {
cometSuccessCount += 1
} else {
+ // show plans for failed queries
+ showPlans(w, sparkPlan, cometPlan)
cometFailureCount += 1
}
@@ -123,7 +104,7 @@ object QueryRunner {
case e: Exception =>
// the query worked in Spark but failed in Comet, so this is likely a bug in Comet
cometFailureCount += 1
- showSQL(w, sql)
+ QueryComparison.showSQL(w, sql)
w.write("### Spark Plan\n")
w.write(s"```\n$sparkPlan\n```\n")
@@ -145,7 +126,7 @@ object QueryRunner {
// we expect many generated queries to be invalid
invalidQueryCount += 1
if (showFailedSparkQueries) {
- showSQL(w, sql)
+ QueryComparison.showSQL(w, sql)
w.write(s"Query failed in Spark: ${e.getMessage}\n")
}
}
@@ -161,6 +142,50 @@ object QueryRunner {
querySource.close()
}
}
+}
+
+object QueryComparison {
+ def assertSameRows(
+ sparkRows: Array[Row],
+ cometRows: Array[Row],
+ output: BufferedWriter): Boolean = {
+ if (sparkRows.length == cometRows.length) {
+ var i = 0
+ while (i < sparkRows.length) {
+ val l = sparkRows(i)
+ val r = cometRows(i)
+ // Check the schema is equal for first row only
+ if (i == 0 && l.schema != r.schema) {
+ output.write(
+ s"[ERROR] Spark produced schema ${l.schema} and " +
+ s"Comet produced schema ${r.schema} rows.\n")
+
+ return false
+ }
+
+ assert(l.length == r.length)
+ for (j <- 0 until l.length) {
+ if (!same(l(j), r(j))) {
+ output.write(s"First difference at row $i:\n")
+ output.write("Spark: `" + formatRow(l) + "`\n")
+ output.write("Comet: `" + formatRow(r) + "`\n")
+ i = sparkRows.length
+
+ return false
+ }
+ }
+ i += 1
+ }
+ } else {
+ output.write(
+ s"[ERROR] Spark produced ${sparkRows.length} rows and " +
+ s"Comet produced ${cometRows.length} rows.\n")
+
+ return false
+ }
+
+ true
+ }
private def same(l: Any, r: Any): Boolean = {
if (l == null || r == null) {
@@ -179,7 +204,7 @@ object QueryRunner {
case (a: Double, b: Double) => (a - b).abs <= 0.000001
case (a: Array[_], b: Array[_]) =>
a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
- case (a: WrappedArray[_], b: WrappedArray[_]) =>
+ case (a: mutable.WrappedArray[_], b: mutable.WrappedArray[_]) =>
a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
case (a: Row, b: Row) =>
val aa = a.toSeq
@@ -192,7 +217,7 @@ object QueryRunner {
private def format(value: Any): String = {
value match {
case null => "NULL"
- case v: WrappedArray[_] => s"[${v.map(format).mkString(",")}]"
+ case v: mutable.WrappedArray[_] => s"[${v.map(format).mkString(",")}]"
case v: Array[Byte] => s"[${v.mkString(",")}]"
case r: Row => formatRow(r)
case other => other.toString
@@ -203,7 +228,7 @@ object QueryRunner {
row.toSeq.map(format).mkString(",")
}
- private def showSQL(w: BufferedWriter, sql: String, maxLength: Int = 120): Unit = {
+ def showSQL(w: BufferedWriter, sql: String, maxLength: Int = 120): Unit = {
w.write("## SQL\n")
w.write("```\n")
val words = sql.split(" ")
@@ -223,11 +248,10 @@ object QueryRunner {
w.write("```\n")
}
- private def showPlans(w: BufferedWriter, sparkPlan: String, cometPlan: String): Unit = {
+ def showPlans(w: BufferedWriter, sparkPlan: String, cometPlan: String): Unit = {
w.write("### Spark Plan\n")
w.write(s"```\n$sparkPlan\n```\n")
w.write("### Comet Plan\n")
w.write(s"```\n$cometPlan\n```\n")
}
-
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]