xuzifu666 commented on code in PR #55982:
URL: https://github.com/apache/spark/pull/55982#discussion_r3278916487
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRand.scala:
##########
@@ -17,47 +17,182 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.expressions.{BinaryComparison,
DoubleLiteral, Expression, GreaterThan, GreaterThanOrEqual, LessThan,
LessThanOrEqual, Rand}
+import org.apache.spark.sql.catalyst.expressions.{Add, BinaryComparison,
Divide,
+ DoubleLiteral, EqualTo, Expression, GreaterThan, GreaterThanOrEqual,
LessThan,
+ LessThanOrEqual, Literal, Multiply, Rand, Subtract}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral,
TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON,
EXPRESSION_WITH_RANDOM_SEED, LITERAL}
+import
org.apache.spark.sql.catalyst.trees.TreePattern.EXPRESSION_WITH_RANDOM_SEED
-/**
- * Rand() generates a random column with i.i.d. uniformly distributed values
in [0, 1), so
- * compare double literal value with 1.0 or 0.0 could eliminate Rand() in
binary comparison.
- *
- * 1. Converts the binary comparison to true literal when the comparison value
must be true.
- * 2. Converts the binary comparison to false literal when the comparison
value must be false.
- */
object OptimizeRand extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan =
- plan.transformAllExpressionsWithPruning(_.containsAllPatterns(
- EXPRESSION_WITH_RANDOM_SEED, LITERAL, BINARY_COMPARISON), ruleId) {
- case op @ BinaryComparison(DoubleLiteral(_), _: Rand) =>
eliminateRand(swapComparison(op))
- case op @ BinaryComparison(_: Rand, DoubleLiteral(_)) =>
eliminateRand(op)
+ plan.transformAllExpressionsWithPruning(_.containsAnyPattern(
+ EXPRESSION_WITH_RANDOM_SEED), ruleId) {
+ case op @ BinaryComparison(DoubleLiteral(_), _: Rand) =>
+ eliminateRand(swapComparison(op))
+ case op @ BinaryComparison(_: Rand, DoubleLiteral(_)) =>
+ eliminateRand(op)
+ case op: BinaryComparison if hasRand(op.left) || hasRand(op.right) =>
+ optimizeArithmetic(op)
}
- /**
- * Swaps the left and right sides of some binary comparisons. e.g.,
transform "a < b" to "b > a"
- */
- private def swapComparison(comparison: BinaryComparison): BinaryComparison =
comparison match {
- case a LessThan b => GreaterThan(b, a)
- case a LessThanOrEqual b => GreaterThanOrEqual(b, a)
- case a GreaterThan b => LessThan(b, a)
- case a GreaterThanOrEqual b => LessThanOrEqual(b, a)
- case o => o
+ private def hasRand(expr: Expression): Boolean = expr match {
+ case _: Rand => true
+ case a: Add => hasRand(a.left) || hasRand(a.right)
+ case s: Subtract => hasRand(s.left) || hasRand(s.right)
+ case m: Multiply => hasRand(m.left) || hasRand(m.right)
+ case d: Divide => hasRand(d.left) || hasRand(d.right)
+ case _ => false
}
+ private def swapComparison(comparison: BinaryComparison): BinaryComparison =
+ comparison match {
+ case GreaterThan(l, r) => LessThan(r, l)
+ case GreaterThanOrEqual(l, r) => LessThanOrEqual(r, l)
+ case LessThan(l, r) => GreaterThan(r, l)
+ case LessThanOrEqual(l, r) => GreaterThanOrEqual(r, l)
+ case o => o
+ }
+
private def eliminateRand(op: BinaryComparison): Expression = op match {
- case GreaterThan(_: Rand, DoubleLiteral(value)) =>
- if (value < 0.0) TrueLiteral else if (value >= 1.0) FalseLiteral else op
- case GreaterThanOrEqual(_: Rand, DoubleLiteral(value)) =>
- if (value <= 0.0) TrueLiteral else if (value >= 1.0) FalseLiteral else op
- case LessThan(_: Rand, DoubleLiteral(value)) =>
- if (value >= 1.0) TrueLiteral else if (value <= 0.0) FalseLiteral else op
- case LessThanOrEqual(_: Rand, DoubleLiteral(value)) =>
- if (value >= 1.0) TrueLiteral else if (value < 0.0) FalseLiteral else op
+ case GreaterThan(_: Rand, DoubleLiteral(v)) =>
+ if (v < 0.0) TrueLiteral else if (v >= 1.0) FalseLiteral else op
+ case GreaterThanOrEqual(_: Rand, DoubleLiteral(v)) =>
+ if (v <= 0.0) TrueLiteral else if (v >= 1.0) FalseLiteral else op
+ case LessThan(_: Rand, DoubleLiteral(v)) =>
+ if (v >= 1.0) TrueLiteral else if (v <= 0.0) FalseLiteral else op
+ case LessThanOrEqual(_: Rand, DoubleLiteral(v)) =>
+ if (v >= 1.0) TrueLiteral else if (v < 0.0) FalseLiteral else op
case other => other
}
+
+ private def extractDouble(lit: Expression): Option[Double] = lit match {
+ case DoubleLiteral(v) => Some(v)
+ case Literal(v: Double, _) => Some(v)
+ case Literal(v: java.lang.Double, _) => Some(v.doubleValue())
+ case Literal(v: java.lang.Number, _) => Some(v.doubleValue())
+ case _ => None
+ }
+
+ case class RandExpr(coeff: Double, offset: Double)
+
+ private def extractRandCoeffOffset(expr: Expression): Option[RandExpr] = {
+ expr match {
+ case _: Rand => Some(RandExpr(1.0, 0.0))
+ case m: Multiply =>
Review Comment:
Good question! I have narrowed the check range to match the extractor's
capabilities by replacing hasRand() with a new isDirectRandChild() function
that only checks direct children.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]