cloud-fan commented on code in PR #46722:
URL: https://github.com/apache/spark/pull/46722#discussion_r1624959243
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteCollationJoin.scala:
##########
@@ -17,29 +17,83 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
CollationKey, Equality}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
+import org.apache.spark.sql.types._
import org.apache.spark.sql.types.StringType
+import org.apache.spark.util.ArrayImplicits.SparkArrayOps
+/**
+ * This rule rewrites Join conditions to ensure that all types containing non-binary collated
+ * strings are compared correctly. This is necessary because join conditions are evaluated using
+ * binary equality, which does not work correctly for non-binary collated strings. However, by
+ * injecting CollationKey expressions into the join condition, we can ensure that the comparison
+ * is done correctly, which then allows HashJoin to work properly on this type of data.
+ */
object RewriteCollationJoin extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case j @ Join(_, _, _, Some(condition), _) =>
val newCondition = condition transform {
case e @ Equality(l: AttributeReference, r: AttributeReference) =>
- (l.dataType, r.dataType) match {
- case (st: StringType, _: StringType)
- if
!CollationFactory.fetchCollation(st.collationId).supportsBinaryEquality =>
- e.withNewChildren(Seq(CollationKey(l), CollationKey(r)))
- case _ =>
- e
- }
+ e.withNewChildren(Seq(processExpression(l, l.dataType),
processExpression(r, r.dataType)))
}
if (!newCondition.fastEquals(condition)) {
j.copy(condition = Some(newCondition))
} else {
j
}
}
+
+ /**
+ * Recursively process the expression in order to replace non-binary collated strings with their
+ * associated collation keys. This is necessary to ensure that the join condition is evaluated
+ * correctly for all types containing non-binary collated strings, including structs and arrays.
+ */
+ private def processExpression(expr: Expression, dt: DataType): Expression = {
+ dt match {
+ // For binary stable expressions, no special handling is needed.
+ case _ if UnsafeRowUtils.isBinaryStable(dt) =>
+ expr
+
+ // Inject CollationKey for non-binary collated strings.
+ case _: StringType =>
+ CollationKey(expr)
+
+ // Recursively process struct fields for non-binary structs.
+ case StructType(fields) =>
+ processStruct(expr, fields)
+
+ // Recursively process array elements for non-binary arrays.
+ case ArrayType(et, containsNull) =>
+ processArray(expr, et, containsNull)
+
+ // Joins are not supported on maps, so there's no special handling for MapType.
+ case _ =>
+ expr
+ }
+ }
+
+ private def processStruct(str: Expression, fields: Array[StructField]):
Expression = {
+ val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) =>
+ Seq(Literal(f.name), processExpression(GetStructField(str, i,
Some(f.name)), f.dataType))
+ }.toImmutableArraySeq)
Review Comment:
Shall we return the original `str` expression if no field is updated with
`CollationKey`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]