This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a8b919f924db [SPARK-47736][SQL] Add support for AbstractArrayType
a8b919f924db is described below

commit a8b919f924db1e2818b2b0de49762292ae20c17c
Author: Mihailo Milosevic <mihailo.milose...@databricks.com>
AuthorDate: Thu Apr 11 15:39:42 2024 +0800

    [SPARK-47736][SQL] Add support for AbstractArrayType

    ### What changes were proposed in this pull request?
    Addition of AbstractArrayType, which accepts StringTypeCollated as elementType. The changes in https://github.com/apache/spark/pull/45693 work for ArrayJoin, but do not generalize to other functions. This PR introduces an interface that all functions can use. Merge only after #45693.

    ### Why are the changes needed?
    This is needed to enable functions to use collated arrays.

    ### Does this PR introduce _any_ user-facing change?
    Yes, collation functions will work.

    ### How was this patch tested?
    Test for array_join added to `CollationSuite`.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #45891 from mihailom-db/SPARK-47736.

    Authored-by: Mihailo Milosevic <mihailo.milose...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/internal/types/AbstractArrayType.scala     | 37 ++++++++++++++++++++++
 .../sql/internal/types/AbstractStringType.scala}   | 10 +++---
 .../sql/catalyst/analysis/AnsiTypeCoercion.scala   |  3 +-
 .../spark/sql/catalyst/analysis/TypeCoercion.scala |  6 +++-
 .../expressions/collationExpressions.scala         |  1 +
 .../expressions/collectionOperations.scala         |  5 +--
 .../catalyst/expressions/stringExpressions.scala   |  1 +
 .../org/apache/spark/sql/CollationSuite.scala      |  3 ++
 8 files changed, 57 insertions(+), 9 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala
new file mode 100644
index 000000000000..406449a33727
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.types
+
+import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType}
+
+
+/**
+ * Use AbstractArrayType(AbstractDataType) for defining expected types for expression parameters.
+ */
+case class AbstractArrayType(elementType: AbstractDataType) extends AbstractDataType {
+
+  override private[sql] def defaultConcreteType: DataType =
+    ArrayType(elementType.defaultConcreteType, containsNull = true)
+
+  override private[sql] def acceptsType(other: DataType): Boolean = {
+    other.isInstanceOf[ArrayType] &&
+      elementType.acceptsType(other.asInstanceOf[ArrayType].elementType)
+  }
+
+  override private[spark] def simpleString: String = s"array<${elementType.simpleString}>"
+}
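For context, here is a minimal sketch of how an expression can request a collation-aware string array through the new type. This is an illustration only, not part of the patch: the trait name is made up, `ExpectsInputTypes` is the existing Spark trait that `inputTypes` comes from, and the code has to live under the `org.apache.spark.sql` package because `AbstractDataType` is `private[sql]`.

```scala
import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
import org.apache.spark.sql.types.AbstractDataType

// Hypothetical mix-in: the first argument must be an array whose elements are
// strings under any collation, the second a string under any collation.
trait AcceptsCollatedStringArray extends ExpectsInputTypes {
  override def inputTypes: Seq[AbstractDataType] =
    Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation)
}
```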
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
similarity index 86%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
index 67b65859e6bb..6403295fe20c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
@@ -15,14 +15,14 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.catalyst.expressions
+package org.apache.spark.sql.internal.types

 import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

 /**
  * StringTypeCollated is an abstract class for StringType with collation support.
  */
-abstract class StringTypeCollated extends AbstractDataType {
+abstract class AbstractStringType extends AbstractDataType {
   override private[sql] def defaultConcreteType: DataType = StringType
   override private[sql] def simpleString: String = "string"
 }
@@ -30,7 +30,7 @@ abstract class StringTypeCollated extends AbstractDataType {
 /**
  * Use StringTypeBinary for expressions supporting only binary collation.
  */
-case object StringTypeBinary extends StringTypeCollated {
+case object StringTypeBinary extends AbstractStringType {
   override private[sql] def acceptsType(other: DataType): Boolean =
     other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality
 }
@@ -38,7 +38,7 @@ case object StringTypeBinary extends StringTypeCollated {
 /**
  * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation.
  */
-case object StringTypeBinaryLcase extends StringTypeCollated {
+case object StringTypeBinaryLcase extends AbstractStringType {
   override private[sql] def acceptsType(other: DataType): Boolean =
     other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality ||
       other.asInstanceOf[StringType].isUTF8BinaryLcaseCollation)
 }
@@ -47,6 +47,6 @@ case object StringTypeBinaryLcase extends StringTypeCollated {
 /**
  * Use StringTypeAnyCollation for expressions supporting all possible collation types.
  */
-case object StringTypeAnyCollation extends StringTypeCollated {
+case object StringTypeAnyCollation extends AbstractStringType {
   override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]
 }
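A small illustration of how the array and string abstractions compose. This is a sketch, not part of the patch, and it assumes the code is compiled inside the `org.apache.spark.sql` package, since `acceptsType` is `private[sql]`.

```scala
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinary}
import org.apache.spark.sql.types.{ArrayType, StringType}

// The default StringType (UTF8_BINARY) supports binary equality, so an array of it
// satisfies both the "any collation" and the "binary collation only" element checks.
val arrayOfDefaultStrings = ArrayType(StringType, containsNull = true)
assert(AbstractArrayType(StringTypeAnyCollation).acceptsType(arrayOfDefaultStrings))
assert(AbstractArrayType(StringTypeBinary).acceptsType(arrayOfDefaultStrings))
```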
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index dd904be9a760..52c1136b1ee3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.types.AbstractStringType
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.UpCastRule.numericPrecedence

@@ -185,7 +186,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
       // If a function expects a StringType, no StringType instance should be implicitly cast to
       // StringType with a collation that's not accepted (aka. lockdown unsupported collations).
       case (_: StringType, _: StringType) => None
-      case (_: StringType, _: StringTypeCollated) => None
+      case (_: StringType, _: AbstractStringType) => None

       // If a function expects integral type, fractional input is not allowed.
       case (_: FractionalType, IntegralType) => None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 615d21f67695..5d8bcd4f7b1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractStringType}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.UpCastRule.numericPrecedence

@@ -997,7 +998,7 @@ object TypeCoercion extends TypeCoercionBase {
       case (_: StringType, BinaryType) => BinaryType
       // Cast any atomic type to string.
       case (any: AtomicType, _: StringType) if !any.isInstanceOf[StringType] => StringType
-      case (any: AtomicType, st: StringTypeCollated)
+      case (any: AtomicType, st: AbstractStringType)
         if !any.isInstanceOf[StringType] => st.defaultConcreteType

       // When we reach here, input type is not acceptable for any types in this type collection,
@@ -1017,6 +1018,9 @@ object TypeCoercion extends TypeCoercionBase {
       case (ArrayType(fromType, fn), ArrayType(toType: DataType, true)) =>
         implicitCast(fromType, toType).map(ArrayType(_, true)).orNull

+      case (ArrayType(fromType, fn), AbstractArrayType(toType)) =>
+        implicitCast(fromType, toType).map(ArrayType(_, true)).orNull
+
       case (ArrayType(fromType, true), ArrayType(toType: DataType, false)) => null

       case (ArrayType(fromType, false), ArrayType(toType: DataType, false))
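The new `TypeCoercion` case can be read as: when the expected type is `AbstractArrayType(toElement)`, coerce the input array's element type to `toElement` and rebuild the array with nullable elements. Below is a standalone paraphrase of that rule, as a sketch only, not the code path Spark actually runs; `coerceElement` stands in for the recursive `implicitCast` on element types, and the same `private[sql]` packaging caveat applies.

```scala
import org.apache.spark.sql.internal.types.AbstractArrayType
import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType}

// Paraphrase of the added coercion case, for illustration only.
def coerceToAbstractArray(
    from: DataType,
    expected: AbstractArrayType,
    coerceElement: (DataType, AbstractDataType) => Option[DataType]): Option[DataType] =
  from match {
    case ArrayType(fromElement, _) =>
      coerceElement(fromElement, expected.elementType)
        .map(elem => ArrayType(elem, containsNull = true))
    case _ => None
  }
```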
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
index e8b738921b73..6af00e193d94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
 import org.apache.spark.sql.types._

 // scalastyle:off line.contains.tab
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index f331a489968a..39bf6734eb27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SQLOpenHashSet
 import org.apache.spark.unsafe.UTF8StringBuilder
@@ -2003,9 +2004,9 @@ case class ArrayJoin(
     this(array, delimiter, Some(nullReplacement))

   override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
-    Seq(ArrayType, StringTypeAnyCollation, StringTypeAnyCollation)
+    Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation, StringTypeAnyCollation)
   } else {
-    Seq(ArrayType, StringTypeAnyCollation)
+    Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation)
   }

   override def children: Seq[Expression] = if (nullReplacement.isDefined) {
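With `ArrayJoin` now declaring `AbstractArrayType(StringTypeAnyCollation)`, collated array elements are accepted and coerced like ordinary string arguments. A usage sketch mirroring the test added below, assuming a running `SparkSession` named `spark` on a build that contains this change:

```scala
// array_join over collated string elements; the expected result matches the
// CollationSuite assertion added in this commit.
val df = spark.sql(
  "SELECT array_join(array('a', 'b' collate UNICODE), 'c' collate UNICODE_CI)")
df.collect()  // Array(Row("acb"))
```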
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index ad15509dc6fd..cf6c9d4f1d94 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
 import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 62150eaeac54..c0322387c804 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -645,6 +645,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       },
       errorClass = "COLLATION_MISMATCH.IMPLICIT"
     )
+
+    checkAnswer(sql("SELECT array_join(array('a', 'b' collate UNICODE), 'c' collate UNICODE_CI)"),
+      Seq(Row("acb")))
   }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org