Github user brkyvz commented on a diff in the pull request:
https://github.com/apache/spark/pull/19271#discussion_r139815966
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExecHelper.scala
---
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Add, Attribute,
AttributeReference, BoundReference, Cast, CheckOverflow, Expression,
ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual,
Literal, Multiply, NamedExpression, PredicateHelper, Subtract, TimeAdd,
TimeSub, UnaryMinus}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
+import
org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+
+/**
+ * Helper object for [[StreamingSymmetricHashJoinExec]].
+ */
+object StreamingSymmetricHashJoinExecHelper extends PredicateHelper with
Logging {
+
+ sealed trait JoinSide
+ case object LeftSide extends JoinSide { override def toString(): String
= "left" }
+ case object RightSide extends JoinSide { override def toString(): String
= "right" }
+
+ sealed trait JoinStateWatermarkPredicate
+ case class JoinStateKeyWatermarkPredicate(expr: Expression) extends
JoinStateWatermarkPredicate
+ case class JoinStateValueWatermarkPredicate(expr: Expression) extends
JoinStateWatermarkPredicate
+
+ case class JoinStateWatermarkPredicates(
+ left: Option[JoinStateWatermarkPredicate] = None,
+ right: Option[JoinStateWatermarkPredicate] = None)
+
+ def getStateWatermarkPredicates(
+ leftAttributes: Seq[Attribute],
+ rightAttributes: Seq[Attribute],
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ condition: Option[Expression],
+ eventTimeWatermark: Option[Long]): JoinStateWatermarkPredicates = {
+ val joinKeyOrdinalForWatermark: Option[Int] = {
+ leftKeys.zipWithIndex.collectFirst {
+ case (ne: NamedExpression, index) if
ne.metadata.contains(delayKey) => index
+ } orElse {
+ rightKeys.zipWithIndex.collectFirst {
+ case (ne: NamedExpression, index) if
ne.metadata.contains(delayKey) => index
+ }
+ }
+ }
+
+ def getOneSideStateWatermarkPredicate(
+ oneSideInputAttributes: Seq[Attribute],
+ oneSideJoinKeys: Seq[Expression],
+ otherSideInputAttributes: Seq[Attribute]):
Option[JoinStateWatermarkPredicate] = {
+ val isWatermarkDefinedOnInput =
oneSideInputAttributes.exists(_.metadata.contains(delayKey))
+ val isWatermarkDefinedOnJoinKey =
joinKeyOrdinalForWatermark.isDefined
+
+ if (isWatermarkDefinedOnJoinKey) { // case 1 and 3 explained in the
class docs
+ val keyExprWithWatermark = BoundReference(
+ joinKeyOrdinalForWatermark.get,
+ oneSideJoinKeys(joinKeyOrdinalForWatermark.get).dataType,
+ oneSideJoinKeys(joinKeyOrdinalForWatermark.get).nullable)
+ val expr = watermarkExpression(Some(keyExprWithWatermark),
eventTimeWatermark)
+ expr.map(JoinStateKeyWatermarkPredicate)
+
+ } else if (isWatermarkDefinedOnInput) { // case 2 explained in the
class docs
+ val stateValueWatermark = getStateValueWatermark(
+ attributesToFindStateWatemarkFor = oneSideInputAttributes,
+ attributesWithEventWatermark = otherSideInputAttributes,
+ condition,
+ eventTimeWatermark)
+ val inputAttributeWithWatermark =
oneSideInputAttributes.find(_.metadata.contains(delayKey))
+ val expr = watermarkExpression(inputAttributeWithWatermark,
stateValueWatermark)
+ expr.map(JoinStateValueWatermarkPredicate)
+
+ } else {
+ None
+
+ }
+ }
+
+ val leftStateWatermarkPredicate =
+ getOneSideStateWatermarkPredicate(leftAttributes, leftKeys,
rightAttributes)
+ val rightStateWatermarkPredicate =
+ getOneSideStateWatermarkPredicate(rightAttributes, rightKeys,
leftAttributes)
+ JoinStateWatermarkPredicates(leftStateWatermarkPredicate,
rightStateWatermarkPredicate)
+ }
+
+ /**
+ * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for
context about it)
+ * given the join condition and the event time watermark. This is how it
works.
+ * - The condition is split into conjunctive predicates, and we find the
predicates of the
+ * form `leftTime + c1 < rightTime + c2` (or <=, >, >=).
+ * - We canonicalize the predicate and solve it with the event time
watermark value to find the
+ * value of the state watermark.
+ *
+ * @param attributesToFindStateWatemarkFor attributes of the side whose
state watermark
+ * is to be calculated
+ * @param attributesWithEventWatermark attributes of the other side
which has a watermark column
+ * @param joinCondition join condition
+ * @param eventWatermark watermark defined on the input
event data
+ * @return state value watermark in milliseconds
+ */
+ def getStateValueWatermark(
+ attributesToFindStateWatemarkFor: Seq[Attribute],
+ attributesWithEventWatermark: Seq[Attribute],
+ joinCondition: Option[Expression],
+ eventWatermark: Option[Long]): Option[Long] = {
+ if (joinCondition.isEmpty || eventWatermark.isEmpty) return None
+
+ def getStateWatermarkSafely(l: Expression, r: Expression):
Option[Long] = {
+ try {
+ getStateWatemarkFromLessThenPredicate(
+ l, r, attributesToFindStateWatemarkFor,
attributesWithEventWatermark, eventWatermark)
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error trying to extract state constraint from
condition $joinCondition", e)
+ None
+ }
+ }
+
+ val allStateWatermarks =
splitConjunctivePredicates(joinCondition.get).flatMap { predicate =>
+ val stateWatermark = predicate match {
+ case LessThan(l, r) => getStateWatermarkSafely(l, r)
+ case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_
- 1)
+ case GreaterThan(l, r) => getStateWatermarkSafely(r, l)
+ case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r,
l).map(_ - 1)
+ case _ => None
+ }
+ if (stateWatermark.nonEmpty) {
+ logInfo(s"Condition $joinCondition generated watermark constraint
= ${stateWatermark.get}")
+ }
+ stateWatermark
+ }
+ allStateWatermarks.reduceOption((x, y) => Math.min(x, y))
+ }
+
+ /**
+ * Extract the constraint from a condition. For example, suppose we want to find
the constraint for
+ * leftTime using the watermark on rightTime:
+ *
+ * Input: rightTime-with-watermark + c1 < leftTime + c2
+ * Canonical form: rightTime-with-watermark + c1 + (-c2) +
(-leftTime) < 0
+ * Solving for leftTime: rightTime-with-watermark + c1 + (-c2) <
leftTime
+ * With watermark value: watermark-value + c1 + (-c2) < leftTime
+ */
+ private def getStateWatemarkFromLessThenPredicate(
+ leftExpr: Expression,
+ rightExpr: Expression,
+ attributesToFindStateWatermarkFor: Seq[Attribute],
+ attributesWithEventWatermark: Seq[Attribute],
+ eventWatermark: Option[Long]): Option[Long] = {
+
+ def containsAttributeToFindStateConstraintFor(e: Expression): Boolean
= {
+ e.collectLeaves().collectFirst {
+ case a@AttributeReference(_, TimestampType, _, _)
+ if attributesToFindStateWatermarkFor.contains(a) => a
+ }.nonEmpty
+ }
+
+ // Canonicalization step 1: convert to (rightTime-with-watermark + c1)
- (leftTime + c2) < 0
+ val allOnLeftExpr = Subtract(leftExpr, rightExpr)
+ logDebug(s"All on
Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}")
+
+ // Canonicalization step 2: extract commutative terms
+ // rightTime-with-watermark, c1, -leftTime, -c2
+ val terms = ExpressionSet(collectTerms(allOnLeftExpr))
+ logDebug("Terms extracted from join condition:\n\t" +
terms.mkString("\n\t"))
+
+ // Find the term that has leftTime (i.e. the one present in
attributesToFindStateWatermarkFor)
+ val constraintTerms =
terms.filter(containsAttributeToFindStateConstraintFor)
+
+ // Verify there is only one correct constraint term and of the correct
type
+ if (constraintTerms.size > 1) {
+ logWarning("Failed to extract state constraint terms: multiple time
terms in condition\n\t" +
+ terms.mkString("\n\t"))
+ return None
+ }
+ if (constraintTerms.isEmpty) {
+ logDebug("Failed to extract state constraint terms: no time terms in
condition\n\t" +
+ terms.mkString("\n\t"))
+ return None
+ }
+ val constraintTerm = constraintTerms.head
+ if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) {
+ // Incorrect condition. We want the constraint term in canonical
form to be `-leftTime`
+ // so that we can solve for it as `-leftTime + watermark + c < 0` ==>
`watermark + c < leftTime`.
+ // Now, if the original condition is `rightTime-with-watermark >
leftTime` and the watermark
+ // condition is `rightTime-with-watermark > watermarkValue`, then no
constraint about
+ // `leftTime` can be inferred. In this case, after canonicalization
and collection of terms,
+ // the constraintTerm would be `leftTime` and not `-leftTime`.
Hence, we return None.
+ return None
+ }
+
+ // Replace watermark attribute with watermark value, and generate the
resolved expression
+ // from the other terms. That is,
+ // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 =>
watermark + c1 + (-c2)
+ logDebug(s"Constraint term from join condition:\t$constraintTerm")
+ val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term
=>
+ term.transform {
+ case a@AttributeReference(_, TimestampType, _, metadata)
+ if attributesWithEventWatermark.contains(a) &&
a.metadata.contains(delayKey) =>
+ Literal(eventWatermark.get)
+ }
+ }.reduceLeft(Add)
+
+ // Calculate the constraint value
+ logInfo(s"Final expression to evaluate
constraint:\t$exprWithWatermarkSubstituted")
+ val constraintValue =
exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double]
+ Some(Double2double(constraintValue).toLong)
+ }
+
+ /**
+ * Collect all the terms present in an expression after converting it
into the form
+ * a + b + c + d where each term is either an attribute or a literal
cast to long,
+ * optionally wrapped in a unary minus.
+ */
+ private def collectTerms(exprToCollectFrom: Expression): Seq[Expression]
= {
+ var invalid = false
+
+ /** Wrap a term with UnaryMinus if it needs to be negated. */
+ def negateIfNeeded(expr: Expression, minus: Boolean): Expression = {
+ if (minus) UnaryMinus(expr) else expr
+ }
+
+ /**
+ * Recursively split the expression into its leaf terms containing
attributes or literals.
+ * Returns terms only of the forms:
+ * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)),
+ * Cast(AttributeReference, Double),
UnaryMinus(Cast(AttributeReference, Double))
+ * Multiply(Literal), UnaryMinus(Multiply(Literal))
+ * Multiply(Cast(Literal)), UnaryMinus(Multiply(Cast(Literal)))
+ *
+ * Note:
+ * - If term needs to be negated for making it a commutative term,
+ * then it will be wrapped in UnaryMinus(...)
+ * - Each term represents a timestamp value or time interval
in milliseconds,
+ * typed as a double.
+ */
+ def collect(expr: Expression, negate: Boolean): Seq[Expression] = {
+ expr match {
+ case Add(left, right) =>
+ collect(left, negate) ++ collect(right, negate)
+ case Subtract(left, right) =>
+ collect(left, negate) ++ collect(right, !negate)
+ case TimeAdd(left, right, _) =>
+ collect(left, negate) ++ collect(right, negate)
+ case TimeSub(left, right, _) =>
+ collect(left, negate) ++ collect(right, !negate)
+ case UnaryMinus(child) =>
+ collect(child, !negate)
+ case CheckOverflow(child, _) =>
+ collect(child, negate)
+ case Cast(child, dataType, _) =>
+ dataType match {
+ case _: NumericType | _: TimestampType => collect(child,
negate)
+ case _ =>
+ invalid = true
+ Seq.empty
+ }
+ case a: AttributeReference =>
+ val castedRef = if (a.dataType != DoubleType) Cast(a,
DoubleType) else a
+ Seq(negateIfNeeded(castedRef, negate))
+ case lit: Literal =>
+ // If literal of type calendar interval, then explicitly convert
to millis
+ // Convert other number like literal to doubles representing
millis (by x1000)
+ val castedLit = lit.dataType match {
+ case CalendarIntervalType =>
+ val calendarInterval =
lit.value.asInstanceOf[CalendarInterval]
+ val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000
* 31
--- End diff --
I don't think 31 is a standard anywhere. I would prefer if we just throw an
AnalysisException here if we find months.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]