cloud-fan commented on code in PR #39722:
URL: https://github.com/apache/spark/pull/39722#discussion_r1100956662
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala:
##########
@@ -1335,3 +1335,72 @@ trait CommutativeExpression extends Expression {
f: PartialFunction[CommutativeExpression, Seq[Expression]]):
Seq[Expression] =
gatherCommutative(this, f).sortBy(_.hashCode())
}
+
+/**
+ * A helper class used by the Commutative expressions during canonicalization.
During
+ * canonicalization, when we have a long tree of commutative operations, we
use the MultiCommutative
+ * expression to represent that tree instead of creating new commutative
objects.
+ * This class is added as a memory optimization for processing large
commutative operation trees
+ * without creating a large number of new intermediate objects.
+ * The MultiCommutativeOp memory optimization is applied to the following
commutative
+ * expressions:
+ * Add, Multiply, And, Or, BitwiseAnd, BitwiseOr, BitwiseXor.
+ * @param operands A sequence of operands that produces a commutative
expression tree.
+ * @param opCls The class of the root operator of the expression tree.
+ * @param evalMode The optional expression evaluation mode.
+ * @param originalRoot Root operator of the commutative expression tree before
canonicalization.
+ * This object reference is used to deduce the return
dataType of Add and
+ * Multiply operations when the input datatype is decimal.
+ */
+case class MultiCommutativeOp(
+ operands: Seq[Expression],
+ opCls: Class[_],
+ evalMode: Option[EvalMode.Value])(originalRoot: Expression) extends
Unevaluable {
+ // Helper method to deduce the data type of a single operation.
+ private def singleOpDataType(lType: DataType, rType: DataType): DataType = {
+ originalRoot match {
+ case add: Add =>
+ (lType, rType) match {
+ case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+ add.resultDecimalType(p1, s1, p2, s2)
+ case _ => lType
+ }
+ case multiply: Multiply =>
+ (lType, rType) match {
+ case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+ multiply.resultDecimalType(p1, s1, p2, s2)
+ case _ => lType
+ }
+ }
+ }
+
+ /**
+ * Returns the [[DataType]] of the result of evaluating this expression. It
is
+ * invalid to query the dataType of an unresolved expression (i.e., when
`resolved` == false).
+ */
+ override def dataType: DataType = {
+ originalRoot match {
+ case _: Add | _: Multiply =>
+ operands.map(_.dataType).reduce((l, r) => singleOpDataType(l, r))
+ case other => other.dataType
+ }
+ }
+
+ /**
+ * Returns whether this node is nullable. This node is nullable if any of
its children is
+ * nullable.
+ */
+ override def nullable: Boolean = operands.exists(_.nullable)
+
+ /**
+ * Returns a Seq of the children of this node.
+ * Children should not change. Immutability required for containsChild
optimization
+ */
+ override def children: Seq[Expression] = operands
+
Review Comment:
How about providing a util function in this trait, so that it is easier for
sub-classes to do canonicalization?
```
def buildCanonicalizedPlan(collectOperands: PartialFunction[Expression,
Seq[Expression]], buildBinaryOp: (Expression, Expression) => Expression) = ...
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]