This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 5fb88b8c7 feat: Define function signatures in CometFuzz (#2614)
5fb88b8c7 is described below
commit 5fb88b8c7112adacc9295dc175ff8a3a6c73a613
Author: Andy Grove <[email protected]>
AuthorDate: Fri Oct 24 09:01:14 2025 -0600
feat: Define function signatures in CometFuzz (#2614)
---
fuzz-testing/README.md | 11 +-
.../main/scala/org/apache/comet/fuzz/Main.scala | 6 +-
.../main/scala/org/apache/comet/fuzz/Meta.scala | 382 ++++++++++++++++-----
.../scala/org/apache/comet/fuzz/QueryGen.scala | 234 ++++++++++---
.../scala/org/apache/comet/fuzz/QueryRunner.scala | 50 ++-
5 files changed, 548 insertions(+), 135 deletions(-)
diff --git a/fuzz-testing/README.md b/fuzz-testing/README.md
index 17b2c151a..c8cea5be8 100644
--- a/fuzz-testing/README.md
+++ b/fuzz-testing/README.md
@@ -61,7 +61,7 @@ Set appropriate values for `SPARK_HOME`, `SPARK_MASTER`, and `COMET_JAR` environ
$SPARK_HOME/bin/spark-submit \
--master $SPARK_MASTER \
--class org.apache.comet.fuzz.Main \
- target/comet-fuzz-spark3.4_2.12-0.7.0-SNAPSHOT-jar-with-dependencies.jar \
+ target/comet-fuzz-spark3.5_2.12-0.12.0-SNAPSHOT-jar-with-dependencies.jar \
data --num-files=2 --num-rows=200 --exclude-negative-zero --generate-arrays --generate-structs --generate-maps
```
@@ -77,7 +77,7 @@ Generate random queries that are based on the available test files.
$SPARK_HOME/bin/spark-submit \
--master $SPARK_MASTER \
--class org.apache.comet.fuzz.Main \
- target/comet-fuzz-spark3.4_2.12-0.7.0-SNAPSHOT-jar-with-dependencies.jar \
+ target/comet-fuzz-spark3.5_2.12-0.12.0-SNAPSHOT-jar-with-dependencies.jar \
queries --num-files=2 --num-queries=500
```
@@ -88,18 +88,17 @@ Note that the output filename is currently hard-coded as `queries.sql`
```shell
$SPARK_HOME/bin/spark-submit \
--master $SPARK_MASTER \
+ --conf spark.memory.offHeap.enabled=true \
+ --conf spark.memory.offHeap.size=16G \
--conf spark.plugins=org.apache.spark.CometPlugin \
--conf spark.comet.enabled=true \
- --conf spark.comet.exec.enabled=true \
- --conf spark.comet.exec.all.enabled=true \
--conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
--conf spark.comet.exec.shuffle.enabled=true \
- --conf spark.comet.exec.shuffle.mode=auto \
--jars $COMET_JAR \
--conf spark.driver.extraClassPath=$COMET_JAR \
--conf spark.executor.extraClassPath=$COMET_JAR \
--class org.apache.comet.fuzz.Main \
- target/comet-fuzz-spark3.4_2.12-0.7.0-SNAPSHOT-jar-with-dependencies.jar \
+ target/comet-fuzz-spark3.5_2.12-0.12.0-SNAPSHOT-jar-with-dependencies.jar \
run --num-files=2 --filename=queries.sql
```
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala
index 1f81dc779..b9e63c76a 100644
--- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala
@@ -87,7 +87,11 @@ object Main {
SchemaGenOptions(
generateArray = conf.generateData.generateArrays(),
generateStruct = conf.generateData.generateStructs(),
- generateMap = conf.generateData.generateMaps()),
+ generateMap = conf.generateData.generateMaps(),
+ // create two columns of each primitive type so that they can be used in binary
+ // expressions such as `a + b` and `a < b`
+ primitiveTypes = SchemaGenOptions.defaultPrimitiveTypes ++
+ SchemaGenOptions.defaultPrimitiveTypes),
DataGenOptions(
allowNull = true,
generateNegativeZero = !conf.generateData.excludeNegativeZero()))
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
index 246216840..74d13f85e 100644
--- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
@@ -22,6 +22,32 @@ package org.apache.comet.fuzz
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.DataTypes
+sealed trait SparkType
+case class SparkTypeOneOf(dataTypes: Seq[SparkType]) extends SparkType
+case object SparkBooleanType extends SparkType
+case object SparkBinaryType extends SparkType
+case object SparkStringType extends SparkType
+case object SparkIntegralType extends SparkType
+case object SparkByteType extends SparkType
+case object SparkShortType extends SparkType
+case object SparkIntType extends SparkType
+case object SparkLongType extends SparkType
+case object SparkFloatType extends SparkType
+case object SparkDoubleType extends SparkType
+case class SparkDecimalType(p: Int, s: Int) extends SparkType
+case object SparkNumericType extends SparkType
+case object SparkDateType extends SparkType
+case object SparkTimestampType extends SparkType
+case object SparkDateOrTimestampType extends SparkType
+case class SparkArrayType(elementType: SparkType) extends SparkType
+case class SparkMapType(keyType: SparkType, valueType: SparkType) extends SparkType
+case class SparkStructType(fields: Seq[SparkType]) extends SparkType
+case object SparkAnyType extends SparkType
+
+case class FunctionSignature(inputTypes: Seq[SparkType])
+
+case class Function(name: String, signatures: Seq[FunctionSignature])
+
object Meta {
val dataTypes: Seq[(DataType, Double)] = Seq(
@@ -35,100 +61,283 @@ object Meta {
(DataTypes.createDecimalType(10, 2), 0.2),
(DataTypes.DateType, 0.2),
(DataTypes.TimestampType, 0.2),
- // TimestampNTZType only in Spark 3.4+
- // (DataTypes.TimestampNTZType, 0.2),
+ (DataTypes.TimestampNTZType, 0.2),
(DataTypes.StringType, 0.2),
(DataTypes.BinaryType, 0.1))
- val stringScalarFunc: Seq[Function] = Seq(
- Function("substring", 3),
- Function("coalesce", 1),
- Function("starts_with", 2),
- Function("ends_with", 2),
- Function("contains", 2),
- Function("ascii", 1),
- Function("bit_length", 1),
- Function("octet_length", 1),
- Function("upper", 1),
- Function("lower", 1),
- Function("chr", 1),
- Function("init_cap", 1),
- Function("trim", 1),
- Function("ltrim", 1),
- Function("rtrim", 1),
- Function("string_space", 1),
- Function("rpad", 2),
- Function("rpad", 3), // rpad can have 2 or 3 arguments
- Function("hex", 1),
- Function("unhex", 1),
- Function("xxhash64", 1),
- Function("sha1", 1),
- // Function("sha2", 1), -- needs a second argument for number of bits
- Function("substring", 3),
- Function("btrim", 1),
- Function("concat_ws", 2),
- Function("repeat", 2),
- Function("length", 1),
- Function("reverse", 1),
- Function("instr", 2),
- Function("replace", 2),
- Function("translate", 2))
-
- val dateScalarFunc: Seq[Function] =
- Seq(Function("year", 1), Function("hour", 1), Function("minute", 1), Function("second", 1))
+ private def createFunctionWithInputTypes(name: String, inputs: Seq[SparkType]): Function = {
+ Function(name, Seq(FunctionSignature(inputs)))
+ }
+
+ private def createFunctions(name: String, signatures: Seq[FunctionSignature]): Function = {
+ Function(name, signatures)
+ }
+ private def createUnaryStringFunction(name: String): Function = {
+ createFunctionWithInputTypes(name, Seq(SparkStringType))
+ }
+
+ private def createUnaryNumericFunction(name: String): Function = {
+ createFunctionWithInputTypes(name, Seq(SparkNumericType))
+ }
+
+ // Math expressions (corresponds to mathExpressions in QueryPlanSerde)
val mathScalarFunc: Seq[Function] = Seq(
- Function("abs", 1),
- Function("acos", 1),
- Function("asin", 1),
- Function("atan", 1),
- Function("Atan2", 1),
- Function("Cos", 1),
- Function("Exp", 2),
- Function("Ln", 1),
- Function("Log10", 1),
- Function("Log2", 1),
- Function("Pow", 2),
- Function("Round", 1),
- Function("Signum", 1),
- Function("Sin", 1),
- Function("Sqrt", 1),
- Function("Tan", 1),
- Function("Ceil", 1),
- Function("Floor", 1),
- Function("bool_and", 1),
- Function("bool_or", 1),
- Function("bitwise_not", 1))
+ createUnaryNumericFunction("abs"),
+ createUnaryNumericFunction("acos"),
+ createUnaryNumericFunction("asin"),
+ createUnaryNumericFunction("atan"),
+ createFunctionWithInputTypes("atan2", Seq(SparkNumericType, SparkNumericType)),
+ createUnaryNumericFunction("cos"),
+ createUnaryNumericFunction("exp"),
+ createUnaryNumericFunction("expm1"),
+ createFunctionWithInputTypes("log", Seq(SparkNumericType, SparkNumericType)),
+ createUnaryNumericFunction("log10"),
+ createUnaryNumericFunction("log2"),
+ createFunctionWithInputTypes("pow", Seq(SparkNumericType, SparkNumericType)),
+ createFunctionWithInputTypes("remainder", Seq(SparkNumericType, SparkNumericType)),
+ createFunctions(
+ "round",
+ Seq(
+ FunctionSignature(Seq(SparkNumericType)),
+ FunctionSignature(Seq(SparkNumericType, SparkIntType)))),
+ createUnaryNumericFunction("signum"),
+ createUnaryNumericFunction("sin"),
+ createUnaryNumericFunction("sqrt"),
+ createUnaryNumericFunction("tan"),
+ createUnaryNumericFunction("ceil"),
+ createUnaryNumericFunction("floor"),
+ createFunctionWithInputTypes("unary_minus", Seq(SparkNumericType)))
+ // Hash expressions (corresponds to hashExpressions in QueryPlanSerde)
+ val hashScalarFunc: Seq[Function] = Seq(
+ createFunctionWithInputTypes("md5", Seq(SparkAnyType)),
+ createFunctionWithInputTypes("murmur3_hash", Seq(SparkAnyType)), // TODO variadic
+ createFunctionWithInputTypes("sha2", Seq(SparkAnyType, SparkIntType)))
+
+ // String expressions (corresponds to stringExpressions in QueryPlanSerde)
+ val stringScalarFunc: Seq[Function] = Seq(
+ createUnaryStringFunction("ascii"),
+ createUnaryStringFunction("bit_length"),
+ createUnaryStringFunction("chr"),
+ createFunctionWithInputTypes(
+ "concat",
+ Seq(
+ SparkTypeOneOf(
+ Seq(
+ SparkStringType,
+ SparkNumericType,
+ SparkBinaryType,
+ SparkArrayType(
+ SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkBinaryType))))),
+ SparkTypeOneOf(
+ Seq(
+ SparkStringType,
+ SparkNumericType,
+ SparkBinaryType,
+ SparkArrayType(
+ SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkBinaryType))))))),
+ createFunctionWithInputTypes("concat_ws", Seq(SparkStringType, SparkStringType)),
+ createFunctionWithInputTypes("contains", Seq(SparkStringType, SparkStringType)),
+ createFunctionWithInputTypes("ends_with", Seq(SparkStringType, SparkStringType)),
+ createFunctionWithInputTypes(
+ "hex",
+ Seq(SparkTypeOneOf(Seq(SparkStringType, SparkBinaryType, SparkIntType, SparkLongType)))),
+ createUnaryStringFunction("init_cap"),
+ createFunctionWithInputTypes("instr", Seq(SparkStringType, SparkStringType)),
+ createFunctionWithInputTypes(
+ "length",
+ Seq(SparkTypeOneOf(Seq(SparkStringType, SparkBinaryType)))),
+ createFunctionWithInputTypes("like", Seq(SparkStringType, SparkStringType)),
+ createUnaryStringFunction("lower"),
+ createFunctions(
+ "lpad",
+ Seq(
+ FunctionSignature(Seq(SparkStringType, SparkIntegralType)),
+ FunctionSignature(Seq(SparkStringType, SparkIntegralType, SparkStringType)))),
+ createUnaryStringFunction("ltrim"),
+ createUnaryStringFunction("octet_length"),
+ createFunctions(
+ "regexp_replace",
+ Seq(
+ FunctionSignature(Seq(SparkStringType, SparkStringType, SparkStringType)),
+ FunctionSignature(Seq(SparkStringType, SparkStringType, SparkStringType, SparkIntType)))),
+ createFunctionWithInputTypes("repeat", Seq(SparkStringType, SparkIntType)),
+ createFunctions(
+ "replace",
+ Seq(
+ FunctionSignature(Seq(SparkStringType, SparkStringType)),
+ FunctionSignature(Seq(SparkStringType, SparkStringType, SparkStringType)))),
+ createFunctions(
+ "reverse",
+ Seq(
+ FunctionSignature(Seq(SparkStringType)),
+ FunctionSignature(Seq(SparkArrayType(SparkAnyType))))),
+ createFunctionWithInputTypes("rlike", Seq(SparkStringType, SparkStringType)),
+ createFunctions(
+ "rpad",
+ Seq(
+ FunctionSignature(Seq(SparkStringType, SparkIntegralType)),
+ FunctionSignature(Seq(SparkStringType, SparkIntegralType, SparkStringType)))),
+ createUnaryStringFunction("rtrim"),
+ createFunctionWithInputTypes("starts_with", Seq(SparkStringType, SparkStringType)),
+ createFunctionWithInputTypes("string_space", Seq(SparkIntType)),
+ createFunctionWithInputTypes("substring", Seq(SparkStringType, SparkIntType, SparkIntType)),
+ createFunctionWithInputTypes("translate", Seq(SparkStringType, SparkStringType)),
+ createUnaryStringFunction("trim"),
+ createUnaryStringFunction("btrim"),
+ createUnaryStringFunction("unhex"),
+ createUnaryStringFunction("upper"),
+ createFunctionWithInputTypes("xxhash64", Seq(SparkAnyType)), // TODO variadic
+ createFunctionWithInputTypes("sha1", Seq(SparkAnyType)))
+
+ // Conditional expressions (corresponds to conditionalExpressions in QueryPlanSerde)
+ val conditionalScalarFunc: Seq[Function] = Seq(
+ createFunctionWithInputTypes("if", Seq(SparkBooleanType, SparkAnyType, SparkAnyType)))
+
+ // Map expressions (corresponds to mapExpressions in QueryPlanSerde)
+ val mapScalarFunc: Seq[Function] = Seq(
+ createFunctionWithInputTypes(
+ "map_extract",
+ Seq(SparkMapType(SparkAnyType, SparkAnyType), SparkAnyType)),
+ createFunctionWithInputTypes("map_keys", Seq(SparkMapType(SparkAnyType, SparkAnyType))),
+ createFunctionWithInputTypes("map_entries", Seq(SparkMapType(SparkAnyType, SparkAnyType))),
+ createFunctionWithInputTypes("map_values", Seq(SparkMapType(SparkAnyType, SparkAnyType))),
+ createFunctionWithInputTypes(
+ "map_from_arrays",
+ Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))))
+
+ // Predicate expressions (corresponds to predicateExpressions in QueryPlanSerde)
+ val predicateScalarFunc: Seq[Function] = Seq(
+ createFunctionWithInputTypes("and", Seq(SparkBooleanType, SparkBooleanType)),
+ createFunctionWithInputTypes("or", Seq(SparkBooleanType, SparkBooleanType)),
+ createFunctionWithInputTypes("not", Seq(SparkBooleanType)),
+ createFunctionWithInputTypes("in", Seq(SparkAnyType, SparkAnyType))
+ ) // TODO: variadic
+
+ // Struct expressions (corresponds to structExpressions in QueryPlanSerde)
+ val structScalarFunc: Seq[Function] = Seq(
+ createFunctionWithInputTypes(
+ "create_named_struct",
+ Seq(SparkStringType, SparkAnyType)
+ ), // TODO: variadic name/value pairs
+ createFunctionWithInputTypes(
+ "get_struct_field",
+ Seq(SparkStructType(Seq(SparkAnyType)), SparkStringType)))
+
+ // Bitwise expressions (corresponds to bitwiseExpressions in QueryPlanSerde)
+ val bitwiseScalarFunc: Seq[Function] = Seq(
+ createFunctionWithInputTypes("bitwise_and", Seq(SparkIntegralType, SparkIntegralType)),
+ createFunctionWithInputTypes("bitwise_count", Seq(SparkIntegralType)),
+ createFunctionWithInputTypes("bitwise_get", Seq(SparkIntegralType, SparkIntType)),
+ createFunctionWithInputTypes("bitwise_or", Seq(SparkIntegralType, SparkIntegralType)),
+ createFunctionWithInputTypes("bitwise_not", Seq(SparkIntegralType)),
+ createFunctionWithInputTypes("bitwise_xor", Seq(SparkIntegralType, SparkIntegralType)),
+ createFunctionWithInputTypes("shift_left", Seq(SparkIntegralType, SparkIntType)),
+ createFunctionWithInputTypes("shift_right", Seq(SparkIntegralType, SparkIntType)))
+
+ // Misc expressions (corresponds to miscExpressions in QueryPlanSerde)
val miscScalarFunc: Seq[Function] =
- Seq(Function("isnan", 1), Function("isnull", 1), Function("isnotnull", 1))
+ Seq(
+ createFunctionWithInputTypes("isnan", Seq(SparkNumericType)),
+ createFunctionWithInputTypes("isnull", Seq(SparkAnyType)),
+ createFunctionWithInputTypes("isnotnull", Seq(SparkAnyType)),
+ createFunctionWithInputTypes("coalesce", Seq(SparkAnyType, SparkAnyType))
+ ) // TODO: variadic
+ // Array expressions (corresponds to arrayExpressions in QueryPlanSerde)
val arrayScalarFunc: Seq[Function] = Seq(
- Function("array", 2),
- Function("array_remove", 2),
- Function("array_insert", 2),
- Function("array_contains", 2),
- Function("array_intersect", 2),
- Function("array_append", 2))
+ createFunctionWithInputTypes("array_append", Seq(SparkArrayType(SparkAnyType), SparkAnyType)),
+ createFunctionWithInputTypes("array_compact", Seq(SparkArrayType(SparkAnyType))),
+ createFunctionWithInputTypes(
+ "array_contains",
+ Seq(SparkArrayType(SparkAnyType), SparkAnyType)),
+ createFunctionWithInputTypes("array_distinct", Seq(SparkArrayType(SparkAnyType))),
+ createFunctionWithInputTypes(
+ "array_except",
+ Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))),
+ createFunctionWithInputTypes(
+ "array_insert",
+ Seq(SparkArrayType(SparkAnyType), SparkIntType, SparkAnyType)),
+ createFunctionWithInputTypes(
+ "array_intersect",
+ Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))),
+ createFunctions(
+ "array_join",
+ Seq(
+ FunctionSignature(Seq(SparkArrayType(SparkAnyType), SparkStringType)),
+ FunctionSignature(Seq(SparkArrayType(SparkAnyType), SparkStringType, SparkStringType)))),
+ createFunctionWithInputTypes("array_max", Seq(SparkArrayType(SparkAnyType))),
+ createFunctionWithInputTypes("array_min", Seq(SparkArrayType(SparkAnyType))),
+ createFunctionWithInputTypes("array_remove", Seq(SparkArrayType(SparkAnyType), SparkAnyType)),
+ createFunctionWithInputTypes("array_repeat", Seq(SparkAnyType, SparkIntType)),
+ createFunctionWithInputTypes(
+ "arrays_overlap",
+ Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))),
+ createFunctionWithInputTypes(
+ "array_union",
+ Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))),
+ createFunctionWithInputTypes("array", Seq(SparkAnyType, SparkAnyType)), // TODO: variadic
+ createFunctionWithInputTypes(
+ "element_at",
+ Seq(
+ SparkTypeOneOf(
+ Seq(SparkArrayType(SparkAnyType), SparkMapType(SparkAnyType, SparkAnyType))),
+ SparkAnyType)),
+ createFunctionWithInputTypes("flatten", Seq(SparkArrayType(SparkArrayType(SparkAnyType)))),
+ createFunctionWithInputTypes(
+ "get_array_item",
+ Seq(SparkArrayType(SparkAnyType), SparkIntType)))
- val scalarFunc: Seq[Function] = stringScalarFunc ++ dateScalarFunc ++
- mathScalarFunc ++ miscScalarFunc ++ arrayScalarFunc
+ // Temporal expressions (corresponds to temporalExpressions in QueryPlanSerde)
+ val temporalScalarFunc: Seq[Function] =
+ Seq(
+ createFunctionWithInputTypes("date_add", Seq(SparkDateType, SparkIntType)),
+ createFunctionWithInputTypes("date_sub", Seq(SparkDateType, SparkIntType)),
+ createFunctions(
+ "from_unixtime",
+ Seq(
+ FunctionSignature(Seq(SparkLongType)),
+ FunctionSignature(Seq(SparkLongType, SparkStringType)))),
+ createFunctionWithInputTypes("hour", Seq(SparkDateOrTimestampType)),
+ createFunctionWithInputTypes("minute", Seq(SparkDateOrTimestampType)),
+ createFunctionWithInputTypes("second", Seq(SparkDateOrTimestampType)),
+ createFunctionWithInputTypes("trunc", Seq(SparkDateOrTimestampType, SparkStringType)),
+ createFunctionWithInputTypes("year", Seq(SparkDateOrTimestampType)),
+ createFunctionWithInputTypes("month", Seq(SparkDateOrTimestampType)),
+ createFunctionWithInputTypes("day", Seq(SparkDateOrTimestampType)),
+ createFunctionWithInputTypes("dayofmonth", Seq(SparkDateOrTimestampType)),
+ createFunctionWithInputTypes("dayofweek", Seq(SparkDateOrTimestampType)),
+ createFunctionWithInputTypes("weekday", Seq(SparkDateOrTimestampType)),
+ createFunctionWithInputTypes("dayofyear", Seq(SparkDateOrTimestampType)),
+ createFunctionWithInputTypes("weekofyear", Seq(SparkDateOrTimestampType)),
+ createFunctionWithInputTypes("quarter", Seq(SparkDateOrTimestampType)))
+
+ // Combined in same order as exprSerdeMap in QueryPlanSerde
+ val scalarFunc: Seq[Function] = mathScalarFunc ++ hashScalarFunc ++ stringScalarFunc ++
+ conditionalScalarFunc ++ mapScalarFunc ++ predicateScalarFunc ++
+ structScalarFunc ++ bitwiseScalarFunc ++ miscScalarFunc ++ arrayScalarFunc ++
+ temporalScalarFunc
val aggFunc: Seq[Function] = Seq(
- Function("min", 1),
- Function("max", 1),
- Function("count", 1),
- Function("avg", 1),
- Function("sum", 1),
- Function("first", 1),
- Function("last", 1),
- Function("var_pop", 1),
- Function("var_samp", 1),
- Function("covar_pop", 1),
- Function("covar_samp", 1),
- Function("stddev_pop", 1),
- Function("stddev_samp", 1),
- Function("corr", 2))
+ createFunctionWithInputTypes("min", Seq(SparkAnyType)),
+ createFunctionWithInputTypes("max", Seq(SparkAnyType)),
+ createFunctionWithInputTypes("count", Seq(SparkAnyType)),
+ createUnaryNumericFunction("avg"),
+ createUnaryNumericFunction("sum"),
+ // first/last are non-deterministic and known to be incompatible with Spark
+// createFunctionWithInputTypes("first", Seq(SparkAnyType)),
+// createFunctionWithInputTypes("last", Seq(SparkAnyType)),
+ createUnaryNumericFunction("var_pop"),
+ createUnaryNumericFunction("var_samp"),
+ createFunctionWithInputTypes("covar_pop", Seq(SparkNumericType, SparkNumericType)),
+ createFunctionWithInputTypes("covar_samp", Seq(SparkNumericType, SparkNumericType)),
+ createUnaryNumericFunction("stddev_pop"),
+ createUnaryNumericFunction("stddev_samp"),
+ createFunctionWithInputTypes("corr", Seq(SparkNumericType, SparkNumericType)),
+ createFunctionWithInputTypes("bit_and", Seq(SparkIntegralType)),
+ createFunctionWithInputTypes("bit_or", Seq(SparkIntegralType)),
+ createFunctionWithInputTypes("bit_xor", Seq(SparkIntegralType)))
val unaryArithmeticOps: Seq[String] = Seq("+", "-")
@@ -137,4 +346,13 @@ object Meta {
val comparisonOps: Seq[String] = Seq("=", "<=>", ">", ">=", "<", "<=")
+ // TODO make this more comprehensive
+ val comparisonTypes: Seq[SparkType] = Seq(
+ SparkStringType,
+ SparkBinaryType,
+ SparkNumericType,
+ SparkDateType,
+ SparkTimestampType,
+ SparkArrayType(SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkDateType))))
+
}
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
index de1117837..d9e3c147d 100644
--- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
@@ -24,7 +24,8 @@ import java.io.{BufferedWriter, FileWriter}
import scala.collection.mutable
import scala.util.Random
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.types._
object QueryGen {
@@ -42,19 +43,25 @@ object QueryGen {
val uniqueQueries = mutable.HashSet[String]()
for (_ <- 0 until numQueries) {
- val sql = r.nextInt().abs % 8 match {
- case 0 => generateJoin(r, spark, numFiles)
- case 1 => generateAggregate(r, spark, numFiles)
- case 2 => generateScalar(r, spark, numFiles)
- case 3 => generateCast(r, spark, numFiles)
- case 4 => generateUnaryArithmetic(r, spark, numFiles)
- case 5 => generateBinaryArithmetic(r, spark, numFiles)
- case 6 => generateBinaryComparison(r, spark, numFiles)
- case _ => generateConditional(r, spark, numFiles)
- }
- if (!uniqueQueries.contains(sql)) {
- uniqueQueries += sql
- w.write(sql + "\n")
+ try {
+ val sql = r.nextInt().abs % 8 match {
+ case 0 => generateJoin(r, spark, numFiles)
+ case 1 => generateAggregate(r, spark, numFiles)
+ case 2 => generateScalar(r, spark, numFiles)
+ case 3 => generateCast(r, spark, numFiles)
+ case 4 => generateUnaryArithmetic(r, spark, numFiles)
+ case 5 => generateBinaryArithmetic(r, spark, numFiles)
+ case 6 => generateBinaryComparison(r, spark, numFiles)
+ case _ => generateConditional(r, spark, numFiles)
+ }
+ if (!uniqueQueries.contains(sql)) {
+ uniqueQueries += sql
+ w.write(sql + "\n")
+ }
+ } catch {
+ case e: Exception =>
+ // scalastyle:off
+ println(s"Failed to generate query: ${e.getMessage}")
}
}
w.close()
@@ -65,35 +72,177 @@ object QueryGen {
val table = spark.table(tableName)
val func = Utils.randomChoice(Meta.aggFunc, r)
- val args = Range(0, func.num_args)
- .map(_ => Utils.randomChoice(table.columns, r))
+ try {
+ val signature = Utils.randomChoice(func.signatures, r)
+ val args = signature.inputTypes.map(x => pickRandomColumn(r, table, x))
+
+ val groupingCols = Range(0, 2).map(_ => Utils.randomChoice(table.columns, r))
+
+ if (groupingCols.isEmpty) {
+ s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
+ s"FROM $tableName " +
+ s"ORDER BY ${args.mkString(", ")};"
+ } else {
+ s"SELECT ${groupingCols.mkString(", ")}, ${func.name}(${args.mkString(", ")}) " +
+ s"FROM $tableName " +
+ s"GROUP BY ${groupingCols.mkString(",")} " +
+ s"ORDER BY ${groupingCols.mkString(", ")};"
+ }
+ } catch {
+ case e: Exception =>
+ throw new IllegalStateException(
+ s"Failed to generate SQL for aggregate function ${func.name}",
+ e)
+ }
+ }
+
+ private def generateScalar(r: Random, spark: SparkSession, numFiles: Int): String = {
+ val tableName = s"test${r.nextInt(numFiles)}"
+ val table = spark.table(tableName)
- val groupingCols = Range(0, 2).map(_ => Utils.randomChoice(table.columns, r))
+ val func = Utils.randomChoice(Meta.scalarFunc, r)
+ try {
+ val signature = Utils.randomChoice(func.signatures, r)
+ val args = signature.inputTypes.map(x => pickRandomColumn(r, table, x))
- if (groupingCols.isEmpty) {
+ // Example SELECT c0, log(c0) as x FROM test0
s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS
x " +
s"FROM $tableName " +
s"ORDER BY ${args.mkString(", ")};"
- } else {
- s"SELECT ${groupingCols.mkString(", ")}, ${func.name}(${args.mkString(", ")}) " +
- s"FROM $tableName " +
- s"GROUP BY ${groupingCols.mkString(",")} " +
- s"ORDER BY ${groupingCols.mkString(", ")};"
+ } catch {
+ case e: Exception =>
+ throw new IllegalStateException(
+ s"Failed to generate SQL for scalar function ${func.name}",
+ e)
}
}
- private def generateScalar(r: Random, spark: SparkSession, numFiles: Int): String = {
- val tableName = s"test${r.nextInt(numFiles)}"
- val table = spark.table(tableName)
+ private def pickRandomColumn(r: Random, df: DataFrame, targetType: SparkType): String = {
+ targetType match {
+ case SparkAnyType =>
+ Utils.randomChoice(df.schema.fields, r).name
+ case SparkBooleanType =>
+ select(r, df, _.dataType == BooleanType)
+ case SparkByteType =>
+ select(r, df, _.dataType == ByteType)
+ case SparkShortType =>
+ select(r, df, _.dataType == ShortType)
+ case SparkIntType =>
+ select(r, df, _.dataType == IntegerType)
+ case SparkLongType =>
+ select(r, df, _.dataType == LongType)
+ case SparkFloatType =>
+ select(r, df, _.dataType == FloatType)
+ case SparkDoubleType =>
+ select(r, df, _.dataType == DoubleType)
+ case SparkDecimalType(_, _) =>
+ select(r, df, _.dataType.isInstanceOf[DecimalType])
+ case SparkIntegralType =>
+ select(
+ r,
+ df,
+ f =>
+ f.dataType == ByteType || f.dataType == ShortType ||
+ f.dataType == IntegerType || f.dataType == LongType)
+ case SparkNumericType =>
+ select(r, df, f => isNumeric(f.dataType))
+ case SparkStringType =>
+ select(r, df, _.dataType == StringType)
+ case SparkBinaryType =>
+ select(r, df, _.dataType == BinaryType)
+ case SparkDateType =>
+ select(r, df, _.dataType == DateType)
+ case SparkTimestampType =>
+ select(r, df, _.dataType == TimestampType)
+ case SparkDateOrTimestampType =>
+ select(r, df, f => f.dataType == DateType || f.dataType == TimestampType)
+ case SparkTypeOneOf(choices) =>
+ pickRandomColumn(r, df, Utils.randomChoice(choices, r))
+ case SparkArrayType(elementType) =>
+ select(
+ r,
+ df,
+ _.dataType match {
+ case ArrayType(x, _) if typeMatch(elementType, x) => true
+ case _ => false
+ })
+ case SparkMapType(keyType, valueType) =>
+ select(
+ r,
+ df,
+ _.dataType match {
+ case MapType(k, v, _) if typeMatch(keyType, k) && typeMatch(valueType, v) => true
+ case _ => false
+ })
+ case SparkStructType(fields) =>
+ select(
+ r,
+ df,
+ _.dataType match {
+ case StructType(structFields) if structFields.length == fields.length => true
+ case _ => false
+ })
+ case _ =>
+ throw new IllegalStateException(targetType.toString)
+ }
+ }
- val func = Utils.randomChoice(Meta.scalarFunc, r)
- val args = Range(0, func.num_args)
- .map(_ => Utils.randomChoice(table.columns, r))
+ def pickTwoRandomColumns(r: Random, df: DataFrame, targetType: SparkType): (String, String) = {
+ val a = pickRandomColumn(r, df, targetType)
+ val df2 = df.drop(a)
+ val b = pickRandomColumn(r, df2, targetType)
+ (a, b)
+ }
- // Example SELECT c0, log(c0) as x FROM test0
- s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
- s"FROM $tableName " +
- s"ORDER BY ${args.mkString(", ")};"
+ /** Select a random field that matches a predicate */
+ private def select(r: Random, df: DataFrame, predicate: StructField => Boolean): String = {
+ val candidates = df.schema.fields.filter(predicate)
+ if (candidates.isEmpty) {
+ throw new IllegalStateException("Failed to find suitable column")
+ }
+ Utils.randomChoice(candidates, r).name
+ }
+
+ private def isNumeric(d: DataType): Boolean = {
+ d match {
+ case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
+ _: DoubleType | _: DecimalType =>
+ true
+ case _ => false
+ }
+ }
+
+ private def typeMatch(s: SparkType, d: DataType): Boolean = {
+ (s, d) match {
+ case (SparkAnyType, _) => true
+ case (SparkBooleanType, BooleanType) => true
+ case (SparkByteType, ByteType) => true
+ case (SparkShortType, ShortType) => true
+ case (SparkIntType, IntegerType) => true
+ case (SparkLongType, LongType) => true
+ case (SparkFloatType, FloatType) => true
+ case (SparkDoubleType, DoubleType) => true
+ case (SparkDecimalType(_, _), _: DecimalType) => true
+ case (SparkIntegralType, ByteType | ShortType | IntegerType | LongType) => true
+ case (SparkNumericType, _) if isNumeric(d) => true
+ case (SparkStringType, StringType) => true
+ case (SparkBinaryType, BinaryType) => true
+ case (SparkDateType, DateType) => true
+ case (SparkTimestampType, TimestampType | TimestampNTZType) => true
+ case (SparkDateOrTimestampType, DateType | TimestampType | TimestampNTZType) => true
+ case (SparkArrayType(elementType), ArrayType(elementDataType, _)) =>
+ typeMatch(elementType, elementDataType)
+ case (SparkMapType(keyType, valueType), MapType(keyDataType, valueDataType, _)) =>
+ typeMatch(keyType, keyDataType) && typeMatch(valueType, valueDataType)
+ case (SparkStructType(fields), StructType(structFields)) =>
+ fields.length == structFields.length &&
+ fields.zip(structFields.map(_.dataType)).forall { case (sparkType, dataType) =>
+ typeMatch(sparkType, dataType)
+ }
+ case (SparkTypeOneOf(choices), _) =>
+ choices.exists(choice => typeMatch(choice, d))
+ case _ => false
+ }
}
private def generateUnaryArithmetic(r: Random, spark: SparkSession, numFiles: Int): String = {
@@ -101,7 +250,7 @@ object QueryGen {
val table = spark.table(tableName)
val op = Utils.randomChoice(Meta.unaryArithmeticOps, r)
- val a = Utils.randomChoice(table.columns, r)
+ val a = pickRandomColumn(r, table, SparkNumericType)
// Example SELECT a, -a FROM test0
s"SELECT $a, $op$a " +
@@ -114,8 +263,7 @@ object QueryGen {
val table = spark.table(tableName)
val op = Utils.randomChoice(Meta.binaryArithmeticOps, r)
- val a = Utils.randomChoice(table.columns, r)
- val b = Utils.randomChoice(table.columns, r)
+ val (a, b) = pickTwoRandomColumns(r, table, SparkNumericType)
// Example SELECT a, b, a+b FROM test0
s"SELECT $a, $b, $a $op $b " +
@@ -128,8 +276,10 @@ object QueryGen {
val table = spark.table(tableName)
val op = Utils.randomChoice(Meta.comparisonOps, r)
- val a = Utils.randomChoice(table.columns, r)
- val b = Utils.randomChoice(table.columns, r)
+
+ // pick two columns with the same type
+ val opType = Utils.randomChoice(Meta.comparisonTypes, r)
+ val (a, b) = pickTwoRandomColumns(r, table, opType)
// Example SELECT a, b, a <=> b FROM test0
s"SELECT $a, $b, $a $op $b " +
@@ -142,8 +292,10 @@ object QueryGen {
val table = spark.table(tableName)
val op = Utils.randomChoice(Meta.comparisonOps, r)
- val a = Utils.randomChoice(table.columns, r)
- val b = Utils.randomChoice(table.columns, r)
+
+ // pick two columns with the same type
+ val opType = Utils.randomChoice(Meta.comparisonTypes, r)
+ val (a, b) = pickTwoRandomColumns(r, table, opType)
// Example SELECT a, b, IF(a <=> b, 1, 2), CASE WHEN a <=> b THEN 1 ELSE 2 END FROM test0
s"SELECT $a, $b, $a $op $b, IF($a $op $b, 1, 2), CASE WHEN $a $op $b THEN
1 ELSE 2 END " +
@@ -192,5 +344,3 @@ object QueryGen {
}
}
-
-case class Function(name: String, num_args: Int)
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
index 8852f4bc1..bcc9f98d0 100644
--- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
@@ -34,6 +34,11 @@ object QueryRunner {
filename: String,
showFailedSparkQueries: Boolean = false): Unit = {
+ var queryCount = 0
+ var invalidQueryCount = 0
+ var cometFailureCount = 0
+ var cometSuccessCount = 0
+
val outputFilename = s"results-${System.currentTimeMillis()}.md"
// scalastyle:off println
println(s"Writing results to $outputFilename")
@@ -56,7 +61,7 @@ object QueryRunner {
querySource
.getLines()
.foreach(sql => {
-
+ queryCount += 1
try {
// execute with Spark
spark.conf.set("spark.comet.enabled", "false")
@@ -67,13 +72,11 @@ object QueryRunner {
// execute with Comet
try {
spark.conf.set("spark.comet.enabled", "true")
- // complex type support until we support it natively
- spark.conf.set("spark.comet.sparkToColumnar.enabled", "true")
- spark.conf.set("spark.comet.convert.parquet.enabled", "true")
val df = spark.sql(sql)
val cometRows = df.collect()
val cometPlan = df.queryExecution.executedPlan.toString
+ var success = true
if (sparkRows.length == cometRows.length) {
var i = 0
while (i < sparkRows.length) {
@@ -82,6 +85,7 @@ object QueryRunner {
assert(l.length == r.length)
for (j <- 0 until l.length) {
if (!same(l(j), r(j))) {
+ success = false
showSQL(w, sql)
showPlans(w, sparkPlan, cometPlan)
w.write(s"First difference at row $i:\n")
@@ -93,16 +97,36 @@ object QueryRunner {
i += 1
}
} else {
+ success = false
showSQL(w, sql)
showPlans(w, sparkPlan, cometPlan)
w.write(
s"[ERROR] Spark produced ${sparkRows.length} rows and " +
s"Comet produced ${cometRows.length} rows.\n")
}
+
+ // check that the plan contains Comet operators
+ if (!cometPlan.contains("Comet")) {
+ success = false
+ showSQL(w, sql)
+ showPlans(w, sparkPlan, cometPlan)
+ w.write("[ERROR] Comet did not accelerate any part of the plan\n")
+ }
+
+ if (success) {
+ cometSuccessCount += 1
+ } else {
+ cometFailureCount += 1
+ }
+
} catch {
case e: Exception =>
// the query worked in Spark but failed in Comet, so this is likely a bug in Comet
+ cometFailureCount += 1
showSQL(w, sql)
+ w.write("### Spark Plan\n")
+ w.write(s"```\n$sparkPlan\n```\n")
+
w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n")
w.write("```\n")
val sw = new StringWriter()
@@ -119,6 +143,7 @@ object QueryRunner {
} catch {
case e: Exception =>
// we expect many generated queries to be invalid
+ invalidQueryCount += 1
if (showFailedSparkQueries) {
showSQL(w, sql)
w.write(s"Query failed in Spark: ${e.getMessage}\n")
@@ -126,6 +151,11 @@ object QueryRunner {
}
})
+ w.write("# Summary\n")
+ w.write(
+ s"Total queries: $queryCount; Invalid queries: $invalidQueryCount; " +
+ s"Comet failed: $cometFailureCount; Comet succeeded: $cometSuccessCount\n")
+
} finally {
w.close()
querySource.close()
@@ -133,10 +163,17 @@ object QueryRunner {
}
private def same(l: Any, r: Any): Boolean = {
+ if (l == null || r == null) {
+ return l == null && r == null
+ }
(l, r) match {
+ case (a: Float, b: Float) if a.isPosInfinity => b.isPosInfinity
+ case (a: Float, b: Float) if a.isNegInfinity => b.isNegInfinity
case (a: Float, b: Float) if a.isInfinity => b.isInfinity
case (a: Float, b: Float) if a.isNaN => b.isNaN
case (a: Float, b: Float) => (a - b).abs <= 0.000001f
+ case (a: Double, b: Double) if a.isPosInfinity => b.isPosInfinity
+ case (a: Double, b: Double) if a.isNegInfinity => b.isNegInfinity
case (a: Double, b: Double) if a.isInfinity => b.isInfinity
case (a: Double, b: Double) if a.isNaN => b.isNaN
case (a: Double, b: Double) => (a - b).abs <= 0.000001
@@ -144,6 +181,10 @@ object QueryRunner {
a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
case (a: WrappedArray[_], b: WrappedArray[_]) =>
a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
+ case (a: Row, b: Row) =>
+ val aa = a.toSeq
+ val bb = b.toSeq
+ aa.length == bb.length && aa.zip(bb).forall(x => same(x._1, x._2))
case (a, b) => a == b
}
}
@@ -153,6 +194,7 @@ object QueryRunner {
case null => "NULL"
case v: WrappedArray[_] => s"[${v.map(format).mkString(",")}]"
case v: Array[Byte] => s"[${v.mkString(",")}]"
+ case r: Row => formatRow(r)
case other => other.toString
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]