This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a8b919f924db [SPARK-47736][SQL] Add support for AbstractArrayType
a8b919f924db is described below
commit a8b919f924db1e2818b2b0de49762292ae20c17c
Author: Mihailo Milosevic <[email protected]>
AuthorDate: Thu Apr 11 15:39:42 2024 +0800
[SPARK-47736][SQL] Add support for AbstractArrayType
### What changes were proposed in this pull request?
Addition of AbstractArrayType, which accepts StringTypeCollated as
elementType. Changes in this PR https://github.com/apache/spark/pull/45693 work
for ArrayJoin, but will not work in general for other functions. This PR
introduces a change to give an interface for all functions. Merge only after
#45693.
### Why are the changes needed?
This is needed in order to enable functions to use collated arrays.
### Does this PR introduce _any_ user-facing change?
Yes, collation functions will work.
### How was this patch tested?
A test for array_join was added to `CollationSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #45891 from mihailom-db/SPARK-47736.
Authored-by: Mihailo Milosevic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/internal/types/AbstractArrayType.scala | 37 ++++++++++++++++++++++
.../sql/internal/types/AbstractStringType.scala} | 10 +++---
.../sql/catalyst/analysis/AnsiTypeCoercion.scala | 3 +-
.../spark/sql/catalyst/analysis/TypeCoercion.scala | 6 +++-
.../expressions/collationExpressions.scala | 1 +
.../expressions/collectionOperations.scala | 5 +--
.../catalyst/expressions/stringExpressions.scala | 1 +
.../org/apache/spark/sql/CollationSuite.scala | 3 ++
8 files changed, 57 insertions(+), 9 deletions(-)
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala
b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala
new file mode 100644
index 000000000000..406449a33727
--- /dev/null
+++
b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.types
+
+import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType}
+
+
+/**
+ * Use AbstractArrayType(AbstractDataType) for defining expected types for
expression parameters.
+ */
+case class AbstractArrayType(elementType: AbstractDataType) extends
AbstractDataType {
+
+ override private[sql] def defaultConcreteType: DataType =
+ ArrayType(elementType.defaultConcreteType, containsNull = true)
+
+ override private[sql] def acceptsType(other: DataType): Boolean = {
+ other.isInstanceOf[ArrayType] &&
+ elementType.acceptsType(other.asInstanceOf[ArrayType].elementType)
+ }
+
+ override private[spark] def simpleString: String =
s"array<${elementType.simpleString}>"
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala
b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
similarity index 86%
rename from
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala
rename to
sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
index 67b65859e6bb..6403295fe20c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
@@ -15,14 +15,14 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.expressions
+package org.apache.spark.sql.internal.types
import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
/**
* StringTypeCollated is an abstract class for StringType with collation
support.
*/
-abstract class StringTypeCollated extends AbstractDataType {
+abstract class AbstractStringType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = StringType
override private[sql] def simpleString: String = "string"
}
@@ -30,7 +30,7 @@ abstract class StringTypeCollated extends AbstractDataType {
/**
* Use StringTypeBinary for expressions supporting only binary collation.
*/
-case object StringTypeBinary extends StringTypeCollated {
+case object StringTypeBinary extends AbstractStringType {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] &&
other.asInstanceOf[StringType].supportsBinaryEquality
}
@@ -38,7 +38,7 @@ case object StringTypeBinary extends StringTypeCollated {
/**
* Use StringTypeBinaryLcase for expressions supporting only binary and
lowercase collation.
*/
-case object StringTypeBinaryLcase extends StringTypeCollated {
+case object StringTypeBinaryLcase extends AbstractStringType {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] &&
(other.asInstanceOf[StringType].supportsBinaryEquality ||
other.asInstanceOf[StringType].isUTF8BinaryLcaseCollation)
@@ -47,6 +47,6 @@ case object StringTypeBinaryLcase extends StringTypeCollated {
/**
* Use StringTypeAnyCollation for expressions supporting all possible
collation types.
*/
-case object StringTypeAnyCollation extends StringTypeCollated {
+case object StringTypeAnyCollation extends AbstractStringType {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType]
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index dd904be9a760..52c1136b1ee3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.types.AbstractStringType
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.UpCastRule.numericPrecedence
@@ -185,7 +186,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
// If a function expects a StringType, no StringType instance should be
implicitly cast to
// StringType with a collation that's not accepted (aka. lockdown
unsupported collations).
case (_: StringType, _: StringType) => None
- case (_: StringType, _: StringTypeCollated) => None
+ case (_: StringType, _: AbstractStringType) => None
// If a function expects integral type, fractional input is not allowed.
case (_: FractionalType, IntegralType) => None
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 615d21f67695..5d8bcd4f7b1f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.{AbstractArrayType,
AbstractStringType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.UpCastRule.numericPrecedence
@@ -997,7 +998,7 @@ object TypeCoercion extends TypeCoercionBase {
case (_: StringType, BinaryType) => BinaryType
// Cast any atomic type to string.
case (any: AtomicType, _: StringType) if !any.isInstanceOf[StringType]
=> StringType
- case (any: AtomicType, st: StringTypeCollated)
+ case (any: AtomicType, st: AbstractStringType)
if !any.isInstanceOf[StringType] => st.defaultConcreteType
// When we reach here, input type is not acceptable for any types in
this type collection,
@@ -1017,6 +1018,9 @@ object TypeCoercion extends TypeCoercionBase {
case (ArrayType(fromType, fn), ArrayType(toType: DataType, true)) =>
implicitCast(fromType, toType).map(ArrayType(_, true)).orNull
+ case (ArrayType(fromType, fn), AbstractArrayType(toType)) =>
+ implicitCast(fromType, toType).map(ArrayType(_, true)).orNull
+
case (ArrayType(fromType, true), ArrayType(toType: DataType, false)) =>
null
case (ArrayType(fromType, false), ArrayType(toType: DataType, false))
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
index e8b738921b73..6af00e193d94 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types._
// scalastyle:off line.contains.tab
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index f331a489968a..39bf6734eb27 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.{AbstractArrayType,
StringTypeAnyCollation}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SQLOpenHashSet
import org.apache.spark.unsafe.UTF8StringBuilder
@@ -2003,9 +2004,9 @@ case class ArrayJoin(
this(array, delimiter, Some(nullReplacement))
override def inputTypes: Seq[AbstractDataType] = if
(nullReplacement.isDefined) {
- Seq(ArrayType, StringTypeAnyCollation, StringTypeAnyCollation)
+ Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation,
StringTypeAnyCollation)
} else {
- Seq(ArrayType, StringTypeAnyCollation)
+ Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation)
}
override def children: Seq[Expression] = if (nullReplacement.isDefined) {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index ad15509dc6fd..cf6c9d4f1d94 100755
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -37,6 +37,7 @@ import
org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory,
GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.array.ByteArrayMethods
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 62150eaeac54..c0322387c804 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -645,6 +645,9 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
},
errorClass = "COLLATION_MISMATCH.IMPLICIT"
)
+
+ checkAnswer(sql("SELECT array_join(array('a', 'b' collate UNICODE), 'c'
collate UNICODE_CI)"),
+ Seq(Row("acb")))
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]