This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 4dad2170b05c [SPARK-47356][SQL] Add support for ConcatWs & Elt (all collations)
4dad2170b05c is described below

commit 4dad2170b05c04faf1da550ab3fb8c52a61b8be7
Author: Mihailo Milosevic <mihailo.milose...@databricks.com>
AuthorDate: Tue Apr 16 21:21:24 2024 +0800

    [SPARK-47356][SQL] Add support for ConcatWs & Elt (all collations)

    ### What changes were proposed in this pull request?
    Add collation support to the ConcatWs and Elt expressions.

    ### Why are the changes needed?
    These functions need to accept collated strings as part of bringing collation support to all string functions.

    ### Does this PR introduce _any_ user-facing change?
    Yes, both expressions no longer return an error when called with collated strings.

    ### How was this patch tested?
    New tests added to `CollationStringExpressionsSuite`.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #46061 from mihailom-db/SPARK-47356.

    Authored-by: Mihailo Milosevic <mihailo.milose...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
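    For illustration (not part of the original patch): a minimal sketch of the user-facing
    behavior described above, assuming a local SparkSession; the queries mirror the new test
    cases, and the object name is only illustrative.

        import org.apache.spark.sql.SparkSession

        object CollatedConcatWsEltExample {
          def main(args: Array[String]): Unit = {
            val spark = SparkSession.builder().master("local[*]").getOrCreate()

            // concat_ws over collated inputs; before this patch, non-default collations such
            // as UTF8_BINARY_LCASE were rejected with DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE.
            spark.sql(
              "SELECT concat_ws(collate(' ', 'UTF8_BINARY_LCASE'), " +
                "collate('Spark', 'UTF8_BINARY_LCASE'), collate('SQL', 'UTF8_BINARY_LCASE'))"
            ).show() // Spark SQL

            // elt over collated inputs; the result keeps the input collation.
            spark.sql(
              "SELECT elt(1, collate('Spark', 'UNICODE_CI'), collate('SQL', 'UNICODE_CI'))"
            ).show() // Spark

            spark.stop()
          }
        }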
---
 .../sql/catalyst/analysis/CollationTypeCasts.scala |  5 ++-
 .../catalyst/expressions/stringExpressions.scala   | 25 ++++++------
 .../sql/CollationStringExpressionsSuite.scala      | 46 ++++++++++++++++------
 3 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index 1a14b4227de8..795e8a696b01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -22,7 +22,7 @@ import javax.annotation.Nullable
 import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType}
-import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least}
+import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
@@ -45,6 +45,9 @@ object CollationTypeCasts extends TypeCoercionRule {
         caseWhenExpr.elseValue.map(e => castStringType(e, outputStringType).getOrElse(e))
       CaseWhen(newBranches, newElseValue)
 
+    case eltExpr: Elt =>
+      eltExpr.withNewChildren(eltExpr.children.head +: collateToSingleType(eltExpr.children.tail))
+
     case otherExpr @ (
       _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least |
       _: Coalesce | _: BinaryExpression | _: ConcatWs) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 34e8f3f40859..4fe57b4f8f02 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO
 import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -79,11 +79,12 @@ case class ConcatWs(children: Seq[Expression])
 
   /** The 1st child (separator) is str, and rest are either str or array of str. */
   override def inputTypes: Seq[AbstractDataType] = {
-    val arrayOrStr = TypeCollection(ArrayType(StringType), StringType)
-    StringType +: Seq.fill(children.size - 1)(arrayOrStr)
+    val arrayOrStr =
+      TypeCollection(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation)
+    StringTypeAnyCollation +: Seq.fill(children.size - 1)(arrayOrStr)
   }
 
-  override def dataType: DataType = StringType
+  override def dataType: DataType = children.head.dataType
 
   override def nullable: Boolean = children.head.nullable
   override def foldable: Boolean = children.forall(_.foldable)
@@ -102,7 +103,8 @@ case class ConcatWs(children: Seq[Expression])
     val flatInputs = children.flatMap { child =>
      child.eval(input) match {
        case s: UTF8String => Iterator(s)
-        case arr: ArrayData => arr.toArray[UTF8String](StringType)
+        case arr: ArrayData =>
+          arr.toArray[UTF8String](child.dataType.asInstanceOf[ArrayType].elementType)
        case null => Iterator(null.asInstanceOf[UTF8String])
      }
    }
@@ -110,7 +112,7 @@ case class ConcatWs(children: Seq[Expression])
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    if (children.forall(_.dataType == StringType)) {
+    if (children.forall(_.dataType.isInstanceOf[StringType])) {
       // All children are strings. In that case we can construct a fixed size array.
       val evals = children.map(_.genCode(ctx))
       val separator = evals.head
@@ -163,7 +165,7 @@ case class ConcatWs(children: Seq[Expression])
         """
 
          val (varCount, varBuild) = child.dataType match {
-           case StringType =>
+           case _: StringType =>
              val reprForValueCast = s"((UTF8String) $reprForValue)"
              ("", // we count all the StringType arguments num at once below.
                if (eval.isNull == TrueLiteral) {
@@ -171,7 +173,7 @@ case class ConcatWs(children: Seq[Expression])
                } else {
                  s"$array[$idxVararg ++] = $reprForIsNull ? (UTF8String) null : $reprForValueCast;"
                })
-           case _: ArrayType =>
+           case arr: ArrayType =>
              val reprForValueCast = s"((ArrayData) $reprForValue)"
              val size = ctx.freshName("n")
              if (eval.isNull == TrueLiteral) {
@@ -187,7 +189,7 @@ case class ConcatWs(children: Seq[Expression])
                  if (!$reprForIsNull) {
                    final int $size = $reprForValueCast.numElements();
                    for (int j = 0; j < $size; j ++) {
-                     $array[$idxVararg ++] = ${CodeGenerator.getValue(reprForValueCast, StringType, "j")};
+                     $array[$idxVararg ++] = ${CodeGenerator.getValue(reprForValueCast, arr.elementType, "j")};
                    }
                  }
                """)
@@ -235,7 +237,7 @@ case class ConcatWs(children: Seq[Expression])
       boolean[] $isNullArgs = new boolean[${children.length - 1}];
       Object[] $valueArgs = new Object[${children.length - 1}];
       $argBuilds
-      int $varargNum = ${children.count(_.dataType == StringType) - 1};
+      int $varargNum = ${children.count(_.dataType.isInstanceOf[StringType]) - 1};
       int $idxVararg = 0;
       $varargCounts
       UTF8String[] $array = new UTF8String[$varargNum];
@@ -287,7 +289,8 @@ case class Elt(
   /** This expression is always nullable because it returns null if index is out of range. */
   override def nullable: Boolean = true
 
-  override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType)
+  override def dataType: DataType =
+    inputExprs.map(_.dataType).headOption.getOrElse(SQLConf.get.defaultStringType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.size < 2) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 0dbd4c0ba713..9237c8a25a5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -32,7 +32,10 @@ class CollationStringExpressionsSuite
     // Supported collations
     case class ConcatWsTestCase[R](s: String, a: Array[String], c: String, result: R)
     val testCases = Seq(
-      ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_BINARY", "Spark SQL")
+      ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_BINARY", "Spark SQL"),
+      ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_BINARY_LCASE", "Spark SQL"),
+      ConcatWsTestCase(" ", Array("Spark", "SQL"), "UNICODE", "Spark SQL"),
+      ConcatWsTestCase(" ", Array("Spark", "SQL"), "UNICODE_CI", "Spark SQL")
     )
     testCases.foreach(t => {
       val arrCollated = t.a.map(s => s"collate('$s', '${t.c}')").mkString(", ")
@@ -49,22 +52,39 @@ class CollationStringExpressionsSuite
       checkAnswer(sql(query), Row(t.result))
       assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
     })
-    // Unsupported collations
-    case class ConcatWsTestFail(s: String, a: Array[String], c: String)
-    val failCases = Seq(
-      ConcatWsTestFail(" ", Array("ABC", "%b%"), "UTF8_BINARY_LCASE"),
-      ConcatWsTestFail(" ", Array("ABC", "%B%"), "UNICODE"),
-      ConcatWsTestFail(" ", Array("ABC", "%b%"), "UNICODE_CI")
+    // Collation mismatch
+    val collationMismatch = intercept[AnalysisException] {
+      sql("SELECT concat_ws(' ',collate('Spark', 'UTF8_BINARY_LCASE'),collate('SQL', 'UNICODE'))")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
+
+  test("Support Elt string expression with collation") {
+    // Supported collations
+    case class EltTestCase[R](index: Int, inputs: Array[String], c: String, result: R)
+    val testCases = Seq(
+      EltTestCase(1, Array("Spark", "SQL"), "UTF8_BINARY", "Spark"),
+      EltTestCase(1, Array("Spark", "SQL"), "UTF8_BINARY_LCASE", "Spark"),
+      EltTestCase(2, Array("Spark", "SQL"), "UNICODE", "SQL"),
+      EltTestCase(2, Array("Spark", "SQL"), "UNICODE_CI", "SQL")
     )
-    failCases.foreach(t => {
-      val arrCollated = t.a.map(s => s"collate('$s', '${t.c}')").mkString(", ")
-      val query = s"SELECT concat_ws(collate('${t.s}', '${t.c}'), $arrCollated)"
-      val unsupportedCollation = intercept[AnalysisException] { sql(query) }
-      assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+    testCases.foreach(t => {
+      var query = s"SELECT elt(${t.index}, collate('${t.inputs(0)}', '${t.c}')," +
+        s" collate('${t.inputs(1)}', '${t.c}'))"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+      // Implicit casting
+      query = s"SELECT elt(${t.index}, collate('${t.inputs(0)}', '${t.c}'), '${t.inputs(1)}')"
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+      query = s"SELECT elt(${t.index}, '${t.inputs(0)}', collate('${t.inputs(1)}', '${t.c}'))"
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
     })
     // Collation mismatch
     val collationMismatch = intercept[AnalysisException] {
-      sql("SELECT concat_ws(' ',collate('Spark', 'UTF8_BINARY_LCASE'),collate('SQL', 'UNICODE'))")
+      sql("SELECT elt(0 ,collate('Spark', 'UTF8_BINARY_LCASE'), collate('SQL', 'UNICODE'))")
     }
     assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org