This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 9a4e5b5bd Feat: Support array_join function (#1290)
9a4e5b5bd is described below

commit 9a4e5b5bd799112451580aae6b76eb7582a08d29
Author: Eren Avsarogullari <erenavsarogull...@gmail.com>
AuthorDate: Thu Jan 23 15:23:32 2025 -0800

    Feat: Support array_join function (#1290)
---
 native/core/src/execution/planner.rs               | 27 ++++++++++++++++++
 native/proto/src/proto/expr.proto                  |  7 +++++
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 32 ++++++++++++++++++++++
 .../org/apache/comet/CometExpressionSuite.scala    | 17 ++++++++++++
 4 files changed, 83 insertions(+)

diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index c7df503a7..95926bfee 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -68,6 +68,7 @@ use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}
 use datafusion_functions_nested::concat::ArrayAppend;
 use datafusion_functions_nested::remove::array_remove_all_udf;
 use datafusion_functions_nested::set_ops::array_intersect_udf;
+use datafusion_functions_nested::string::array_to_string_udf;
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
 
 use crate::execution::shuffle::CompressionCodec;
@@ -791,6 +792,32 @@ impl PhysicalPlanner {
                 ));
                 Ok(array_intersect_expr)
             }
+            ExprStruct::ArrayJoin(expr) => {
+                let array_expr =
+                    self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let delimiter_expr = self.create_expr(
+                    expr.delimiter_expr.as_ref().unwrap(),
+                    Arc::clone(&input_schema),
+                )?;
+
+                let mut args = vec![Arc::clone(&array_expr), delimiter_expr];
+                if expr.null_replacement_expr.is_some() {
+                    let null_replacement_expr = self.create_expr(
+                        expr.null_replacement_expr.as_ref().unwrap(),
+                        Arc::clone(&input_schema),
+                    )?;
+                    args.push(null_replacement_expr)
+                }
+
+                let datafusion_array_to_string = array_to_string_udf();
+                let array_join_expr = Arc::new(ScalarFunctionExpr::new(
+                    "array_join",
+                    datafusion_array_to_string,
+                    args,
+                    DataType::Utf8,
+                ));
+                Ok(array_join_expr)
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 0b7d24d9f..83d6da7cb 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -87,6 +87,7 @@ message Expr {
     BinaryExpr array_contains = 60;
     BinaryExpr array_remove = 61;
     BinaryExpr array_intersect = 62;
+    ArrayJoin array_join = 63;
   }
 }
 
@@ -415,6 +416,12 @@ message ArrayInsert {
   bool legacy_negative_index = 4;
 }
 
+message ArrayJoin {
+  Expr array_expr = 1;
+  Expr delimiter_expr = 2;
+  Expr null_replacement_expr = 3;
+}
+
 message DataType {
   enum DataTypeId {
     BOOL = 0;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 70ed49464..22bb7dc82 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2312,6 +2312,38 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
             expr.children(1),
             inputs,
             (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr))
+        case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) =>
+          val arrayExprProto = exprToProto(arrayExpr, inputs, binding)
+          val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding)
+
+          if (arrayExprProto.isDefined && delimiterExprProto.isDefined) {
+            val arrayJoinBuilder = nullReplacementExpr match {
+              case Some(nrExpr) =>
+                val nullReplacementExprProto = exprToProto(nrExpr, inputs, binding)
+                ExprOuterClass.ArrayJoin
+                  .newBuilder()
+                  .setArrayExpr(arrayExprProto.get)
+                  .setDelimiterExpr(delimiterExprProto.get)
+                  .setNullReplacementExpr(nullReplacementExprProto.get)
+              case None =>
+                ExprOuterClass.ArrayJoin
+                  .newBuilder()
+                  .setArrayExpr(arrayExprProto.get)
+                  .setDelimiterExpr(delimiterExprProto.get)
+            }
+            Some(
+              ExprOuterClass.Expr
+                .newBuilder()
+                .setArrayJoin(arrayJoinBuilder)
+                .build())
+          } else {
+            val exprs: List[Expression] = nullReplacementExpr match {
+              case Some(nrExpr) => List(arrayExpr, delimiterExpr, nrExpr)
+              case None => List(arrayExpr, delimiterExpr)
+            }
+            withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
+            None
+          }
         case _ =>
           withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
           None
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index b43fa7728..99cf4bad4 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2684,4 +2684,21 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("array_join") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+        checkSparkAnswerAndOperator(sql(
+          "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ') from t1"))
+        checkSparkAnswerAndOperator(sql(
+          "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ', ' +++ ') from t1"))
+        checkSparkAnswerAndOperator(sql(
+          "SELECT array_join(array('hello', 'world', cast(_2 as string)), ' ') from t1 where _2 is not null"))
+        checkSparkAnswerAndOperator(
+          sql("SELECT array_join(array('hello', '-', 'world', cast(_2 as string)), ' ') from t1"))
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to