This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a8b919f924db [SPARK-47736][SQL] Add support for AbstractArrayType
a8b919f924db is described below

commit a8b919f924db1e2818b2b0de49762292ae20c17c
Author: Mihailo Milosevic <mihailo.milose...@databricks.com>
AuthorDate: Thu Apr 11 15:39:42 2024 +0800

    [SPARK-47736][SQL] Add support for AbstractArrayType
    
    ### What changes were proposed in this pull request?
    Addition of AbstractArrayType, which accepts StringTypeCollated as its
    elementType. The changes in https://github.com/apache/spark/pull/45693
    work for ArrayJoin, but do not generalize to other functions. This PR
    introduces an interface that all functions can use. Merge only after
    #45693.
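    
    For illustration, a minimal sketch (not part of this patch) of how an
    expression can declare that it accepts an array of strings of any
    collation, mirroring the ArrayJoin change in the diff below:
    
        import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
        import org.apache.spark.sql.types.AbstractDataType
    
        // Hypothetical expression parameters: the first argument must be an
        // array whose elements are strings of any collation, the second a
        // string of any collation.
        def inputTypes: Seq[AbstractDataType] =
          Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation)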
    
    ### Why are the changes needed?
    This is needed to enable functions to operate on arrays of collated strings.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, functions called on arrays of collated strings will now work.
    
    ### How was this patch tested?
    A test for array_join was added to `CollationSuite`.
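    
    For reference, a condensed sketch of the added check (assuming the
    CollationSuite test context, where sql, checkAnswer, and Row are in
    scope):
    
        // Array elements collate UNICODE while the delimiter collates
        // UNICODE_CI; implicit casting through AbstractArrayType lets the
        // call resolve instead of failing with a type mismatch.
        checkAnswer(
          sql("SELECT array_join(array('a', 'b' collate UNICODE), 'c' collate UNICODE_CI)"),
          Seq(Row("acb")))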
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #45891 from mihailom-db/SPARK-47736.
    
    Authored-by: Mihailo Milosevic <mihailo.milose...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/internal/types/AbstractArrayType.scala     | 37 ++++++++++++++++++++++
 .../sql/internal/types/AbstractStringType.scala}   | 10 +++---
 .../sql/catalyst/analysis/AnsiTypeCoercion.scala   |  3 +-
 .../spark/sql/catalyst/analysis/TypeCoercion.scala |  6 +++-
 .../expressions/collationExpressions.scala         |  1 +
 .../expressions/collectionOperations.scala         |  5 +--
 .../catalyst/expressions/stringExpressions.scala   |  1 +
 .../org/apache/spark/sql/CollationSuite.scala      |  3 ++
 8 files changed, 57 insertions(+), 9 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala
new file mode 100644
index 000000000000..406449a33727
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.types
+
+import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType}
+
+
+/**
+ * Use AbstractArrayType(AbstractDataType) for defining expected types for expression parameters.
+ */
+case class AbstractArrayType(elementType: AbstractDataType) extends AbstractDataType {
+
+  override private[sql] def defaultConcreteType: DataType =
+    ArrayType(elementType.defaultConcreteType, containsNull = true)
+
+  override private[sql] def acceptsType(other: DataType): Boolean = {
+    other.isInstanceOf[ArrayType] &&
+      elementType.acceptsType(other.asInstanceOf[ArrayType].elementType)
+  }
+
+  override private[spark] def simpleString: String = s"array<${elementType.simpleString}>"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
similarity index 86%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
index 67b65859e6bb..6403295fe20c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
@@ -15,14 +15,14 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.catalyst.expressions
+package org.apache.spark.sql.internal.types
 
 import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
 
 /**
  * StringTypeCollated is an abstract class for StringType with collation support.
  */
-abstract class StringTypeCollated extends AbstractDataType {
+abstract class AbstractStringType extends AbstractDataType {
   override private[sql] def defaultConcreteType: DataType = StringType
   override private[sql] def simpleString: String = "string"
 }
@@ -30,7 +30,7 @@ abstract class StringTypeCollated extends AbstractDataType {
 /**
  * Use StringTypeBinary for expressions supporting only binary collation.
  */
-case object StringTypeBinary extends StringTypeCollated {
+case object StringTypeBinary extends AbstractStringType {
   override private[sql] def acceptsType(other: DataType): Boolean =
     other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality
 }
@@ -38,7 +38,7 @@ case object StringTypeBinary extends StringTypeCollated {
 /**
  * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation.
  */
-case object StringTypeBinaryLcase extends StringTypeCollated {
+case object StringTypeBinaryLcase extends AbstractStringType {
   override private[sql] def acceptsType(other: DataType): Boolean =
     other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality ||
       other.asInstanceOf[StringType].isUTF8BinaryLcaseCollation)
@@ -47,6 +47,6 @@ case object StringTypeBinaryLcase extends StringTypeCollated {
 /**
  * Use StringTypeAnyCollation for expressions supporting all possible collation types.
  */
-case object StringTypeAnyCollation extends StringTypeCollated {
+case object StringTypeAnyCollation extends AbstractStringType {
   override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index dd904be9a760..52c1136b1ee3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.types.AbstractStringType
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.UpCastRule.numericPrecedence
 
@@ -185,7 +186,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
       // If a function expects a StringType, no StringType instance should be implicitly cast to
       // StringType with a collation that's not accepted (aka. lockdown unsupported collations).
       case (_: StringType, _: StringType) => None
-      case (_: StringType, _: StringTypeCollated) => None
+      case (_: StringType, _: AbstractStringType) => None
 
       // If a function expects integral type, fractional input is not allowed.
       case (_: FractionalType, IntegralType) => None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 615d21f67695..5d8bcd4f7b1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractStringType}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.UpCastRule.numericPrecedence
 
@@ -997,7 +998,7 @@ object TypeCoercion extends TypeCoercionBase {
       case (_: StringType, BinaryType) => BinaryType
       // Cast any atomic type to string.
       case (any: AtomicType, _: StringType) if !any.isInstanceOf[StringType] => StringType
-      case (any: AtomicType, st: StringTypeCollated)
+      case (any: AtomicType, st: AbstractStringType)
         if !any.isInstanceOf[StringType] => st.defaultConcreteType
 
       // When we reach here, input type is not acceptable for any types in this type collection,
@@ -1017,6 +1018,9 @@ object TypeCoercion extends TypeCoercionBase {
       case (ArrayType(fromType, fn), ArrayType(toType: DataType, true)) =>
         implicitCast(fromType, toType).map(ArrayType(_, true)).orNull
 
+      case (ArrayType(fromType, fn), AbstractArrayType(toType)) =>
+        implicitCast(fromType, toType).map(ArrayType(_, true)).orNull
+
       case (ArrayType(fromType, true), ArrayType(toType: DataType, false)) => null
 
       case (ArrayType(fromType, false), ArrayType(toType: DataType, false))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
index e8b738921b73..6af00e193d94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
 import org.apache.spark.sql.types._
 
 // scalastyle:off line.contains.tab
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index f331a489968a..39bf6734eb27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SQLOpenHashSet
 import org.apache.spark.unsafe.UTF8StringBuilder
@@ -2003,9 +2004,9 @@ case class ArrayJoin(
     this(array, delimiter, Some(nullReplacement))
 
   override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
-    Seq(ArrayType, StringTypeAnyCollation, StringTypeAnyCollation)
+    Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation, StringTypeAnyCollation)
   } else {
-    Seq(ArrayType, StringTypeAnyCollation)
+    Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation)
   }
 
   override def children: Seq[Expression] = if (nullReplacement.isDefined) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index ad15509dc6fd..cf6c9d4f1d94 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO
 import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 62150eaeac54..c0322387c804 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -645,6 +645,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
         },
         errorClass = "COLLATION_MISMATCH.IMPLICIT"
       )
+
+      checkAnswer(sql("SELECT array_join(array('a', 'b' collate UNICODE), 'c' collate UNICODE_CI)"),
+        Seq(Row("acb")))
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
