Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19327#discussion_r141698177
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
 ---
    @@ -0,0 +1,268 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.catalyst.analysis
    +
    +import scala.util.control.NonFatal
    +
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, 
AttributeSet, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, 
GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, 
PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, 
UnaryMinus}
    +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
    +import org.apache.spark.sql.types._
    +import org.apache.spark.unsafe.types.CalendarInterval
    +
    +
    +/**
    + * Helper object for stream joins. See [[StreamingSymmetricHashJoinExec]] 
in SQL for more details.
    + */
    +object StreamingJoinHelper extends PredicateHelper with Logging {
    +  /**
    +   * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for 
context about it)
    +   * given the join condition and the event time watermark. This is how it 
works.
    +   * - The condition is split into conjunctive predicates, and we find the 
predicates of the
    +   *   form `leftTime + c1 < rightTime + c2`   (or <=, >, >=).
    +   * - We canonicalize the predicate and solve it with the event time 
watermark value to find the
    +   *   value of the state watermark.
    +   * This function is supposed to make a best-effort attempt to get the 
state watermark. If there is
    +   * any error, it will return None.
    +   *
    +   * @param attributesToFindStateWatermarkFor attributes of the side whose 
state watermark
    +   *                                         is to be calculated
    +   * @param attributesWithEventWatermark  attributes of the other side 
which has a watermark column
    +   * @param joinCondition                 join condition
    +   * @param eventWatermark                watermark defined on the input 
event data
    +   * @return state value watermark in milliseconds, if possible.
    +   */
    +  def getStateValueWatermark(
    +      attributesToFindStateWatermarkFor: AttributeSet,
    +      attributesWithEventWatermark: AttributeSet,
    +      joinCondition: Option[Expression],
    +      eventWatermark: Option[Long]): Option[Long] = {
    +
    +    // If condition or event time watermark is not provided, then cannot 
calculate state watermark
    +    if (joinCondition.isEmpty || eventWatermark.isEmpty) return None
    +
    +    // If there is no watermark attribute, then cannot define state 
watermark
    +    if 
(!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return 
None
    +
    +    // Best-effort wrapper: a non-fatal error is logged and turned into None.
    +    def getStateWatermarkSafely(l: Expression, r: Expression): 
Option[Long] = {
      try {
    +        getStateWatermarkFromLessThenPredicate(
    +          l, r, attributesToFindStateWatermarkFor, 
attributesWithEventWatermark, eventWatermark)
    +      } catch {
    +        case NonFatal(e) =>
    +          logWarning(s"Error trying to extract state constraint from 
condition $joinCondition", e)
    +          None
    +      }
    +    }
    +
    +    val allStateWatermarks = 
splitConjunctivePredicates(joinCondition.get).flatMap { predicate =>
    +
    +      // The generated state watermark cleanup expression is inclusive 
of the state watermark.
    +      // If state watermark is W, all state where timestamp <= W will be 
cleaned up.
    +      // Now when the canonicalized join condition solves to leftTime >= 
W, we dont want to clean
    +      // up leftTime <= W. Rather we should clean up leftTime <= W - 1. 
Hence the -1 below.
    +      // The "greater than" forms are handled as "less than" with the operands swapped.
    +      val stateWatermark = predicate match {
    +        case LessThan(l, r) => getStateWatermarkSafely(l, r)
    +        case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ 
- 1)
    +        case GreaterThan(l, r) => getStateWatermarkSafely(r, l)
    +        case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, 
l).map(_ - 1)
    +        case _ => None
    +      }
    +      if (stateWatermark.nonEmpty) {
    +        logInfo(s"Condition $joinCondition generated watermark constraint 
= ${stateWatermark.get}")
    +      }
    +      stateWatermark
    +    }
    +    // Keep the smallest watermark across all predicates; None if no predicate produced one.
    +    allStateWatermarks.reduceOption((x, y) => Math.min(x, y))
    +  }
    +
    +  /**
    +   * Extract the state value watermark (milliseconds) from the condition
    +   * `LessThan(leftExpr, rightExpr)`. For example, suppose we want to 
find the constraint for
    +   * leftTime using the watermark on the rightTime:
    +   *
    +   * Input:                 rightTime-with-watermark + c1 < leftTime + c2
    +   * Canonical form:        rightTime-with-watermark + c1 + (-c2) + 
(-leftTime) < 0
    +   * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < 
leftTime
    +   * With watermark value:  watermark-value + c1 + (-c2) < leftTime
    +   */
    +  private def getStateWatermarkFromLessThenPredicate(
    +      leftExpr: Expression,
    +      rightExpr: Expression,
    +      attributesToFindStateWatermarkFor: AttributeSet,
    +      attributesWithEventWatermark: AttributeSet,
    +      eventWatermark: Option[Long]): Option[Long] = {
    +
    +    val attributesInCondition = AttributeSet(
    +      leftExpr.collect { case a: AttributeReference => a } ++
    +      rightExpr.collect { case a: AttributeReference => a }
    +    )
    +    if (attributesInCondition.filter { 
attributesToFindStateWatermarkFor.contains(_) }.size > 1 ||
    +        attributesInCondition.filter { 
attributesWithEventWatermark.contains(_) }.size > 1) {
    +      // If more than one attribute is present in the condition from one side, then 
it cannot be solved
    +      return None
    +    }
    +
    +    def containsAttributeToFindStateConstraintFor(e: Expression): Boolean 
= {
    +      e.collectLeaves().collectFirst {
    +        case a @ AttributeReference(_, _, _, _)
    +          if attributesToFindStateWatermarkFor.contains(a) => a
    +      }.nonEmpty
    +    }
    +
    +    // Canonicalization step 1: convert to (rightTime-with-watermark + c1) 
- (leftTime + c2) < 0
    +    val allOnLeftExpr = Subtract(leftExpr, rightExpr)
    +    logDebug(s"All on 
Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}")
    +
    +    // Canonicalization step 2: extract commutative terms
    +    //    rightTime-with-watermark, c1, -leftTime, -c2
    +    val terms = ExpressionSet(collectTerms(allOnLeftExpr))
    +    logDebug("Terms extracted from join condition:\n\t" + 
terms.mkString("\n\t"))
    +
    +
    +
    --- End diff --
    
    remove 2 extra lines


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to