This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new c25060e6 feat: add support for array_remove expression (#1179)
c25060e6 is described below
commit c25060e61a5b5fa3d62d8782aa3f1a06467792dd
Author: Jagdish Parihar <[email protected]>
AuthorDate: Sun Jan 12 12:30:40 2025 -0700
feat: add support for array_remove expression (#1179)
* wip: array remove
* added comet expression test
* updated test cases
* fixed array_remove function for null values
* removed commented code
* remove unnecessary code
* updated the test for 'array_remove'
* added test for array_remove in case the input array is null
* wip: case array is empty
* removed test case for empty array
---
native/core/src/execution/planner.rs | 30 ++++++++++++++++++++++
native/proto/src/proto/expr.proto | 1 +
.../org/apache/comet/serde/QueryPlanSerde.scala | 6 +++++
.../org/apache/comet/CometExpressionSuite.scala | 16 ++++++++++++
4 files changed, 53 insertions(+)
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 294922f2..c43230f4 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -66,6 +66,7 @@ use datafusion::{
};
use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr};
use datafusion_functions_nested::concat::ArrayAppend;
+use datafusion_functions_nested::remove::array_remove_all_udf;
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use crate::execution::shuffle::CompressionCodec;
@@ -735,6 +736,35 @@ impl PhysicalPlanner {
));
Ok(array_has_expr)
}
+ ExprStruct::ArrayRemove(expr) => {
+ let src_array_expr =
+ self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
+ let key_expr =
+ self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
+ let args = vec![Arc::clone(&src_array_expr), Arc::clone(&key_expr)];
+ let return_type = src_array_expr.data_type(&input_schema)?;
+
+ let datafusion_array_remove = array_remove_all_udf();
+
+ let array_remove_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
+ "array_remove",
+ datafusion_array_remove,
+ args,
+ return_type,
+ ));
+ let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(key_expr));
+
+ let null_literal_expr: Arc<dyn PhysicalExpr> =
+ Arc::new(Literal::new(ScalarValue::Null));
+
+ let case_expr = CaseExpr::try_new(
+ None,
+ vec![(is_null_expr, null_literal_expr)],
+ Some(array_remove_expr),
+ )?;
+
+ Ok(Arc::new(case_expr))
+ }
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index e76ecdcc..8e3bc60b 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -85,6 +85,7 @@ message Expr {
BinaryExpr array_append = 58;
ArrayInsert array_insert = 59;
BinaryExpr array_contains = 60;
+ BinaryExpr array_remove = 61;
}
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7a69e630..81864865 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2266,6 +2266,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
None
}
+ case expr if expr.prettyName == "array_remove" =>
+ createBinaryExpr(
+ expr.children(0),
+ expr.children(1),
+ inputs,
+ (builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
case expr if expr.prettyName == "array_contains" =>
createBinaryExpr(
expr.children(0),
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index afdf8601..8c2759a3 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2529,4 +2529,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
}
}
+
+ test("array_remove") {
+ Seq(true, false).foreach { dictionaryEnabled =>
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_remove(array(_2, _3,_4), _2) from t1 where _2 is null"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_remove(array(_2, _3,_4), _3) from t1 where _3 is not null"))
+ checkSparkAnswerAndOperator(sql(
+ "SELECT array_remove(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1"))
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]