kelvinjian-db commented on code in PR #45649:
URL: https://github.com/apache/spark/pull/45649#discussion_r1543403309
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala:
##########
@@ -35,12 +35,84 @@ case class With(child: Expression, defs:
Seq[CommonExpressionDef])
newChildren: IndexedSeq[Expression]): Expression = {
copy(child = newChildren.head, defs =
newChildren.tail.map(_.asInstanceOf[CommonExpressionDef]))
}
+
+ /**
+ * Builds a map of ids (originally assigned ids -> canonicalized ids) to be
re-assigned during
+ * canonicalization.
+ */
+ protected lazy val canonicalizationIdMap: Map[Long, Long] = {
+ // Start numbering after taking into account all nested With expression id
maps.
+ var currentId = child.map {
+ case w: With => w.canonicalizationIdMap.size
+ case _ => 0L
+ }.sum
+ defs.map { d =>
+ currentId += 1
+ d.id.id -> currentId
+ }.toMap
+ }
+
+ /**
+ * Canonicalize by re-assigning all ids in CommonExpressionRef's and
CommonExpressionDef's
+ * starting from 0. This uses [[canonicalizationIdMap]], which contains all
mappings for
+ * CommonExpressionDef's defined in this scope.
+ * Note that this takes into account nested With expressions by sharing a
numbering scope (see
+ * [[canonicalizationIdMap]].
+ */
+ override lazy val canonicalized: Expression = copy(
+ child = child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
+ case r: CommonExpressionRef if !r.id.canonicalized =>
+ r.copy(id = r.id.canonicalize(canonicalizationIdMap))
+ }.canonicalized,
+ defs = defs.map {
+ case d: CommonExpressionDef if !d.id.canonicalized =>
+ d.copy(id = d.id.canonicalize(canonicalizationIdMap)).canonicalized
+ .asInstanceOf[CommonExpressionDef]
+ case d => d.canonicalized.asInstanceOf[CommonExpressionDef]
+ }
+ )
+}
+
+object With {
+ /**
+ * Helper function to create a [[With]] statement with an arbitrary number
of common expressions.
+ * Note that the number of arguments in `commonExprs` should be the same as
the number of
+ * arguments taken by `replaced`.
+ *
+ * @param commonExprs list of common expressions
+ * @param replaced closure that defines the common expressions in the
main expression
+ * @return the expression returned by replaced with its arguments replaced
by commonExprs in order
+ */
+ def apply(commonExprs: Expression*)(replaced: Seq[Expression] =>
Expression): With = {
+ val commonExprDefs = commonExprs.map(CommonExpressionDef(_))
+ val commonExprRefs = commonExprDefs.map(new CommonExpressionRef(_))
+ With(replaced(commonExprRefs), commonExprDefs)
+ }
+}
+
+case class CommonExpressionId(id: Long = CommonExpressionId.newId,
canonicalized: Boolean = false) {
Review Comment:
What replacement are you suggesting? The `canonicalized` parameter in
`CommonExpressionId` is used to distinguish between an ID that was assigned
initially by `curId.getAndIncrement()` and one that was newly assigned by
canonicalization. Isn't that more of a property of the ID itself than of the
`With`/`CommonExpressionDef`/`CommonExpressionRef` operators? Also, if we
haven't called `.canonicalized` on the outermost `With`, it is possible to have
a `With` expression that contains some canonicalized IDs and some
non-canonicalized IDs, so the flag needs to be tracked per ID rather than per
operator.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]