This is an automated email from the ASF dual-hosted git repository.
kazuyukitanimura pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new a31ece991 Chore: simplify array related functions impl (#1490)
a31ece991 is described below
commit a31ece991163d21542ca05ce1c0fb52f3d5c2808
Author: Kazantsev Maksim <[email protected]>
AuthorDate: Tue Mar 25 23:39:58 2025 +0400
Chore: simplify array related functions impl (#1490)
## Which issue does this PR close?
Related to issue: https://github.com/apache/datafusion-comet/issues/1459
## Rationale for this change
Described in issue: https://github.com/apache/datafusion-comet/issues/1459
## What changes are included in this PR?
Array-related functions now use scalarExprToProtoWithReturnType or
scalarExprToProto instead of creating a separate protobuf message for each
function.
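Every conversion follows the same shape: serialize the child expressions, then wrap them in a generic scalar-function message identified by its DataFusion function name, with a CASE WHEN guard where Spark's null semantics require it. A minimal standalone sketch of that shape (the `Proto` type and the helper signatures below are simplified stand-ins for the QueryPlanSerde helpers, not the real API):

```scala
// Simplified stand-ins for the Comet serde helpers; the real versions live in
// QueryPlanSerde and produce ExprOuterClass.Expr protobuf messages.
object ScalarExprSketch {
  case class Proto(repr: String) // stand-in for ExprOuterClass.Expr

  // Serialize one child expression, or None if it is unsupported.
  def exprToProto(name: String): Option[Proto] = Some(Proto(name))

  // Wrap already-serialized children in a generic scalar-function message
  // identified only by the DataFusion function name, instead of a dedicated
  // protobuf field per array function.
  def scalarExprToProto(fn: String, children: Option[Proto]*): Option[Proto] =
    if (children.forall(_.isDefined))
      Some(Proto(s"$fn(${children.flatten.map(_.repr).mkString(", ")})"))
    else None

  // Null guard used by array_append/array_remove: if the guarded child is
  // NULL, the result is NULL (CASE WHEN child IS NOT NULL THEN call ELSE NULL).
  def nullGuard(child: Proto, call: Proto): Proto =
    Proto(s"CASE WHEN ${child.repr} IS NOT NULL THEN ${call.repr} ELSE NULL END")

  def main(args: Array[String]): Unit = {
    val arr = exprToProto("arrayCol")
    val key = exprToProto("key")
    // array_contains needs no dedicated protobuf field anymore: it is just
    // the DataFusion "array_has" scalar function over the two children.
    println(scalarExprToProto("array_has", arr, key))
    // array_remove maps to "array_remove_all", guarded on the key child.
    println(scalarExprToProto("array_remove_all", arr, key).map(nullGuard(key.get, _)))
  }
}
```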
## How are these changes tested?
Regression testing with the existing unit tests.
---
native/core/src/execution/planner.rs | 132 +--------------
native/proto/src/proto/expr.proto | 10 +-
.../org/apache/comet/serde/QueryPlanSerde.scala | 31 +---
.../main/scala/org/apache/comet/serde/arrays.scala | 187 +++++++++++++++------
.../apache/comet/CometArrayExpressionSuite.scala | 93 +++++-----
5 files changed, 186 insertions(+), 267 deletions(-)
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 4c3a0fde9..60803dfeb 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -38,11 +38,6 @@ use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf,
use datafusion::functions_aggregate::min_max::max_udaf;
use datafusion::functions_aggregate::min_max::min_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
-use datafusion::functions_nested::array_has::array_has_any_udf;
-use datafusion::functions_nested::concat::ArrayAppend;
-use datafusion::functions_nested::remove::array_remove_all_udf;
-use datafusion::functions_nested::set_ops::array_intersect_udf;
-use datafusion::functions_nested::string::array_to_string_udf;
use datafusion::physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion::physical_plan::windows::BoundedWindowAggExec;
use datafusion::physical_plan::InputOrderMode;
@@ -83,10 +78,9 @@ use datafusion::common::{
JoinType as DFJoinType, ScalarValue,
};
use datafusion::datasource::listing::PartitionedFile;
-use datafusion::functions_nested::array_has::ArrayHas;
use datafusion::logical_expr::type_coercion::other::get_coerce_type_for_case_expression;
use datafusion::logical_expr::{
- AggregateUDF, ReturnTypeArgs, ScalarUDF, WindowFrame, WindowFrameBound, WindowFrameUnits,
+ AggregateUDF, ReturnTypeArgs, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
use datafusion::physical_expr::expressions::{Literal, StatsType};
@@ -759,32 +753,6 @@ impl PhysicalPlanner {
expr.ordinal as usize,
)))
}
- ExprStruct::ArrayAppend(expr) => {
- let left =
- self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let right =
- self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let return_type = left.data_type(&input_schema)?;
- let args = vec![Arc::clone(&left), right];
- let datafusion_array_append =
- Arc::new(ScalarUDF::new_from_impl(ArrayAppend::new()));
- let array_append_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
- "array_append",
- datafusion_array_append,
- args,
- return_type,
- ));
-
- let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(left));
- let null_literal_expr: Arc<dyn PhysicalExpr> =
- Arc::new(Literal::new(ScalarValue::Null));
-
- create_case_expr(
- vec![(is_null_expr, null_literal_expr)],
- Some(array_append_expr),
- &input_schema,
- )
- }
ExprStruct::ArrayInsert(expr) => {
let src_array_expr = self.create_expr(
expr.src_array_expr.as_ref().unwrap(),
@@ -801,104 +769,6 @@ impl PhysicalPlanner {
expr.legacy_negative_index,
)))
}
- ExprStruct::ArrayContains(expr) => {
- let src_array_expr =
- self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let key_expr =
- self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let args = vec![Arc::clone(&src_array_expr), key_expr];
- let array_has_expr = Arc::new(ScalarFunctionExpr::new(
- "array_has",
- Arc::new(ScalarUDF::new_from_impl(ArrayHas::new())),
- args,
- DataType::Boolean,
- ));
- Ok(array_has_expr)
- }
- ExprStruct::ArrayRemove(expr) => {
- let src_array_expr =
- self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let key_expr =
- self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let args = vec![Arc::clone(&src_array_expr), Arc::clone(&key_expr)];
- let return_type = src_array_expr.data_type(&input_schema)?;
-
- let datafusion_array_remove = array_remove_all_udf();
-
- let array_remove_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
- "array_remove",
- datafusion_array_remove,
- args,
- return_type,
- ));
- let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(key_expr));
-
- let null_literal_expr: Arc<dyn PhysicalExpr> =
- Arc::new(Literal::new(ScalarValue::Null));
-
- create_case_expr(
- vec![(is_null_expr, null_literal_expr)],
- Some(array_remove_expr),
- &input_schema,
- )
- }
- ExprStruct::ArrayIntersect(expr) => {
- let left_expr =
- self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let right_expr =
- self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let args = vec![Arc::clone(&left_expr), right_expr];
- let datafusion_array_intersect = array_intersect_udf();
- let return_type = left_expr.data_type(&input_schema)?;
- let array_intersect_expr = Arc::new(ScalarFunctionExpr::new(
- "array_intersect",
- datafusion_array_intersect,
- args,
- return_type,
- ));
- Ok(array_intersect_expr)
- }
- ExprStruct::ArrayJoin(expr) => {
- let array_expr =
- self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let delimiter_expr = self.create_expr(
- expr.delimiter_expr.as_ref().unwrap(),
- Arc::clone(&input_schema),
- )?;
-
- let mut args = vec![Arc::clone(&array_expr), delimiter_expr];
- if expr.null_replacement_expr.is_some() {
- let null_replacement_expr = self.create_expr(
- expr.null_replacement_expr.as_ref().unwrap(),
- Arc::clone(&input_schema),
- )?;
- args.push(null_replacement_expr)
- }
-
- let datafusion_array_to_string = array_to_string_udf();
- let array_join_expr = Arc::new(ScalarFunctionExpr::new(
- "array_join",
- datafusion_array_to_string,
- args,
- DataType::Utf8,
- ));
- Ok(array_join_expr)
- }
- ExprStruct::ArraysOverlap(expr) => {
- let left_array_expr =
- self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let right_array_expr =
- self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let args = vec![Arc::clone(&left_array_expr), right_array_expr];
- let datafusion_array_has_any = array_has_any_udf();
- let array_has_any_expr = Arc::new(ScalarFunctionExpr::new(
- "array_has_any",
- datafusion_array_has_any,
- args,
- DataType::Boolean,
- ));
- Ok(array_has_any_expr)
- }
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 71ad8cf3f..90fd08948 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -82,14 +82,8 @@ message Expr {
ToJson to_json = 55;
ListExtract list_extract = 56;
GetArrayStructFields get_array_struct_fields = 57;
- BinaryExpr array_append = 58;
- ArrayInsert array_insert = 59;
- BinaryExpr array_contains = 60;
- BinaryExpr array_remove = 61;
- BinaryExpr array_intersect = 62;
- ArrayJoin array_join = 63;
- BinaryExpr arrays_overlap = 64;
- MathExpr integral_divide = 65;
+ ArrayInsert array_insert = 58;
+ MathExpr integral_divide = 59;
}
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index b3e65a8a0..a8a3df0c1 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1925,34 +1925,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
None
}
- case expr if expr.prettyName == "array_insert" =>
- val srcExprProto = exprToProtoInternal(expr.children(0), inputs, binding)
- val posExprProto = exprToProtoInternal(expr.children(1), inputs, binding)
- val itemExprProto = exprToProtoInternal(expr.children(2), inputs, binding)
- val legacyNegativeIndex =
- SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
- if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) {
- val arrayInsertBuilder = ExprOuterClass.ArrayInsert
- .newBuilder()
- .setSrcArrayExpr(srcExprProto.get)
- .setPosExpr(posExprProto.get)
- .setItemExpr(itemExprProto.get)
- .setLegacyNegativeIndex(legacyNegativeIndex)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setArrayInsert(arrayInsertBuilder)
- .build())
- } else {
- withInfo(
- expr,
- "unsupported arguments for ArrayInsert",
- expr.children(0),
- expr.children(1),
- expr.children(2))
- None
- }
+ case expr if expr.prettyName == "array_insert" => convert(CometArrayInsert)
case ElementAt(child, ordinal, defaultValue, failOnError)
if child.dataType.isInstanceOf[ArrayType] =>
@@ -2899,7 +2872,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}
// Utility method. Adds explain info if the result of calling exprToProto is None
- private def optExprWithInfo(
+ def optExprWithInfo(
optExpr: Option[Expr],
expr: Expression,
childExpr: Expression*): Option[Expr] = {
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 60df51b08..8550d5201 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -20,10 +20,11 @@
package org.apache.comet.serde
import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, ArrayRemove, Attribute, Expression, Literal}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProto, scalarExprToProtoWithReturnType}
+import org.apache.comet.serde.QueryPlanSerde._
import org.apache.comet.shims.CometExprShim
object CometArrayRemove extends CometExpressionSerde with CometExprShim {
@@ -56,13 +57,37 @@ object CometArrayRemove extends CometExpressionSerde with CometExprShim {
return None
}
}
- createBinaryExpr(
+ val arrayExprProto = exprToProto(ar.left, inputs, binding)
+ val keyExprProto = exprToProto(ar.right, inputs, binding)
+
+ val arrayRemoveScalarExpr =
+ scalarExprToProto("array_remove_all", arrayExprProto, keyExprProto)
+
+ val isNotNullExpr = createUnaryExpr(
expr,
- expr.children(0),
- expr.children(1),
+ ar.right,
inputs,
binding,
- (builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
+ (builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
+
+ val nullLiteralProto = exprToProto(Literal(null, ar.right.dataType), Seq.empty)
+
+ if (arrayRemoveScalarExpr.isDefined && isNotNullExpr.isDefined && nullLiteralProto.isDefined) {
+ val caseWhenExpr = ExprOuterClass.CaseWhen
+ .newBuilder()
+ .addWhen(isNotNullExpr.get)
+ .addThen(arrayRemoveScalarExpr.get)
+ .setElseExpr(nullLiteralProto.get)
+ .build()
+ Some(
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setCaseWhen(caseWhenExpr)
+ .build())
+ } else {
+ withInfo(expr, expr.children: _*)
+ None
+ }
}
}
@@ -71,13 +96,39 @@ object CometArrayAppend extends CometExpressionSerde with IncompatExpr {
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
- createBinaryExpr(
+ val child = expr.children.head
+ val elementType = child.dataType.asInstanceOf[ArrayType].elementType
+
+ val arrayExprProto = exprToProto(expr.children(0), inputs, binding)
+ val keyExprProto = exprToProto(expr.children(1), inputs, binding)
+
+ val arrayAppendScalarExpr = scalarExprToProto("array_append", arrayExprProto, keyExprProto)
+
+ val isNotNullExpr = createUnaryExpr(
expr,
expr.children(0),
- expr.children(1),
inputs,
binding,
- (builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
+ (builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
+
+ val nullLiteralProto = exprToProto(Literal(null, elementType), Seq.empty)
+
+ if (arrayAppendScalarExpr.isDefined && isNotNullExpr.isDefined && nullLiteralProto.isDefined) {
+ val caseWhenExpr = ExprOuterClass.CaseWhen
+ .newBuilder()
+ .addWhen(isNotNullExpr.get)
+ .addThen(arrayAppendScalarExpr.get)
+ .setElseExpr(nullLiteralProto.get)
+ .build()
+ Some(
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setCaseWhen(caseWhenExpr)
+ .build())
+ } else {
+ withInfo(expr, expr.children: _*)
+ None
+ }
}
}
@@ -86,13 +137,12 @@ object CometArrayContains extends CometExpressionSerde with IncompatExpr {
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
- createBinaryExpr(
- expr,
- expr.children(0),
- expr.children(1),
- inputs,
- binding,
- (builder, binaryExpr) => builder.setArrayContains(binaryExpr))
+ val arrayExprProto = exprToProto(expr.children(0), inputs, binding)
+ val keyExprProto = exprToProto(expr.children(1), inputs, binding)
+
+ val arrayContainsScalarExpr =
+ scalarExprToProto("array_has", arrayExprProto, keyExprProto)
+ optExprWithInfo(arrayContainsScalarExpr, expr, expr.children: _*)
}
}
@@ -101,13 +151,12 @@ object CometArrayIntersect extends CometExpressionSerde with IncompatExpr {
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
- createBinaryExpr(
- expr,
- expr.children(0),
- expr.children(1),
- inputs,
- binding,
- (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr))
+ val leftArrayExprProto = exprToProto(expr.children(0), inputs, binding)
+ val rightArrayExprProto = exprToProto(expr.children(1), inputs, binding)
+
+ val arraysIntersectScalarExpr =
+ scalarExprToProto("array_intersect", leftArrayExprProto,
rightArrayExprProto)
+ optExprWithInfo(arraysIntersectScalarExpr, expr, expr.children: _*)
}
}
@@ -116,13 +165,15 @@ object CometArraysOverlap extends CometExpressionSerde with IncompatExpr {
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
- createBinaryExpr(
- expr,
- expr.children(0),
- expr.children(1),
- inputs,
- binding,
- (builder, binaryExpr) => builder.setArraysOverlap(binaryExpr))
+ val leftArrayExprProto = exprToProto(expr.children(0), inputs, binding)
+ val rightArrayExprProto = exprToProto(expr.children(1), inputs, binding)
+
+ val arraysOverlapScalarExpr = scalarExprToProtoWithReturnType(
+ "array_has_any",
+ BooleanType,
+ leftArrayExprProto,
+ rightArrayExprProto)
+ optExprWithInfo(arraysOverlapScalarExpr, expr, expr.children: _*)
}
}
@@ -142,12 +193,7 @@ object CometArrayCompact extends CometExpressionSerde with IncompatExpr {
ArrayType(elementType = elementType),
arrayExprProto,
nullLiteralProto)
- arrayCompactScalarExpr match {
- case None =>
- withInfo(expr, "unsupported arguments for ArrayCompact",
expr.children: _*)
- None
- case expr => expr
- }
+ optExprWithInfo(arrayCompactScalarExpr, expr, expr.children: _*)
}
}
@@ -160,32 +206,61 @@ object CometArrayJoin extends CometExpressionSerde with IncompatExpr {
val arrayExprProto = exprToProto(arrayExpr.array, inputs, binding)
val delimiterExprProto = exprToProto(arrayExpr.delimiter, inputs, binding)
- if (arrayExprProto.isDefined && delimiterExprProto.isDefined) {
- val arrayJoinBuilder = arrayExpr.nullReplacement match {
- case Some(nrExpr) =>
- val nullReplacementExprProto = exprToProto(nrExpr, inputs, binding)
- ExprOuterClass.ArrayJoin
- .newBuilder()
- .setArrayExpr(arrayExprProto.get)
- .setDelimiterExpr(delimiterExprProto.get)
- .setNullReplacementExpr(nullReplacementExprProto.get)
- case None =>
- ExprOuterClass.ArrayJoin
- .newBuilder()
- .setArrayExpr(arrayExprProto.get)
- .setDelimiterExpr(delimiterExprProto.get)
- }
+ arrayExpr.nullReplacement match {
+ case Some(nullReplacementExpr) =>
+ val nullReplacementExprProto = exprToProto(nullReplacementExpr, inputs, binding)
+
+ val arrayJoinScalarExpr = scalarExprToProto(
+ "array_to_string",
+ arrayExprProto,
+ delimiterExprProto,
+ nullReplacementExprProto)
+
+ optExprWithInfo(
+ arrayJoinScalarExpr,
+ expr,
+ arrayExpr,
+ arrayExpr.delimiter,
+ nullReplacementExpr)
+ case None =>
+ val arrayJoinScalarExpr =
+ scalarExprToProto("array_to_string", arrayExprProto,
delimiterExprProto)
+
+ optExprWithInfo(arrayJoinScalarExpr, expr, arrayExpr, arrayExpr.delimiter)
+ }
+ }
+}
+
+object CometArrayInsert extends CometExpressionSerde with IncompatExpr {
+ override def convert(
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[ExprOuterClass.Expr] = {
+ val srcExprProto = exprToProtoInternal(expr.children(0), inputs, binding)
+ val posExprProto = exprToProtoInternal(expr.children(1), inputs, binding)
+ val itemExprProto = exprToProtoInternal(expr.children(2), inputs, binding)
+ val legacyNegativeIndex =
+ SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
+ if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) {
+ val arrayInsertBuilder = ExprOuterClass.ArrayInsert
+ .newBuilder()
+ .setSrcArrayExpr(srcExprProto.get)
+ .setPosExpr(posExprProto.get)
+ .setItemExpr(itemExprProto.get)
+ .setLegacyNegativeIndex(legacyNegativeIndex)
+
Some(
ExprOuterClass.Expr
.newBuilder()
- .setArrayJoin(arrayJoinBuilder)
+ .setArrayInsert(arrayInsertBuilder)
.build())
} else {
- val exprs: List[Expression] = arrayExpr.nullReplacement match {
- case Some(nrExpr) => List(arrayExpr, arrayExpr.delimiter, nrExpr)
- case None => List(arrayExpr, arrayExpr.delimiter)
- }
- withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
+ withInfo(
+ expr,
+ "unsupported arguments for ArrayInsert",
+ expr.children(0),
+ expr.children(1),
+ expr.children(2))
None
}
}
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 8b3299dc9..cef48c50c 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -160,59 +160,66 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
test("array_prepend") {
assume(isSpark35Plus) // in Spark 3.5 array_prepend is implemented via array_insert
- Seq(true, false).foreach { dictionaryEnabled =>
- withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
- spark.read.parquet(path.toString).createOrReplaceTempView("t1");
- checkSparkAnswerAndOperator(spark.sql("Select
array_prepend(array(_1),false) from t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1"))
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM t1"));
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS DOUBLE))
FROM t1"));
- checkSparkAnswerAndOperator(spark.sql("SELECT array_prepend(array(_8),
'test') FROM t1"));
- checkSparkAnswerAndOperator(spark.sql("SELECT
array_prepend(array(_19), _19) FROM t1"));
- checkSparkAnswerAndOperator(
- spark.sql("SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4)
END), _4) FROM t1"));
+ withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+ Seq(true, false).foreach { dictionaryEnabled =>
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+ checkSparkAnswerAndOperator(spark.sql("Select
array_prepend(array(_1),false) from t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM
t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS
DOUBLE)) FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_8), 'test') FROM t1"));
+ checkSparkAnswerAndOperator(spark.sql("SELECT
array_prepend(array(_19), _19) FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4)
END), _4) FROM t1"));
+ }
}
}
}
test("ArrayInsert") {
- Seq(true, false).foreach(dictionaryEnabled =>
- withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
- val df = spark.read
- .parquet(path.toString)
- .withColumn("arr", array(col("_4"), lit(null), col("_4")))
- .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)"))
- .withColumn("arrInsertNegativeIndexResult", expr("array_insert(arr,
-1, 1)"))
- .withColumn("arrPosGreaterThanSize", expr("array_insert(arr, 8, 1)"))
- .withColumn("arrNegPosGreaterThanSize", expr("array_insert(arr, -8,
1)"))
- .withColumn("arrInsertNone", expr("array_insert(arr, 1, null)"))
- checkSparkAnswerAndOperator(df.select("arrInsertResult"))
- checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult"))
- checkSparkAnswerAndOperator(df.select("arrPosGreaterThanSize"))
- checkSparkAnswerAndOperator(df.select("arrNegPosGreaterThanSize"))
- checkSparkAnswerAndOperator(df.select("arrInsertNone"))
- })
+ withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+ Seq(true, false).foreach(dictionaryEnabled =>
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
+ val df = spark.read
+ .parquet(path.toString)
+ .withColumn("arr", array(col("_4"), lit(null), col("_4")))
+ .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)"))
+ .withColumn("arrInsertNegativeIndexResult",
expr("array_insert(arr, -1, 1)"))
+ .withColumn("arrPosGreaterThanSize", expr("array_insert(arr, 8,
1)"))
+ .withColumn("arrNegPosGreaterThanSize", expr("array_insert(arr,
-8, 1)"))
+ .withColumn("arrInsertNone", expr("array_insert(arr, 1, null)"))
+ checkSparkAnswerAndOperator(df.select("arrInsertResult"))
+ checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult"))
+ checkSparkAnswerAndOperator(df.select("arrPosGreaterThanSize"))
+ checkSparkAnswerAndOperator(df.select("arrNegPosGreaterThanSize"))
+ checkSparkAnswerAndOperator(df.select("arrInsertNone"))
+ })
+ }
}
test("ArrayInsertUnsupportedArgs") {
// This test checks that the else branch in ArrayInsert
// mapping to the comet is valid and fallback to spark is working fine.
- withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "test.parquet")
- makeParquetFileAllTypes(path, dictionaryEnabled = false, 10000)
- val df = spark.read
- .parquet(path.toString)
- .withColumn("arr", array(col("_4"), lit(null), col("_4")))
- .withColumn("idx", udf((_: Int) => 1).apply(col("_4")))
- .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)"))
- checkSparkAnswer(df.select("arrUnsupportedArgs"))
+ withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllTypes(path, dictionaryEnabled = false, 10000)
+ val df = spark.read
+ .parquet(path.toString)
+ .withColumn("arr", array(col("_4"), lit(null), col("_4")))
+ .withColumn("idx", udf((_: Int) => 1).apply(col("_4")))
+ .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)"))
+ checkSparkAnswer(df.select("arrUnsupportedArgs"))
+ }
}
}
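A closing note on the CASE WHEN guards above: Spark returns NULL for array_remove when the key is NULL and for array_append when the array is NULL, which the guards reproduce around the DataFusion calls. A hypothetical spark-shell style check of that Spark behavior (standalone sketch, not part of this commit):

```scala
import org.apache.spark.sql.SparkSession

// Standalone check (assumed local Spark session, not part of this commit) of
// the Spark null semantics that the CASE WHEN guards preserve.
object NullSemanticsCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("null-check").getOrCreate()
    // array_remove with a NULL key yields NULL, not the unchanged array.
    spark.sql("SELECT array_remove(array(1, 2, 2), CAST(NULL AS INT))").show()
    // array_append on a NULL array yields NULL, not a single-element array.
    spark.sql("SELECT array_append(CAST(NULL AS ARRAY<INT>), 1)").show()
    spark.stop()
  }
}
```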