cloud-fan commented on code in PR #45649:
URL: https://github.com/apache/spark/pull/45649#discussion_r1543996991
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala:
##########
@@ -35,12 +35,84 @@ case class With(child: Expression, defs:
Seq[CommonExpressionDef])
newChildren: IndexedSeq[Expression]): Expression = {
copy(child = newChildren.head, defs =
newChildren.tail.map(_.asInstanceOf[CommonExpressionDef]))
}
+
+ /**
+ * Builds a map of ids (originally assigned ids -> canonicalized ids) to be
re-assigned during
+ * canonicalization.
+ */
+ protected lazy val canonicalizationIdMap: Map[Long, Long] = {
+ // Start numbering after taking into account all nested With expression id
maps.
+ var currentId = child.map {
+ case w: With => w.canonicalizationIdMap.size
+ case _ => 0L
+ }.sum
+ defs.map { d =>
+ currentId += 1
+ d.id.id -> currentId
+ }.toMap
+ }
+
+ /**
+ * Canonicalize by re-assigning all ids in CommonExpressionRef's and
CommonExpressionDef's
+ * starting from 1. This uses [[canonicalizationIdMap]], which contains all
mappings for
+ * CommonExpressionDef's defined in this scope.
+ * Note that this takes into account nested With expressions by sharing a
numbering scope (see
+ * [[canonicalizationIdMap]]).
+ */
+ override lazy val canonicalized: Expression = copy(
+ child = child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
+ case r: CommonExpressionRef if !r.id.canonicalized =>
+ r.copy(id = r.id.canonicalize(canonicalizationIdMap))
+ }.canonicalized,
+ defs = defs.map {
+ case d: CommonExpressionDef if !d.id.canonicalized =>
+ d.copy(id = d.id.canonicalize(canonicalizationIdMap)).canonicalized
+ .asInstanceOf[CommonExpressionDef]
+ case d => d.canonicalized.asInstanceOf[CommonExpressionDef]
+ }
+ )
+}
+
+object With {
+ /**
+ * Helper function to create a [[With]] statement with an arbitrary number
of common expressions.
+ * Note that the number of arguments in `commonExprs` should be the same as
the number of
+ * arguments taken by `replaced`.
+ *
+ * @param commonExprs list of common expressions
+ * @param replaced closure that defines the common expressions in the
main expression
+ * @return the expression returned by replaced with its arguments replaced
by commonExprs in order
+ */
+ def apply(commonExprs: Expression*)(replaced: Seq[Expression] =>
Expression): With = {
+ val commonExprDefs = commonExprs.map(CommonExpressionDef(_))
+ val commonExprRefs = commonExprDefs.map(new CommonExpressionRef(_))
+ With(replaced(commonExprRefs), commonExprDefs)
+ }
+}
+
+case class CommonExpressionId(id: Long = CommonExpressionId.newId,
canonicalized: Boolean = false) {
Review Comment:
After normalization, the `With` is still a valid expression and I'm not sure
why we need a per-id flag to indicate if it has been canonicalized or not. A
per-With flag seems sufficient for testing.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]