GitHub user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19327#discussion_r141698177
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
---
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference,
AttributeSet, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan,
GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply,
PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub,
UnaryMinus}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+
+/**
+ * Helper object for stream joins. See [[StreamingSymmetricHashJoinExec]]
in SQL for more details.
+ */
+object StreamingJoinHelper extends PredicateHelper with Logging {
/**
 * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it)
 * given the join condition and the event time watermark. This is how it works.
 * - The condition is split into conjunctive predicates, and we find the predicates of the
 *   form `leftTime + c1 < rightTime + c2` (or <=, >, >=).
 * - We canonicalize each such predicate and solve it with the event time watermark value to
 *   find the value of the state watermark.
 * - If several predicates yield a state watermark, the smallest one is returned.
 * This function makes a best-effort attempt to get the state watermark. If there is any
 * error, it will return None.
 *
 * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark
 *                                          is to be calculated
 * @param attributesWithEventWatermark      attributes of the other side which has a
 *                                          watermark column
 * @param joinCondition                     join condition
 * @param eventWatermark                    watermark defined on the input event data
 * @return state value watermark in milliseconds, if possible; None otherwise.
 */
def getStateValueWatermark(
    attributesToFindStateWatermarkFor: AttributeSet,
    attributesWithEventWatermark: AttributeSet,
    joinCondition: Option[Expression],
    eventWatermark: Option[Long]): Option[Long] = {

  // Solves one `l < r` predicate. Any non-fatal failure is logged and treated as
  // "no constraint", so that watermark extraction stays best-effort.
  def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = {
    try {
      getStateWatermarkFromLessThenPredicate(
        l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark)
    } catch {
      case NonFatal(e) =>
        logWarning(s"Error trying to extract state constraint from condition $joinCondition", e)
        None
    }
  }

  // A state watermark can only be derived when there is a condition to solve, an event time
  // watermark value to plug into it, and at least one attribute carrying watermark metadata.
  // If any of these is missing, the result is None.
  val solvableCondition = for {
    condition <- joinCondition
    _ <- eventWatermark
    if attributesWithEventWatermark.exists(_.metadata.contains(delayKey))
  } yield condition

  solvableCondition.flatMap { condition =>
    val allStateWatermarks = splitConjunctivePredicates(condition).flatMap { predicate =>
      // The generated state watermark cleanup expression is inclusive of the state watermark:
      // if the state watermark is W, all state where timestamp <= W will be cleaned up.
      // When the canonicalized condition solves to leftTime > W (strict), cleaning up
      // leftTime <= W is correct. When it solves to leftTime >= W, rows with leftTime == W
      // must be kept, so we clean up leftTime <= W - 1 instead. Hence the -1 below.
      val stateWatermark = predicate match {
        case LessThan(l, r) => getStateWatermarkSafely(l, r)
        case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1)
        case GreaterThan(l, r) => getStateWatermarkSafely(r, l)
        case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1)
        case _ => None
      }
      stateWatermark.foreach { watermark =>
        logInfo(s"Condition $joinCondition generated watermark constraint = $watermark")
      }
      stateWatermark
    }
    // Cleanup removes timestamps <= W, so the smallest W is the most conservative choice:
    // it satisfies every individual predicate's constraint.
    allStateWatermarks.reduceOption((x, y) => Math.min(x, y))
  }
}
+
+ /**
+ * Extract the state value watermark (milliseconds) from the condition
+ * `LessThan(leftExpr, rightExpr)`. For example, suppose we want to find the constraint on
+ * leftTime using the watermark on rightTime:
+ *
+ * Input: rightTime-with-watermark + c1 < leftTime + c2
+ * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0
+ * Solving for leftTime: rightTime-with-watermark + c1 + (-c2) < leftTime
+ * With watermark value: watermark-value + c1 + (-c2) < leftTime
+ */
+ private def getStateWatermarkFromLessThenPredicate(
+ leftExpr: Expression,
+ rightExpr: Expression,
+ attributesToFindStateWatermarkFor: AttributeSet,
+ attributesWithEventWatermark: AttributeSet,
+ eventWatermark: Option[Long]): Option[Long] = {
+
+ val attributesInCondition = AttributeSet(
+ leftExpr.collect { case a: AttributeReference => a } ++
+ rightExpr.collect { case a: AttributeReference => a }
+ )
+ if (attributesInCondition.filter {
attributesToFindStateWatermarkFor.contains(_) }.size > 1 ||
+ attributesInCondition.filter {
attributesWithEventWatermark.contains(_) }.size > 1) {
+ // If more than attributes present in condition from one side, then
it cannot be solved
+ return None
+ }
+
+ def containsAttributeToFindStateConstraintFor(e: Expression): Boolean
= {
+ e.collectLeaves().collectFirst {
+ case a @ AttributeReference(_, _, _, _)
+ if attributesToFindStateWatermarkFor.contains(a) => a
+ }.nonEmpty
+ }
+
+ // Canonicalization step 1: convert to (rightTime-with-watermark + c1)
- (leftTime + c2) < 0
+ val allOnLeftExpr = Subtract(leftExpr, rightExpr)
+ logDebug(s"All on
Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}")
+
+ // Canonicalization step 2: extract commutative terms
+ // rightTime-with-watermark, c1, -leftTime, -c2
+ val terms = ExpressionSet(collectTerms(allOnLeftExpr))
+ logDebug("Terms extracted from join condition:\n\t" +
terms.mkString("\n\t"))
+
+
+
--- End diff --
Remove the 2 extra blank lines.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]