[FLINK-6233] [table] Add inner rowtime window join between two streams for SQL.
This closes #4625. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/655d8b16 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/655d8b16 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/655d8b16 Branch: refs/heads/master Commit: 655d8b16193ac7131fa1f58fb4ba7ff96e439438 Parents: 9829ca0 Author: Xingcan Cui <[email protected]> Authored: Wed Aug 30 13:57:38 2017 +0800 Committer: Fabian Hueske <[email protected]> Committed: Tue Oct 10 23:09:07 2017 +0200 ---------------------------------------------------------------------- docs/dev/table/sql.md | 2 +- .../nodes/datastream/DataStreamWindowJoin.scala | 132 +++++- .../datastream/DataStreamWindowJoinRule.scala | 11 +- .../join/ProcTimeBoundedStreamInnerJoin.scala | 68 +++ .../runtime/join/ProcTimeWindowInnerJoin.scala | 346 ---------------- .../join/RowTimeBoundedStreamInnerJoin.scala | 82 ++++ .../join/TimeBoundedStreamInnerJoin.scala | 412 +++++++++++++++++++ .../table/runtime/join/WindowJoinUtil.scala | 40 +- .../flink/table/api/stream/sql/JoinTest.scala | 94 ++++- .../table/runtime/harness/JoinHarnessTest.scala | 305 +++++++++++--- .../table/runtime/stream/sql/JoinITCase.scala | 205 ++++++++- 11 files changed, 1230 insertions(+), 467 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/docs/dev/table/sql.md ---------------------------------------------------------------------- diff --git a/docs/dev/table/sql.md b/docs/dev/table/sql.md index b9205ab..533aa6e 100644 --- a/docs/dev/table/sql.md +++ b/docs/dev/table/sql.md @@ -409,7 +409,7 @@ FROM Orders LEFT JOIN Product ON Orders.productId = Product.id </ul> </p> - <p><b>Note:</b> Currently, only processing time window joins and <code>INNER</code> joins are supported.</p> + <p><b>Note:</b> Currently, only <code>INNER</code> joins are supported.</p> {% highlight sql %} SELECT * 
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala index f8015b3..9358aa3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala @@ -23,14 +23,20 @@ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.{JoinInfo, JoinRelType} import org.apache.calcite.rel.{BiRel, RelNode, RelWriter} import org.apache.calcite.rex.RexNode +import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.functions.NullByteKeySelector +import org.apache.flink.api.java.tuple.Tuple import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.streaming.api.functions.co.CoProcessFunction import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException} import org.apache.flink.table.plan.nodes.CommonJoin import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.plan.util.UpdatingPlanChecker -import org.apache.flink.table.runtime.join.{ProcTimeWindowInnerJoin, WindowJoinUtil} +import org.apache.flink.table.runtime.join.{ProcTimeBoundedStreamInnerJoin, RowTimeBoundedStreamInnerJoin, WindowJoinUtil} +import org.apache.flink.table.runtime.operators.KeyedCoProcessOperatorWithWatermarkDelay import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} +import org.apache.flink.table.util.Logging +import org.apache.flink.util.Collector /** * 
RelNode for a time windowed stream join. @@ -48,11 +54,14 @@ class DataStreamWindowJoin( isRowTime: Boolean, leftLowerBound: Long, leftUpperBound: Long, + leftTimeIdx: Int, + rightTimeIdx: Int, remainCondition: Option[RexNode], ruleDescription: String) extends BiRel(cluster, traitSet, leftNode, rightNode) with CommonJoin - with DataStreamRel { + with DataStreamRel + with Logging { override def deriveRowType(): RelDataType = schema.relDataType @@ -70,6 +79,8 @@ class DataStreamWindowJoin( isRowTime, leftLowerBound, leftUpperBound, + leftTimeIdx, + rightTimeIdx, remainCondition, ruleDescription) } @@ -107,10 +118,12 @@ class DataStreamWindowJoin( val leftDataStream = left.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig) val rightDataStream = right.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig) - // get the equality keys and other condition + // get the equi-keys and other conditions val joinInfo = JoinInfo.of(leftNode, rightNode, joinCondition) val leftKeys = joinInfo.leftKeys.toIntArray val rightKeys = joinInfo.rightKeys.toIntArray + val relativeWindowSize = leftUpperBound - leftLowerBound + val returnTypeInfo = CRowTypeInfo(schema.typeInfo) // generate join function val joinFunction = @@ -125,20 +138,32 @@ class DataStreamWindowJoin( joinType match { case JoinRelType.INNER => - if (isRowTime) { - // RowTime JoinCoProcessFunction - throw new TableException( - "RowTime inner join between stream and stream is not supported yet.") + if (relativeWindowSize < 0) { + LOG.warn(s"The relative window size $relativeWindowSize is negative," + + " please check the join conditions.") + createEmptyInnerJoin(leftDataStream, rightDataStream, returnTypeInfo) } else { - // Proctime JoinCoProcessFunction - createProcTimeInnerJoinFunction( - leftDataStream, - rightDataStream, - joinFunction.name, - joinFunction.code, - leftKeys, - rightKeys - ) + if (isRowTime) { + createRowTimeInnerJoin( + leftDataStream, + rightDataStream, + returnTypeInfo, + 
joinFunction.name, + joinFunction.code, + leftKeys, + rightKeys + ) + } else { + createProcTimeInnerJoin( + leftDataStream, + rightDataStream, + returnTypeInfo, + joinFunction.name, + joinFunction.code, + leftKeys, + rightKeys + ) + } } case JoinRelType.FULL => throw new TableException( @@ -152,19 +177,40 @@ class DataStreamWindowJoin( } } - def createProcTimeInnerJoinFunction( + def createEmptyInnerJoin( + leftDataStream: DataStream[CRow], + rightDataStream: DataStream[CRow], + returnTypeInfo: TypeInformation[CRow]): DataStream[CRow] = { + leftDataStream.connect(rightDataStream).process( + new CoProcessFunction[CRow, CRow, CRow] { + override def processElement1( + value: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]): Unit = { + //Do nothing. + } + override def processElement2( + value: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]): Unit = { + //Do nothing. + } + }).returns(returnTypeInfo) + } + + def createProcTimeInnerJoin( leftDataStream: DataStream[CRow], rightDataStream: DataStream[CRow], + returnTypeInfo: TypeInformation[CRow], joinFunctionName: String, joinFunctionCode: String, leftKeys: Array[Int], rightKeys: Array[Int]): DataStream[CRow] = { - val returnTypeInfo = CRowTypeInfo(schema.typeInfo) - - val procInnerJoinFunc = new ProcTimeWindowInnerJoin( + val procInnerJoinFunc = new ProcTimeBoundedStreamInnerJoin( leftLowerBound, leftUpperBound, + allowedLateness = 0L, leftSchema.typeInfo, rightSchema.typeInfo, joinFunctionName, @@ -184,4 +230,50 @@ class DataStreamWindowJoin( .returns(returnTypeInfo) } } + + def createRowTimeInnerJoin( + leftDataStream: DataStream[CRow], + rightDataStream: DataStream[CRow], + returnTypeInfo: TypeInformation[CRow], + joinFunctionName: String, + joinFunctionCode: String, + leftKeys: Array[Int], + rightKeys: Array[Int]): DataStream[CRow] = { + + val rowTimeInnerJoinFunc = new RowTimeBoundedStreamInnerJoin( + leftLowerBound, + leftUpperBound, + 
allowedLateness = 0L, + leftSchema.typeInfo, + rightSchema.typeInfo, + joinFunctionName, + joinFunctionCode, + leftTimeIdx, + rightTimeIdx) + + if (!leftKeys.isEmpty) { + leftDataStream + .connect(rightDataStream) + .keyBy(leftKeys, rightKeys) + .transform( + "InnerRowtimeWindowJoin", + returnTypeInfo, + new KeyedCoProcessOperatorWithWatermarkDelay[Tuple, CRow, CRow, CRow]( + rowTimeInnerJoinFunc, + rowTimeInnerJoinFunc.getMaxOutputDelay) + ) + } else { + leftDataStream.connect(rightDataStream) + .keyBy(new NullByteKeySelector[CRow](), new NullByteKeySelector[CRow]) + .transform( + "InnerRowtimeWindowJoin", + returnTypeInfo, + new KeyedCoProcessOperatorWithWatermarkDelay[java.lang.Byte, CRow, CRow, CRow]( + rowTimeInnerJoinFunc, + rowTimeInnerJoinFunc.getMaxOutputDelay) + ) + .setParallelism(1) + .setMaxParallelism(1) + } + } } http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala index 7dfcbc5..d208d2b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala @@ -40,10 +40,9 @@ class DataStreamWindowJoinRule override def matches(call: RelOptRuleCall): Boolean = { val join: FlinkLogicalJoin = call.rel(0).asInstanceOf[FlinkLogicalJoin] - val joinInfo = join.analyzeCondition val (windowBounds, remainingPreds) = WindowJoinUtil.extractWindowBoundsFromPredicate( - joinInfo.getRemaining(join.getCluster.getRexBuilder), + 
join.getCondition, join.getLeft.getRowType.getFieldCount, join.getRowType, join.getCluster.getRexBuilder, @@ -55,8 +54,7 @@ class DataStreamWindowJoinRule if (windowBounds.isDefined) { if (windowBounds.get.isEventTime) { - // we cannot handle event-time window joins yet - false + !remainingPredsAccessTime } else { // Check that no event-time attributes are in the input. // The proc-time join implementation does ensure that record timestamp are correctly set. @@ -80,13 +78,12 @@ class DataStreamWindowJoinRule val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM) val convLeft: RelNode = RelOptRule.convert(join.getInput(0), FlinkConventions.DATASTREAM) val convRight: RelNode = RelOptRule.convert(join.getInput(1), FlinkConventions.DATASTREAM) - val joinInfo = join.analyzeCondition val leftRowSchema = new RowSchema(convLeft.getRowType) val rightRowSchema = new RowSchema(convRight.getRowType) val (windowBounds, remainCondition) = WindowJoinUtil.extractWindowBoundsFromPredicate( - joinInfo.getRemaining(join.getCluster.getRexBuilder), + join.getCondition, leftRowSchema.arity, join.getRowType, join.getCluster.getRexBuilder, @@ -105,6 +102,8 @@ class DataStreamWindowJoinRule windowBounds.get.isEventTime, windowBounds.get.leftLowerBound, windowBounds.get.leftUpperBound, + windowBounds.get.leftTimeIdx, + windowBounds.get.rightTimeIdx, remainCondition, description) } http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala new file mode 100644 index 0000000..ab5a9c3 --- /dev/null +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.join + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.streaming.api.functions.co.CoProcessFunction +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.types.Row + +/** + * The function to execute processing time bounded stream inner-join. 
+ */ +final class ProcTimeBoundedStreamInnerJoin( + leftLowerBound: Long, + leftUpperBound: Long, + allowedLateness: Long, + leftType: TypeInformation[Row], + rightType: TypeInformation[Row], + genJoinFuncName: String, + genJoinFuncCode: String) + extends TimeBoundedStreamInnerJoin( + leftLowerBound, + leftUpperBound, + allowedLateness, + leftType, + rightType, + genJoinFuncName, + genJoinFuncCode) { + + override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { + leftOperatorTime = ctx.timerService().currentProcessingTime() + rightOperatorTime = leftOperatorTime + } + + override def getTimeForLeftStream( + context: CoProcessFunction[CRow, CRow, CRow]#Context, + row: Row): Long = { + leftOperatorTime + } + + override def getTimeForRightStream( + context: CoProcessFunction[CRow, CRow, CRow]#Context, + row: Row): Long = { + rightOperatorTime + } + + override def registerTimer( + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + cleanupTime: Long): Unit = { + ctx.timerService.registerProcessingTimeTimer(cleanupTime) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala deleted file mode 100644 index 8240376..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala +++ /dev/null @@ -1,346 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.runtime.join - -import java.util -import java.util.{List => JList} - -import org.apache.flink.api.common.functions.FlatJoinFunction -import org.apache.flink.api.common.state._ -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.typeutils.ListTypeInfo -import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.functions.co.CoProcessFunction -import org.apache.flink.table.codegen.Compiler -import org.apache.flink.table.runtime.CRowWrappingCollector -import org.apache.flink.table.runtime.types.CRow -import org.apache.flink.table.util.Logging -import org.apache.flink.types.Row -import org.apache.flink.util.Collector - -/** - * A CoProcessFunction to support stream join stream, currently just support inner-join - * - * @param leftLowerBound - * the left stream lower bound, and -leftLowerBound is the right stream upper bound - * @param leftUpperBound - * the left stream upper bound, and -leftUpperBound is the right stream lower bound - * @param element1Type the input type of left stream - * @param element2Type the input type of right stream - * @param genJoinFuncName the function code of other non-equi condition - * @param genJoinFuncCode the function name of other non-equi condition - * - */ -class ProcTimeWindowInnerJoin( - private val leftLowerBound: Long, - 
private val leftUpperBound: Long, - private val element1Type: TypeInformation[Row], - private val element2Type: TypeInformation[Row], - private val genJoinFuncName: String, - private val genJoinFuncCode: String) - extends CoProcessFunction[CRow, CRow, CRow] - with Compiler[FlatJoinFunction[Row, Row, Row]] - with Logging { - - private var cRowWrapper: CRowWrappingCollector = _ - - // other condition function - private var joinFunction: FlatJoinFunction[Row, Row, Row] = _ - - // tmp list to store expired records - private var removeList: JList[Long] = _ - - // state to hold left stream element - private var row1MapState: MapState[Long, JList[Row]] = _ - // state to hold right stream element - private var row2MapState: MapState[Long, JList[Row]] = _ - - // state to record last timer of left stream, 0 means no timer - private var timerState1: ValueState[Long] = _ - // state to record last timer of right stream, 0 means no timer - private var timerState2: ValueState[Long] = _ - - // compute window sizes, i.e., how long to keep rows in state. - // window size of -1 means rows do not need to be put into state. 
- private val leftStreamWinSize: Long = if (leftLowerBound <= 0) -leftLowerBound else -1 - private val rightStreamWinSize: Long = if (leftUpperBound >= 0) leftUpperBound else -1 - - override def open(config: Configuration) { - LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n " + - s"Code:\n$genJoinFuncCode") - val clazz = compile( - getRuntimeContext.getUserCodeClassLoader, - genJoinFuncName, - genJoinFuncCode) - LOG.debug("Instantiating JoinFunction.") - joinFunction = clazz.newInstance() - - removeList = new util.ArrayList[Long]() - cRowWrapper = new CRowWrappingCollector() - cRowWrapper.setChange(true) - - // initialize row state - val rowListTypeInfo1: TypeInformation[JList[Row]] = new ListTypeInfo[Row](element1Type) - val mapStateDescriptor1: MapStateDescriptor[Long, JList[Row]] = - new MapStateDescriptor[Long, JList[Row]]("row1mapstate", - BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo1) - row1MapState = getRuntimeContext.getMapState(mapStateDescriptor1) - - val rowListTypeInfo2: TypeInformation[JList[Row]] = new ListTypeInfo[Row](element2Type) - val mapStateDescriptor2: MapStateDescriptor[Long, JList[Row]] = - new MapStateDescriptor[Long, JList[Row]]("row2mapstate", - BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo2) - row2MapState = getRuntimeContext.getMapState(mapStateDescriptor2) - - // initialize timer state - val valueStateDescriptor1: ValueStateDescriptor[Long] = - new ValueStateDescriptor[Long]("timervaluestate1", classOf[Long]) - timerState1 = getRuntimeContext.getState(valueStateDescriptor1) - - val valueStateDescriptor2: ValueStateDescriptor[Long] = - new ValueStateDescriptor[Long]("timervaluestate2", classOf[Long]) - timerState2 = getRuntimeContext.getState(valueStateDescriptor2) - } - - /** - * Process left stream records - * - * @param valueC The input value. 
- * @param ctx The ctx to register timer or get current time - * @param out The collector for returning result values. - * - */ - override def processElement1( - valueC: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]): Unit = { - - processElement( - valueC, - ctx, - out, - leftStreamWinSize, - timerState1, - row1MapState, - row2MapState, - -leftUpperBound, // right stream lower - -leftLowerBound, // right stream upper - isLeft = true - ) - } - - /** - * Process right stream records - * - * @param valueC The input value. - * @param ctx The ctx to register timer or get current time - * @param out The collector for returning result values. - * - */ - override def processElement2( - valueC: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]): Unit = { - - processElement( - valueC, - ctx, - out, - rightStreamWinSize, - timerState2, - row2MapState, - row1MapState, - leftLowerBound, // left stream lower - leftUpperBound, // left stream upper - isLeft = false - ) - } - - /** - * Called when a processing timer trigger. - * Expire left/right records which earlier than current time - windowsize. - * - * @param timestamp The timestamp of the firing timer. - * @param ctx The ctx to register timer or get current time - * @param out The collector for returning result values. - */ - override def onTimer( - timestamp: Long, - ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, - out: Collector[CRow]): Unit = { - - if (timerState1.value == timestamp) { - expireOutTimeRow( - timestamp, - leftStreamWinSize, - row1MapState, - timerState1, - ctx - ) - } - - if (timerState2.value == timestamp) { - expireOutTimeRow( - timestamp, - rightStreamWinSize, - row2MapState, - timerState2, - ctx - ) - } - } - - /** - * Puts an element from the input stream into state and search the other state to - * output records meet the condition, and registers a timer for the current record - * if there is no timer at present. 
- */ - private def processElement( - valueC: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow], - winSize: Long, - timerState: ValueState[Long], - rowMapState: MapState[Long, JList[Row]], - otherRowMapState: MapState[Long, JList[Row]], - otherLowerBound: Long, - otherUpperBound: Long, - isLeft: Boolean): Unit = { - - cRowWrapper.out = out - - val row = valueC.row - - val curProcessTime = ctx.timerService.currentProcessingTime - val otherLowerTime = curProcessTime + otherLowerBound - val otherUpperTime = curProcessTime + otherUpperBound - - if (winSize >= 0) { - // put row into state for later joining. - // (winSize == 0) joins rows received in the same millisecond. - var rowList = rowMapState.get(curProcessTime) - if (rowList == null) { - rowList = new util.ArrayList[Row]() - } - rowList.add(row) - rowMapState.put(curProcessTime, rowList) - - // register a timer to remove the row from state once it is expired - if (timerState.value == 0) { - val cleanupTime = curProcessTime + winSize + 1 - ctx.timerService.registerProcessingTimeTimer(cleanupTime) - timerState.update(cleanupTime) - } - } - - // join row with rows received from the other input - val otherTimeIter = otherRowMapState.keys().iterator() - if (isLeft) { - // go over all timestamps in the other input's state - while (otherTimeIter.hasNext) { - val otherTimestamp = otherTimeIter.next() - if (otherTimestamp < otherLowerTime) { - // other timestamp is expired. Remove it later. 
- removeList.add(otherTimestamp) - } else if (otherTimestamp <= otherUpperTime) { - // join row with all rows from the other input for this timestamp - val otherRows = otherRowMapState.get(otherTimestamp) - var i = 0 - while (i < otherRows.size) { - joinFunction.join(row, otherRows.get(i), cRowWrapper) - i += 1 - } - } - } - } else { - // go over all timestamps in the other input's state - while (otherTimeIter.hasNext) { - val otherTimestamp = otherTimeIter.next() - if (otherTimestamp < otherLowerTime) { - // other timestamp is expired. Remove it later. - removeList.add(otherTimestamp) - } else if (otherTimestamp <= otherUpperTime) { - // join row with all rows from the other input for this timestamp - val otherRows = otherRowMapState.get(otherTimestamp) - var i = 0 - while (i < otherRows.size) { - joinFunction.join(otherRows.get(i), row, cRowWrapper) - i += 1 - } - } - } - } - - // remove rows for expired timestamps - var i = removeList.size - 1 - while (i >= 0) { - otherRowMapState.remove(removeList.get(i)) - i -= 1 - } - removeList.clear() - } - - /** - * Removes records which are outside the join window from the state. - * Registers a new timer if the state still holds records after the clean-up. - */ - private def expireOutTimeRow( - curTime: Long, - winSize: Long, - rowMapState: MapState[Long, JList[Row]], - timerState: ValueState[Long], - ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext): Unit = { - - val expiredTime = curTime - winSize - val keyIter = rowMapState.keys().iterator() - var validTimestamp: Boolean = false - // Search for expired timestamps. - // If we find a non-expired timestamp, remember the timestamp and leave the loop. - // This way we find all expired timestamps if they are sorted without doing a full pass. 
- while (keyIter.hasNext && !validTimestamp) { - val recordTime = keyIter.next - if (recordTime < expiredTime) { - removeList.add(recordTime) - } else { - // we found a timestamp that is still valid - validTimestamp = true - } - } - - // If the state has non-expired timestamps, register a new timer. - // Otherwise clean the complete state for this input. - if (validTimestamp) { - - // Remove expired records from state - var i = removeList.size - 1 - while (i >= 0) { - rowMapState.remove(removeList.get(i)) - i -= 1 - } - removeList.clear() - - val cleanupTime = curTime + winSize + 1 - ctx.timerService.registerProcessingTimeTimer(cleanupTime) - timerState.update(cleanupTime) - } else { - timerState.clear() - rowMapState.clear() - } - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala new file mode 100644 index 0000000..5cf5a53 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.join + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.streaming.api.functions.co.CoProcessFunction +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.types.Row + +/** + * The function to execute row(event) time bounded stream inner-join. + */ +final class RowTimeBoundedStreamInnerJoin( + leftLowerBound: Long, + leftUpperBound: Long, + allowedLateness: Long, + leftType: TypeInformation[Row], + rightType: TypeInformation[Row], + genJoinFuncName: String, + genJoinFuncCode: String, + leftTimeIdx: Int, + rightTimeIdx: Int) + extends TimeBoundedStreamInnerJoin( + leftLowerBound, + leftUpperBound, + allowedLateness, + leftType, + rightType, + genJoinFuncName, + genJoinFuncCode) { + + /** + * Get the maximum interval between receiving a row and emitting it (as part of a joined result). + * Only reasonable for row time join. + * + * @return the maximum delay for the outputs + */ + def getMaxOutputDelay: Long = Math.max(leftRelativeSize, rightRelativeSize) + allowedLateness + + override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { + leftOperatorTime = + if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() + else 0L + // We may set different operator times in the future. 
+ rightOperatorTime = leftOperatorTime + } + + override def getTimeForLeftStream( + context: CoProcessFunction[CRow, CRow, CRow]#Context, + row: Row): Long = { + row.getField(leftTimeIdx).asInstanceOf[Long] + } + + override def getTimeForRightStream( + context: CoProcessFunction[CRow, CRow, CRow]#Context, + row: Row): Long = { + row.getField(rightTimeIdx).asInstanceOf[Long] + } + + override def registerTimer( + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + cleanupTime: Long): Unit = { + // Maybe we can register timers for different streams in the future. + ctx.timerService.registerEventTimeTimer(cleanupTime) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala new file mode 100644 index 0000000..7bf3d33 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.join + +import java.util +import java.util.{List => JList} + +import org.apache.flink.api.common.functions.FlatJoinFunction +import org.apache.flink.api.common.state._ +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.ListTypeInfo +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.co.CoProcessFunction +import org.apache.flink.table.api.Types +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.table.runtime.CRowWrappingCollector +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging +import org.apache.flink.types.Row +import org.apache.flink.util.Collector + +/** + * A CoProcessFunction to execute time-bounded stream inner-join. + * Two kinds of time criteria: + * "L.time between R.time + X and R.time + Y" or "R.time between L.time - Y and L.time - X". 
+ * + * @param leftLowerBound the lower bound for the left stream (X in the criteria) + * @param leftUpperBound the upper bound for the left stream (Y in the criteria) + * @param allowedLateness the lateness allowed for the two streams + * @param leftType the input type of left stream + * @param rightType the input type of right stream + * @param genJoinFuncName the function name of other non-equi conditions + * @param genJoinFuncCode the function code of other non-equi conditions + * + */ +abstract class TimeBoundedStreamInnerJoin( + private val leftLowerBound: Long, + private val leftUpperBound: Long, + private val allowedLateness: Long, + private val leftType: TypeInformation[Row], + private val rightType: TypeInformation[Row], + private val genJoinFuncName: String, + private val genJoinFuncCode: String) + extends CoProcessFunction[CRow, CRow, CRow] + with Compiler[FlatJoinFunction[Row, Row, Row]] + with Logging { + + private var cRowWrapper: CRowWrappingCollector = _ + + // the join function for other conditions + private var joinFunction: FlatJoinFunction[Row, Row, Row] = _ + + // cache to store rows from the left stream + private var leftCache: MapState[Long, JList[Row]] = _ + // cache to store rows from the right stream + private var rightCache: MapState[Long, JList[Row]] = _ + + // state to record the timer on the left stream. 0 means no timer set + private var leftTimerState: ValueState[Long] = _ + // state to record the timer on the right stream.
0 means no timer set + private var rightTimerState: ValueState[Long] = _ + + protected val leftRelativeSize: Long = -leftLowerBound + protected val rightRelativeSize: Long = leftUpperBound + + private var leftExpirationTime: Long = 0L + private var rightExpirationTime: Long = 0L + + protected var leftOperatorTime: Long = 0L + protected var rightOperatorTime: Long = 0L + + + // for delayed cleanup + private val cleanupDelay = (leftRelativeSize + rightRelativeSize) / 2 + + if (allowedLateness < 0) { + throw new IllegalArgumentException("The allowed lateness must be non-negative.") + } + + override def open(config: Configuration) { + LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n " + + s"Code:\n$genJoinFuncCode") + val clazz = compile( + getRuntimeContext.getUserCodeClassLoader, + genJoinFuncName, + genJoinFuncCode) + LOG.debug("Instantiating JoinFunction.") + joinFunction = clazz.newInstance() + + cRowWrapper = new CRowWrappingCollector() + cRowWrapper.setChange(true) + + // Initialize the data caches. + val leftListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](leftType) + val leftStateDescriptor: MapStateDescriptor[Long, JList[Row]] = + new MapStateDescriptor[Long, JList[Row]]( + "InnerJoinLeftCache", + Types.LONG.asInstanceOf[TypeInformation[Long]], + leftListTypeInfo) + leftCache = getRuntimeContext.getMapState(leftStateDescriptor) + + val rightListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](rightType) + val rightStateDescriptor: MapStateDescriptor[Long, JList[Row]] = + new MapStateDescriptor[Long, JList[Row]]( + "InnerJoinRightCache", + Types.LONG.asInstanceOf[TypeInformation[Long]], + rightListTypeInfo) + rightCache = getRuntimeContext.getMapState(rightStateDescriptor) + + // Initialize the timer states. 
+ val leftTimerStateDesc: ValueStateDescriptor[Long] = + new ValueStateDescriptor[Long]("InnerJoinLeftTimerState", classOf[Long]) + leftTimerState = getRuntimeContext.getState(leftTimerStateDesc) + + val rightTimerStateDesc: ValueStateDescriptor[Long] = + new ValueStateDescriptor[Long]("InnerJoinRightTimerState", classOf[Long]) + rightTimerState = getRuntimeContext.getState(rightTimerStateDesc) + } + + /** + * Process rows from the left stream. + */ + override def processElement1( + cRowValue: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]): Unit = { + updateOperatorTime(ctx) + val leftRow = cRowValue.row + val timeForLeftRow: Long = getTimeForLeftStream(ctx, leftRow) + val rightQualifiedLowerBound: Long = timeForLeftRow - rightRelativeSize + val rightQualifiedUpperBound: Long = timeForLeftRow + leftRelativeSize + cRowWrapper.out = out + // Check if we need to cache the current row. + if (rightOperatorTime < rightQualifiedUpperBound) { + // Operator time of right stream has not exceeded the upper window bound of the current + // row. Put it into the left cache, since later coming records from the right stream are + // expected to be joined with it. + var leftRowList = leftCache.get(timeForLeftRow) + if (null == leftRowList) { + leftRowList = new util.ArrayList[Row](1) + } + leftRowList.add(leftRow) + leftCache.put(timeForLeftRow, leftRowList) + if (rightTimerState.value == 0) { + // Register a timer on the RIGHT stream to remove rows. + registerCleanUpTimer(ctx, timeForLeftRow, leftRow = true) + } + } + // Check if we need to join the current row against cached rows of the right input. + // The condition here should be rightMinimumTime < rightQualifiedUpperBound. + // I use rightExpirationTime as an approximation of the rightMinimumTime here, + // since rightExpirationTime <= rightMinimumTime is always true. 
+ if (rightExpirationTime < rightQualifiedUpperBound) { + // Upper bound of current join window has not passed the cache expiration time yet. + // There might be qualifying rows in the cache that the current row needs to be joined with. + rightExpirationTime = calExpirationTime(leftOperatorTime, rightRelativeSize) + // Join the leftRow with rows from the right cache. + val rightIterator = rightCache.iterator() + while (rightIterator.hasNext) { + val rightEntry = rightIterator.next + val rightTime = rightEntry.getKey + if (rightTime >= rightQualifiedLowerBound && rightTime <= rightQualifiedUpperBound) { + val rightRows = rightEntry.getValue + var i = 0 + while (i < rightRows.size) { + joinFunction.join(leftRow, rightRows.get(i), cRowWrapper) + i += 1 + } + } + + if (rightTime <= rightExpirationTime) { + // eager remove + rightIterator.remove() + }// We could do the short-cutting optimization here once we get a state with ordered keys. + } + } + } + + /** + * Process rows from the right stream. + */ + override def processElement2( + cRowValue: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]): Unit = { + updateOperatorTime(ctx) + val rightRow = cRowValue.row + val timeForRightRow: Long = getTimeForRightStream(ctx, rightRow) + val leftQualifiedLowerBound: Long = timeForRightRow - leftRelativeSize + val leftQualifiedUpperBound: Long = timeForRightRow + rightRelativeSize + cRowWrapper.out = out + // Check if we need to cache the current row. + if (leftOperatorTime < leftQualifiedUpperBound) { + // Operator time of left stream has not exceeded the upper window bound of the current + // row. Put it into the right cache, since later coming records from the left stream are + // expected to be joined with it. 
+ var rightRowList = rightCache.get(timeForRightRow) + if (null == rightRowList) { + rightRowList = new util.ArrayList[Row](1) + } + rightRowList.add(rightRow) + rightCache.put(timeForRightRow, rightRowList) + if (leftTimerState.value == 0) { + // Register a timer on the LEFT stream to remove rows. + registerCleanUpTimer(ctx, timeForRightRow, leftRow = false) + } + } + // Check if we need to join the current row against cached rows of the left input. + // The condition here should be leftMinimumTime < leftQualifiedUpperBound. + // I use leftExpirationTime as an approximation of the leftMinimumTime here, + // since leftExpirationTime <= leftMinimumTime is always true. + if (leftExpirationTime < leftQualifiedUpperBound) { + leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize) + // Join the rightRow with rows from the left cache. + val leftIterator = leftCache.iterator() + while (leftIterator.hasNext) { + val leftEntry = leftIterator.next + val leftTime = leftEntry.getKey + if (leftTime >= leftQualifiedLowerBound && leftTime <= leftQualifiedUpperBound) { + val leftRows = leftEntry.getValue + var i = 0 + while (i < leftRows.size) { + joinFunction.join(leftRows.get(i), rightRow, cRowWrapper) + i += 1 + } + } + if (leftTime <= leftExpirationTime) { + // eager remove + leftIterator.remove() + } // We could do the short-cutting optimization here once we get a state with ordered keys. + } + } + } + + /** + * Called when a registered timer is fired. + * Remove rows whose timestamps are earlier than the expiration time, + * and register a new timer for the remaining rows. 
+ * + * @param timestamp the timestamp of the timer + * @param ctx the context to register timer or get current time + * @param out the collector for returning result values + */ + override def onTimer( + timestamp: Long, + ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, + out: Collector[CRow]): Unit = { + updateOperatorTime(ctx) + // In the future, we should separate the left and right watermarks. Otherwise, the + // registered timer of the faster stream will be delayed, even if the watermarks have + // already been emitted by the source. + if (leftTimerState.value == timestamp) { + rightExpirationTime = calExpirationTime(leftOperatorTime, rightRelativeSize) + removeExpiredRows( + rightExpirationTime, + rightCache, + leftTimerState, + ctx, + removeLeft = false + ) + } + + if (rightTimerState.value == timestamp) { + leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize) + removeExpiredRows( + leftExpirationTime, + leftCache, + rightTimerState, + ctx, + removeLeft = true + ) + } + } + + /** + * Calculate the expiration time with the given operator time and relative window size. + * + * @param operatorTime the operator time + * @param relativeSize the relative window size + * @return the expiration time for cached rows + */ + private def calExpirationTime(operatorTime: Long, relativeSize: Long): Long = { + if (operatorTime < Long.MaxValue) { + operatorTime - relativeSize - allowedLateness - 1 + } else { + // When operatorTime = Long.MaxValue, it means the stream has reached the end. + Long.MaxValue + } + } + + /** + * Register a timer for cleaning up rows in a specified time. 
+ * + * @param ctx the context to register timer + * @param rowTime time for the input row + * @param leftRow whether this row comes from the left stream + */ + private def registerCleanUpTimer( + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + rowTime: Long, + leftRow: Boolean): Unit = { + if (leftRow) { + val cleanupTime = rowTime + leftRelativeSize + cleanupDelay + allowedLateness + 1 + registerTimer(ctx, cleanupTime) + rightTimerState.update(cleanupTime) + } else { + val cleanupTime = rowTime + rightRelativeSize + cleanupDelay + allowedLateness + 1 + registerTimer(ctx, cleanupTime) + leftTimerState.update(cleanupTime) + } + } + + /** + * Remove the expired rows. Register a new timer if the cache still holds valid rows + * after the cleaning up. + * + * @param expirationTime the expiration time for this cache + * @param rowCache the row cache + * @param timerState timer state for the opposite stream + * @param ctx the context to register the cleanup timer + * @param removeLeft whether to remove the left rows + */ + private def removeExpiredRows( + expirationTime: Long, + rowCache: MapState[Long, JList[Row]], + timerState: ValueState[Long], + ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, + removeLeft: Boolean): Unit = { + + val keysIterator = rowCache.keys().iterator() + + var earliestTimestamp: Long = -1L + var rowTime: Long = 0L + + // We remove all expired keys and do not leave the loop early. + // Hence, we do a full pass over the state. + while (keysIterator.hasNext) { + rowTime = keysIterator.next + if (rowTime <= expirationTime) { + keysIterator.remove() + } else { + // We find the earliest timestamp that is still valid. + if (rowTime < earliestTimestamp || earliestTimestamp < 0) { + earliestTimestamp = rowTime + } + } + } + if (earliestTimestamp > 0) { + // There are rows left in the cache. Register a timer to expire them later. + registerCleanUpTimer( + ctx, + earliestTimestamp, + removeLeft) + } else { + // No rows left in the cache. 
Clear the states and the timerState will be 0. + timerState.clear() + rowCache.clear() + } + } + + /** + * Update the operator time of the two streams. + * Must be the first call in all processing methods (i.e., processElement(), onTimer()). + * + * @param ctx the context to acquire watermarks + */ + def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit + + /** + * Return the time for the target row from the left stream. + * + * @param context the runtime context + * @param row the target row + * @return time for the target row + */ + def getTimeForLeftStream(context: CoProcessFunction[CRow, CRow, CRow]#Context, row: Row): Long + + /** + * Return the time for the target row from the right stream. + * + * @param context the runtime context + * @param row the target row + * @return time for the target row + */ + def getTimeForRightStream(context: CoProcessFunction[CRow, CRow, CRow]#Context, row: Row): Long + + /** + * Register a proctime or rowtime timer. + * + * @param ctx the context to register the timer + * @param cleanupTime timestamp for the timer + */ + def registerTimer( + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + cleanupTime: Long): Unit +} http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala index b566113..6f97f2a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala @@ -39,14 +39,23 @@ import scala.collection.JavaConverters._ */ object WindowJoinUtil { - case class 
WindowBounds(isEventTime: Boolean, leftLowerBound: Long, leftUpperBound: Long) + case class WindowBounds( + isEventTime: Boolean, + leftLowerBound: Long, + leftUpperBound: Long, + leftTimeIdx: Int, + rightTimeIdx: Int) protected case class WindowBound(bound: Long, isLeftLower: Boolean) + protected case class TimePredicate( isEventTime: Boolean, leftInputOnLeftSide: Boolean, + leftTimeIdx: Int, + rightTimeIdx: Int, pred: RexCall) - protected case class TimeAttributeAccess(isEventTime: Boolean, isLeftInput: Boolean) + + protected case class TimeAttributeAccess(isEventTime: Boolean, isLeftInput: Boolean, idx: Int) /** * Extracts the window bounds from a join predicate. @@ -116,7 +125,21 @@ object WindowJoinUtil { Some(otherPreds.reduceLeft((l, r) => RelOptUtil.andJoinFilters(rexBuilder, l, r))) } - val bounds = Some(WindowBounds(timePreds.head.isEventTime, leftLowerBound, leftUpperBound)) + val bounds = if (timePreds.head.leftInputOnLeftSide) { + Some(WindowBounds( + timePreds.head.isEventTime, + leftLowerBound, + leftUpperBound, + timePreds.head.leftTimeIdx, + timePreds.head.rightTimeIdx)) + } else { + Some(WindowBounds( + timePreds.head.isEventTime, + leftLowerBound, + leftUpperBound, + timePreds.head.rightTimeIdx, + timePreds.head.leftTimeIdx)) + } (bounds, remainCondition) } @@ -196,8 +219,8 @@ object WindowJoinUtil { case (Some(left), Some(right)) if left.isLeftInput == right.isLeftInput => // Window join predicates must reference the time attribute of both inputs. Right(pred) - case (Some(left), Some(_)) => - Left(TimePredicate(left.isEventTime, left.isLeftInput, c)) + case (Some(left), Some(right)) => + Left(TimePredicate(left.isEventTime, left.isLeftInput, left.idx, right.idx, c)) } // not a comparison predicate. case _ => Right(pred) @@ -224,8 +247,11 @@ object WindowJoinUtil { inputType.getFieldList.get(idx).getType match { case t: TimeIndicatorRelDataType => // time attribute access. 
Remember time type and side of input - val isLeftInput = idx < leftFieldCount - Seq(TimeAttributeAccess(t.isEventTime, isLeftInput)) + if (idx < leftFieldCount) { + Seq(TimeAttributeAccess(t.isEventTime, true, idx)) + } else { + Seq(TimeAttributeAccess(t.isEventTime, false, idx - leftFieldCount)) + } case _ => // not a time attribute access. Seq() http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala index e066fe4..a4234c5 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala @@ -32,7 +32,7 @@ class JoinTest extends TableTestBase { streamUtil.addTable[(Int, String, Long)]("MyTable2", 'a, 'b, 'c.rowtime, 'proctime.proctime) @Test - def testProcessingTimeInnerJoinWithOnClause() = { + def testProcessingTimeInnerJoinWithOnClause(): Unit = { val sqlQuery = """ @@ -70,7 +70,45 @@ class JoinTest extends TableTestBase { } @Test - def testProcessingTimeInnerJoinWithWhereClause() = { + def testRowTimeInnerJoinWithOnClause(): Unit = { + + val sqlQuery = + """ + |SELECT t1.a, t2.b + |FROM MyTable t1 JOIN MyTable2 t2 ON + | t1.a = t2.a AND + | t1.c BETWEEN t2.c - INTERVAL '10' SECOND AND t2.c + INTERVAL '1' HOUR + |""".stripMargin + + val expected = + unaryNode( + "DataStreamCalc", + binaryNode( + "DataStreamWindowJoin", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "c") + ), + unaryNode( + "DataStreamCalc", + streamTableNode(1), + term("select", "a", "b", "c") + ), + term("where", + "AND(=(a, a0), >=(c, -(c0, 10000)), " 
+ + "<=(c, DATETIME_PLUS(c0, 3600000)))"), + term("join", "a, c, a0, b, c0"), + term("joinType", "InnerJoin") + ), + term("select", "a", "b") + ) + + streamUtil.verifySql(sqlQuery, expected) + } + + @Test + def testProcessingTimeInnerJoinWithWhereClause(): Unit = { val sqlQuery = """ @@ -108,6 +146,44 @@ class JoinTest extends TableTestBase { } @Test + def testRowTimeInnerJoinWithWhereClause(): Unit = { + + val sqlQuery = + """ + |SELECT t1.a, t2.b + |FROM MyTable t1, MyTable2 t2 + |WHERE t1.a = t2.a AND + | t1.c BETWEEN t2.c - INTERVAL '10' MINUTE AND t2.c + INTERVAL '1' HOUR + |""".stripMargin + + val expected = + unaryNode( + "DataStreamCalc", + binaryNode( + "DataStreamWindowJoin", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "c") + ), + unaryNode( + "DataStreamCalc", + streamTableNode(1), + term("select", "a", "b", "c") + ), + term("where", + "AND(=(a, a0), >=(c, -(c0, 600000)), " + + "<=(c, DATETIME_PLUS(c0, 3600000)))"), + term("join", "a, c, a0, b, c0"), + term("joinType", "InnerJoin") + ), + term("select", "a", "b0 AS b") + ) + + streamUtil.verifySql(sqlQuery, expected) + } + + @Test def testJoinTimeBoundary(): Unit = { verifyTimeBoundary( "t1.proctime between t2.proctime - interval '1' hour " + @@ -175,16 +251,17 @@ class JoinTest extends TableTestBase { "SELECT t1.a, t2.c FROM MyTable3 as t1 join MyTable4 as t2 on t1.a = t2.a and " + "t1.b >= t2.b - interval '10' second and t1.b <= t2.b - interval '5' second and " + "t1.c > t2.c" + // The equi-join predicate should also be included verifyRemainConditionConvert( query, - ">($2, $6)") + "AND(=($0, $4), >($2, $6))") val query1 = "SELECT t1.a, t2.c FROM MyTable3 as t1 join MyTable4 as t2 on t1.a = t2.a and " + "t1.b >= t2.b - interval '10' second and t1.b <= t2.b - interval '5' second " verifyRemainConditionConvert( query1, - "") + "=($0, $4)") streamUtil.addTable[(Int, Long, Int)]("MyTable5", 'a, 'b, 'c, 'proctime.proctime) streamUtil.addTable[(Int, Long, Int)]("MyTable6", 
'a, 'b, 'c, 'proctime.proctime) @@ -195,7 +272,7 @@ class JoinTest extends TableTestBase { "t1.c > t2.c" verifyRemainConditionConvert( query2, - ">($2, $6)") + "AND(=($0, $4), >($2, $6))") } private def verifyTimeBoundary( @@ -209,10 +286,9 @@ class JoinTest extends TableTestBase { val resultTable = streamUtil.tableEnv.sqlQuery(query) val relNode = resultTable.getRelNode val joinNode = relNode.getInput(0).asInstanceOf[LogicalJoin] - val rexNode = joinNode.getCondition val (windowBounds, _) = WindowJoinUtil.extractWindowBoundsFromPredicate( - rexNode, + joinNode.getCondition, 4, joinNode.getRowType, joinNode.getCluster.getRexBuilder, @@ -233,11 +309,9 @@ class JoinTest extends TableTestBase { val resultTable = streamUtil.tableEnv.sqlQuery(query) val relNode = resultTable.getRelNode val joinNode = relNode.getInput(0).asInstanceOf[LogicalJoin] - val joinInfo = joinNode.analyzeCondition - val rexNode = joinInfo.getRemaining(joinNode.getCluster.getRexBuilder) val (_, remainCondition) = WindowJoinUtil.extractWindowBoundsFromPredicate( - rexNode, + joinNode.getCondition, 4, joinNode.getRowType, joinNode.getCluster.getRexBuilder, http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala index 065b7bc..192befd 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala @@ -17,29 +17,26 @@ */ package org.apache.flink.table.runtime.harness +import java.lang.{Long => JLong} import 
java.util.concurrent.ConcurrentLinkedQueue -import java.lang.{Integer => JInt} -import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator +import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.streaming.runtime.streamrecord.StreamRecord import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness +import org.apache.flink.table.api.Types import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, TupleRowKeySelector} -import org.apache.flink.table.runtime.join.ProcTimeWindowInnerJoin +import org.apache.flink.table.runtime.join.{ProcTimeBoundedStreamInnerJoin, RowTimeBoundedStreamInnerJoin} import org.apache.flink.table.runtime.types.CRow import org.apache.flink.types.Row +import org.junit.Assert.{assertEquals} import org.junit.Test -import org.junit.Assert.{assertEquals, assertTrue} -class JoinHarnessTest extends HarnessTestBase{ - - private val rT = new RowTypeInfo(Array[TypeInformation[_]]( - INT_TYPE_INFO, - STRING_TYPE_INFO), - Array("a", "b")) +class JoinHarnessTest extends HarnessTestBase { + private val rowType = Types.ROW( + Types.LONG, + Types.STRING) val funcCode: String = """ @@ -75,84 +72,88 @@ class JoinHarnessTest extends HarnessTestBase{ /** a.proctime >= b.proctime - 10 and a.proctime <= b.proctime + 20 **/ @Test - def testNormalProcTimeJoin() { + def testProcTimeJoinWithCommonBounds() { - val joinProcessFunc = new ProcTimeWindowInnerJoin(-10, 20, rT, rT, "TestJoinFunction", funcCode) + val joinProcessFunc = new ProcTimeBoundedStreamInnerJoin( + -10, 20, 0, rowType, rowType, "TestJoinFunction", funcCode) val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) val testHarness: 
KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( - operator, - new TupleRowKeySelector[Integer](0), - new TupleRowKeySelector[Integer](0), - BasicTypeInfo.INT_TYPE_INFO, - 1, 1, 0) + operator, + new TupleRowKeySelector[Integer](0), + new TupleRowKeySelector[Integer](0), + Types.INT, + 1, 1, 0) testHarness.open() - // left stream input testHarness.setProcessingTime(1) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa"), true), 1)) + CRow(Row.of(1L: JLong, "1a1"), true), 1)) assertEquals(1, testHarness.numProcessingTimeTimers()) testHarness.setProcessingTime(2) testHarness.processElement1(new StreamRecord( - CRow(Row.of(2: JInt, "bbb"), true), 2)) + CRow(Row.of(2L: JLong, "2a2"), true), 2)) + + // timers for key = 1 and key = 2 assertEquals(2, testHarness.numProcessingTimeTimers()) + testHarness.setProcessingTime(3) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa2"), true), 3)) + CRow(Row.of(1L: JLong, "1a3"), true), 3)) assertEquals(4, testHarness.numKeyedStateEntries()) + + // The number of timers won't increase. assertEquals(2, testHarness.numProcessingTimeTimers()) - // right stream input and output normally testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "1b3"), true), 3)) testHarness.setProcessingTime(4) testHarness.processElement2(new StreamRecord( - CRow(Row.of(2: JInt, "Hello1"), true), 4)) - assertEquals(8, testHarness.numKeyedStateEntries()) - assertEquals(4, testHarness.numProcessingTimeTimers()) + CRow(Row.of(2L: JLong, "2b4"), true), 4)) - // expired left stream record at timestamp 1 - testHarness.setProcessingTime(12) + // The number of states should be doubled. assertEquals(8, testHarness.numKeyedStateEntries()) assertEquals(4, testHarness.numProcessingTimeTimers()) + + // Test for -10 boundary (13 - 10 = 3). 
+ // The left row (key = 1) with timestamp = 1 will be eagerly removed here. + testHarness.setProcessingTime(13) testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "Hi2"), true), 12)) + CRow(Row.of(1L: JLong, "1b13"), true), 13)) - // expired right stream record at timestamp 4 and all left stream - testHarness.setProcessingTime(25) - assertEquals(2, testHarness.numKeyedStateEntries()) - assertEquals(1, testHarness.numProcessingTimeTimers()) + // Test for +20 boundary (13 + 20 = 33). + testHarness.setProcessingTime(33) + assertEquals(4, testHarness.numKeyedStateEntries()) + assertEquals(2, testHarness.numProcessingTimeTimers()) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3"), true), 25)) + CRow(Row.of(1L: JLong, "1a33"), true), 33)) + testHarness.processElement1(new StreamRecord( - CRow(Row.of(2: JInt, "bbb2"), true), 25)) + CRow(Row.of(2L: JLong, "2a33"), true), 33)) + + // The left row (key = 2) with timestamp = 2 will be eagerly removed here. 
testHarness.processElement2(new StreamRecord( - CRow(Row.of(2: JInt, "Hello2"), true), 25)) + CRow(Row.of(2L: JLong, "2b33"), true), 33)) - testHarness.setProcessingTime(45) - assertTrue(testHarness.numKeyedStateEntries() > 0) - testHarness.setProcessingTime(46) - assertEquals(0, testHarness.numKeyedStateEntries()) val result = testHarness.getOutput val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa", 1: JInt, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "1a1", 1L: JLong, "1b3"), true), 3)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa2", 1: JInt, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "1a3", 1L: JLong, "1b3"), true), 3)) expectedOutput.add(new StreamRecord( - CRow(Row.of(2: JInt, "bbb", 2: JInt, "Hello1"), true), 4)) + CRow(Row.of(2L: JLong, "2a2", 2L: JLong, "2b4"), true), 4)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa2", 1: JInt, "Hi2"), true), 12)) + CRow(Row.of(1L: JLong, "1a3", 1L: JLong, "1b13"), true), 13)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3", 1: JInt, "Hi2"), true), 25)) + CRow(Row.of(1L: JLong, "1a33", 1L: JLong, "1b13"), true), 33)) expectedOutput.add(new StreamRecord( - CRow(Row.of(2: JInt, "bbb2", 2: JInt, "Hello2"), true), 25)) + CRow(Row.of(2L: JLong, "2a33", 2L: JLong, "2b33"), true), 33)) verify(expectedOutput, result, new RowResultSortComparator()) @@ -161,9 +162,10 @@ class JoinHarnessTest extends HarnessTestBase{ /** a.proctime >= b.proctime - 10 and a.proctime <= b.proctime - 5 **/ @Test - def testProcTimeJoinSingleNeedStore() { + def testProcTimeJoinWithNegativeBounds() { - val joinProcessFunc = new ProcTimeWindowInnerJoin(-10, -5, rT, rT, "TestJoinFunction", funcCode) + val joinProcessFunc = new ProcTimeBoundedStreamInnerJoin( + -10, -5, 0, rowType, rowType, "TestJoinFunction", funcCode) val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = new KeyedCoProcessOperator[Integer, CRow, 
CRow, CRow](joinProcessFunc) @@ -172,50 +174,58 @@ class JoinHarnessTest extends HarnessTestBase{ operator, new TupleRowKeySelector[Integer](0), new TupleRowKeySelector[Integer](0), - BasicTypeInfo.INT_TYPE_INFO, + Types.INT, 1, 1, 0) testHarness.open() testHarness.setProcessingTime(1) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa1"), true), 1)) + CRow(Row.of(1L: JLong, "1a1"), true), 1)) testHarness.setProcessingTime(2) testHarness.processElement1(new StreamRecord( - CRow(Row.of(2: JInt, "aaa2"), true), 2)) + CRow(Row.of(2L: JLong, "2a2"), true), 2)) testHarness.setProcessingTime(3) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3"), true), 3)) + CRow(Row.of(1L: JLong, "1a3"), true), 3)) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) - // Do not store b elements - // not meet a.proctime <= b.proctime - 5 + // All the right rows will not be cached. testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "bbb3"), true), 3)) + CRow(Row.of(1L: JLong, "1b3"), true), 3)) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) - // meet a.proctime <= b.proctime - 5 testHarness.setProcessingTime(7) + + // Meets a.proctime <= b.proctime - 5. + // This row will only be joined without being cached (7 >= 7 - 5). testHarness.processElement2(new StreamRecord( - CRow(Row.of(2: JInt, "bbb7"), true), 7)) + CRow(Row.of(2L: JLong, "2b7"), true), 7)) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) - // expire record of stream a at timestamp 1 testHarness.setProcessingTime(12) - assertEquals(4, testHarness.numKeyedStateEntries()) - assertEquals(2, testHarness.numProcessingTimeTimers()) + // The left row (key = 1) with timestamp = 1 will be eagerly removed here. 
+ testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "bbb12"), true), 12)) + CRow(Row.of(1L: JLong, "1b12"), true), 12)) + // We add a delay (relativeWindowSize / 2) for cleaning up state. + // No timers will be triggered here. testHarness.setProcessingTime(13) + assertEquals(4, testHarness.numKeyedStateEntries()) + assertEquals(2, testHarness.numProcessingTimeTimers()) + + // Trigger the timer registered by the left row (key = 1) with timestamp = 1 + // (1 + 10 + 2 + 0 + 1 = 14). + // The left row (key = 1) with timestamp = 3 will be removed here. + testHarness.setProcessingTime(14) assertEquals(2, testHarness.numKeyedStateEntries()) assertEquals(1, testHarness.numProcessingTimeTimers()) - // state must be cleaned after the window timer interval has passed without new rows. - testHarness.setProcessingTime(23) + // Clean up the left row (key = 2) with timestamp = 2. + testHarness.setProcessingTime(16) assertEquals(0, testHarness.numKeyedStateEntries()) assertEquals(0, testHarness.numProcessingTimeTimers()) val result = testHarness.getOutput @@ -223,13 +233,174 @@ class JoinHarnessTest extends HarnessTestBase{ val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow(Row.of(2: JInt, "aaa2", 2: JInt, "bbb7"), true), 7)) + CRow(Row.of(2L: JLong, "2a2", 2L: JLong, "2b7"), true), 7)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3", 1: JInt, "bbb12"), true), 12)) + CRow(Row.of(1L: JLong, "1a3", 1L: JLong, "1b12"), true), 12)) verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } + /** a.rowtime >= b.rowtime - 10 and a.rowtime <= b.rowtime + 20 **/ + @Test + def testRowTimeJoinWithCommonBounds() { + + val joinProcessFunc = new RowTimeBoundedStreamInnerJoin( + -10, 20, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0) + + val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] = + new KeyedCoProcessOperator[String, CRow, CRow, CRow](joinProcessFunc) + val
testHarness: KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow] = + new KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow]( + operator, + new TupleRowKeySelector[String](1), + new TupleRowKeySelector[String](1), + Types.STRING, + 1, 1, 0) + + testHarness.open() + + testHarness.processWatermark1(new Watermark(1)) + testHarness.processWatermark2(new Watermark(1)) + + // Test late data. + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(1L: JLong, "k1"), true), 0)) + + // Though (1L, "k1") is actually late, it will also be cached. + assertEquals(1, testHarness.numEventTimeTimers()) + + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(2L: JLong, "k1"), true), 0)) + testHarness.processElement2(new StreamRecord[CRow]( + CRow(Row.of(2L: JLong, "k1"), true), 0)) + + assertEquals(2, testHarness.numEventTimeTimers()) + assertEquals(4, testHarness.numKeyedStateEntries()) + + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(5L: JLong, "k1"), true), 0)) + + testHarness.processElement2(new StreamRecord[CRow]( + CRow(Row.of(15L: JLong, "k1"), true), 0)) + + testHarness.processWatermark1(new Watermark(20)) + testHarness.processWatermark2(new Watermark(20)) + + assertEquals(4, testHarness.numKeyedStateEntries()) + + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(35L: JLong, "k1"), true), 0)) + + // The left rows with timestamp = 2 and 5 will be removed here. + // The right rows with timestamp = 2 and 15 will be removed here. + testHarness.processWatermark1(new Watermark(38)) + testHarness.processWatermark2(new Watermark(38)) + + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(40L: JLong, "k2"), true), 0)) + testHarness.processElement2(new StreamRecord[CRow]( + CRow(Row.of(39L: JLong, "k2"), true), 0)) + + assertEquals(6, testHarness.numKeyedStateEntries()) + + // The left row with timestamp = 35 will be removed here.
+ testHarness.processWatermark1(new Watermark(61)) + testHarness.processWatermark2(new Watermark(61)) + + assertEquals(4, testHarness.numKeyedStateEntries()) + + val expectedOutput = new ConcurrentLinkedQueue[Object]() + expectedOutput.add(new StreamRecord( + CRow(Row.of(2L: JLong, "k1", 2L: JLong, "k1"), true), 0)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(5L: JLong, "k1", 2L: JLong, "k1"), true), 0)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(5L: JLong, "k1", 15L: JLong, "k1"), true), 0)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(35L: JLong, "k1", 15L: JLong, "k1"), true), 0)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(40L: JLong, "k2", 39L: JLong, "k2"), true), 0)) + + // This result is produced by the late row (1, "k1"). + expectedOutput.add(new StreamRecord( + CRow(Row.of(1L: JLong, "k1", 2L: JLong, "k1"), true), 0)) + + val result = testHarness.getOutput + verify(expectedOutput, result, new RowResultSortComparator()) + testHarness.close() + } + + /** a.rowtime >= b.rowtime - 10 and a.rowtime <= b.rowtime - 7 **/ + @Test + def testRowTimeJoinWithNegativeBounds() { + + val joinProcessFunc = new RowTimeBoundedStreamInnerJoin( + -10, -7, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0) + + val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] = + new KeyedCoProcessOperator[String, CRow, CRow, CRow](joinProcessFunc) + val testHarness: KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow] = + new KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow]( + operator, + new TupleRowKeySelector[String](1), + new TupleRowKeySelector[String](1), + Types.STRING, + 1, 1, 0) + + testHarness.open() + + testHarness.processWatermark1(new Watermark(1)) + testHarness.processWatermark2(new Watermark(1)) + + // This row will not be cached. 
+ testHarness.processElement2(new StreamRecord[CRow]( + CRow(Row.of(2L: JLong, "k1"), true), 0)) + + assertEquals(0, testHarness.numKeyedStateEntries()) + + testHarness.processWatermark1(new Watermark(2)) + testHarness.processWatermark2(new Watermark(2)) + + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(3L: JLong, "k1"), true), 0)) + testHarness.processElement2(new StreamRecord[CRow]( + CRow(Row.of(3L: JLong, "k1"), true), 0)) + + // Test for -10 boundary (13 - 10 = 3). + // This row from the right stream will be cached. + // The clean time for the left stream is 13 - 7 + 1 - 1 = 8 + testHarness.processElement2(new StreamRecord[CRow]( + CRow(Row.of(13L: JLong, "k1"), true), 0)) + + // Test for -7 boundary (13 - 7 = 6). + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(6L: JLong, "k1"), true), 0)) + + assertEquals(4, testHarness.numKeyedStateEntries()) + + // Trigger the left timer with timestamp 8. + // The row with timestamp = 13 will be removed here (13 < 10 + 7). + testHarness.processWatermark1(new Watermark(10)) + testHarness.processWatermark2(new Watermark(10)) + + assertEquals(2, testHarness.numKeyedStateEntries()) + + // Clear the states. 
+ testHarness.processWatermark1(new Watermark(18)) + testHarness.processWatermark2(new Watermark(18)) + + assertEquals(0, testHarness.numKeyedStateEntries()) + + val expectedOutput = new ConcurrentLinkedQueue[Object]() + expectedOutput.add(new StreamRecord( + CRow(Row.of(3L: JLong, "k1", 13L: JLong, "k1"), true), 0)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(6L: JLong, "k1", 13L: JLong, "k1"), true), 0)) + + val result = testHarness.getOutput + verify(expectedOutput, result, new RowResultSortComparator()) + testHarness.close() + } } http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala index e40da7a..13bfbcd 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala @@ -19,18 +19,22 @@ package org.apache.flink.table.runtime.stream.sql import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.TimeCharacteristic +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.api.scala._ import org.apache.flink.table.runtime.utils.{StreamITCase, StreamingWithStateTestBase} import org.apache.flink.types.Row +import org.hamcrest.CoreMatchers import org.junit._ import scala.collection.mutable class JoinITCase extends StreamingWithStateTestBase { - /** test process time 
inner join **/ + /** test proctime inner join **/ @Test def testProcessTimeInnerJoin(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment @@ -39,8 +43,14 @@ class JoinITCase extends StreamingWithStateTestBase { StreamITCase.clear env.setParallelism(1) - val sqlQuery = "SELECT t2.a, t2.c, t1.c from T1 as t1 join T2 as t2 on t1.a = t2.a and " + - "t1.proctime between t2.proctime - interval '5' second and t2.proctime + interval '5' second" + val sqlQuery = + """ + |SELECT t2.a, t2.c, t1.c + |FROM T1 as t1 join T2 as t2 ON + | t1.a = t2.a AND + | t1.proctime BETWEEN t2.proctime - INTERVAL '5' SECOND AND + | t2.proctime + INTERVAL '5' SECOND + |""".stripMargin val data1 = new mutable.MutableList[(Int, Long, String)] data1.+=((1, 1L, "Hi1")) @@ -65,19 +75,24 @@ class JoinITCase extends StreamingWithStateTestBase { env.execute() } - /** test process time inner join with other condition **/ + /** test proctime inner join with other condition **/ @Test - def testProcessTimeInnerJoinWithOtherCondition(): Unit = { + def testProcessTimeInnerJoinWithOtherConditions(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) env.setStateBackend(getStateBackend) StreamITCase.clear - env.setParallelism(1) + env.setParallelism(2) - val sqlQuery = "SELECT t2.a, t2.c, t1.c from T1 as t1 join T2 as t2 on t1.a = t2.a and " + - "t1.proctime between t2.proctime - interval '5' second " + - "and t2.proctime + interval '5' second " + - "and t1.b > t2.b and t1.b + t2.b < 14" + val sqlQuery = + """ + |SELECT t2.a, t2.c, t1.c + |FROM T1 as t1 JOIN T2 as t2 ON + | t1.a = t2.a AND + | t1.proctime BETWEEN t2.proctime - interval '5' SECOND AND + | t2.proctime + interval '5' second AND + | t1.b = t2.b + |""".stripMargin val data1 = new mutable.MutableList[(String, Long, String)] data1.+=(("1", 1L, "Hi1")) @@ -91,6 +106,10 @@ class JoinITCase extends StreamingWithStateTestBase { data2.+=(("1", 5L, "HiHi")) 
data2.+=(("2", 2L, "HeHe")) + // For null key test + data1.+=((null.asInstanceOf[String], 20L, "leftNull")) + data2.+=((null.asInstanceOf[String], 20L, "rightNull")) + val t1 = env.fromCollection(data1).toTable(tEnv, 'a, 'b, 'c, 'proctime.proctime) val t2 = env.fromCollection(data2).toTable(tEnv, 'a, 'b, 'c, 'proctime.proctime) @@ -100,7 +119,173 @@ class JoinITCase extends StreamingWithStateTestBase { val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] result.addSink(new StreamITCase.StringSink[Row]) env.execute() + + // Assert there is no result with null keys. + Assert.assertFalse(StreamITCase.testResults.toString().contains("null")) + } + + /** test rowtime inner join **/ + @Test + def testRowTimeInnerJoin(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStateBackend(getStateBackend) + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + StreamITCase.clear + env.setParallelism(1) + + val sqlQuery = + """ + |SELECT t2.a, t2.c, t1.c + |FROM T1 as t1 join T2 as t2 ON + | t1.a = t2.a AND + | t1.rt BETWEEN t2.rt - INTERVAL '5' SECOND AND + | t2.rt + INTERVAL '6' SECOND + |""".stripMargin + + val data1 = new mutable.MutableList[(Int, Long, String, Long)] + // for boundary test + data1.+=((1, 999L, "LEFT0.999", 999L)) + data1.+=((1, 1000L, "LEFT1", 1000L)) + data1.+=((1, 2000L, "LEFT2", 2000L)) + data1.+=((1, 3000L, "LEFT3", 3000L)) + data1.+=((2, 4000L, "LEFT4", 4000L)) + data1.+=((1, 5000L, "LEFT5", 5000L)) + data1.+=((1, 6000L, "LEFT6", 6000L)) + + val data2 = new mutable.MutableList[(Int, Long, String, Long)] + data2.+=((1, 6000L, "RIGHT6", 6000L)) + data2.+=((2, 7000L, "RIGHT7", 7000L)) + + val t1 = env.fromCollection(data1) + .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor) + .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime) + val t2 = env.fromCollection(data2) + .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor) + .toTable(tEnv, 'a, 'b, 'c, 
'rt.rowtime) + + tEnv.registerTable("T1", t1) + tEnv.registerTable("T2", t2) + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + val expected = new java.util.ArrayList[String] + expected.add("1,RIGHT6,LEFT1") + expected.add("1,RIGHT6,LEFT2") + expected.add("1,RIGHT6,LEFT3") + expected.add("1,RIGHT6,LEFT5") + expected.add("1,RIGHT6,LEFT6") + expected.add("2,RIGHT7,LEFT4") + StreamITCase.compareWithList(expected) } + /** test rowtime inner join with other conditions **/ + @Test + def testRowTimeInnerJoinWithOtherConditions(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStateBackend(getStateBackend) + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + StreamITCase.clear + + // different parallelisms lead to different join results + env.setParallelism(1) + + val sqlQuery = + """ + |SELECT t2.a, t1.c, t2.c + |FROM T1 as t1 JOIN T2 as t2 ON + | t1.a = t2.a AND + | t1.rt > t2.rt - INTERVAL '5' SECOND AND + | t1.rt < t2.rt - INTERVAL '1' SECOND AND + | t1.b < t2.b AND + | t1.b > 2 + |""".stripMargin + + val data1 = new mutable.MutableList[(Int, Long, String, Long)] + data1.+=((1, 4L, "LEFT1", 1000L)) + // for boundary test + data1.+=((1, 8L, "LEFT1.1", 1001L)) + // predicate (t1.b > 2) push down + data1.+=((1, 2L, "LEFT2", 2000L)) + data1.+=((1, 7L, "LEFT3", 3000L)) + data1.+=((2, 5L, "LEFT4", 4000L)) + // for boundary test + data1.+=((1, 4L, "LEFT4.9", 4999L)) + data1.+=((1, 4L, "LEFT5", 5000L)) + data1.+=((1, 10L, "LEFT6", 6000L)) + // a left late row + data1.+=((1, 3L, "LEFT3.5", 3500L)) + + val data2 = new mutable.MutableList[(Int, Long, String, Long)] + // just for watermark + data2.+=((1, 1L, "RIGHT1", 1000L)) + data2.+=((1, 9L, "RIGHT6", 6000L)) + data2.+=((2, 14L, "RIGHT7", 7000L)) + data2.+=((1, 4L, "RIGHT8", 8000L)) + // a right late row + data2.+=((1, 10L, "RIGHT5", 5000L)) + + val t1 = 
env.fromCollection(data1) + .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor) + .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime) + val t2 = env.fromCollection(data2) + .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor) + .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime) + + tEnv.registerTable("T1", t1) + tEnv.registerTable("T2", t2) + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + // There may be two expected results according to the process order. + val expected1 = new mutable.MutableList[String] + expected1+= "1,LEFT3,RIGHT6" + expected1+= "1,LEFT1.1,RIGHT6" + expected1+= "2,LEFT4,RIGHT7" + expected1+= "1,LEFT4.9,RIGHT6" + // produced by the left late rows + expected1+= "1,LEFT3.5,RIGHT6" + expected1+= "1,LEFT3.5,RIGHT8" + // produced by the right late rows + expected1+= "1,LEFT3,RIGHT5" + expected1+= "1,LEFT3.5,RIGHT5" + + val expected2 = new mutable.MutableList[String] + expected2+= "1,LEFT3,RIGHT6" + expected2+= "1,LEFT1.1,RIGHT6" + expected2+= "2,LEFT4,RIGHT7" + expected2+= "1,LEFT4.9,RIGHT6" + // produced by the left late rows + expected2+= "1,LEFT3.5,RIGHT6" + expected2+= "1,LEFT3.5,RIGHT8" + // produced by the right late rows + expected2+= "1,LEFT3,RIGHT5" + expected2+= "1,LEFT1,RIGHT5" + expected2+= "1,LEFT1.1,RIGHT5" + + Assert.assertThat( + StreamITCase.testResults.sorted, + CoreMatchers.either(CoreMatchers.is(expected1.sorted)). + or(CoreMatchers.is(expected2.sorted))) + } } +private class Tuple2WatermarkExtractor + extends AssignerWithPunctuatedWatermarks[(Int, Long, String, Long)] { + + override def checkAndGetNextWatermark( + lastElement: (Int, Long, String, Long), + extractedTimestamp: Long): Watermark = { + new Watermark(extractedTimestamp - 1) + } + + override def extractTimestamp( + element: (Int, Long, String, Long), + previousElementTimestamp: Long): Long = { + element._4 + } +}
