stefankandic commented on code in PR #48663:
URL: https://github.com/apache/spark/pull/48663#discussion_r1848606677
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -193,88 +144,208 @@ object CollationTypeCoercion {
case other => other
}
+ /**
+ * If childType is collated and target is UTF8_BINARY, the collation of the
output
+ * should be that of the childType.
+ */
+ private def shouldRemoveCast(cast: Cast): Boolean = {
+ val isUserDefined = cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined
+ val isChildTypeCollatedString = cast.child.dataType match {
+ case st: StringType => !st.isUTF8BinaryCollation
+ case _ => false
+ }
+ val targetType = cast.dataType
+
+ isUserDefined && isChildTypeCollatedString && targetType == StringType
+ }
+
/**
* Extracts StringTypes from filtered hasStringType
*/
@tailrec
- private def extractStringType(dt: DataType): StringType = dt match {
- case st: StringType => st
+ private def extractStringType(dt: DataType): Option[StringType] = dt match {
+ case st: StringType => Some(st)
case ArrayType(et, _) => extractStringType(et)
+ case _ => None
}
/**
* Casts given expression to collated StringType with id equal to
collationId only
* if expression has StringType in the first place.
- * @param expr
- * @param collationId
- * @return
*/
- def castStringType(expr: Expression, st: StringType): Option[Expression] =
- castStringType(expr.dataType, st).map { dt => Cast(expr, dt)}
+ def castStringType(expr: Expression, st: StringType): Expression = {
+ castStringType(expr.dataType, st)
+ .map(dt => Cast(expr, dt))
+ .getOrElse(expr)
+ }
private def castStringType(inType: DataType, castType: StringType):
Option[DataType] = {
- @Nullable val ret: DataType = inType match {
- case st: StringType if st.collationId != castType.collationId => castType
+ inType match {
+ case st: StringType if st.collationId != castType.collationId =>
+ Some(castType)
case ArrayType(arrType, nullable) =>
- castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull
- case _ => null
+ castStringType(arrType, castType).map(ArrayType(_, nullable))
+ case _ => None
}
- Option(ret)
}
/**
* Collates input expressions to a single collation.
*/
- def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = {
- val st = getOutputCollation(exprs)
+ def collateToSingleType(expressions: Seq[Expression]): Seq[Expression] = {
+ val lctOpt = findLeastCommonStringType(expressions)
- exprs.map(e => castStringType(e, st).getOrElse(e))
+ lctOpt match {
+ case Some(lct) =>
+ expressions.map(e => castStringType(e, lct))
+ case _ =>
+ expressions
+ }
}
/**
- * Based on the data types of the input expressions this method determines
- * a collation type which the output will have. This function accepts Seq of
- * any expressions, but will only be affected by collated StringTypes or
- * complex DataTypes with collated StringTypes (e.g. ArrayType)
+ * Tries to find the least common StringType among the given expressions.
*/
- def getOutputCollation(expr: Seq[Expression]): StringType = {
- val explicitTypes = expr.filter {
- case _: Collate => true
- case _ => false
- }
- .map(_.dataType.asInstanceOf[StringType].collationId)
- .distinct
-
- explicitTypes.size match {
- // We have 1 explicit collation
- case 1 => StringType(explicitTypes.head)
- // Multiple explicit collations occurred
- case size if size > 1 =>
- throw QueryCompilationErrors
- .explicitCollationMismatchError(
- explicitTypes.map(t => StringType(t))
- )
- // Only implicit or default collations present
- case 0 =>
- val implicitTypes = expr.filter {
- case Literal(_, _: StringType) => false
- case cast: Cast if
cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty =>
- cast.child.dataType.isInstanceOf[StringType]
- case _ => true
- }
- .map(_.dataType)
- .filter(hasStringType)
- .map(extractStringType(_).collationId)
- .distinct
-
- if (implicitTypes.length > 1) {
- throw QueryCompilationErrors.implicitCollationMismatchError(
- implicitTypes.map(t => StringType(t))
- )
+ private def findLeastCommonStringType(expressions: Seq[Expression]):
Option[StringType] = {
+ if (!expressions.exists(e =>
SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) {
+ return None
+ }
+
+ val collationContextWinner =
expressions.foldLeft(findCollationContext(expressions.head)) {
+ case (Some(left), right) =>
+ findCollationContext(right).flatMap { ctx =>
+ collationPrecedenceWinner(left, ctx)
}
- else {
-
implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType)
+ case (None, _) => return None
+ }
+
+ collationContextWinner.flatMap { cc =>
+ extractStringType(cc.dataType)
+ }
+ }
+
+ /**
+ * Tries to find the collation context for the given expression.
+ * If found, it will also set the [[COLLATION_CONTEXT_TAG]] on the
expression,
+ * so that the context can be reused later.
+ */
+ private def findCollationContext(expr: Expression): Option[CollationContext]
= {
+
+ def getChildren: Seq[Expression] = expr match {
+ // we don't need to consider the struct name expressions
+ case struct: CreateNamedStruct => struct.valExprs
+ case _ => expr.children
+ }
+
+ val contextOpt = expr match {
+ case _ if hasCollationContextTag(expr) =>
+ Some(expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+
+ // if `expr` doesn't have a string in its dataType then it doesn't
+ // have the collation context either
+ case _ if !expr.dataType.existsRecursively(_.isInstanceOf[StringType]) =>
+ None
+
+ case collate: Collate =>
+ Some(CollationContext(collate.dataType, Explicit))
+
+ case _: Alias | _: SubqueryExpression | _: AttributeReference | _:
VariableReference =>
+ Some(CollationContext(expr.dataType, Implicit))
+
+ case _: Literal =>
+ Some(CollationContext(expr.dataType, Default))
+
+ // if it does have a string type but none of its children do
+ // then the collation context strength is default
+ case _ if
!expr.children.exists(_.dataType.existsRecursively(_.isInstanceOf[StringType]))
=>
+ Some(CollationContext(expr.dataType, Default))
+
+ case extract: ExtractValue =>
Review Comment:
No, let's say we have a map column and we do `mapCol['key' COLLATE
UNICODE]`. Since we are just retrieving the value, it shouldn't have an explicit
collation just because the key is explicit. That's why we should only look at the
child and not at all children.
There is a test for this called `access collated map via literal`.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala:
##########
@@ -193,88 +144,208 @@ object CollationTypeCoercion {
case other => other
}
+ /**
+ * If childType is collated and target is UTF8_BINARY, the collation of the
output
+ * should be that of the childType.
+ */
+ private def shouldRemoveCast(cast: Cast): Boolean = {
+ val isUserDefined = cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined
+ val isChildTypeCollatedString = cast.child.dataType match {
+ case st: StringType => !st.isUTF8BinaryCollation
+ case _ => false
+ }
+ val targetType = cast.dataType
+
+ isUserDefined && isChildTypeCollatedString && targetType == StringType
+ }
+
/**
* Extracts StringTypes from filtered hasStringType
*/
@tailrec
- private def extractStringType(dt: DataType): StringType = dt match {
- case st: StringType => st
+ private def extractStringType(dt: DataType): Option[StringType] = dt match {
+ case st: StringType => Some(st)
case ArrayType(et, _) => extractStringType(et)
+ case _ => None
}
/**
* Casts given expression to collated StringType with id equal to
collationId only
* if expression has StringType in the first place.
- * @param expr
- * @param collationId
- * @return
*/
- def castStringType(expr: Expression, st: StringType): Option[Expression] =
- castStringType(expr.dataType, st).map { dt => Cast(expr, dt)}
+ def castStringType(expr: Expression, st: StringType): Expression = {
+ castStringType(expr.dataType, st)
+ .map(dt => Cast(expr, dt))
+ .getOrElse(expr)
+ }
private def castStringType(inType: DataType, castType: StringType):
Option[DataType] = {
- @Nullable val ret: DataType = inType match {
- case st: StringType if st.collationId != castType.collationId => castType
+ inType match {
+ case st: StringType if st.collationId != castType.collationId =>
+ Some(castType)
case ArrayType(arrType, nullable) =>
- castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull
- case _ => null
+ castStringType(arrType, castType).map(ArrayType(_, nullable))
+ case _ => None
}
- Option(ret)
}
/**
* Collates input expressions to a single collation.
*/
- def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = {
- val st = getOutputCollation(exprs)
+ def collateToSingleType(expressions: Seq[Expression]): Seq[Expression] = {
+ val lctOpt = findLeastCommonStringType(expressions)
- exprs.map(e => castStringType(e, st).getOrElse(e))
+ lctOpt match {
+ case Some(lct) =>
+ expressions.map(e => castStringType(e, lct))
+ case _ =>
+ expressions
+ }
}
/**
- * Based on the data types of the input expressions this method determines
- * a collation type which the output will have. This function accepts Seq of
- * any expressions, but will only be affected by collated StringTypes or
- * complex DataTypes with collated StringTypes (e.g. ArrayType)
+ * Tries to find the least common StringType among the given expressions.
*/
- def getOutputCollation(expr: Seq[Expression]): StringType = {
- val explicitTypes = expr.filter {
- case _: Collate => true
- case _ => false
- }
- .map(_.dataType.asInstanceOf[StringType].collationId)
- .distinct
-
- explicitTypes.size match {
- // We have 1 explicit collation
- case 1 => StringType(explicitTypes.head)
- // Multiple explicit collations occurred
- case size if size > 1 =>
- throw QueryCompilationErrors
- .explicitCollationMismatchError(
- explicitTypes.map(t => StringType(t))
- )
- // Only implicit or default collations present
- case 0 =>
- val implicitTypes = expr.filter {
- case Literal(_, _: StringType) => false
- case cast: Cast if
cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty =>
- cast.child.dataType.isInstanceOf[StringType]
- case _ => true
- }
- .map(_.dataType)
- .filter(hasStringType)
- .map(extractStringType(_).collationId)
- .distinct
-
- if (implicitTypes.length > 1) {
- throw QueryCompilationErrors.implicitCollationMismatchError(
- implicitTypes.map(t => StringType(t))
- )
+ private def findLeastCommonStringType(expressions: Seq[Expression]):
Option[StringType] = {
+ if (!expressions.exists(e =>
SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) {
+ return None
+ }
+
+ val collationContextWinner =
expressions.foldLeft(findCollationContext(expressions.head)) {
+ case (Some(left), right) =>
+ findCollationContext(right).flatMap { ctx =>
+ collationPrecedenceWinner(left, ctx)
}
- else {
-
implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType)
+ case (None, _) => return None
+ }
+
+ collationContextWinner.flatMap { cc =>
+ extractStringType(cc.dataType)
+ }
+ }
+
+ /**
+ * Tries to find the collation context for the given expression.
+ * If found, it will also set the [[COLLATION_CONTEXT_TAG]] on the
expression,
+ * so that the context can be reused later.
+ */
+ private def findCollationContext(expr: Expression): Option[CollationContext]
= {
+
+ def getChildren: Seq[Expression] = expr match {
+ // we don't need to consider the struct name expressions
+ case struct: CreateNamedStruct => struct.valExprs
+ case _ => expr.children
+ }
+
+ val contextOpt = expr match {
+ case _ if hasCollationContextTag(expr) =>
+ Some(expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+
+ // if `expr` doesn't have a string in its dataType then it doesn't
+ // have the collation context either
+ case _ if !expr.dataType.existsRecursively(_.isInstanceOf[StringType]) =>
+ None
+
+ case collate: Collate =>
+ Some(CollationContext(collate.dataType, Explicit))
+
+ case _: Alias | _: SubqueryExpression | _: AttributeReference | _:
VariableReference =>
+ Some(CollationContext(expr.dataType, Implicit))
+
+ case _: Literal =>
+ Some(CollationContext(expr.dataType, Default))
+
+ // if it does have a string type but none of its children do
+ // then the collation context strength is default
+ case _ if
!expr.children.exists(_.dataType.existsRecursively(_.isInstanceOf[StringType]))
=>
+ Some(CollationContext(expr.dataType, Default))
+
+ case extract: ExtractValue =>
+ findCollationContext(extract.child)
+ .map(cc => CollationContext(extract.dataType, cc.strength))
+
+ case _ =>
+ val contextWinnerOpt = getChildren
+ .flatMap(findCollationContext)
+ .foldLeft(Option.empty[CollationContext]) {
+ case (Some(left), right) =>
+ collationPrecedenceWinner(left, right)
+ case (None, right) =>
+ Some(right)
+ }
+
+ contextWinnerOpt.map { context =>
+ if (hasStringType(expr.dataType)) {
+ CollationContext(expr.dataType, context.strength)
+ } else {
+ context
+ }
}
}
+
+ contextOpt.foreach(expr.setTagValue(COLLATION_CONTEXT_TAG, _))
+ contextOpt
}
+
+ /**
+ * Returns the collation context that wins in precedence between left and
right.
+ */
+ private def collationPrecedenceWinner(
+ left: CollationContext,
+ right: CollationContext): Option[CollationContext] = {
+
+ val (leftStringType, rightStringType) =
+ (extractStringType(left.dataType), extractStringType(right.dataType))
match {
+ case (Some(l), Some(r)) =>
+ (l, r)
+ case (None, None) =>
+ return None
+ case (Some(_), None) =>
Review Comment:
This can happen if the data type of the context is a map or a struct, since
`extractStringType` only works for strings and arrays, like you mentioned
below.
However, I'm not 100% sure we can hit this case, as usually there is something
above the struct or map that only extracts the primitive field (like
`ExtractValue`).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]