This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
     new ddab35211  Chore: Improve array contains test coverage (#2030)
ddab35211 is described below

commit ddab35211f45c42085e74be0f5b8589f9c351089
Author: Kazantsev Maksim <kazantsev....@yandex.ru>
AuthorDate: Thu Jul 31 19:33:26 2025 +0400

    Chore: Improve array contains test coverage (#2030)
---
 .../main/scala/org/apache/comet/serde/arrays.scala |   2 +-
 .../apache/comet/CometArrayExpressionSuite.scala   | 123 +++++++++++++++++++--
 2 files changed, 115 insertions(+), 10 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 4dfc59045..15d86bea1 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -136,7 +136,7 @@ object CometArrayAppend extends CometExpressionSerde with IncompatExpr {
   }
 }
 
-object CometArrayContains extends CometExpressionSerde with IncompatExpr {
+object CometArrayContains extends CometExpressionSerde {
   override def convert(
       expr: Expression,
       inputs: Seq[Attribute],

diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 0be89c512..9951f4f9d 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -25,7 +25,7 @@ import scala.util.Random
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.functions.{array, col, expr, lit, udf}
+import org.apache.spark.sql.functions._
 
 import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
 import org.apache.comet.serde.CometArrayExcept
@@ -218,16 +218,121 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     }
   }
 
-  test("array_contains") {
-    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
-      withTempDir { dir =>
-        val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000)
-        spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+  test("array_contains - int values") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000)
+      spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+      checkSparkAnswerAndOperator(
+        spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
+      checkSparkAnswerAndOperator(
+        spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+    }
+  }
+
+  test("array_contains - test all types (native Parquet reader)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = true,
+            generateStruct = true,
+            generateMap = false))
+      }
+      val table = spark.read.parquet(filename)
+      table.createOrReplaceTempView("t1")
+      val complexTypeFields =
+        table.schema.fields.filter(field => isComplexType(field.dataType))
+      val primitiveTypeFields =
+        table.schema.fields.filterNot(field => isComplexType(field.dataType))
+      for (field <- primitiveTypeFields) {
+        val fieldName = field.name
+        val typeName = field.dataType.typeName
+        sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
+          .createOrReplaceTempView("t2")
+        checkSparkAnswerAndOperator(sql("SELECT array_contains(a, b) FROM t2"))
         checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
+          sql(s"SELECT array_contains(a, cast(null as $typeName)) FROM t2"))
+      }
+      for (field <- complexTypeFields) {
+        val fieldName = field.name
+        sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
+          .createOrReplaceTempView("t3")
+        checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t3"))
+      }
+    }
+  }
+
+  // https://github.com/apache/datafusion-comet/issues/1929
+  ignore("array_contains - array literals") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = false,
+            generateStruct = false,
+            generateMap = false))
+      }
+      val table = spark.read.parquet(filename)
+      for (field <- table.schema.fields) {
+        val typeName = field.dataType.typeName
         checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+          sql(s"SELECT array_contains(cast(null as array<$typeName>), b) FROM t2"))
+        checkSparkAnswerAndOperator(sql(
+          s"SELECT array_contains(cast(array() as array<$typeName>), cast(null as $typeName)) FROM t2"))
+        checkSparkAnswerAndOperator(sql("SELECT array_contains(array(), 1) FROM t2"))
+      }
+    }
+  }
+
+  test("array_contains - test all types (convert from Parquet)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = true,
+            generateStruct = true,
+            generateMap = false))
+      }
+      withSQLConf(
+        CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+        CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
+        val table = spark.read.parquet(filename)
+        table.createOrReplaceTempView("t1")
+        for (field <- table.schema.fields) {
+          val fieldName = field.name
+          sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
+            .createOrReplaceTempView("t2")
+          checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t2"))
+        }
+      }
     }
   }
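Note on the serde change above: dropping the IncompatExpr marker from CometArrayContains means array_contains is now treated as a compatible expression, so it no longer has to be unlocked with CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE (the removed test wrapped its queries in that conf; the replacement tests do not). The sketch below is a hypothetical, minimal illustration of that effect and is not part of the commit; it assumes a Spark deployment with the Comet plugin jar on the classpath, and the view and column names simply mirror the suite.

import org.apache.spark.sql.SparkSession

// Hypothetical standalone example (not from the commit): runs array_contains
// with Comet enabled and no allow-incompatible override. Before this change,
// native execution of array_contains required the conf keyed by
// CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE, as the removed test set.
object ArrayContainsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("array-contains-example")
      .master("local[1]")
      .config("spark.plugins", "org.apache.spark.CometPlugin")
      .getOrCreate()

    import spark.implicits._

    // Column names mirror the test suite's Parquet schema.
    Seq((1, 2, 3), (4, 5, 6)).toDF("_2", "_3", "_4").createOrReplaceTempView("t1")

    // Same query shape the updated suite exercises.
    spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1").show()

    spark.stop()
  }
}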
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org