This is an automated email from the ASF dual-hosted git repository.

kazuyukitanimura pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new a31ece991 Chore: simplify array related functions impl (#1490)
a31ece991 is described below

commit a31ece991163d21542ca05ce1c0fb52f3d5c2808
Author: Kazantsev Maksim <[email protected]>
AuthorDate: Tue Mar 25 23:39:58 2025 +0400

    Chore: simplify array related functions impl (#1490)
    
    ## Which issue does this PR close?
    
    Related to issue: https://github.com/apache/datafusion-comet/issues/1459
    
    ## Rationale for this change
    
    Defined under Issue: https://github.com/apache/datafusion-comet/issues/1459
    
    ## What changes are included in this PR?
    
    In functions related to arrays, scalarExprToProtoWithReturnType or 
scalarExprToProto is used instead of creating a separate proto for each 
function.
    
    ## How are these changes tested?
    
    Regression testing performed with the available unit tests
---
 native/core/src/execution/planner.rs               | 132 +--------------
 native/proto/src/proto/expr.proto                  |  10 +-
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  31 +---
 .../main/scala/org/apache/comet/serde/arrays.scala | 187 +++++++++++++++------
 .../apache/comet/CometArrayExpressionSuite.scala   |  93 +++++-----
 5 files changed, 186 insertions(+), 267 deletions(-)

diff --git a/native/core/src/execution/planner.rs 
b/native/core/src/execution/planner.rs
index 4c3a0fde9..60803dfeb 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -38,11 +38,6 @@ use 
datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf,
 use datafusion::functions_aggregate::min_max::max_udaf;
 use datafusion::functions_aggregate::min_max::min_udaf;
 use datafusion::functions_aggregate::sum::sum_udaf;
-use datafusion::functions_nested::array_has::array_has_any_udf;
-use datafusion::functions_nested::concat::ArrayAppend;
-use datafusion::functions_nested::remove::array_remove_all_udf;
-use datafusion::functions_nested::set_ops::array_intersect_udf;
-use datafusion::functions_nested::string::array_to_string_udf;
 use datafusion::physical_expr::aggregate::{AggregateExprBuilder, 
AggregateFunctionExpr};
 use datafusion::physical_plan::windows::BoundedWindowAggExec;
 use datafusion::physical_plan::InputOrderMode;
@@ -83,10 +78,9 @@ use datafusion::common::{
     JoinType as DFJoinType, ScalarValue,
 };
 use datafusion::datasource::listing::PartitionedFile;
-use datafusion::functions_nested::array_has::ArrayHas;
 use 
datafusion::logical_expr::type_coercion::other::get_coerce_type_for_case_expression;
 use datafusion::logical_expr::{
-    AggregateUDF, ReturnTypeArgs, ScalarUDF, WindowFrame, WindowFrameBound, 
WindowFrameUnits,
+    AggregateUDF, ReturnTypeArgs, WindowFrame, WindowFrameBound, 
WindowFrameUnits,
     WindowFunctionDefinition,
 };
 use datafusion::physical_expr::expressions::{Literal, StatsType};
@@ -759,32 +753,6 @@ impl PhysicalPlanner {
                     expr.ordinal as usize,
                 )))
             }
-            ExprStruct::ArrayAppend(expr) => {
-                let left =
-                    self.create_expr(expr.left.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let right =
-                    self.create_expr(expr.right.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let return_type = left.data_type(&input_schema)?;
-                let args = vec![Arc::clone(&left), right];
-                let datafusion_array_append =
-                    Arc::new(ScalarUDF::new_from_impl(ArrayAppend::new()));
-                let array_append_expr: Arc<dyn PhysicalExpr> = 
Arc::new(ScalarFunctionExpr::new(
-                    "array_append",
-                    datafusion_array_append,
-                    args,
-                    return_type,
-                ));
-
-                let is_null_expr: Arc<dyn PhysicalExpr> = 
Arc::new(IsNullExpr::new(left));
-                let null_literal_expr: Arc<dyn PhysicalExpr> =
-                    Arc::new(Literal::new(ScalarValue::Null));
-
-                create_case_expr(
-                    vec![(is_null_expr, null_literal_expr)],
-                    Some(array_append_expr),
-                    &input_schema,
-                )
-            }
             ExprStruct::ArrayInsert(expr) => {
                 let src_array_expr = self.create_expr(
                     expr.src_array_expr.as_ref().unwrap(),
@@ -801,104 +769,6 @@ impl PhysicalPlanner {
                     expr.legacy_negative_index,
                 )))
             }
-            ExprStruct::ArrayContains(expr) => {
-                let src_array_expr =
-                    self.create_expr(expr.left.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let key_expr =
-                    self.create_expr(expr.right.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let args = vec![Arc::clone(&src_array_expr), key_expr];
-                let array_has_expr = Arc::new(ScalarFunctionExpr::new(
-                    "array_has",
-                    Arc::new(ScalarUDF::new_from_impl(ArrayHas::new())),
-                    args,
-                    DataType::Boolean,
-                ));
-                Ok(array_has_expr)
-            }
-            ExprStruct::ArrayRemove(expr) => {
-                let src_array_expr =
-                    self.create_expr(expr.left.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let key_expr =
-                    self.create_expr(expr.right.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let args = vec![Arc::clone(&src_array_expr), 
Arc::clone(&key_expr)];
-                let return_type = src_array_expr.data_type(&input_schema)?;
-
-                let datafusion_array_remove = array_remove_all_udf();
-
-                let array_remove_expr: Arc<dyn PhysicalExpr> = 
Arc::new(ScalarFunctionExpr::new(
-                    "array_remove",
-                    datafusion_array_remove,
-                    args,
-                    return_type,
-                ));
-                let is_null_expr: Arc<dyn PhysicalExpr> = 
Arc::new(IsNullExpr::new(key_expr));
-
-                let null_literal_expr: Arc<dyn PhysicalExpr> =
-                    Arc::new(Literal::new(ScalarValue::Null));
-
-                create_case_expr(
-                    vec![(is_null_expr, null_literal_expr)],
-                    Some(array_remove_expr),
-                    &input_schema,
-                )
-            }
-            ExprStruct::ArrayIntersect(expr) => {
-                let left_expr =
-                    self.create_expr(expr.left.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let right_expr =
-                    self.create_expr(expr.right.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let args = vec![Arc::clone(&left_expr), right_expr];
-                let datafusion_array_intersect = array_intersect_udf();
-                let return_type = left_expr.data_type(&input_schema)?;
-                let array_intersect_expr = Arc::new(ScalarFunctionExpr::new(
-                    "array_intersect",
-                    datafusion_array_intersect,
-                    args,
-                    return_type,
-                ));
-                Ok(array_intersect_expr)
-            }
-            ExprStruct::ArrayJoin(expr) => {
-                let array_expr =
-                    self.create_expr(expr.array_expr.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let delimiter_expr = self.create_expr(
-                    expr.delimiter_expr.as_ref().unwrap(),
-                    Arc::clone(&input_schema),
-                )?;
-
-                let mut args = vec![Arc::clone(&array_expr), delimiter_expr];
-                if expr.null_replacement_expr.is_some() {
-                    let null_replacement_expr = self.create_expr(
-                        expr.null_replacement_expr.as_ref().unwrap(),
-                        Arc::clone(&input_schema),
-                    )?;
-                    args.push(null_replacement_expr)
-                }
-
-                let datafusion_array_to_string = array_to_string_udf();
-                let array_join_expr = Arc::new(ScalarFunctionExpr::new(
-                    "array_join",
-                    datafusion_array_to_string,
-                    args,
-                    DataType::Utf8,
-                ));
-                Ok(array_join_expr)
-            }
-            ExprStruct::ArraysOverlap(expr) => {
-                let left_array_expr =
-                    self.create_expr(expr.left.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let right_array_expr =
-                    self.create_expr(expr.right.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let args = vec![Arc::clone(&left_array_expr), 
right_array_expr];
-                let datafusion_array_has_any = array_has_any_udf();
-                let array_has_any_expr = Arc::new(ScalarFunctionExpr::new(
-                    "array_has_any",
-                    datafusion_array_has_any,
-                    args,
-                    DataType::Boolean,
-                ));
-                Ok(array_has_any_expr)
-            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/native/proto/src/proto/expr.proto 
b/native/proto/src/proto/expr.proto
index 71ad8cf3f..90fd08948 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -82,14 +82,8 @@ message Expr {
     ToJson to_json = 55;
     ListExtract list_extract = 56;
     GetArrayStructFields get_array_struct_fields = 57;
-    BinaryExpr array_append = 58;
-    ArrayInsert array_insert = 59;
-    BinaryExpr array_contains = 60;
-    BinaryExpr array_remove = 61;
-    BinaryExpr array_intersect = 62;
-    ArrayJoin array_join = 63;
-    BinaryExpr arrays_overlap = 64;
-    MathExpr integral_divide = 65;
+    ArrayInsert array_insert = 58;
+    MathExpr integral_divide = 59;
   }
 }
 
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index b3e65a8a0..a8a3df0c1 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1925,34 +1925,7 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
           None
         }
 
-      case expr if expr.prettyName == "array_insert" =>
-        val srcExprProto = exprToProtoInternal(expr.children(0), inputs, 
binding)
-        val posExprProto = exprToProtoInternal(expr.children(1), inputs, 
binding)
-        val itemExprProto = exprToProtoInternal(expr.children(2), inputs, 
binding)
-        val legacyNegativeIndex =
-          
SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
-        if (srcExprProto.isDefined && posExprProto.isDefined && 
itemExprProto.isDefined) {
-          val arrayInsertBuilder = ExprOuterClass.ArrayInsert
-            .newBuilder()
-            .setSrcArrayExpr(srcExprProto.get)
-            .setPosExpr(posExprProto.get)
-            .setItemExpr(itemExprProto.get)
-            .setLegacyNegativeIndex(legacyNegativeIndex)
-
-          Some(
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setArrayInsert(arrayInsertBuilder)
-              .build())
-        } else {
-          withInfo(
-            expr,
-            "unsupported arguments for ArrayInsert",
-            expr.children(0),
-            expr.children(1),
-            expr.children(2))
-          None
-        }
+      case expr if expr.prettyName == "array_insert" => 
convert(CometArrayInsert)
 
       case ElementAt(child, ordinal, defaultValue, failOnError)
           if child.dataType.isInstanceOf[ArrayType] =>
@@ -2899,7 +2872,7 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
   }
 
   // Utility method. Adds explain info if the result of calling exprToProto is 
None
-  private def optExprWithInfo(
+  def optExprWithInfo(
       optExpr: Option[Expr],
       expr: Expression,
       childExpr: Expression*): Option[Expr] = {
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala 
b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 60df51b08..8550d5201 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -20,10 +20,11 @@
 package org.apache.comet.serde
 
 import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, ArrayRemove, 
Attribute, Expression, Literal}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProto, 
scalarExprToProtoWithReturnType}
+import org.apache.comet.serde.QueryPlanSerde._
 import org.apache.comet.shims.CometExprShim
 
 object CometArrayRemove extends CometExpressionSerde with CometExprShim {
@@ -56,13 +57,37 @@ object CometArrayRemove extends CometExpressionSerde with 
CometExprShim {
         return None
       }
     }
-    createBinaryExpr(
+    val arrayExprProto = exprToProto(ar.left, inputs, binding)
+    val keyExprProto = exprToProto(ar.right, inputs, binding)
+
+    val arrayRemoveScalarExpr =
+      scalarExprToProto("array_remove_all", arrayExprProto, keyExprProto)
+
+    val isNotNullExpr = createUnaryExpr(
       expr,
-      expr.children(0),
-      expr.children(1),
+      ar.right,
       inputs,
       binding,
-      (builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
+      (builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
+
+    val nullLiteralProto = exprToProto(Literal(null, ar.right.dataType), 
Seq.empty)
+
+    if (arrayRemoveScalarExpr.isDefined && isNotNullExpr.isDefined && 
nullLiteralProto.isDefined) {
+      val caseWhenExpr = ExprOuterClass.CaseWhen
+        .newBuilder()
+        .addWhen(isNotNullExpr.get)
+        .addThen(arrayRemoveScalarExpr.get)
+        .setElseExpr(nullLiteralProto.get)
+        .build()
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setCaseWhen(caseWhenExpr)
+          .build())
+    } else {
+      withInfo(expr, expr.children: _*)
+      None
+    }
   }
 }
 
@@ -71,13 +96,39 @@ object CometArrayAppend extends CometExpressionSerde with 
IncompatExpr {
       expr: Expression,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
-    createBinaryExpr(
+    val child = expr.children.head
+    val elementType = child.dataType.asInstanceOf[ArrayType].elementType
+
+    val arrayExprProto = exprToProto(expr.children(0), inputs, binding)
+    val keyExprProto = exprToProto(expr.children(1), inputs, binding)
+
+    val arrayAppendScalarExpr = scalarExprToProto("array_append", 
arrayExprProto, keyExprProto)
+
+    val isNotNullExpr = createUnaryExpr(
       expr,
       expr.children(0),
-      expr.children(1),
       inputs,
       binding,
-      (builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
+      (builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
+
+    val nullLiteralProto = exprToProto(Literal(null, elementType), Seq.empty)
+
+    if (arrayAppendScalarExpr.isDefined && isNotNullExpr.isDefined && 
nullLiteralProto.isDefined) {
+      val caseWhenExpr = ExprOuterClass.CaseWhen
+        .newBuilder()
+        .addWhen(isNotNullExpr.get)
+        .addThen(arrayAppendScalarExpr.get)
+        .setElseExpr(nullLiteralProto.get)
+        .build()
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setCaseWhen(caseWhenExpr)
+          .build())
+    } else {
+      withInfo(expr, expr.children: _*)
+      None
+    }
   }
 }
 
@@ -86,13 +137,12 @@ object CometArrayContains extends CometExpressionSerde 
with IncompatExpr {
       expr: Expression,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
-    createBinaryExpr(
-      expr,
-      expr.children(0),
-      expr.children(1),
-      inputs,
-      binding,
-      (builder, binaryExpr) => builder.setArrayContains(binaryExpr))
+    val arrayExprProto = exprToProto(expr.children(0), inputs, binding)
+    val keyExprProto = exprToProto(expr.children(1), inputs, binding)
+
+    val arrayContainsScalarExpr =
+      scalarExprToProto("array_has", arrayExprProto, keyExprProto)
+    optExprWithInfo(arrayContainsScalarExpr, expr, expr.children: _*)
   }
 }
 
@@ -101,13 +151,12 @@ object CometArrayIntersect extends CometExpressionSerde 
with IncompatExpr {
       expr: Expression,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
-    createBinaryExpr(
-      expr,
-      expr.children(0),
-      expr.children(1),
-      inputs,
-      binding,
-      (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr))
+    val leftArrayExprProto = exprToProto(expr.children(0), inputs, binding)
+    val rightArrayExprProto = exprToProto(expr.children(1), inputs, binding)
+
+    val arraysIntersectScalarExpr =
+      scalarExprToProto("array_intersect", leftArrayExprProto, 
rightArrayExprProto)
+    optExprWithInfo(arraysIntersectScalarExpr, expr, expr.children: _*)
   }
 }
 
@@ -116,13 +165,15 @@ object CometArraysOverlap extends CometExpressionSerde 
with IncompatExpr {
       expr: Expression,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
-    createBinaryExpr(
-      expr,
-      expr.children(0),
-      expr.children(1),
-      inputs,
-      binding,
-      (builder, binaryExpr) => builder.setArraysOverlap(binaryExpr))
+    val leftArrayExprProto = exprToProto(expr.children(0), inputs, binding)
+    val rightArrayExprProto = exprToProto(expr.children(1), inputs, binding)
+
+    val arraysOverlapScalarExpr = scalarExprToProtoWithReturnType(
+      "array_has_any",
+      BooleanType,
+      leftArrayExprProto,
+      rightArrayExprProto)
+    optExprWithInfo(arraysOverlapScalarExpr, expr, expr.children: _*)
   }
 }
 
@@ -142,12 +193,7 @@ object CometArrayCompact extends CometExpressionSerde with 
IncompatExpr {
       ArrayType(elementType = elementType),
       arrayExprProto,
       nullLiteralProto)
-    arrayCompactScalarExpr match {
-      case None =>
-        withInfo(expr, "unsupported arguments for ArrayCompact", 
expr.children: _*)
-        None
-      case expr => expr
-    }
+    optExprWithInfo(arrayCompactScalarExpr, expr, expr.children: _*)
   }
 }
 
@@ -160,32 +206,61 @@ object CometArrayJoin extends CometExpressionSerde with 
IncompatExpr {
     val arrayExprProto = exprToProto(arrayExpr.array, inputs, binding)
     val delimiterExprProto = exprToProto(arrayExpr.delimiter, inputs, binding)
 
-    if (arrayExprProto.isDefined && delimiterExprProto.isDefined) {
-      val arrayJoinBuilder = arrayExpr.nullReplacement match {
-        case Some(nrExpr) =>
-          val nullReplacementExprProto = exprToProto(nrExpr, inputs, binding)
-          ExprOuterClass.ArrayJoin
-            .newBuilder()
-            .setArrayExpr(arrayExprProto.get)
-            .setDelimiterExpr(delimiterExprProto.get)
-            .setNullReplacementExpr(nullReplacementExprProto.get)
-        case None =>
-          ExprOuterClass.ArrayJoin
-            .newBuilder()
-            .setArrayExpr(arrayExprProto.get)
-            .setDelimiterExpr(delimiterExprProto.get)
-      }
+    arrayExpr.nullReplacement match {
+      case Some(nullReplacementExpr) =>
+        val nullReplacementExprProto = exprToProto(nullReplacementExpr, 
inputs, binding)
+
+        val arrayJoinScalarExpr = scalarExprToProto(
+          "array_to_string",
+          arrayExprProto,
+          delimiterExprProto,
+          nullReplacementExprProto)
+
+        optExprWithInfo(
+          arrayJoinScalarExpr,
+          expr,
+          arrayExpr,
+          arrayExpr.delimiter,
+          nullReplacementExpr)
+      case None =>
+        val arrayJoinScalarExpr =
+          scalarExprToProto("array_to_string", arrayExprProto, 
delimiterExprProto)
+
+        optExprWithInfo(arrayJoinScalarExpr, expr, arrayExpr, 
arrayExpr.delimiter)
+    }
+  }
+}
+
+object CometArrayInsert extends CometExpressionSerde with IncompatExpr {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val srcExprProto = exprToProtoInternal(expr.children(0), inputs, binding)
+    val posExprProto = exprToProtoInternal(expr.children(1), inputs, binding)
+    val itemExprProto = exprToProtoInternal(expr.children(2), inputs, binding)
+    val legacyNegativeIndex =
+      
SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
+    if (srcExprProto.isDefined && posExprProto.isDefined && 
itemExprProto.isDefined) {
+      val arrayInsertBuilder = ExprOuterClass.ArrayInsert
+        .newBuilder()
+        .setSrcArrayExpr(srcExprProto.get)
+        .setPosExpr(posExprProto.get)
+        .setItemExpr(itemExprProto.get)
+        .setLegacyNegativeIndex(legacyNegativeIndex)
+
       Some(
         ExprOuterClass.Expr
           .newBuilder()
-          .setArrayJoin(arrayJoinBuilder)
+          .setArrayInsert(arrayInsertBuilder)
           .build())
     } else {
-      val exprs: List[Expression] = arrayExpr.nullReplacement match {
-        case Some(nrExpr) => List(arrayExpr, arrayExpr.delimiter, nrExpr)
-        case None => List(arrayExpr, arrayExpr.delimiter)
-      }
-      withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
+      withInfo(
+        expr,
+        "unsupported arguments for ArrayInsert",
+        expr.children(0),
+        expr.children(1),
+        expr.children(2))
       None
     }
   }
diff --git 
a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 8b3299dc9..cef48c50c 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -160,59 +160,66 @@ class CometArrayExpressionSuite extends CometTestBase 
with AdaptiveSparkPlanHelp
 
   test("array_prepend") {
     assume(isSpark35Plus) // in Spark 3.5 array_prepend is implemented via 
array_insert
-    Seq(true, false).foreach { dictionaryEnabled =>
-      withTempDir { dir =>
-        val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 
10000)
-        spark.read.parquet(path.toString).createOrReplaceTempView("t1");
-        checkSparkAnswerAndOperator(spark.sql("Select 
array_prepend(array(_1),false) from t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM t1"));
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS DOUBLE)) 
FROM t1"));
-        checkSparkAnswerAndOperator(spark.sql("SELECT array_prepend(array(_8), 
'test') FROM t1"));
-        checkSparkAnswerAndOperator(spark.sql("SELECT 
array_prepend(array(_19), _19) FROM t1"));
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4) 
END), _4) FROM t1"));
+    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+      Seq(true, false).foreach { dictionaryEnabled =>
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 
10000)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+          checkSparkAnswerAndOperator(spark.sql("Select 
array_prepend(array(_1),false) from t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM 
t1"));
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS 
DOUBLE)) FROM t1"));
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_prepend(array(_8), 'test') FROM t1"));
+          checkSparkAnswerAndOperator(spark.sql("SELECT 
array_prepend(array(_19), _19) FROM t1"));
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4) 
END), _4) FROM t1"));
+        }
       }
     }
   }
 
   test("ArrayInsert") {
-    Seq(true, false).foreach(dictionaryEnabled =>
-      withTempDir { dir =>
-        val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
-        val df = spark.read
-          .parquet(path.toString)
-          .withColumn("arr", array(col("_4"), lit(null), col("_4")))
-          .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)"))
-          .withColumn("arrInsertNegativeIndexResult", expr("array_insert(arr, 
-1, 1)"))
-          .withColumn("arrPosGreaterThanSize", expr("array_insert(arr, 8, 1)"))
-          .withColumn("arrNegPosGreaterThanSize", expr("array_insert(arr, -8, 
1)"))
-          .withColumn("arrInsertNone", expr("array_insert(arr, 1, null)"))
-        checkSparkAnswerAndOperator(df.select("arrInsertResult"))
-        checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult"))
-        checkSparkAnswerAndOperator(df.select("arrPosGreaterThanSize"))
-        checkSparkAnswerAndOperator(df.select("arrNegPosGreaterThanSize"))
-        checkSparkAnswerAndOperator(df.select("arrInsertNone"))
-      })
+    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+      Seq(true, false).foreach(dictionaryEnabled =>
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
+          val df = spark.read
+            .parquet(path.toString)
+            .withColumn("arr", array(col("_4"), lit(null), col("_4")))
+            .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)"))
+            .withColumn("arrInsertNegativeIndexResult", 
expr("array_insert(arr, -1, 1)"))
+            .withColumn("arrPosGreaterThanSize", expr("array_insert(arr, 8, 
1)"))
+            .withColumn("arrNegPosGreaterThanSize", expr("array_insert(arr, 
-8, 1)"))
+            .withColumn("arrInsertNone", expr("array_insert(arr, 1, null)"))
+          checkSparkAnswerAndOperator(df.select("arrInsertResult"))
+          
checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult"))
+          checkSparkAnswerAndOperator(df.select("arrPosGreaterThanSize"))
+          checkSparkAnswerAndOperator(df.select("arrNegPosGreaterThanSize"))
+          checkSparkAnswerAndOperator(df.select("arrInsertNone"))
+        })
+    }
   }
 
   test("ArrayInsertUnsupportedArgs") {
     // This test checks that the else branch in ArrayInsert
     // mapping to the comet is valid and fallback to spark is working fine.
-    withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      makeParquetFileAllTypes(path, dictionaryEnabled = false, 10000)
-      val df = spark.read
-        .parquet(path.toString)
-        .withColumn("arr", array(col("_4"), lit(null), col("_4")))
-        .withColumn("idx", udf((_: Int) => 1).apply(col("_4")))
-        .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)"))
-      checkSparkAnswer(df.select("arrUnsupportedArgs"))
+    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllTypes(path, dictionaryEnabled = false, 10000)
+        val df = spark.read
+          .parquet(path.toString)
+          .withColumn("arr", array(col("_4"), lit(null), col("_4")))
+          .withColumn("idx", udf((_: Int) => 1).apply(col("_4")))
+          .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)"))
+        checkSparkAnswer(df.select("arrUnsupportedArgs"))
+      }
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to