This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new d92ea64fa feat: Support reverse function with ArrayType input (#2481)
d92ea64fa is described below
commit d92ea64fa61d9e9c8358a9ee16245e1e023a774a
Author: Fu Chen <[email protected]>
AuthorDate: Sat Oct 4 00:19:51 2025 +0800
feat: Support reverse function with ArrayType input (#2481)
* Support reverse function with ArrayType input
* nit
* refactor ut
* assert
* fix ci
* Update spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
Co-authored-by: Oleks V <[email protected]>
* Update spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
Co-authored-by: Oleks V <[email protected]>
---------
Co-authored-by: Oleks V <[email protected]>
---
.../org/apache/comet/serde/QueryPlanSerde.scala | 2 +
.../main/scala/org/apache/comet/serde/arrays.scala | 18 +-
.../apache/comet/CometArrayExpressionSuite.scala | 833 ++++++++++++---------
3 files changed, 482 insertions(+), 371 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index bb05015c2..4d1daacd6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -916,6 +916,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
case l @ Length(child) if child.dataType == BinaryType =>
withInfo(l, "Length on BinaryType is not supported")
None
+ case r @ Reverse(child) if child.dataType.isInstanceOf[ArrayType] =>
+ convert(r, CometArrayReverse)
case expr =>
QueryPlanSerde.exprSerdeMap.get(expr.getClass) match {
case Some(handler) =>
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 5b1603aaf..09ea547cc 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
import scala.annotation.tailrec
-import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains,
ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax,
ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute,
CreateArray, ElementAt, Expression, Flatten, GetArrayItem, Literal}
+import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains,
ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax,
ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute,
CreateArray, ElementAt, Expression, Flatten, GetArrayItem, Literal, Reverse}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -432,6 +432,22 @@ object CometGetArrayItem extends
CometExpressionSerde[GetArrayItem] {
}
}
+object CometArrayReverse extends CometExpressionSerde[Reverse] with ArraysBase
{
+ override def convert(
+ expr: Reverse,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[ExprOuterClass.Expr] = {
+ if (!isTypeSupported(expr.child.dataType)) {
+ withInfo(expr, s"child data type not supported: ${expr.child.dataType}")
+ return None
+ }
+ val reverseExprProto = exprToProto(expr.child, inputs, binding)
+ val reverseScalarExpr = scalarFunctionExprToProto("array_reverse",
reverseExprProto)
+ optExprWithInfo(reverseScalarExpr, expr, expr.children: _*)
+ }
+
+}
+
object CometElementAt extends CometExpressionSerde[ElementAt] {
override def convert(
diff --git
a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 56d9b3b42..2adb7a9ed 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -29,88 +29,94 @@ import org.apache.spark.sql.functions._
import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus,
isSpark40Plus}
import org.apache.comet.DataTypeSupport.isComplexType
-import org.apache.comet.serde.{CometArrayExcept, CometArrayRemove,
CometFlatten}
+import org.apache.comet.serde.{CometArrayExcept, CometArrayRemove,
CometArrayReverse, CometFlatten}
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
class CometArrayExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
test("array_remove - integer") {
Seq(true, false).foreach { dictionaryEnabled =>
- withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1")
- checkSparkAnswerAndOperator(
- sql("SELECT array_remove(array(_2, _3,_4), _2) from t1 where _2 is
null"))
- checkSparkAnswerAndOperator(
- sql("SELECT array_remove(array(_2, _3,_4), _3) from t1 where _3 is
not null"))
- checkSparkAnswerAndOperator(sql(
- "SELECT array_remove(case when _2 = _3 THEN array(_2, _3,_4) ELSE
null END, _3) from t1"))
+ withTempView("t1") {
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_remove(array(_2, _3,_4), _2) from t1 where _2 is
null"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_remove(array(_2, _3,_4), _3) from t1 where _3 is
not null"))
+ checkSparkAnswerAndOperator(sql(
+ "SELECT array_remove(case when _2 = _3 THEN array(_2, _3,_4) ELSE
null END, _3) from t1"))
+ }
}
}
}
test("array_remove - test all types (native Parquet reader)") {
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- val filename = path.toString
- val random = new Random(42)
- withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
- ParquetGenerator.makeParquetFile(
- random,
- spark,
- filename,
- 100,
- DataGenOptions(
- allowNull = true,
- generateNegativeZero = true,
- generateArray = false,
- generateStruct = false,
- generateMap = false))
- }
- val table = spark.read.parquet(filename)
- table.createOrReplaceTempView("t1")
- // test with array of each column
- val fieldNames =
- table.schema.fields
- .filter(field => CometArrayRemove.isTypeSupported(field.dataType))
- .map(_.name)
- for (fieldName <- fieldNames) {
- sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM
t1")
- .createOrReplaceTempView("t2")
- val df = sql("SELECT array_remove(a, b) FROM t2")
- checkSparkAnswerAndOperator(df)
- }
- }
- }
-
- test("array_remove - test all types (convert from Parquet)") {
- withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- val filename = path.toString
- val random = new Random(42)
- withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
- val options = DataGenOptions(
- allowNull = true,
- generateNegativeZero = true,
- generateArray = true,
- generateStruct = true,
- generateMap = false)
- ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
- }
- withSQLConf(
- CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
- CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
- CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ val filename = path.toString
+ val random = new Random(42)
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ ParquetGenerator.makeParquetFile(
+ random,
+ spark,
+ filename,
+ 100,
+ DataGenOptions(
+ allowNull = true,
+ generateNegativeZero = true,
+ generateArray = false,
+ generateStruct = false,
+ generateMap = false))
+ }
val table = spark.read.parquet(filename)
table.createOrReplaceTempView("t1")
// test with array of each column
- for (field <- table.schema.fields) {
- val fieldName = field.name
+ val fieldNames =
+ table.schema.fields
+ .filter(field => CometArrayRemove.isTypeSupported(field.dataType))
+ .map(_.name)
+ for (fieldName <- fieldNames) {
sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b
FROM t1")
.createOrReplaceTempView("t2")
val df = sql("SELECT array_remove(a, b) FROM t2")
- checkSparkAnswer(df)
+ checkSparkAnswerAndOperator(df)
+ }
+ }
+ }
+ }
+
+ test("array_remove - test all types (convert from Parquet)") {
+ withTempDir { dir =>
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ val filename = path.toString
+ val random = new Random(42)
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ val options = DataGenOptions(
+ allowNull = true,
+ generateNegativeZero = true,
+ generateArray = true,
+ generateStruct = true,
+ generateMap = false)
+ ParquetGenerator.makeParquetFile(random, spark, filename, 100,
options)
+ }
+ withSQLConf(
+ CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
+ CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+ CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
+ val table = spark.read.parquet(filename)
+ table.createOrReplaceTempView("t1")
+ // test with array of each column
+ for (field <- table.schema.fields) {
+ val fieldName = field.name
+ sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b
FROM t1")
+ .createOrReplaceTempView("t2")
+ val df = sql("SELECT array_remove(a, b) FROM t2")
+ checkSparkAnswer(df)
+ }
}
}
}
@@ -118,19 +124,21 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
test("array_remove - fallback for unsupported type struct") {
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = true, 100)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1")
- sql("SELECT array(struct(_1, _2)) as a, struct(_1, _2) as b FROM t1")
- .createOrReplaceTempView("t2")
- val expectedFallbackReasons = HashSet(
- "data type not supported:
ArrayType(StructType(StructField(_1,BooleanType,true),StructField(_2,ByteType,true)),false)")
- // note that checkExtended is disabled here due to an unrelated issue
- // https://github.com/apache/datafusion-comet/issues/1313
- checkSparkAnswerAndCompareExplainPlan(
- sql("SELECT array_remove(a, b) FROM t2"),
- expectedFallbackReasons,
- checkExplainString = false)
+ withTempView("t1", "t2") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = true, 100)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+ sql("SELECT array(struct(_1, _2)) as a, struct(_1, _2) as b FROM t1")
+ .createOrReplaceTempView("t2")
+ val expectedFallbackReasons = HashSet(
+ "data type not supported:
ArrayType(StructType(StructField(_1,BooleanType,true),StructField(_2,ByteType,true)),false)")
+ // note that checkExtended is disabled here due to an unrelated issue
+ // https://github.com/apache/datafusion-comet/issues/1313
+ checkSparkAnswerAndCompareExplainPlan(
+ sql("SELECT array_remove(a, b) FROM t2"),
+ expectedFallbackReasons,
+ checkExplainString = false)
+ }
}
}
@@ -138,21 +146,25 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled =
dictionaryEnabled, 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1");
- checkSparkAnswerAndOperator(spark.sql("Select
array_append(array(_1),false) from t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_append(array(_2, _3, _4), 4) FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_append(array(_2, _3, _4), null) FROM t1"));
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_append(array(_6, _7), CAST(6.5 AS DOUBLE))
FROM t1"));
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_append(array(_8), 'test') FROM t1"));
- checkSparkAnswerAndOperator(spark.sql("SELECT
array_append(array(_19), _19) FROM t1"));
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_append((CASE WHEN _2 =_3 THEN array(_4)
END), _4) FROM t1"));
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled =
dictionaryEnabled, 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+ checkSparkAnswerAndOperator(spark.sql("Select
array_append(array(_1),false) from t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_append(array(_2, _3, _4), 4) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_append(array(_2, _3, _4), null) FROM
t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_append(array(_6, _7), CAST(6.5 AS
DOUBLE)) FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_append(array(_8), 'test') FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_append(array(_19), _19) FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql(
+ "SELECT array_append((CASE WHEN _2 =_3 THEN array(_4) END),
_4) FROM t1"));
+ }
}
}
}
@@ -163,21 +175,26 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled =
dictionaryEnabled, 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1");
- checkSparkAnswerAndOperator(spark.sql("Select
array_prepend(array(_1),false) from t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM
t1"));
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS
DOUBLE)) FROM t1"));
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_prepend(array(_8), 'test') FROM t1"));
- checkSparkAnswerAndOperator(spark.sql("SELECT
array_prepend(array(_19), _19) FROM t1"));
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4)
END), _4) FROM t1"));
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled =
dictionaryEnabled, 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+ checkSparkAnswerAndOperator(
+ spark.sql("Select array_prepend(array(_1),false) from t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM
t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS
DOUBLE)) FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_8), 'test') FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_19), _19) FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql(
+ "SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4) END),
_4) FROM t1"));
+ }
}
}
}
@@ -225,84 +242,90 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
test("array_contains - int values") {
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n =
10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1");
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4)
END), _4) FROM t1"));
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n =
10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4)
END), _4) FROM t1"));
+ }
}
}
test("array_contains - test all types (native Parquet reader)") {
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- val filename = path.toString
- val random = new Random(42)
- withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
- ParquetGenerator.makeParquetFile(
- random,
- spark,
- filename,
- 100,
- DataGenOptions(
- allowNull = true,
- generateNegativeZero = true,
- generateArray = true,
- generateStruct = true,
- generateMap = false))
- }
- val table = spark.read.parquet(filename)
- table.createOrReplaceTempView("t1")
- val complexTypeFields =
- table.schema.fields.filter(field => isComplexType(field.dataType))
- val primitiveTypeFields =
- table.schema.fields.filterNot(field => isComplexType(field.dataType))
- for (field <- primitiveTypeFields) {
- val fieldName = field.name
- val typeName = field.dataType.typeName
- sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM
t1")
- .createOrReplaceTempView("t2")
- checkSparkAnswerAndOperator(sql("SELECT array_contains(a, b) FROM t2"))
- checkSparkAnswerAndOperator(
- sql(s"SELECT array_contains(a, cast(null as $typeName)) FROM t2"))
- }
- for (field <- complexTypeFields) {
- val fieldName = field.name
- sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM
t1")
- .createOrReplaceTempView("t3")
- checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t3"))
+ withTempView("t1", "t2", "t3") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ val filename = path.toString
+ val random = new Random(42)
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ ParquetGenerator.makeParquetFile(
+ random,
+ spark,
+ filename,
+ 100,
+ DataGenOptions(
+ allowNull = true,
+ generateNegativeZero = true,
+ generateArray = true,
+ generateStruct = true,
+ generateMap = false))
+ }
+ val table = spark.read.parquet(filename)
+ table.createOrReplaceTempView("t1")
+ val complexTypeFields =
+ table.schema.fields.filter(field => isComplexType(field.dataType))
+ val primitiveTypeFields =
+ table.schema.fields.filterNot(field => isComplexType(field.dataType))
+ for (field <- primitiveTypeFields) {
+ val fieldName = field.name
+ val typeName = field.dataType.typeName
+ sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b
FROM t1")
+ .createOrReplaceTempView("t2")
+ checkSparkAnswerAndOperator(sql("SELECT array_contains(a, b) FROM
t2"))
+ checkSparkAnswerAndOperator(
+ sql(s"SELECT array_contains(a, cast(null as $typeName)) FROM t2"))
+ }
+ for (field <- complexTypeFields) {
+ val fieldName = field.name
+ sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b
FROM t1")
+ .createOrReplaceTempView("t3")
+ checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t3"))
+ }
}
}
}
test("array_contains - array literals") {
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- val filename = path.toString
- val random = new Random(42)
- withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
- ParquetGenerator.makeParquetFile(
- random,
- spark,
- filename,
- 100,
- DataGenOptions(
- allowNull = true,
- generateNegativeZero = true,
- generateArray = false,
- generateStruct = false,
- generateMap = false))
- }
- val table = spark.read.parquet(filename)
- table.createOrReplaceTempView("t2")
- for (field <- table.schema.fields) {
- val typeName = field.dataType.typeName
- checkSparkAnswerAndOperator(sql(
- s"SELECT array_contains(cast(null as array<$typeName>), cast(null as
$typeName)) FROM t2"))
+ withTempView("t2") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ val filename = path.toString
+ val random = new Random(42)
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ ParquetGenerator.makeParquetFile(
+ random,
+ spark,
+ filename,
+ 100,
+ DataGenOptions(
+ allowNull = true,
+ generateNegativeZero = true,
+ generateArray = false,
+ generateStruct = false,
+ generateMap = false))
+ }
+ val table = spark.read.parquet(filename)
+ table.createOrReplaceTempView("t2")
+ for (field <- table.schema.fields) {
+ val typeName = field.dataType.typeName
+ checkSparkAnswerAndOperator(sql(
+ s"SELECT array_contains(cast(null as array<$typeName>), cast(null
as $typeName)) FROM t2"))
+ }
+ checkSparkAnswerAndOperator(sql("SELECT array_contains(array(), 1)
FROM t2"))
}
- checkSparkAnswerAndOperator(sql("SELECT array_contains(array(), 1) FROM
t2"))
}
}
@@ -328,13 +351,15 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
- val table = spark.read.parquet(filename)
- table.createOrReplaceTempView("t1")
- for (field <- table.schema.fields) {
- val fieldName = field.name
- sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b
FROM t1")
- .createOrReplaceTempView("t2")
- checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t2"))
+ withTempView("t1", "t2") {
+ val table = spark.read.parquet(filename)
+ table.createOrReplaceTempView("t1")
+ for (field <- table.schema.fields) {
+ val fieldName = field.name
+ sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b
FROM t1")
+ .createOrReplaceTempView("t2")
+ checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t2"))
+ }
}
}
}
@@ -344,24 +369,26 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1")
- // The result needs to be in ascending order for
checkSparkAnswerAndOperator to pass
- // because datafusion array_distinct sorts the elements and then
removes the duplicates
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_distinct(array(_2, _2, _3, _4, _4)) FROM
t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_4)
END)) FROM t1"))
- checkSparkAnswerAndOperator(spark.sql(
- "SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_2, _2, _4,
_4, _5) END)) FROM t1"))
- // NULL needs to be the first element for
checkSparkAnswerAndOperator to pass because
- // datafusion array_distinct sorts the elements and then removes the
duplicates
- checkSparkAnswerAndOperator(
- spark.sql(
- "SELECT array_distinct(array(CAST(NULL AS INT), _2, _2, _3, _4,
_4)) FROM t1"))
- checkSparkAnswerAndOperator(spark.sql(
- "SELECT array_distinct(array(CAST(NULL AS INT), CAST(NULL AS INT),
_2, _2, _3, _4, _4)) FROM t1"))
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n =
10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+ // The result needs to be in ascending order for
checkSparkAnswerAndOperator to pass
+ // because datafusion array_distinct sorts the elements and then
removes the duplicates
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_distinct(array(_2, _2, _3, _4, _4)) FROM
t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_distinct((CASE WHEN _2 =_3 THEN
array(_4) END)) FROM t1"))
+ checkSparkAnswerAndOperator(spark.sql(
+ "SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_2, _2, _4,
_4, _5) END)) FROM t1"))
+ // NULL needs to be the first element for
checkSparkAnswerAndOperator to pass because
+ // datafusion array_distinct sorts the elements and then removes
the duplicates
+ checkSparkAnswerAndOperator(
+ spark.sql(
+ "SELECT array_distinct(array(CAST(NULL AS INT), _2, _2, _3,
_4, _4)) FROM t1"))
+ checkSparkAnswerAndOperator(spark.sql(
+ "SELECT array_distinct(array(CAST(NULL AS INT), CAST(NULL AS
INT), _2, _2, _3, _4, _4)) FROM t1"))
+ }
}
}
}
@@ -371,16 +398,18 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1")
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_union(array(_2, _3, _4), array(_3, _4))
FROM t1"))
- checkSparkAnswerAndOperator(sql("SELECT array_union(array(_18),
array(_19)) from t1"))
- checkSparkAnswerAndOperator(spark.sql(
- "SELECT array_union(array(CAST(NULL AS INT), _2, _3, _4),
array(CAST(NULL AS INT), _2, _3)) FROM t1"))
- checkSparkAnswerAndOperator(spark.sql(
- "SELECT array_union(array(CAST(NULL AS INT), CAST(NULL AS INT),
_2, _3, _4), array(CAST(NULL AS INT), CAST(NULL AS INT), _2, _3)) FROM t1"))
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n =
10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_union(array(_2, _3, _4), array(_3, _4))
FROM t1"))
+ checkSparkAnswerAndOperator(sql("SELECT array_union(array(_18),
array(_19)) from t1"))
+ checkSparkAnswerAndOperator(spark.sql(
+ "SELECT array_union(array(CAST(NULL AS INT), _2, _3, _4),
array(CAST(NULL AS INT), _2, _3)) FROM t1"))
+ checkSparkAnswerAndOperator(spark.sql(
+ "SELECT array_union(array(CAST(NULL AS INT), CAST(NULL AS INT),
_2, _3, _4), array(CAST(NULL AS INT), CAST(NULL AS INT), _2, _3)) FROM t1"))
+ }
}
}
}
@@ -389,22 +418,24 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
test("array_max") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1");
- checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array(_2, _3,
_4)) FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_4) END))
FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_2, _4)
END)) FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_max(array(CAST(NULL AS INT), CAST(NULL AS
INT))) FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_max(array(_2, CAST(NULL AS INT))) FROM t1"))
- checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array()) FROM
t1"))
- checkSparkAnswerAndOperator(
- spark.sql(
- "SELECT array_max(array(double('-Infinity'), 0.0,
double('Infinity'))) FROM t1"))
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+ checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array(_2,
_3, _4)) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_4) END))
FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_2, _4)
END)) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_max(array(CAST(NULL AS INT), CAST(NULL AS
INT))) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_max(array(_2, CAST(NULL AS INT))) FROM
t1"))
+ checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array())
FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql(
+ "SELECT array_max(array(double('-Infinity'), 0.0,
double('Infinity'))) FROM t1"))
+ }
}
}
}
@@ -412,40 +443,43 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
test("array_min") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1");
- checkSparkAnswerAndOperator(spark.sql("SELECT array_min(array(_2, _3,
_4)) FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_min((CASE WHEN _2 =_3 THEN array(_4) END))
FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_min((CASE WHEN _2 =_3 THEN array(_2, _4)
END)) FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_min(array(CAST(NULL AS INT), CAST(NULL AS
INT))) FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_min(array(_2, CAST(NULL AS INT))) FROM t1"))
- checkSparkAnswerAndOperator(spark.sql("SELECT array_min(array()) FROM
t1"))
- checkSparkAnswerAndOperator(
- spark.sql(
- "SELECT array_min(array(double('-Infinity'), 0.0,
double('Infinity'))) FROM t1"))
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+ checkSparkAnswerAndOperator(spark.sql("SELECT array_min(array(_2,
_3, _4)) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_min((CASE WHEN _2 =_3 THEN array(_4) END))
FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_min((CASE WHEN _2 =_3 THEN array(_2, _4)
END)) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_min(array(CAST(NULL AS INT), CAST(NULL AS
INT))) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_min(array(_2, CAST(NULL AS INT))) FROM
t1"))
+ checkSparkAnswerAndOperator(spark.sql("SELECT array_min(array())
FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql(
+ "SELECT array_min(array(double('-Infinity'), 0.0,
double('Infinity'))) FROM t1"))
+ }
}
}
}
test("array_intersect") {
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
-
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1")
- checkSparkAnswerAndOperator(
- sql("SELECT array_intersect(array(_2, _3, _4), array(_3, _4)) from
t1"))
- checkSparkAnswerAndOperator(
- sql("SELECT array_intersect(array(_4 * -1), array(_5)) from t1"))
- checkSparkAnswerAndOperator(
- sql("SELECT array_intersect(array(_18), array(_19)) from t1"))
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_intersect(array(_2, _3, _4), array(_3, _4))
from t1"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_intersect(array(_4 * -1), array(_5)) from t1"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_intersect(array(_18), array(_19)) from t1"))
+ }
}
}
}
@@ -455,18 +489,19 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1")
- checkSparkAnswerAndOperator(sql(
- "SELECT array_join(array(cast(_1 as string), cast(_2 as string),
cast(_6 as string)), ' @ ') from t1"))
- checkSparkAnswerAndOperator(sql(
- "SELECT array_join(array(cast(_1 as string), cast(_2 as string),
cast(_6 as string)), ' @ ', ' +++ ') from t1"))
- checkSparkAnswerAndOperator(sql(
- "SELECT array_join(array('hello', 'world', cast(_2 as string)), '
') from t1 where _2 is not null"))
- checkSparkAnswerAndOperator(
- sql(
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+ checkSparkAnswerAndOperator(sql(
+ "SELECT array_join(array(cast(_1 as string), cast(_2 as string),
cast(_6 as string)), ' @ ') from t1"))
+ checkSparkAnswerAndOperator(sql(
+ "SELECT array_join(array(cast(_1 as string), cast(_2 as string),
cast(_6 as string)), ' @ ', ' +++ ') from t1"))
+ checkSparkAnswerAndOperator(sql(
+ "SELECT array_join(array('hello', 'world', cast(_2 as string)),
' ') from t1 where _2 is not null"))
+ checkSparkAnswerAndOperator(sql(
"SELECT array_join(array('hello', '-', 'world', cast(_2 as
string)), ' ') from t1"))
+ }
}
}
}
@@ -476,17 +511,19 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1")
- checkSparkAnswerAndOperator(sql(
- "SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1
where _2 is not null"))
- checkSparkAnswerAndOperator(sql(
- "SELECT arrays_overlap(array('a', null, cast(_1 as string)),
array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not
null"))
- checkSparkAnswerAndOperator(sql(
- "SELECT arrays_overlap(array('a', null), array('b', null)) from t1
where _1 is not null"))
- checkSparkAnswerAndOperator(spark.sql(
- "SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END),
array(_6, _7)) FROM t1"));
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+ checkSparkAnswerAndOperator(sql(
+ "SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1
where _2 is not null"))
+ checkSparkAnswerAndOperator(sql(
+ "SELECT arrays_overlap(array('a', null, cast(_1 as string)),
array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not
null"))
+ checkSparkAnswerAndOperator(sql(
+ "SELECT arrays_overlap(array('a', null), array('b', null)) from
t1 where _1 is not null"))
+ checkSparkAnswerAndOperator(spark.sql(
+ "SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7)
END), array(_6, _7)) FROM t1"));
+ }
}
}
}
@@ -498,16 +535,21 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled =
dictionaryEnabled, n = 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-
- checkSparkAnswerAndOperator(
- sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NULL"))
- checkSparkAnswerAndOperator(
- sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NOT
NULL"))
- checkSparkAnswerAndOperator(
- sql("SELECT array_compact(array(_2, _3, null)) FROM t1 WHERE _2 IS
NOT NULL"))
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(
+ path,
+ dictionaryEnabled = dictionaryEnabled,
+ n = 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NULL"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NOT
NULL"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_compact(array(_2, _3, null)) FROM t1 WHERE _2
IS NOT NULL"))
+ }
}
}
}
@@ -517,16 +559,19 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-
- checkSparkAnswerAndOperator(
- sql("SELECT array_except(array(_2, _3, _4), array(_3, _4)) from
t1"))
- checkSparkAnswerAndOperator(sql("SELECT array_except(array(_18),
array(_19)) from t1"))
- checkSparkAnswerAndOperator(
- spark.sql(
- "SELECT array_except(array(_2, _2, _4), array(_4)) FROM t1 WHERE
_2 IS NOT NULL"))
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_except(array(_2, _3, _4), array(_3, _4)) from
t1"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_except(array(_18), array(_19)) from t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql(
+ "SELECT array_except(array(_2, _2, _4), array(_4)) FROM t1
WHERE _2 IS NOT NULL"))
+ }
}
}
}
@@ -551,19 +596,21 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
generateMap = false))
}
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
- val table = spark.read.parquet(filename)
- table.createOrReplaceTempView("t1")
- // test with array of each column
- val fields =
- table.schema.fields.filter(field =>
CometArrayExcept.isTypeSupported(field.dataType))
- for (field <- fields) {
- val fieldName = field.name
- val typeName = field.dataType.typeName
- sql(
- s"SELECT cast(array($fieldName, $fieldName) as array<$typeName>)
as a, cast(array($fieldName) as array<$typeName>) as b FROM t1")
- .createOrReplaceTempView("t2")
- val df = sql("SELECT array_except(a, b) FROM t2")
- checkSparkAnswerAndOperator(df)
+ withTempView("t1", "t2") {
+ val table = spark.read.parquet(filename)
+ table.createOrReplaceTempView("t1")
+ // test with array of each column
+ val fields =
+ table.schema.fields.filter(field =>
CometArrayExcept.isTypeSupported(field.dataType))
+ for (field <- fields) {
+ val fieldName = field.name
+ val typeName = field.dataType.typeName
+ sql(
+ s"SELECT cast(array($fieldName, $fieldName) as array<$typeName>)
as a, cast(array($fieldName) as array<$typeName>) as b FROM t1")
+ .createOrReplaceTempView("t2")
+ val df = sql("SELECT array_except(a, b) FROM t2")
+ checkSparkAnswerAndOperator(df)
+ }
}
}
}
@@ -588,17 +635,19 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true",
CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
- val table = spark.read.parquet(filename)
- table.createOrReplaceTempView("t1")
- // test with array of each column
- val fields =
- table.schema.fields.filter(field =>
CometArrayExcept.isTypeSupported(field.dataType))
- for (field <- fields) {
- val fieldName = field.name
- sql(s"SELECT array($fieldName, $fieldName) as a, array($fieldName)
as b FROM t1")
- .createOrReplaceTempView("t2")
- val df = sql("SELECT array_except(a, b) FROM t2")
- checkSparkAnswer(df)
+ withTempView("t1", "t2") {
+ val table = spark.read.parquet(filename)
+ table.createOrReplaceTempView("t1")
+ // test with array of each column
+ val fields =
+ table.schema.fields.filter(field =>
CometArrayExcept.isTypeSupported(field.dataType))
+ for (field <- fields) {
+ val fieldName = field.name
+ sql(s"SELECT array($fieldName, $fieldName) as a, array($fieldName)
as b FROM t1")
+ .createOrReplaceTempView("t2")
+ val df = sql("SELECT array_except(a, b) FROM t2")
+ checkSparkAnswer(df)
+ }
}
}
}
@@ -610,19 +659,22 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 100)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-
- checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, null) from
t1"))
- checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, 0) from
t1"))
- checkSparkAnswerAndOperator(
- sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null"))
- checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, 5) from t1
where _2 is null"))
- checkSparkAnswerAndOperator(
- sql("SELECT array_repeat(_3, _4) from t1 where _3 is not null"))
- checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as
string), 2) from t1"))
- checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3,
_4), 2) from t1"))
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 100)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+ checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, null)
from t1"))
+ checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, 0) from
t1"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_repeat(_2, 5) from t1 where _2 is null"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_repeat(_3, _4) from t1 where _3 is not null"))
+ checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as
string), 2) from t1"))
+ checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3,
_4), 2) from t1"))
+ }
}
}
}
@@ -630,32 +682,34 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
test("flatten - test all types (native Parquet reader)") {
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- val filename = path.toString
- val random = new Random(42)
- withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
- ParquetGenerator.makeParquetFile(
- random,
- spark,
- filename,
- 100,
- DataGenOptions(
- allowNull = true,
- generateNegativeZero = true,
- generateArray = false,
- generateStruct = false,
- generateMap = false))
- }
- val table = spark.read.parquet(filename)
- table.createOrReplaceTempView("t1")
- val fieldNames =
- table.schema.fields
- .filter(field => CometFlatten.isTypeSupported(field.dataType))
- .map(_.name)
- for (fieldName <- fieldNames) {
- sql(s"SELECT array(array($fieldName, $fieldName), array($fieldName))
as a FROM t1")
- .createOrReplaceTempView("t2")
- checkSparkAnswerAndOperator(sql("SELECT flatten(a) FROM t2"))
+ withTempView("t1", "t2") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ val filename = path.toString
+ val random = new Random(42)
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ ParquetGenerator.makeParquetFile(
+ random,
+ spark,
+ filename,
+ 100,
+ DataGenOptions(
+ allowNull = true,
+ generateNegativeZero = true,
+ generateArray = false,
+ generateStruct = false,
+ generateMap = false))
+ }
+ val table = spark.read.parquet(filename)
+ table.createOrReplaceTempView("t1")
+ val fieldNames =
+ table.schema.fields
+ .filter(field => CometFlatten.isTypeSupported(field.dataType))
+ .map(_.name)
+ for (fieldName <- fieldNames) {
+ sql(s"SELECT array(array($fieldName, $fieldName), array($fieldName))
as a FROM t1")
+ .createOrReplaceTempView("t2")
+ checkSparkAnswerAndOperator(sql("SELECT flatten(a) FROM t2"))
+ }
}
}
}
@@ -678,16 +732,18 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
- val table = spark.read.parquet(filename)
- table.createOrReplaceTempView("t1")
- val fieldNames =
- table.schema.fields
- .filter(field => CometFlatten.isTypeSupported(field.dataType))
- .map(_.name)
- for (fieldName <- fieldNames) {
- sql(s"SELECT array(array($fieldName, $fieldName), array($fieldName))
as a FROM t1")
- .createOrReplaceTempView("t2")
- checkSparkAnswer(sql("SELECT flatten(a) FROM t2"))
+ withTempView("t1", "t2") {
+ val table = spark.read.parquet(filename)
+ table.createOrReplaceTempView("t1")
+ val fieldNames =
+ table.schema.fields
+ .filter(field => CometFlatten.isTypeSupported(field.dataType))
+ .map(_.name)
+ for (fieldName <- fieldNames) {
+ sql(s"SELECT array(array($fieldName, $fieldName),
array($fieldName)) as a FROM t1")
+ .createOrReplaceTempView("t2")
+ checkSparkAnswer(sql("SELECT flatten(a) FROM t2"))
+ }
}
}
}
@@ -699,11 +755,48 @@ class CometArrayExpressionSuite extends CometTestBase
with AdaptiveSparkPlanHelp
CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 100)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1")
- checkSparkAnswerAndOperator(
- sql("SELECT array(array(1, 2, 3), null, array(), array(null),
array(1)) from t1"))
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 100)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+ checkSparkAnswerAndOperator(
+ sql("SELECT array(array(1, 2, 3), null, array(), array(null),
array(1)) from t1"))
+ }
+ }
+ }
+ }
+ }
+
+ test("array_reverse") {
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ val filename = path.toString
+ val random = new Random(42)
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ val options = DataGenOptions(
+ allowNull = true,
+ generateNegativeZero = true,
+ generateArray = true,
+ generateStruct = true,
+ generateMap = false)
+ ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
+ }
+ withSQLConf(
+ CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
+ CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+ CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
+ withTempView("t1", "t2") {
+ val table = spark.read.parquet(filename)
+ table.createOrReplaceTempView("t1")
+ val fieldNames =
+ table.schema.fields
+ .filter(field =>
CometArrayReverse.isTypeSupported(field.dataType))
+ .map(_.name)
+ for (fieldName <- fieldNames) {
+ sql(s"SELECT $fieldName as a FROM t1")
+ .createOrReplaceTempView("t2")
+ checkSparkAnswer(sql("SELECT reverse(a) FROM t2"))
+ }
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]