Github user brkyvz commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19271#discussion_r139814268
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExecHelper.scala ---
    @@ -0,0 +1,303 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.streaming
    +
    +import scala.util.control.NonFatal
    +
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, 
AttributeReference, BoundReference, Cast, CheckOverflow, Expression, 
ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, 
Literal, Multiply, NamedExpression, PredicateHelper, Subtract, TimeAdd, 
TimeSub, UnaryMinus}
    +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
    +import 
org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression
    +import org.apache.spark.sql.types._
    +import org.apache.spark.unsafe.types.CalendarInterval
    +
    +
    +/**
    + * Helper object for [[StreamingSymmetricHashJoinExec]].
    + */
    +object StreamingSymmetricHashJoinExecHelper extends PredicateHelper with 
Logging {
    +
    +  sealed trait JoinSide
    +  case object LeftSide extends JoinSide { override def toString(): String 
= "left" }
    +  case object RightSide extends JoinSide { override def toString(): String 
= "right" }
    +
    +  sealed trait JoinStateWatermarkPredicate
    +  case class JoinStateKeyWatermarkPredicate(expr: Expression) extends 
JoinStateWatermarkPredicate
    +  case class JoinStateValueWatermarkPredicate(expr: Expression) extends 
JoinStateWatermarkPredicate
    +
    +  case class JoinStateWatermarkPredicates(
    +    left: Option[JoinStateWatermarkPredicate] = None,
    +    right: Option[JoinStateWatermarkPredicate] = None)
    +
    +  def getStateWatermarkPredicates(
    +      leftAttributes: Seq[Attribute],
    +      rightAttributes: Seq[Attribute],
    +      leftKeys: Seq[Expression],
    +      rightKeys: Seq[Expression],
    +      condition: Option[Expression],
    +      eventTimeWatermark: Option[Long]): JoinStateWatermarkPredicates = {
    +    val joinKeyOrdinalForWatermark: Option[Int] = {
    +      leftKeys.zipWithIndex.collectFirst {
    +        case (ne: NamedExpression, index) if 
ne.metadata.contains(delayKey) => index
    +      } orElse {
    +        rightKeys.zipWithIndex.collectFirst {
    +          case (ne: NamedExpression, index) if 
ne.metadata.contains(delayKey) => index
    +        }
    +      }
    +    }
    +
    +    def getOneSideStateWatermarkPredicate(
    +        oneSideInputAttributes: Seq[Attribute],
    +        oneSideJoinKeys: Seq[Expression],
    +        otherSideInputAttributes: Seq[Attribute]): 
Option[JoinStateWatermarkPredicate] = {
    +      val isWatermarkDefinedOnInput = 
oneSideInputAttributes.exists(_.metadata.contains(delayKey))
    +      val isWatermarkDefinedOnJoinKey = 
joinKeyOrdinalForWatermark.isDefined
    +
    +      if (isWatermarkDefinedOnJoinKey) { // case 1 and 3 explained in the 
class docs
    +        val keyExprWithWatermark = BoundReference(
    +          joinKeyOrdinalForWatermark.get,
    +          oneSideJoinKeys(joinKeyOrdinalForWatermark.get).dataType,
    +          oneSideJoinKeys(joinKeyOrdinalForWatermark.get).nullable)
    +        val expr = watermarkExpression(Some(keyExprWithWatermark), 
eventTimeWatermark)
    +        expr.map(JoinStateKeyWatermarkPredicate)
    +
    +      } else if (isWatermarkDefinedOnInput) { // case 2 explained in the 
class docs
    +        val stateValueWatermark = getStateValueWatermark(
    +          attributesToFindStateWatemarkFor = oneSideInputAttributes,
    +          attributesWithEventWatermark = otherSideInputAttributes,
    +          condition,
    +          eventTimeWatermark)
    +        val inputAttributeWithWatermark = 
oneSideInputAttributes.find(_.metadata.contains(delayKey))
    +        val expr = watermarkExpression(inputAttributeWithWatermark, 
stateValueWatermark)
    +        expr.map(JoinStateValueWatermarkPredicate)
    +
    +      } else {
    +        None
    +
    +      }
    +    }
    +
    +    val leftStateWatermarkPredicate =
    +      getOneSideStateWatermarkPredicate(leftAttributes, leftKeys, 
rightAttributes)
    +    val rightStateWatermarkPredicate =
    +      getOneSideStateWatermarkPredicate(rightAttributes, rightKeys, 
leftAttributes)
    +    JoinStateWatermarkPredicates(leftStateWatermarkPredicate, 
rightStateWatermarkPredicate)
    +  }
    +
    +  /**
    +   * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for 
context about it)
    +   * given the join condition and the event time watermark. This is how it 
works.
    +   * - The condition is split into conjunctive predicates, and we find the 
predicates of the
    +   *   form `leftTime + c1 < rightTime + c2`   (or <=, >, >=).
    +   * - We canoncalize the predicate and solve it with the event time 
watermark value to find the
    +   *  value of the state watermark.
    +   *
    +   * @param attributesToFindStateWatemarkFor attributes of the side whose 
state watermark
    +   *                                         is to be calculated
    +   * @param attributesWithEventWatermark  attributes of the other side 
which has a watermark column
    +   * @param joinCondition                 join condition
    +   * @param eventWatermark                watermark defined on the input 
event data
    +   * @return state value watermark in milliseconds
    +   */
    +  def getStateValueWatermark(
    +      attributesToFindStateWatemarkFor: Seq[Attribute],
    +      attributesWithEventWatermark: Seq[Attribute],
    +      joinCondition: Option[Expression],
    +      eventWatermark: Option[Long]): Option[Long] = {
    +    if (joinCondition.isEmpty || eventWatermark.isEmpty) return None
    +
    +    def getStateWatermarkSafely(l: Expression, r: Expression): 
Option[Long] = {
    +      try {
    +        getStateWatemarkFromLessThenPredicate(
    +          l, r, attributesToFindStateWatemarkFor, 
attributesWithEventWatermark, eventWatermark)
    +      } catch {
    +        case NonFatal(e) =>
    +          logWarning(s"Error trying to extract state constraint from 
condition $joinCondition", e)
    +          None
    +      }
    +    }
    +
    +    val allStateWatermarks = 
splitConjunctivePredicates(joinCondition.get).flatMap { predicate =>
    +      val stateWatermark = predicate match {
    +        case LessThan(l, r) => getStateWatermarkSafely(l, r)
    +        case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ 
- 1)
    +        case GreaterThan(l, r) => getStateWatermarkSafely(r, l)
    +        case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, 
l).map(_ - 1)
    --- End diff --
    
    if the above line is `-1`, then I would have assumed this should be `+1`


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to