This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
     new c25060e6  feat: add support for array_remove expression (#1179)
c25060e6 is described below

commit c25060e61a5b5fa3d62d8782aa3f1a06467792dd
Author: Jagdish Parihar <jatin6...@gmail.com>
AuthorDate: Sun Jan 12 12:30:40 2025 -0700

    feat: add support for array_remove expression (#1179)

    * wip: array remove
    * added comet expression test
    * updated test cases
    * fixed array_remove function for null values
    * removed commented code
    * remove unnecessary code
    * updated the test for 'array_remove'
    * added test for array_remove in case the input array is null
    * wip: case array is empty
    * removed test case for empty array
---
 native/core/src/execution/planner.rs              | 30 ++++++++++++++++++++++
 native/proto/src/proto/expr.proto                 |  1 +
 .../org/apache/comet/serde/QueryPlanSerde.scala   |  6 +++++
 .../org/apache/comet/CometExpressionSuite.scala   | 16 ++++++++++++
 4 files changed, 53 insertions(+)

diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 294922f2..c43230f4 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -66,6 +66,7 @@ use datafusion::{
 };
 use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr};
 use datafusion_functions_nested::concat::ArrayAppend;
+use datafusion_functions_nested::remove::array_remove_all_udf;
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
 
 use crate::execution::shuffle::CompressionCodec;
@@ -735,6 +736,35 @@ impl PhysicalPlanner {
                 ));
                 Ok(array_has_expr)
             }
+            ExprStruct::ArrayRemove(expr) => {
+                let src_array_expr =
+                    self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let key_expr =
+                    self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let args = vec![Arc::clone(&src_array_expr), Arc::clone(&key_expr)];
+                let return_type = src_array_expr.data_type(&input_schema)?;
+
+                let datafusion_array_remove = array_remove_all_udf();
+
+                let array_remove_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
+                    "array_remove",
+                    datafusion_array_remove,
+                    args,
+                    return_type,
+                ));
+                let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(key_expr));
+
+                let null_literal_expr: Arc<dyn PhysicalExpr> =
+                    Arc::new(Literal::new(ScalarValue::Null));
+
+                let case_expr = CaseExpr::try_new(
+                    None,
+                    vec![(is_null_expr, null_literal_expr)],
+                    Some(array_remove_expr),
+                )?;
+
+                Ok(Arc::new(case_expr))
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index e76ecdcc..8e3bc60b 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -85,6 +85,7 @@ message Expr {
     BinaryExpr array_append = 58;
     ArrayInsert array_insert = 59;
     BinaryExpr array_contains = 60;
+    BinaryExpr array_remove = 61;
   }
 }
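A note on the CASE wrapper in the planner change above: Spark's array_remove returns NULL whenever the element to remove is NULL, so the native plan short-circuits to a NULL literal instead of calling DataFusion's array_remove_all with a null key. The standalone Scala sketch below is illustrative only and not part of this commit (the object name and local SparkSession setup are assumptions); it just demonstrates the Spark-side behaviour being matched:

    // Illustrative sketch of Spark's array_remove null semantics.
    import org.apache.spark.sql.SparkSession

    object ArrayRemoveNullSemantics {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()

        // Every occurrence of the element is removed: result is [1, 3].
        spark.sql("SELECT array_remove(array(1, 2, 2, 3), 2)").show()

        // A NULL element yields NULL rather than the original array, which is
        // what the CASE WHEN key IS NULL THEN NULL branch above reproduces.
        spark.sql("SELECT array_remove(array(1, 2, 3), CAST(NULL AS INT))").show()

        spark.stop()
      }
    }

Wrapping the DataFusion call in a CASE expression keeps the null behaviour consistent with Spark without changing the upstream UDF itself.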
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7a69e630..81864865 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2266,6 +2266,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
           None
         }
+      case expr if expr.prettyName == "array_remove" =>
+        createBinaryExpr(
+          expr.children(0),
+          expr.children(1),
+          inputs,
+          (builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
       case expr if expr.prettyName == "array_contains" =>
         createBinaryExpr(
           expr.children(0),
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index afdf8601..8c2759a3 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2529,4 +2529,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
     }
   }
+
+  test("array_remove") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+        checkSparkAnswerAndOperator(
+          sql("SELECT array_remove(array(_2, _3,_4), _2) from t1 where _2 is null"))
+        checkSparkAnswerAndOperator(
+          sql("SELECT array_remove(array(_2, _3,_4), _3) from t1 where _3 is not null"))
+        checkSparkAnswerAndOperator(sql(
+          "SELECT array_remove(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1"))
+      }
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org