This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 5fb88b8c7 feat: Define function signatures in CometFuzz (#2614)
5fb88b8c7 is described below

commit 5fb88b8c7112adacc9295dc175ff8a3a6c73a613
Author: Andy Grove <[email protected]>
AuthorDate: Fri Oct 24 09:01:14 2025 -0600

    feat: Define function signatures in CometFuzz (#2614)
---
 fuzz-testing/README.md                             |  11 +-
 .../main/scala/org/apache/comet/fuzz/Main.scala    |   6 +-
 .../main/scala/org/apache/comet/fuzz/Meta.scala    | 382 ++++++++++++++++-----
 .../scala/org/apache/comet/fuzz/QueryGen.scala     | 234 ++++++++++---
 .../scala/org/apache/comet/fuzz/QueryRunner.scala  |  50 ++-
 5 files changed, 548 insertions(+), 135 deletions(-)

diff --git a/fuzz-testing/README.md b/fuzz-testing/README.md
index 17b2c151a..c8cea5be8 100644
--- a/fuzz-testing/README.md
+++ b/fuzz-testing/README.md
@@ -61,7 +61,7 @@ Set appropriate values for `SPARK_HOME`, `SPARK_MASTER`, and `COMET_JAR` environ
 $SPARK_HOME/bin/spark-submit \
     --master $SPARK_MASTER \
     --class org.apache.comet.fuzz.Main \
-    target/comet-fuzz-spark3.4_2.12-0.7.0-SNAPSHOT-jar-with-dependencies.jar \
+    target/comet-fuzz-spark3.5_2.12-0.12.0-SNAPSHOT-jar-with-dependencies.jar \
     data --num-files=2 --num-rows=200 --exclude-negative-zero --generate-arrays --generate-structs --generate-maps
 ```
 
@@ -77,7 +77,7 @@ Generate random queries that are based on the available test files.
 $SPARK_HOME/bin/spark-submit \
     --master $SPARK_MASTER \
     --class org.apache.comet.fuzz.Main \
-    target/comet-fuzz-spark3.4_2.12-0.7.0-SNAPSHOT-jar-with-dependencies.jar \
+    target/comet-fuzz-spark3.5_2.12-0.12.0-SNAPSHOT-jar-with-dependencies.jar \
     queries --num-files=2 --num-queries=500
 ```
 
@@ -88,18 +88,17 @@ Note that the output filename is currently hard-coded as `queries.sql`
 ```shell
 $SPARK_HOME/bin/spark-submit \
     --master $SPARK_MASTER \
+    --conf spark.memory.offHeap.enabled=true \
+    --conf spark.memory.offHeap.size=16G \
     --conf spark.plugins=org.apache.spark.CometPlugin \
     --conf spark.comet.enabled=true \
-    --conf spark.comet.exec.enabled=true \
-    --conf spark.comet.exec.all.enabled=true \
     --conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
     --conf spark.comet.exec.shuffle.enabled=true \
-    --conf spark.comet.exec.shuffle.mode=auto \
     --jars $COMET_JAR \
     --conf spark.driver.extraClassPath=$COMET_JAR \
     --conf spark.executor.extraClassPath=$COMET_JAR \
     --class org.apache.comet.fuzz.Main \
-    target/comet-fuzz-spark3.4_2.12-0.7.0-SNAPSHOT-jar-with-dependencies.jar \
+    target/comet-fuzz-spark3.5_2.12-0.12.0-SNAPSHOT-jar-with-dependencies.jar \
     run --num-files=2 --filename=queries.sql
 ```
 
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala
index 1f81dc779..b9e63c76a 100644
--- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala
@@ -87,7 +87,11 @@ object Main {
             SchemaGenOptions(
               generateArray = conf.generateData.generateArrays(),
               generateStruct = conf.generateData.generateStructs(),
-              generateMap = conf.generateData.generateMaps()),
+              generateMap = conf.generateData.generateMaps(),
+              // create two columns of each primitive type so that they can be used in binary
+              // expressions such as `a + b` and `a <  b`
+              primitiveTypes = SchemaGenOptions.defaultPrimitiveTypes ++
+                SchemaGenOptions.defaultPrimitiveTypes),
             DataGenOptions(
               allowNull = true,
               generateNegativeZero = !conf.generateData.excludeNegativeZero()))
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
index 246216840..74d13f85e 100644
--- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
@@ -22,6 +22,32 @@ package org.apache.comet.fuzz
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.types.DataTypes
 
+sealed trait SparkType
+case class SparkTypeOneOf(dataTypes: Seq[SparkType]) extends SparkType
+case object SparkBooleanType extends SparkType
+case object SparkBinaryType extends SparkType
+case object SparkStringType extends SparkType
+case object SparkIntegralType extends SparkType
+case object SparkByteType extends SparkType
+case object SparkShortType extends SparkType
+case object SparkIntType extends SparkType
+case object SparkLongType extends SparkType
+case object SparkFloatType extends SparkType
+case object SparkDoubleType extends SparkType
+case class SparkDecimalType(p: Int, s: Int) extends SparkType
+case object SparkNumericType extends SparkType
+case object SparkDateType extends SparkType
+case object SparkTimestampType extends SparkType
+case object SparkDateOrTimestampType extends SparkType
+case class SparkArrayType(elementType: SparkType) extends SparkType
+case class SparkMapType(keyType: SparkType, valueType: SparkType) extends SparkType
+case class SparkStructType(fields: Seq[SparkType]) extends SparkType
+case object SparkAnyType extends SparkType
+
+case class FunctionSignature(inputTypes: Seq[SparkType])
+
+case class Function(name: String, signatures: Seq[FunctionSignature])
+
 object Meta {
 
   val dataTypes: Seq[(DataType, Double)] = Seq(
@@ -35,100 +61,283 @@ object Meta {
     (DataTypes.createDecimalType(10, 2), 0.2),
     (DataTypes.DateType, 0.2),
     (DataTypes.TimestampType, 0.2),
-    // TimestampNTZType only in Spark 3.4+
-    // (DataTypes.TimestampNTZType, 0.2),
+    (DataTypes.TimestampNTZType, 0.2),
     (DataTypes.StringType, 0.2),
     (DataTypes.BinaryType, 0.1))
 
-  val stringScalarFunc: Seq[Function] = Seq(
-    Function("substring", 3),
-    Function("coalesce", 1),
-    Function("starts_with", 2),
-    Function("ends_with", 2),
-    Function("contains", 2),
-    Function("ascii", 1),
-    Function("bit_length", 1),
-    Function("octet_length", 1),
-    Function("upper", 1),
-    Function("lower", 1),
-    Function("chr", 1),
-    Function("init_cap", 1),
-    Function("trim", 1),
-    Function("ltrim", 1),
-    Function("rtrim", 1),
-    Function("string_space", 1),
-    Function("rpad", 2),
-    Function("rpad", 3), // rpad can have 2 or 3 arguments
-    Function("hex", 1),
-    Function("unhex", 1),
-    Function("xxhash64", 1),
-    Function("sha1", 1),
-    // Function("sha2", 1), -- needs a second argument for number of bits
-    Function("substring", 3),
-    Function("btrim", 1),
-    Function("concat_ws", 2),
-    Function("repeat", 2),
-    Function("length", 1),
-    Function("reverse", 1),
-    Function("instr", 2),
-    Function("replace", 2),
-    Function("translate", 2))
-
-  val dateScalarFunc: Seq[Function] =
-    Seq(Function("year", 1), Function("hour", 1), Function("minute", 1), Function("second", 1))
+  private def createFunctionWithInputTypes(name: String, inputs: Seq[SparkType]): Function = {
+    Function(name, Seq(FunctionSignature(inputs)))
+  }
+
+  private def createFunctions(name: String, signatures: Seq[FunctionSignature]): Function = {
+    Function(name, signatures)
+  }
 
+  private def createUnaryStringFunction(name: String): Function = {
+    createFunctionWithInputTypes(name, Seq(SparkStringType))
+  }
+
+  private def createUnaryNumericFunction(name: String): Function = {
+    createFunctionWithInputTypes(name, Seq(SparkNumericType))
+  }
+
+  // Math expressions (corresponds to mathExpressions in QueryPlanSerde)
   val mathScalarFunc: Seq[Function] = Seq(
-    Function("abs", 1),
-    Function("acos", 1),
-    Function("asin", 1),
-    Function("atan", 1),
-    Function("Atan2", 1),
-    Function("Cos", 1),
-    Function("Exp", 2),
-    Function("Ln", 1),
-    Function("Log10", 1),
-    Function("Log2", 1),
-    Function("Pow", 2),
-    Function("Round", 1),
-    Function("Signum", 1),
-    Function("Sin", 1),
-    Function("Sqrt", 1),
-    Function("Tan", 1),
-    Function("Ceil", 1),
-    Function("Floor", 1),
-    Function("bool_and", 1),
-    Function("bool_or", 1),
-    Function("bitwise_not", 1))
+    createUnaryNumericFunction("abs"),
+    createUnaryNumericFunction("acos"),
+    createUnaryNumericFunction("asin"),
+    createUnaryNumericFunction("atan"),
+    createFunctionWithInputTypes("atan2", Seq(SparkNumericType, SparkNumericType)),
+    createUnaryNumericFunction("cos"),
+    createUnaryNumericFunction("exp"),
+    createUnaryNumericFunction("expm1"),
+    createFunctionWithInputTypes("log", Seq(SparkNumericType, SparkNumericType)),
+    createUnaryNumericFunction("log10"),
+    createUnaryNumericFunction("log2"),
+    createFunctionWithInputTypes("pow", Seq(SparkNumericType, SparkNumericType)),
+    createFunctionWithInputTypes("remainder", Seq(SparkNumericType, SparkNumericType)),
+    createFunctions(
+      "round",
+      Seq(
+        FunctionSignature(Seq(SparkNumericType)),
+        FunctionSignature(Seq(SparkNumericType, SparkIntType)))),
+    createUnaryNumericFunction("signum"),
+    createUnaryNumericFunction("sin"),
+    createUnaryNumericFunction("sqrt"),
+    createUnaryNumericFunction("tan"),
+    createUnaryNumericFunction("ceil"),
+    createUnaryNumericFunction("floor"),
+    createFunctionWithInputTypes("unary_minus", Seq(SparkNumericType)))
 
+  // Hash expressions (corresponds to hashExpressions in QueryPlanSerde)
+  val hashScalarFunc: Seq[Function] = Seq(
+    createFunctionWithInputTypes("md5", Seq(SparkAnyType)),
+    createFunctionWithInputTypes("murmur3_hash", Seq(SparkAnyType)), // TODO variadic
+    createFunctionWithInputTypes("sha2", Seq(SparkAnyType, SparkIntType)))
+
+  // String expressions (corresponds to stringExpressions in QueryPlanSerde)
+  val stringScalarFunc: Seq[Function] = Seq(
+    createUnaryStringFunction("ascii"),
+    createUnaryStringFunction("bit_length"),
+    createUnaryStringFunction("chr"),
+    createFunctionWithInputTypes(
+      "concat",
+      Seq(
+        SparkTypeOneOf(
+          Seq(
+            SparkStringType,
+            SparkNumericType,
+            SparkBinaryType,
+            SparkArrayType(
+              SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkBinaryType))))),
+        SparkTypeOneOf(
+          Seq(
+            SparkStringType,
+            SparkNumericType,
+            SparkBinaryType,
+            SparkArrayType(
+              SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkBinaryType))))))),
+    createFunctionWithInputTypes("concat_ws", Seq(SparkStringType, SparkStringType)),
+    createFunctionWithInputTypes("contains", Seq(SparkStringType, SparkStringType)),
+    createFunctionWithInputTypes("ends_with", Seq(SparkStringType, SparkStringType)),
+    createFunctionWithInputTypes(
+      "hex",
+      Seq(SparkTypeOneOf(Seq(SparkStringType, SparkBinaryType, SparkIntType, SparkLongType)))),
+    createUnaryStringFunction("init_cap"),
+    createFunctionWithInputTypes("instr", Seq(SparkStringType, SparkStringType)),
+    createFunctionWithInputTypes(
+      "length",
+      Seq(SparkTypeOneOf(Seq(SparkStringType, SparkBinaryType)))),
+    createFunctionWithInputTypes("like", Seq(SparkStringType, SparkStringType)),
+    createUnaryStringFunction("lower"),
+    createFunctions(
+      "lpad",
+      Seq(
+        FunctionSignature(Seq(SparkStringType, SparkIntegralType)),
+        FunctionSignature(Seq(SparkStringType, SparkIntegralType, SparkStringType)))),
+    createUnaryStringFunction("ltrim"),
+    createUnaryStringFunction("octet_length"),
+    createFunctions(
+      "regexp_replace",
+      Seq(
+        FunctionSignature(Seq(SparkStringType, SparkStringType, SparkStringType)),
+        FunctionSignature(Seq(SparkStringType, SparkStringType, SparkStringType, SparkIntType)))),
+    createFunctionWithInputTypes("repeat", Seq(SparkStringType, SparkIntType)),
+    createFunctions(
+      "replace",
+      Seq(
+        FunctionSignature(Seq(SparkStringType, SparkStringType)),
+        FunctionSignature(Seq(SparkStringType, SparkStringType, SparkStringType)))),
+    createFunctions(
+      "reverse",
+      Seq(
+        FunctionSignature(Seq(SparkStringType)),
+        FunctionSignature(Seq(SparkArrayType(SparkAnyType))))),
+    createFunctionWithInputTypes("rlike", Seq(SparkStringType, SparkStringType)),
+    createFunctions(
+      "rpad",
+      Seq(
+        FunctionSignature(Seq(SparkStringType, SparkIntegralType)),
+        FunctionSignature(Seq(SparkStringType, SparkIntegralType, SparkStringType)))),
+    createUnaryStringFunction("rtrim"),
+    createFunctionWithInputTypes("starts_with", Seq(SparkStringType, SparkStringType)),
+    createFunctionWithInputTypes("string_space", Seq(SparkIntType)),
+    createFunctionWithInputTypes("substring", Seq(SparkStringType, SparkIntType, SparkIntType)),
+    createFunctionWithInputTypes("translate", Seq(SparkStringType, SparkStringType)),
+    createUnaryStringFunction("trim"),
+    createUnaryStringFunction("btrim"),
+    createUnaryStringFunction("unhex"),
+    createUnaryStringFunction("upper"),
+    createFunctionWithInputTypes("xxhash64", Seq(SparkAnyType)), // TODO variadic
+    createFunctionWithInputTypes("sha1", Seq(SparkAnyType)))
+
+  // Conditional expressions (corresponds to conditionalExpressions in QueryPlanSerde)
+  val conditionalScalarFunc: Seq[Function] = Seq(
+    createFunctionWithInputTypes("if", Seq(SparkBooleanType, SparkAnyType, SparkAnyType)))
+
+  // Map expressions (corresponds to mapExpressions in QueryPlanSerde)
+  val mapScalarFunc: Seq[Function] = Seq(
+    createFunctionWithInputTypes(
+      "map_extract",
+      Seq(SparkMapType(SparkAnyType, SparkAnyType), SparkAnyType)),
+    createFunctionWithInputTypes("map_keys", Seq(SparkMapType(SparkAnyType, SparkAnyType))),
+    createFunctionWithInputTypes("map_entries", Seq(SparkMapType(SparkAnyType, SparkAnyType))),
+    createFunctionWithInputTypes("map_values", Seq(SparkMapType(SparkAnyType, SparkAnyType))),
+    createFunctionWithInputTypes(
+      "map_from_arrays",
+      Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))))
+
+  // Predicate expressions (corresponds to predicateExpressions in QueryPlanSerde)
+  val predicateScalarFunc: Seq[Function] = Seq(
+    createFunctionWithInputTypes("and", Seq(SparkBooleanType, SparkBooleanType)),
+    createFunctionWithInputTypes("or", Seq(SparkBooleanType, SparkBooleanType)),
+    createFunctionWithInputTypes("not", Seq(SparkBooleanType)),
+    createFunctionWithInputTypes("in", Seq(SparkAnyType, SparkAnyType))
+  ) // TODO: variadic
+
+  // Struct expressions (corresponds to structExpressions in QueryPlanSerde)
+  val structScalarFunc: Seq[Function] = Seq(
+    createFunctionWithInputTypes(
+      "create_named_struct",
+      Seq(SparkStringType, SparkAnyType)
+    ), // TODO: variadic name/value pairs
+    createFunctionWithInputTypes(
+      "get_struct_field",
+      Seq(SparkStructType(Seq(SparkAnyType)), SparkStringType)))
+
+  // Bitwise expressions (corresponds to bitwiseExpressions in QueryPlanSerde)
+  val bitwiseScalarFunc: Seq[Function] = Seq(
+    createFunctionWithInputTypes("bitwise_and", Seq(SparkIntegralType, SparkIntegralType)),
+    createFunctionWithInputTypes("bitwise_count", Seq(SparkIntegralType)),
+    createFunctionWithInputTypes("bitwise_get", Seq(SparkIntegralType, SparkIntType)),
+    createFunctionWithInputTypes("bitwise_or", Seq(SparkIntegralType, SparkIntegralType)),
+    createFunctionWithInputTypes("bitwise_not", Seq(SparkIntegralType)),
+    createFunctionWithInputTypes("bitwise_xor", Seq(SparkIntegralType, SparkIntegralType)),
+    createFunctionWithInputTypes("shift_left", Seq(SparkIntegralType, SparkIntType)),
+    createFunctionWithInputTypes("shift_right", Seq(SparkIntegralType, SparkIntType)))
+
+  // Misc expressions (corresponds to miscExpressions in QueryPlanSerde)
   val miscScalarFunc: Seq[Function] =
-    Seq(Function("isnan", 1), Function("isnull", 1), Function("isnotnull", 1))
+    Seq(
+      createFunctionWithInputTypes("isnan", Seq(SparkNumericType)),
+      createFunctionWithInputTypes("isnull", Seq(SparkAnyType)),
+      createFunctionWithInputTypes("isnotnull", Seq(SparkAnyType)),
+      createFunctionWithInputTypes("coalesce", Seq(SparkAnyType, SparkAnyType))
+    ) // TODO: variadic
 
+  // Array expressions (corresponds to arrayExpressions in QueryPlanSerde)
   val arrayScalarFunc: Seq[Function] = Seq(
-    Function("array", 2),
-    Function("array_remove", 2),
-    Function("array_insert", 2),
-    Function("array_contains", 2),
-    Function("array_intersect", 2),
-    Function("array_append", 2))
+    createFunctionWithInputTypes("array_append", Seq(SparkArrayType(SparkAnyType), SparkAnyType)),
+    createFunctionWithInputTypes("array_compact", Seq(SparkArrayType(SparkAnyType))),
+    createFunctionWithInputTypes(
+      "array_contains",
+      Seq(SparkArrayType(SparkAnyType), SparkAnyType)),
+    createFunctionWithInputTypes("array_distinct", Seq(SparkArrayType(SparkAnyType))),
+    createFunctionWithInputTypes(
+      "array_except",
+      Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))),
+    createFunctionWithInputTypes(
+      "array_insert",
+      Seq(SparkArrayType(SparkAnyType), SparkIntType, SparkAnyType)),
+    createFunctionWithInputTypes(
+      "array_intersect",
+      Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))),
+    createFunctions(
+      "array_join",
+      Seq(
+        FunctionSignature(Seq(SparkArrayType(SparkAnyType), SparkStringType)),
+        FunctionSignature(Seq(SparkArrayType(SparkAnyType), SparkStringType, SparkStringType)))),
+    createFunctionWithInputTypes("array_max", Seq(SparkArrayType(SparkAnyType))),
+    createFunctionWithInputTypes("array_min", Seq(SparkArrayType(SparkAnyType))),
+    createFunctionWithInputTypes("array_remove", Seq(SparkArrayType(SparkAnyType), SparkAnyType)),
+    createFunctionWithInputTypes("array_repeat", Seq(SparkAnyType, SparkIntType)),
+    createFunctionWithInputTypes(
+      "arrays_overlap",
+      Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))),
+    createFunctionWithInputTypes(
+      "array_union",
+      Seq(SparkArrayType(SparkAnyType), SparkArrayType(SparkAnyType))),
+    createFunctionWithInputTypes("array", Seq(SparkAnyType, SparkAnyType)), // TODO: variadic
+    createFunctionWithInputTypes(
+      "element_at",
+      Seq(
+        SparkTypeOneOf(
+          Seq(SparkArrayType(SparkAnyType), SparkMapType(SparkAnyType, SparkAnyType))),
+        SparkAnyType)),
+    createFunctionWithInputTypes("flatten", Seq(SparkArrayType(SparkArrayType(SparkAnyType)))),
+    createFunctionWithInputTypes(
+      "get_array_item",
+      Seq(SparkArrayType(SparkAnyType), SparkIntType)))
 
-  val scalarFunc: Seq[Function] = stringScalarFunc ++ dateScalarFunc ++
-    mathScalarFunc ++ miscScalarFunc ++ arrayScalarFunc
+  // Temporal expressions (corresponds to temporalExpressions in QueryPlanSerde)
+  val temporalScalarFunc: Seq[Function] =
+    Seq(
+      createFunctionWithInputTypes("date_add", Seq(SparkDateType, SparkIntType)),
+      createFunctionWithInputTypes("date_sub", Seq(SparkDateType, SparkIntType)),
+      createFunctions(
+        "from_unixtime",
+        Seq(
+          FunctionSignature(Seq(SparkLongType)),
+          FunctionSignature(Seq(SparkLongType, SparkStringType)))),
+      createFunctionWithInputTypes("hour", Seq(SparkDateOrTimestampType)),
+      createFunctionWithInputTypes("minute", Seq(SparkDateOrTimestampType)),
+      createFunctionWithInputTypes("second", Seq(SparkDateOrTimestampType)),
+      createFunctionWithInputTypes("trunc", Seq(SparkDateOrTimestampType, SparkStringType)),
+      createFunctionWithInputTypes("year", Seq(SparkDateOrTimestampType)),
+      createFunctionWithInputTypes("month", Seq(SparkDateOrTimestampType)),
+      createFunctionWithInputTypes("day", Seq(SparkDateOrTimestampType)),
+      createFunctionWithInputTypes("dayofmonth", Seq(SparkDateOrTimestampType)),
+      createFunctionWithInputTypes("dayofweek", Seq(SparkDateOrTimestampType)),
+      createFunctionWithInputTypes("weekday", Seq(SparkDateOrTimestampType)),
+      createFunctionWithInputTypes("dayofyear", Seq(SparkDateOrTimestampType)),
+      createFunctionWithInputTypes("weekofyear", Seq(SparkDateOrTimestampType)),
+      createFunctionWithInputTypes("quarter", Seq(SparkDateOrTimestampType)))
+
+  // Combined in same order as exprSerdeMap in QueryPlanSerde
+  val scalarFunc: Seq[Function] = mathScalarFunc ++ hashScalarFunc ++ stringScalarFunc ++
+    conditionalScalarFunc ++ mapScalarFunc ++ predicateScalarFunc ++
+    structScalarFunc ++ bitwiseScalarFunc ++ miscScalarFunc ++ arrayScalarFunc ++
+    temporalScalarFunc
 
   val aggFunc: Seq[Function] = Seq(
-    Function("min", 1),
-    Function("max", 1),
-    Function("count", 1),
-    Function("avg", 1),
-    Function("sum", 1),
-    Function("first", 1),
-    Function("last", 1),
-    Function("var_pop", 1),
-    Function("var_samp", 1),
-    Function("covar_pop", 1),
-    Function("covar_samp", 1),
-    Function("stddev_pop", 1),
-    Function("stddev_samp", 1),
-    Function("corr", 2))
+    createFunctionWithInputTypes("min", Seq(SparkAnyType)),
+    createFunctionWithInputTypes("max", Seq(SparkAnyType)),
+    createFunctionWithInputTypes("count", Seq(SparkAnyType)),
+    createUnaryNumericFunction("avg"),
+    createUnaryNumericFunction("sum"),
+    // first/last are non-deterministic and known to be incompatible with Spark
+//    createFunctionWithInputTypes("first", Seq(SparkAnyType)),
+//    createFunctionWithInputTypes("last", Seq(SparkAnyType)),
+    createUnaryNumericFunction("var_pop"),
+    createUnaryNumericFunction("var_samp"),
+    createFunctionWithInputTypes("covar_pop", Seq(SparkNumericType, SparkNumericType)),
+    createFunctionWithInputTypes("covar_samp", Seq(SparkNumericType, SparkNumericType)),
+    createUnaryNumericFunction("stddev_pop"),
+    createUnaryNumericFunction("stddev_samp"),
+    createFunctionWithInputTypes("corr", Seq(SparkNumericType, SparkNumericType)),
+    createFunctionWithInputTypes("bit_and", Seq(SparkIntegralType)),
+    createFunctionWithInputTypes("bit_or", Seq(SparkIntegralType)),
+    createFunctionWithInputTypes("bit_xor", Seq(SparkIntegralType)))
 
   val unaryArithmeticOps: Seq[String] = Seq("+", "-")
 
@@ -137,4 +346,13 @@ object Meta {
 
   val comparisonOps: Seq[String] = Seq("=", "<=>", ">", ">=", "<", "<=")
 
+  // TODO make this more comprehensive
+  val comparisonTypes: Seq[SparkType] = Seq(
+    SparkStringType,
+    SparkBinaryType,
+    SparkNumericType,
+    SparkDateType,
+    SparkTimestampType,
+    SparkArrayType(SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkDateType))))
+
 }
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
index de1117837..d9e3c147d 100644
--- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
@@ -24,7 +24,8 @@ import java.io.{BufferedWriter, FileWriter}
 import scala.collection.mutable
 import scala.util.Random
 
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.types._
 
 object QueryGen {
 
@@ -42,19 +43,25 @@ object QueryGen {
     val uniqueQueries = mutable.HashSet[String]()
 
     for (_ <- 0 until numQueries) {
-      val sql = r.nextInt().abs % 8 match {
-        case 0 => generateJoin(r, spark, numFiles)
-        case 1 => generateAggregate(r, spark, numFiles)
-        case 2 => generateScalar(r, spark, numFiles)
-        case 3 => generateCast(r, spark, numFiles)
-        case 4 => generateUnaryArithmetic(r, spark, numFiles)
-        case 5 => generateBinaryArithmetic(r, spark, numFiles)
-        case 6 => generateBinaryComparison(r, spark, numFiles)
-        case _ => generateConditional(r, spark, numFiles)
-      }
-      if (!uniqueQueries.contains(sql)) {
-        uniqueQueries += sql
-        w.write(sql + "\n")
+      try {
+        val sql = r.nextInt().abs % 8 match {
+          case 0 => generateJoin(r, spark, numFiles)
+          case 1 => generateAggregate(r, spark, numFiles)
+          case 2 => generateScalar(r, spark, numFiles)
+          case 3 => generateCast(r, spark, numFiles)
+          case 4 => generateUnaryArithmetic(r, spark, numFiles)
+          case 5 => generateBinaryArithmetic(r, spark, numFiles)
+          case 6 => generateBinaryComparison(r, spark, numFiles)
+          case _ => generateConditional(r, spark, numFiles)
+        }
+        if (!uniqueQueries.contains(sql)) {
+          uniqueQueries += sql
+          w.write(sql + "\n")
+        }
+      } catch {
+        case e: Exception =>
+          // scalastyle:off
+          println(s"Failed to generate query: ${e.getMessage}")
       }
     }
     w.close()
@@ -65,35 +72,177 @@ object QueryGen {
     val table = spark.table(tableName)
 
     val func = Utils.randomChoice(Meta.aggFunc, r)
-    val args = Range(0, func.num_args)
-      .map(_ => Utils.randomChoice(table.columns, r))
+    try {
+      val signature = Utils.randomChoice(func.signatures, r)
+      val args = signature.inputTypes.map(x => pickRandomColumn(r, table, x))
+
+      val groupingCols = Range(0, 2).map(_ => Utils.randomChoice(table.columns, r))
+
+      if (groupingCols.isEmpty) {
+        s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
+          s"FROM $tableName " +
+          s"ORDER BY ${args.mkString(", ")};"
+      } else {
+        s"SELECT ${groupingCols.mkString(", ")}, ${func.name}(${args.mkString(", ")}) " +
+          s"FROM $tableName " +
+          s"GROUP BY ${groupingCols.mkString(",")} " +
+          s"ORDER BY ${groupingCols.mkString(", ")};"
+      }
+    } catch {
+      case e: Exception =>
+        throw new IllegalStateException(
+          s"Failed to generate SQL for aggregate function ${func.name}",
+          e)
+    }
+  }
+
+  private def generateScalar(r: Random, spark: SparkSession, numFiles: Int): String = {
+    val tableName = s"test${r.nextInt(numFiles)}"
+    val table = spark.table(tableName)
 
-    val groupingCols = Range(0, 2).map(_ => Utils.randomChoice(table.columns, r))
+    val func = Utils.randomChoice(Meta.scalarFunc, r)
+    try {
+      val signature = Utils.randomChoice(func.signatures, r)
+      val args = signature.inputTypes.map(x => pickRandomColumn(r, table, x))
 
-    if (groupingCols.isEmpty) {
+      // Example SELECT c0, log(c0) as x FROM test0
       s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
         s"FROM $tableName " +
         s"ORDER BY ${args.mkString(", ")};"
-    } else {
-      s"SELECT ${groupingCols.mkString(", ")}, ${func.name}(${args.mkString(", ")}) " +
-        s"FROM $tableName " +
-        s"GROUP BY ${groupingCols.mkString(",")} " +
-        s"ORDER BY ${groupingCols.mkString(", ")};"
+    } catch {
+      case e: Exception =>
+        throw new IllegalStateException(
+          s"Failed to generate SQL for scalar function ${func.name}",
+          e)
     }
   }
 
-  private def generateScalar(r: Random, spark: SparkSession, numFiles: Int): String = {
-    val tableName = s"test${r.nextInt(numFiles)}"
-    val table = spark.table(tableName)
+  private def pickRandomColumn(r: Random, df: DataFrame, targetType: SparkType): String = {
+    targetType match {
+      case SparkAnyType =>
+        Utils.randomChoice(df.schema.fields, r).name
+      case SparkBooleanType =>
+        select(r, df, _.dataType == BooleanType)
+      case SparkByteType =>
+        select(r, df, _.dataType == ByteType)
+      case SparkShortType =>
+        select(r, df, _.dataType == ShortType)
+      case SparkIntType =>
+        select(r, df, _.dataType == IntegerType)
+      case SparkLongType =>
+        select(r, df, _.dataType == LongType)
+      case SparkFloatType =>
+        select(r, df, _.dataType == FloatType)
+      case SparkDoubleType =>
+        select(r, df, _.dataType == DoubleType)
+      case SparkDecimalType(_, _) =>
+        select(r, df, _.dataType.isInstanceOf[DecimalType])
+      case SparkIntegralType =>
+        select(
+          r,
+          df,
+          f =>
+            f.dataType == ByteType || f.dataType == ShortType ||
+              f.dataType == IntegerType || f.dataType == LongType)
+      case SparkNumericType =>
+        select(r, df, f => isNumeric(f.dataType))
+      case SparkStringType =>
+        select(r, df, _.dataType == StringType)
+      case SparkBinaryType =>
+        select(r, df, _.dataType == BinaryType)
+      case SparkDateType =>
+        select(r, df, _.dataType == DateType)
+      case SparkTimestampType =>
+        select(r, df, _.dataType == TimestampType)
+      case SparkDateOrTimestampType =>
+        select(r, df, f => f.dataType == DateType || f.dataType == TimestampType)
+      case SparkTypeOneOf(choices) =>
+        pickRandomColumn(r, df, Utils.randomChoice(choices, r))
+      case SparkArrayType(elementType) =>
+        select(
+          r,
+          df,
+          _.dataType match {
+            case ArrayType(x, _) if typeMatch(elementType, x) => true
+            case _ => false
+          })
+      case SparkMapType(keyType, valueType) =>
+        select(
+          r,
+          df,
+          _.dataType match {
+            case MapType(k, v, _) if typeMatch(keyType, k) && typeMatch(valueType, v) => true
+            case _ => false
+          })
+      case SparkStructType(fields) =>
+        select(
+          r,
+          df,
+          _.dataType match {
+            case StructType(structFields) if structFields.length == fields.length => true
+            case _ => false
+          })
+      case _ =>
+        throw new IllegalStateException(targetType.toString)
+    }
+  }
 
-    val func = Utils.randomChoice(Meta.scalarFunc, r)
-    val args = Range(0, func.num_args)
-      .map(_ => Utils.randomChoice(table.columns, r))
+  def pickTwoRandomColumns(r: Random, df: DataFrame, targetType: SparkType): (String, String) = {
+    val a = pickRandomColumn(r, df, targetType)
+    val df2 = df.drop(a)
+    val b = pickRandomColumn(r, df2, targetType)
+    (a, b)
+  }
 
-    // Example SELECT c0, log(c0) as x FROM test0
-    s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
-      s"FROM $tableName " +
-      s"ORDER BY ${args.mkString(", ")};"
+  /** Select a random field that matches a predicate */
+  private def select(r: Random, df: DataFrame, predicate: StructField => Boolean): String = {
+    val candidates = df.schema.fields.filter(predicate)
+    if (candidates.isEmpty) {
+      throw new IllegalStateException("Failed to find suitable column")
+    }
+    Utils.randomChoice(candidates, r).name
+  }
+
+  private def isNumeric(d: DataType): Boolean = {
+    d match {
+      case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
+          _: DoubleType | _: DecimalType =>
+        true
+      case _ => false
+    }
+  }
+
+  private def typeMatch(s: SparkType, d: DataType): Boolean = {
+    (s, d) match {
+      case (SparkAnyType, _) => true
+      case (SparkBooleanType, BooleanType) => true
+      case (SparkByteType, ByteType) => true
+      case (SparkShortType, ShortType) => true
+      case (SparkIntType, IntegerType) => true
+      case (SparkLongType, LongType) => true
+      case (SparkFloatType, FloatType) => true
+      case (SparkDoubleType, DoubleType) => true
+      case (SparkDecimalType(_, _), _: DecimalType) => true
+      case (SparkIntegralType, ByteType | ShortType | IntegerType | LongType) => true
+      case (SparkNumericType, _) if isNumeric(d) => true
+      case (SparkStringType, StringType) => true
+      case (SparkBinaryType, BinaryType) => true
+      case (SparkDateType, DateType) => true
+      case (SparkTimestampType, TimestampType | TimestampNTZType) => true
+      case (SparkDateOrTimestampType, DateType | TimestampType | TimestampNTZType) => true
+      case (SparkArrayType(elementType), ArrayType(elementDataType, _)) =>
+        typeMatch(elementType, elementDataType)
+      case (SparkMapType(keyType, valueType), MapType(keyDataType, valueDataType, _)) =>
+        typeMatch(keyType, keyDataType) && typeMatch(valueType, valueDataType)
+      case (SparkStructType(fields), StructType(structFields)) =>
+        fields.length == structFields.length &&
+        fields.zip(structFields.map(_.dataType)).forall { case (sparkType, dataType) =>
+          typeMatch(sparkType, dataType)
+        }
+      case (SparkTypeOneOf(choices), _) =>
+        choices.exists(choice => typeMatch(choice, d))
+      case _ => false
+    }
   }
 
   private def generateUnaryArithmetic(r: Random, spark: SparkSession, numFiles: Int): String = {
@@ -101,7 +250,7 @@ object QueryGen {
     val table = spark.table(tableName)
 
     val op = Utils.randomChoice(Meta.unaryArithmeticOps, r)
-    val a = Utils.randomChoice(table.columns, r)
+    val a = pickRandomColumn(r, table, SparkNumericType)
 
     // Example SELECT a, -a FROM test0
     s"SELECT $a, $op$a " +
@@ -114,8 +263,7 @@ object QueryGen {
     val table = spark.table(tableName)
 
     val op = Utils.randomChoice(Meta.binaryArithmeticOps, r)
-    val a = Utils.randomChoice(table.columns, r)
-    val b = Utils.randomChoice(table.columns, r)
+    val (a, b) = pickTwoRandomColumns(r, table, SparkNumericType)
 
     // Example SELECT a, b, a+b FROM test0
     s"SELECT $a, $b, $a $op $b " +
@@ -128,8 +276,10 @@ object QueryGen {
     val table = spark.table(tableName)
 
     val op = Utils.randomChoice(Meta.comparisonOps, r)
-    val a = Utils.randomChoice(table.columns, r)
-    val b = Utils.randomChoice(table.columns, r)
+
+    // pick two columns with the same type
+    val opType = Utils.randomChoice(Meta.comparisonTypes, r)
+    val (a, b) = pickTwoRandomColumns(r, table, opType)
 
     // Example SELECT a, b, a <=> b FROM test0
     s"SELECT $a, $b, $a $op $b " +
@@ -142,8 +292,10 @@ object QueryGen {
     val table = spark.table(tableName)
 
     val op = Utils.randomChoice(Meta.comparisonOps, r)
-    val a = Utils.randomChoice(table.columns, r)
-    val b = Utils.randomChoice(table.columns, r)
+
+    // pick two columns with the same type
+    val opType = Utils.randomChoice(Meta.comparisonTypes, r)
+    val (a, b) = pickTwoRandomColumns(r, table, opType)
 
     // Example SELECT a, b, IF(a <=> b, 1, 2), CASE WHEN a <=> b THEN 1 ELSE 2 
END FROM test0
     s"SELECT $a, $b, $a $op $b, IF($a $op $b, 1, 2), CASE WHEN $a $op $b THEN 
1 ELSE 2 END " +
@@ -192,5 +344,3 @@ object QueryGen {
   }
 
 }
-
-case class Function(name: String, num_args: Int)
diff --git 
a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala 
b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
index 8852f4bc1..bcc9f98d0 100644
--- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
@@ -34,6 +34,11 @@ object QueryRunner {
       filename: String,
       showFailedSparkQueries: Boolean = false): Unit = {
 
+    var queryCount = 0
+    var invalidQueryCount = 0
+    var cometFailureCount = 0
+    var cometSuccessCount = 0
+
     val outputFilename = s"results-${System.currentTimeMillis()}.md"
     // scalastyle:off println
     println(s"Writing results to $outputFilename")
@@ -56,7 +61,7 @@ object QueryRunner {
       querySource
         .getLines()
         .foreach(sql => {
-
+          queryCount += 1
           try {
             // execute with Spark
             spark.conf.set("spark.comet.enabled", "false")
@@ -67,13 +72,11 @@ object QueryRunner {
             // execute with Comet
             try {
               spark.conf.set("spark.comet.enabled", "true")
-              // complex type support until we support it natively
-              spark.conf.set("spark.comet.sparkToColumnar.enabled", "true")
-              spark.conf.set("spark.comet.convert.parquet.enabled", "true")
               val df = spark.sql(sql)
               val cometRows = df.collect()
               val cometPlan = df.queryExecution.executedPlan.toString
 
+              var success = true
               if (sparkRows.length == cometRows.length) {
                 var i = 0
                 while (i < sparkRows.length) {
@@ -82,6 +85,7 @@ object QueryRunner {
                   assert(l.length == r.length)
                   for (j <- 0 until l.length) {
                     if (!same(l(j), r(j))) {
+                      success = false
                       showSQL(w, sql)
                       showPlans(w, sparkPlan, cometPlan)
                       w.write(s"First difference at row $i:\n")
@@ -93,16 +97,36 @@ object QueryRunner {
                   i += 1
                 }
               } else {
+                success = false
                 showSQL(w, sql)
                 showPlans(w, sparkPlan, cometPlan)
                 w.write(
                   s"[ERROR] Spark produced ${sparkRows.length} rows and " +
                     s"Comet produced ${cometRows.length} rows.\n")
               }
+
+              // check that the plan contains Comet operators
+              if (!cometPlan.contains("Comet")) {
+                success = false
+                showSQL(w, sql)
+                showPlans(w, sparkPlan, cometPlan)
+                w.write("[ERROR] Comet did not accelerate any part of the 
plan\n")
+              }
+
+              if (success) {
+                cometSuccessCount += 1
+              } else {
+                cometFailureCount += 1
+              }
+
             } catch {
               case e: Exception =>
                 // the query worked in Spark but failed in Comet, so this is 
likely a bug in Comet
+                cometFailureCount += 1
                 showSQL(w, sql)
+                w.write("### Spark Plan\n")
+                w.write(s"```\n$sparkPlan\n```\n")
+
                 w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n")
                 w.write("```\n")
                 val sw = new StringWriter()
@@ -119,6 +143,7 @@ object QueryRunner {
           } catch {
             case e: Exception =>
               // we expect many generated queries to be invalid
+              invalidQueryCount += 1
               if (showFailedSparkQueries) {
                 showSQL(w, sql)
                 w.write(s"Query failed in Spark: ${e.getMessage}\n")
@@ -126,6 +151,11 @@ object QueryRunner {
           }
         })
 
+      w.write("# Summary\n")
+      w.write(
+        s"Total queries: $queryCount; Invalid queries: $invalidQueryCount; " +
+          s"Comet failed: $cometFailureCount; Comet succeeded: 
$cometSuccessCount\n")
+
     } finally {
       w.close()
       querySource.close()
@@ -133,10 +163,17 @@ object QueryRunner {
   }
 
   private def same(l: Any, r: Any): Boolean = {
+    if (l == null || r == null) {
+      return l == null && r == null
+    }
     (l, r) match {
+      case (a: Float, b: Float) if a.isPosInfinity => b.isPosInfinity
+      case (a: Float, b: Float) if a.isNegInfinity => b.isNegInfinity
       case (a: Float, b: Float) if a.isInfinity => b.isInfinity
       case (a: Float, b: Float) if a.isNaN => b.isNaN
       case (a: Float, b: Float) => (a - b).abs <= 0.000001f
+      case (a: Double, b: Double) if a.isPosInfinity => b.isPosInfinity
+      case (a: Double, b: Double) if a.isNegInfinity => b.isNegInfinity
       case (a: Double, b: Double) if a.isInfinity => b.isInfinity
       case (a: Double, b: Double) if a.isNaN => b.isNaN
       case (a: Double, b: Double) => (a - b).abs <= 0.000001
@@ -144,6 +181,10 @@ object QueryRunner {
         a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
       case (a: WrappedArray[_], b: WrappedArray[_]) =>
         a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
+      case (a: Row, b: Row) =>
+        val aa = a.toSeq
+        val bb = b.toSeq
+        aa.length == bb.length && aa.zip(bb).forall(x => same(x._1, x._2))
       case (a, b) => a == b
     }
   }
@@ -153,6 +194,7 @@ object QueryRunner {
       case null => "NULL"
       case v: WrappedArray[_] => s"[${v.map(format).mkString(",")}]"
       case v: Array[Byte] => s"[${v.mkString(",")}]"
+      case r: Row => formatRow(r)
       case other => other.toString
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to