Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19271#discussion_r139835014
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExecHelper.scala
---
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Add, Attribute,
AttributeReference, BoundReference, Cast, CheckOverflow, Expression,
ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual,
Literal, Multiply, NamedExpression, PredicateHelper, Subtract, TimeAdd,
TimeSub, UnaryMinus}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
+import
org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+
+/**
+ * Helper object for [[StreamingSymmetricHashJoinExec]].
+ */
+object StreamingSymmetricHashJoinExecHelper extends PredicateHelper with
Logging {
+
+ sealed trait JoinSide
+ case object LeftSide extends JoinSide { override def toString(): String
= "left" }
+ case object RightSide extends JoinSide { override def toString(): String
= "right" }
+
+ sealed trait JoinStateWatermarkPredicate
+ case class JoinStateKeyWatermarkPredicate(expr: Expression) extends
JoinStateWatermarkPredicate
+ case class JoinStateValueWatermarkPredicate(expr: Expression) extends
JoinStateWatermarkPredicate
+
+ case class JoinStateWatermarkPredicates(
+ left: Option[JoinStateWatermarkPredicate] = None,
+ right: Option[JoinStateWatermarkPredicate] = None)
+
+ def getStateWatermarkPredicates(
+ leftAttributes: Seq[Attribute],
+ rightAttributes: Seq[Attribute],
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ condition: Option[Expression],
+ eventTimeWatermark: Option[Long]): JoinStateWatermarkPredicates = {
+ val joinKeyOrdinalForWatermark: Option[Int] = {
+ leftKeys.zipWithIndex.collectFirst {
+ case (ne: NamedExpression, index) if
ne.metadata.contains(delayKey) => index
+ } orElse {
+ rightKeys.zipWithIndex.collectFirst {
+ case (ne: NamedExpression, index) if
ne.metadata.contains(delayKey) => index
+ }
+ }
+ }
+
+ def getOneSideStateWatermarkPredicate(
+ oneSideInputAttributes: Seq[Attribute],
+ oneSideJoinKeys: Seq[Expression],
+ otherSideInputAttributes: Seq[Attribute]):
Option[JoinStateWatermarkPredicate] = {
+ val isWatermarkDefinedOnInput =
oneSideInputAttributes.exists(_.metadata.contains(delayKey))
+ val isWatermarkDefinedOnJoinKey =
joinKeyOrdinalForWatermark.isDefined
+
+ if (isWatermarkDefinedOnJoinKey) { // case 1 and 3 explained in the
class docs
+ val keyExprWithWatermark = BoundReference(
+ joinKeyOrdinalForWatermark.get,
+ oneSideJoinKeys(joinKeyOrdinalForWatermark.get).dataType,
+ oneSideJoinKeys(joinKeyOrdinalForWatermark.get).nullable)
+ val expr = watermarkExpression(Some(keyExprWithWatermark),
eventTimeWatermark)
+ expr.map(JoinStateKeyWatermarkPredicate)
+
+ } else if (isWatermarkDefinedOnInput) { // case 2 explained in the
class docs
+ val stateValueWatermark = getStateValueWatermark(
+ attributesToFindStateWatemarkFor = oneSideInputAttributes,
+ attributesWithEventWatermark = otherSideInputAttributes,
+ condition,
+ eventTimeWatermark)
+ val inputAttributeWithWatermark =
oneSideInputAttributes.find(_.metadata.contains(delayKey))
+ val expr = watermarkExpression(inputAttributeWithWatermark,
stateValueWatermark)
+ expr.map(JoinStateValueWatermarkPredicate)
+
+ } else {
+ None
+
+ }
+ }
+
+ val leftStateWatermarkPredicate =
+ getOneSideStateWatermarkPredicate(leftAttributes, leftKeys,
rightAttributes)
+ val rightStateWatermarkPredicate =
+ getOneSideStateWatermarkPredicate(rightAttributes, rightKeys,
leftAttributes)
+ JoinStateWatermarkPredicates(leftStateWatermarkPredicate,
rightStateWatermarkPredicate)
+ }
+
+ /**
+ * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for
context about it)
+ * given the join condition and the event time watermark. This is how it
works.
+ * - The condition is split into conjunctive predicates, and we find the
predicates of the
+ * form `leftTime + c1 < rightTime + c2` (or <=, >, >=).
+ * - We canonicalize the predicate and solve it with the event time
watermark value to find the
+ * value of the state watermark.
+ *
+ * @param attributesToFindStateWatemarkFor attributes of the side whose
state watermark
+ * is to be calculated
+ * @param attributesWithEventWatermark attributes of the other side
which has a watermark column
+ * @param joinCondition join condition
+ * @param eventWatermark watermark defined on the input
event data
+ * @return state value watermark in milliseconds
+ */
+ def getStateValueWatermark(
+ attributesToFindStateWatemarkFor: Seq[Attribute],
+ attributesWithEventWatermark: Seq[Attribute],
+ joinCondition: Option[Expression],
+ eventWatermark: Option[Long]): Option[Long] = {
+ if (joinCondition.isEmpty || eventWatermark.isEmpty) return None
+
+ def getStateWatermarkSafely(l: Expression, r: Expression):
Option[Long] = {
+ try {
+ getStateWatemarkFromLessThenPredicate(
+ l, r, attributesToFindStateWatemarkFor,
attributesWithEventWatermark, eventWatermark)
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error trying to extract state constraint from
condition $joinCondition", e)
+ None
+ }
+ }
+
+ val allStateWatermarks =
splitConjunctivePredicates(joinCondition.get).flatMap { predicate =>
+ val stateWatermark = predicate match {
+ case LessThan(l, r) => getStateWatermarkSafely(l, r)
+ case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_
- 1)
+ case GreaterThan(l, r) => getStateWatermarkSafely(r, l)
+ case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r,
l).map(_ - 1)
+ case _ => None
+ }
+ if (stateWatermark.nonEmpty) {
+ logInfo(s"Condition $joinCondition generated watermark constraint
= ${stateWatermark.get}")
+ }
+ stateWatermark
+ }
+ allStateWatermarks.reduceOption((x, y) => Math.min(x, y))
+ }
+
+ /**
+ * Extract a constraint from a condition, for example to find
the constraint for
+ * leftTime using the watermark on the rightTime. Example:
+ *
+ * Input: rightTime-with-watermark + c1 < leftTime + c2
+ * Canonical form: rightTime-with-watermark + c1 + (-c2) +
(-leftTime) < 0
+ * Solving for leftTime: rightTime-with-watermark + c1 + (-c2) <
leftTime
+ * With watermark value: watermark-value + c1 + (-c2) < leftTime
+ */
+ private def getStateWatemarkFromLessThenPredicate(
+ leftExpr: Expression,
+ rightExpr: Expression,
+ attributesToFindStateWatermarkFor: Seq[Attribute],
+ attributesWithEventWatermark: Seq[Attribute],
+ eventWatermark: Option[Long]): Option[Long] = {
+
+ def containsAttributeToFindStateConstraintFor(e: Expression): Boolean
= {
+ e.collectLeaves().collectFirst {
+ case a@AttributeReference(_, TimestampType, _, _)
+ if attributesToFindStateWatermarkFor.contains(a) => a
+ }.nonEmpty
+ }
+
+ // Canonicalization step 1: convert to (rightTime-with-watermark + c1)
- (leftTime + c2) < 0
+ val allOnLeftExpr = Subtract(leftExpr, rightExpr)
+ logDebug(s"All on
Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}")
+
+ // Canonicalization step 2: extract commutative terms
+ // rightTime-with-watermark, c1, -leftTime, -c2
+ val terms = ExpressionSet(collectTerms(allOnLeftExpr))
+ logDebug("Terms extracted from join condition:\n\t" +
terms.mkString("\n\t"))
+
+ // Find the term that has leftTime (i.e. the one present in
attributesToFindStateWatermarkFor)
+ val constraintTerms =
terms.filter(containsAttributeToFindStateConstraintFor)
+
+ // Verify there is only one correct constraint term and of the correct
type
+ if (constraintTerms.size > 1) {
+ logWarning("Failed to extract state constraint terms: multiple time
terms in condition\n\t" +
+ terms.mkString("\n\t"))
+ return None
+ }
+ if (constraintTerms.isEmpty) {
+ logDebug("Failed to extract state constraint terms: no time terms in
condition\n\t" +
+ terms.mkString("\n\t"))
+ return None
+ }
+ val constraintTerm = constraintTerms.head
+ if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) {
+ // Incorrect condition. We want the constraint term in canonical
form to be `-leftTime`
+ // so that we can solve for it as `-leftTime + watermark + c < 0` ==>
`watermark + c < leftTime`.
+ // Now, if the original conditions is `rightTime-with-watermark >
leftTime` and watermark
+ // condition is `rightTime-with-watermark > watermarkValue`, then no
constraint about
+ // `leftTime` can be inferred. In this case, after canonicalization
and collection of terms,
+ // the constraintTerm would be `leftTime` and not `-leftTime`.
Hence, we return None.
+ return None
+ }
+
+ // Replace watermark attribute with watermark value, and generate the
resolved expression
+ // from the other terms. That is,
+ // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 =>
watermark + c1 + (-c2)
+ logDebug(s"Constraint term from join condition:\t$constraintTerm")
+ val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term
=>
+ term.transform {
+ case a@AttributeReference(_, TimestampType, _, metadata)
+ if attributesWithEventWatermark.contains(a) &&
a.metadata.contains(delayKey) =>
+ Literal(eventWatermark.get)
+ }
+ }.reduceLeft(Add)
+
+ // Calculate the constraint value
+ logInfo(s"Final expression to evaluate
constraint:\t$exprWithWatermarkSubstituted")
+ val constraintValue =
exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double]
+ Some(Double2double(constraintValue).toLong)
+ }
+
+ /**
+ * Collect all the terms present in an expression after converting it
into the form
+ * a + b + c + d where each term is either an attribute or a literal
cast to double,
+ * optionally wrapped in a unary minus.
+ */
+ private def collectTerms(exprToCollectFrom: Expression): Seq[Expression]
= {
+ var invalid = false
+
+ /** Wrap a term with UnaryMinus if it needs to be negated. */
+ def negateIfNeeded(expr: Expression, minus: Boolean): Expression = {
+ if (minus) UnaryMinus(expr) else expr
+ }
+
+ /**
+ * Recursively split the expression into its leaf terms containing
attributes or literals.
+ * Returns terms only of the forms:
+ * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)),
+ * Cast(AttributeReference, Double),
UnaryMinus(Cast(AttributeReference, Double))
+ * Multiply(Literal), UnaryMinus(Multiply(Literal))
+ * Multiply(Cast(Literal)), UnaryMinus(Multiply(Cast(Literal)))
+ *
+ * Note:
+ * - If term needs to be negated for making it a commutative term,
+ * then it will be wrapped in UnaryMinus(...)
+ * - Each term represents a timestamp value or a time interval
in milliseconds,
+ * typed as a double.
+ */
+ def collect(expr: Expression, negate: Boolean): Seq[Expression] = {
+ expr match {
+ case Add(left, right) =>
+ collect(left, negate) ++ collect(right, negate)
+ case Subtract(left, right) =>
+ collect(left, negate) ++ collect(right, !negate)
+ case TimeAdd(left, right, _) =>
+ collect(left, negate) ++ collect(right, negate)
+ case TimeSub(left, right, _) =>
+ collect(left, negate) ++ collect(right, !negate)
+ case UnaryMinus(child) =>
+ collect(child, !negate)
+ case CheckOverflow(child, _) =>
+ collect(child, negate)
+ case Cast(child, dataType, _) =>
+ dataType match {
+ case _: NumericType | _: TimestampType => collect(child,
negate)
+ case _ =>
+ invalid = true
+ Seq.empty
+ }
+ case a: AttributeReference =>
+ val castedRef = if (a.dataType != DoubleType) Cast(a,
DoubleType) else a
+ Seq(negateIfNeeded(castedRef, negate))
+ case lit: Literal =>
+ // If literal of type calendar interval, then explicitly convert
to millis
+ // Convert other number like literal to doubles representing
millis (by x1000)
+ val castedLit = lit.dataType match {
+ case CalendarIntervalType =>
+ val calendarInterval =
lit.value.asInstanceOf[CalendarInterval]
+ val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000
* 31
--- End diff --
I think there is a problem with months. Overestimation (similar to what we
do for watermark delay specified in months) is not correct in this case. It
should not throw an analysis exception, though, as this is fundamentally a
best-effort extraction.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]