This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 74d82665bc2e [SPARK-50219][SQL] Refactor `ApplyCharTypePadding` so
that helper methods can be used in single-pass resolver
74d82665bc2e is described below
commit 74d82665bc2e30daf86452de529c4688214d225d
Author: Mihailo Timotic <[email protected]>
AuthorDate: Tue Nov 5 16:25:41 2024 +0100
[SPARK-50219][SQL] Refactor `ApplyCharTypePadding` so that helper methods
can be used in single-pass resolver
### What changes were proposed in this pull request?
Refactor `ApplyCharTypePadding` so that helper methods can be used in
single-pass resolver. This means refactoring these helpers to a separate object
and moving it to catalyst package.
### Why are the changes needed?
Necessary in order to unblock Analyzer++ features.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48753 from mihailotim-db/mihailotim-db/refactor_rpad.
Authored-by: Mihailo Timotic <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../analysis/ApplyCharTypePaddingHelper.scala | 206 +++++++++++++++++++++
.../datasources/ApplyCharTypePadding.scala | 161 +---------------
2 files changed, 213 insertions(+), 154 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala
new file mode 100644
index 000000000000..54f9abe0b9f1
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{
+ Alias,
+ Attribute,
+ BinaryComparison,
+ Expression,
+ In,
+ Literal,
+ NamedExpression,
+ OuterReference,
+ StringRPad
+}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN}
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.types.{CharType, Metadata, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Helper object used by the [[ApplyCharTypePadding]] rule. This object is
under catalyst
+ * package in order to make the methods accessible to single-pass [[Resolver]].
+ */
+object ApplyCharTypePaddingHelper {
+
+ object AttrOrOuterRef {
+ def unapply(e: Expression): Option[Attribute] = e match {
+ case a: Attribute => Some(a)
+ case OuterReference(a: Attribute) => Some(a)
+ case _ => None
+ }
+ }
+
+ private[sql] def readSidePadding(
+ relation: LogicalPlan,
+ cleanedRelation: () => LogicalPlan): (LogicalPlan, Seq[(Attribute,
Attribute)]) = {
+ val projectList = relation.output.map { attr =>
+ CharVarcharUtils.addPaddingForScan(attr) match {
+ case ne: NamedExpression => ne
+ case other => Alias(other, attr.name)(explicitMetadata =
Some(attr.metadata))
+ }
+ }
+ if (projectList == relation.output) {
+ relation -> Nil
+ } else {
+ val newPlan = Project(projectList, cleanedRelation())
+ newPlan -> relation.output.zip(newPlan.output)
+ }
+ }
+
+ private[sql] def paddingForStringComparison(
+ plan: LogicalPlan,
+ padCharCol: Boolean): LogicalPlan = {
+ plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON,
IN)) {
+ case operator =>
+
operator.transformExpressionsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON,
IN)) {
+ case e if !e.childrenResolved => e
+ case withChildrenResolved =>
+ singleNodePaddingForStringComparison(withChildrenResolved,
padCharCol)
+ }
+ }
+ }
+
+ private[sql] def singleNodePaddingForStringComparison(
+ expression: Expression,
+ padCharCol: Boolean): Expression =
+ expression match {
+ // String literal is treated as char type when it's compared to a char
type column.
+ // We should pad the shorter one to the longer length.
+ case b @ BinaryComparison(e @ AttrOrOuterRef(attr), lit) if lit.foldable
=>
+ padAttrLitCmp(e, attr.metadata, padCharCol, lit)
+ .map { newChildren =>
+ b.withNewChildren(newChildren)
+ }
+ .getOrElse(b)
+
+ case b @ BinaryComparison(lit, e @ AttrOrOuterRef(attr)) if lit.foldable
=>
+ padAttrLitCmp(e, attr.metadata, padCharCol, lit)
+ .map { newChildren =>
+ b.withNewChildren(newChildren.reverse)
+ }
+ .getOrElse(b)
+
+ case i @ In(e @ AttrOrOuterRef(attr), list)
+ if attr.dataType == StringType && list.forall(_.foldable) =>
+ CharVarcharUtils
+ .getRawType(attr.metadata)
+ .flatMap {
+ case CharType(length) =>
+ val (nulls, literalChars) =
+ list.map(_.eval().asInstanceOf[UTF8String]).partition(_ ==
null)
+ val literalCharLengths = literalChars.map(_.numChars())
+ val targetLen = (length +: literalCharLengths).max
+ Some(
+ i.copy(
+ value = addPadding(e, length, targetLen, alwaysPad =
padCharCol),
+ list = list.zip(literalCharLengths).map {
+ case (lit, charLength) =>
+ addPadding(lit, charLength, targetLen, alwaysPad =
false)
+ } ++ nulls.map(Literal.create(_, StringType))
+ )
+ )
+ case _ => None
+ }
+ .getOrElse(i)
+
+ // For char type column or inner field comparison, pad the shorter one
to the longer length.
+ case b @ BinaryComparison(e1 @ AttrOrOuterRef(left), e2 @
AttrOrOuterRef(right))
+ // For the same attribute, they must be the same length and no
padding is needed.
+ if !left.semanticEquals(right) =>
+ val outerRefs = (e1, e2) match {
+ case (_: OuterReference, _: OuterReference) => Seq(left, right)
+ case (_: OuterReference, _) => Seq(left)
+ case (_, _: OuterReference) => Seq(right)
+ case _ => Nil
+ }
+ val newChildren =
+ CharVarcharUtils.addPaddingInStringComparison(Seq(left, right),
padCharCol)
+ if (outerRefs.nonEmpty) {
+ b.withNewChildren(newChildren.map(_.transform {
+ case a: Attribute if outerRefs.exists(_.semanticEquals(a)) =>
OuterReference(a)
+ }))
+ } else {
+ b.withNewChildren(newChildren)
+ }
+
+ case i @ In(e @ AttrOrOuterRef(attr), list) if
list.forall(_.isInstanceOf[Attribute]) =>
+ val newChildren = CharVarcharUtils.addPaddingInStringComparison(
+ attr +: list.map(_.asInstanceOf[Attribute]),
+ padCharCol
+ )
+ if (e.isInstanceOf[OuterReference]) {
+ i.copy(value = newChildren.head.transform {
+ case a: Attribute if a.semanticEquals(attr) => OuterReference(a)
+ }, list = newChildren.tail)
+ } else {
+ i.copy(value = newChildren.head, list = newChildren.tail)
+ }
+
+ case other => other
+ }
+
+ private def padAttrLitCmp(
+ expr: Expression,
+ metadata: Metadata,
+ padCharCol: Boolean,
+ lit: Expression): Option[Seq[Expression]] = {
+ if (expr.dataType == StringType) {
+ CharVarcharUtils.getRawType(metadata).flatMap {
+ case CharType(length) =>
+ val str = lit.eval().asInstanceOf[UTF8String]
+ if (str == null) {
+ None
+ } else {
+ val stringLitLen = str.numChars()
+ if (length < stringLitLen) {
+ Some(Seq(StringRPad(expr, Literal(stringLitLen)), lit))
+ } else if (length > stringLitLen) {
+ val paddedExpr = if (padCharCol) {
+ StringRPad(expr, Literal(length))
+ } else {
+ expr
+ }
+ Some(Seq(paddedExpr, StringRPad(lit, Literal(length))))
+ } else if (padCharCol) {
+ Some(Seq(StringRPad(expr, Literal(length)), lit))
+ } else {
+ None
+ }
+ }
+ case _ => None
+ }
+ } else {
+ None
+ }
+ }
+
+ private def addPadding(
+ expr: Expression,
+ charLength: Int,
+ targetLength: Int,
+ alwaysPad: Boolean): Expression = {
+ if (targetLength > charLength) {
+ StringRPad(expr, Literal(targetLength))
+ } else if (alwaysPad) {
+ StringRPad(expr, Literal(charLength))
+ } else expr
+ }
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala
index 141767135a50..d952927f9d30 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala
@@ -17,16 +17,13 @@
package org.apache.spark.sql.execution.datasources
+import org.apache.spark.sql.catalyst.analysis.ApplyCharTypePaddingHelper
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{CharType, Metadata, StringType}
-import org.apache.spark.unsafe.types.UTF8String
/**
* This rule performs string padding for char type.
@@ -39,14 +36,6 @@ import org.apache.spark.unsafe.types.UTF8String
*/
object ApplyCharTypePadding extends Rule[LogicalPlan] {
- object AttrOrOuterRef {
- def unapply(e: Expression): Option[Attribute] = e match {
- case a: Attribute => Some(a)
- case OuterReference(a: Attribute) => Some(a)
- case _ => None
- }
- }
-
override def apply(plan: LogicalPlan): LogicalPlan = {
if (conf.charVarcharAsString) {
return plan
@@ -55,158 +44,22 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
if (conf.readSideCharPadding) {
val newPlan = plan.resolveOperatorsUpWithNewOutput {
case r: LogicalRelation =>
- readSidePadding(r, () =>
+ ApplyCharTypePaddingHelper.readSidePadding(r, () =>
r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata)))
case r: DataSourceV2Relation =>
- readSidePadding(r, () =>
+ ApplyCharTypePaddingHelper.readSidePadding(r, () =>
r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata)))
case r: HiveTableRelation =>
- readSidePadding(r, () => {
+ ApplyCharTypePaddingHelper.readSidePadding(r, () => {
val cleanedDataCols =
r.dataCols.map(CharVarcharUtils.cleanAttrMetadata)
val cleanedPartCols =
r.partitionCols.map(CharVarcharUtils.cleanAttrMetadata)
r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols)
})
}
- paddingForStringComparison(newPlan, padCharCol = false)
+ ApplyCharTypePaddingHelper.paddingForStringComparison(newPlan,
padCharCol = false)
} else {
- paddingForStringComparison(
+ ApplyCharTypePaddingHelper.paddingForStringComparison(
plan, padCharCol =
!conf.getConf(SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE))
}
}
-
- private def readSidePadding(
- relation: LogicalPlan,
- cleanedRelation: () => LogicalPlan)
- : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
- val projectList = relation.output.map { attr =>
- CharVarcharUtils.addPaddingForScan(attr) match {
- case ne: NamedExpression => ne
- case other => Alias(other, attr.name)(explicitMetadata =
Some(attr.metadata))
- }
- }
- if (projectList == relation.output) {
- relation -> Nil
- } else {
- val newPlan = Project(projectList, cleanedRelation())
- newPlan -> relation.output.zip(newPlan.output)
- }
- }
-
- private def paddingForStringComparison(plan: LogicalPlan, padCharCol:
Boolean): LogicalPlan = {
- plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON,
IN)) {
- case operator => operator.transformExpressionsUpWithPruning(
- _.containsAnyPattern(BINARY_COMPARISON, IN)) {
- case e if !e.childrenResolved => e
-
- // String literal is treated as char type when it's compared to a char
type column.
- // We should pad the shorter one to the longer length.
- case b @ BinaryComparison(e @ AttrOrOuterRef(attr), lit) if
lit.foldable =>
- padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren =>
- b.withNewChildren(newChildren)
- }.getOrElse(b)
-
- case b @ BinaryComparison(lit, e @ AttrOrOuterRef(attr)) if
lit.foldable =>
- padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren =>
- b.withNewChildren(newChildren.reverse)
- }.getOrElse(b)
-
- case i @ In(e @ AttrOrOuterRef(attr), list)
- if attr.dataType == StringType && list.forall(_.foldable) =>
- CharVarcharUtils.getRawType(attr.metadata).flatMap {
- case CharType(length) =>
- val (nulls, literalChars) =
- list.map(_.eval().asInstanceOf[UTF8String]).partition(_ ==
null)
- val literalCharLengths = literalChars.map(_.numChars())
- val targetLen = (length +: literalCharLengths).max
- Some(i.copy(
- value = addPadding(e, length, targetLen, alwaysPad =
padCharCol),
- list = list.zip(literalCharLengths).map {
- case (lit, charLength) =>
- addPadding(lit, charLength, targetLen, alwaysPad = false)
- } ++ nulls.map(Literal.create(_, StringType))))
- case _ => None
- }.getOrElse(i)
-
- // For char type column or inner field comparison, pad the shorter one
to the longer length.
- case b @ BinaryComparison(e1 @ AttrOrOuterRef(left), e2 @
AttrOrOuterRef(right))
- // For the same attribute, they must be the same length and no
padding is needed.
- if !left.semanticEquals(right) =>
- val outerRefs = (e1, e2) match {
- case (_: OuterReference, _: OuterReference) => Seq(left, right)
- case (_: OuterReference, _) => Seq(left)
- case (_, _: OuterReference) => Seq(right)
- case _ => Nil
- }
- val newChildren = CharVarcharUtils.addPaddingInStringComparison(
- Seq(left, right), padCharCol)
- if (outerRefs.nonEmpty) {
- b.withNewChildren(newChildren.map(_.transform {
- case a: Attribute if outerRefs.exists(_.semanticEquals(a)) =>
OuterReference(a)
- }))
- } else {
- b.withNewChildren(newChildren)
- }
-
- case i @ In(e @ AttrOrOuterRef(attr), list) if
list.forall(_.isInstanceOf[Attribute]) =>
- val newChildren = CharVarcharUtils.addPaddingInStringComparison(
- attr +: list.map(_.asInstanceOf[Attribute]), padCharCol)
- if (e.isInstanceOf[OuterReference]) {
- i.copy(
- value = newChildren.head.transform {
- case a: Attribute if a.semanticEquals(attr) =>
OuterReference(a)
- },
- list = newChildren.tail)
- } else {
- i.copy(value = newChildren.head, list = newChildren.tail)
- }
- }
- }
- }
-
- private def padAttrLitCmp(
- expr: Expression,
- metadata: Metadata,
- padCharCol: Boolean,
- lit: Expression): Option[Seq[Expression]] = {
- if (expr.dataType == StringType) {
- CharVarcharUtils.getRawType(metadata).flatMap {
- case CharType(length) =>
- val str = lit.eval().asInstanceOf[UTF8String]
- if (str == null) {
- None
- } else {
- val stringLitLen = str.numChars()
- if (length < stringLitLen) {
- Some(Seq(StringRPad(expr, Literal(stringLitLen)), lit))
- } else if (length > stringLitLen) {
- val paddedExpr = if (padCharCol) {
- StringRPad(expr, Literal(length))
- } else {
- expr
- }
- Some(Seq(paddedExpr, StringRPad(lit, Literal(length))))
- } else if (padCharCol) {
- Some(Seq(StringRPad(expr, Literal(length)), lit))
- } else {
- None
- }
- }
- case _ => None
- }
- } else {
- None
- }
- }
-
- private def addPadding(
- expr: Expression,
- charLength: Int,
- targetLength: Int,
- alwaysPad: Boolean): Expression = {
- if (targetLength > charLength) {
- StringRPad(expr, Literal(targetLength))
- } else if (alwaysPad) {
- StringRPad(expr, Literal(charLength))
- } else expr
- }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]