cloud-fan commented on code in PR #39722:
URL: https://github.com/apache/spark/pull/39722#discussion_r1100954235
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala:
##########
@@ -1335,3 +1335,72 @@ trait CommutativeExpression extends Expression {
f: PartialFunction[CommutativeExpression, Seq[Expression]]):
Seq[Expression] =
gatherCommutative(this, f).sortBy(_.hashCode())
}
+
+/**
+ * A helper class used by the Commutative expressions during canonicalization. During
+ * canonicalization, when we have a long tree of commutative operations, we use the MultiCommutativeOp
+ * expression to represent that tree instead of creating new commutative objects.
+ * This class is added as a memory optimization for processing large commutative operation trees
+ * without creating a large number of new intermediate objects.
+ * The MultiCommutativeOp memory optimization is applied to the following commutative
+ * expressions:
+ * Add, Multiply, And, Or, BitwiseAnd, BitwiseOr, BitwiseXor.
+ * @param operands A sequence of operands that produces a commutative expression tree.
+ * @param opCls The class of the root operator of the expression tree.
+ * @param evalMode The optional expression evaluation mode.
+ * @param originalRoot Root operator of the commutative expression tree before canonicalization.
+ * This object reference is used to deduce the return dataType of Add and
+ * Multiply operations when the input datatype is decimal.
+ */
+case class MultiCommutativeOp(
+ operands: Seq[Expression],
+ opCls: Class[_],
+ evalMode: Option[EvalMode.Value])(originalRoot: Expression) extends
Unevaluable {
+ // Helper method to deduce the data type of a single operation.
+ private def singleOpDataType(lType: DataType, rType: DataType): DataType = {
+ originalRoot match {
+ case add: Add =>
+ (lType, rType) match {
+ case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+ add.resultDecimalType(p1, s1, p2, s2)
+ case _ => lType
+ }
+ case multiply: Multiply =>
+ (lType, rType) match {
+ case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+ multiply.resultDecimalType(p1, s1, p2, s2)
+ case _ => lType
+ }
+ }
+ }
+
+ /**
+ * Returns the [[DataType]] of the result of evaluating this expression. It is
+ * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false).
Review Comment:
nit: if you look at other subclasses of `Expression`, we don't repeat the
API doc in the overriding methods.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]