Repository: flink Updated Branches: refs/heads/master 8395508b0 -> dea417260
[FLINK-8456] Add Scala API for Connected Streams with Broadcast State. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/9628dc89 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/9628dc89 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/9628dc89 Branch: refs/heads/master Commit: 9628dc8917b2f69bb04119c1d41932076632de01 Parents: 8395508 Author: kkloudas <kklou...@gmail.com> Authored: Wed Feb 7 17:22:01 2018 +0100 Committer: kkloudas <kklou...@gmail.com> Committed: Fri Feb 9 18:14:18 2018 +0100 ---------------------------------------------------------------------- .../streaming/api/datastream/DataStream.java | 4 +- .../api/scala/BroadcastConnectedStream.scala | 81 ++++++++++ .../flink/streaming/api/scala/DataStream.scala | 42 +++++ .../flink/streaming/api/scala/package.scala | 8 +- .../api/scala/BroadcastStateITCase.scala | 161 +++++++++++++++++++ 5 files changed, 293 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/9628dc89/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java index 8d18b80..9a17987 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java @@ -399,8 +399,8 @@ public class DataStream<T> { /** * Sets the partitioning of the {@link DataStream} so that the output elements * are broadcasted to every parallel instance of the next operation. 
In addition, - * it implicitly creates a {@link org.apache.flink.api.common.state.BroadcastState broadcast state} - * which can be used to store the element of the stream. + * it implicitly creates as many {@link org.apache.flink.api.common.state.BroadcastState broadcast states} + * as the specified descriptors which can be used to store the element of the stream. * * @param broadcastStateDescriptors the descriptors of the broadcast states to create. * @return A {@link BroadcastStream} which can be used in the {@link #connect(BroadcastStream)} to http://git-wip-us.apache.org/repos/asf/flink/blob/9628dc89/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/BroadcastConnectedStream.scala ---------------------------------------------------------------------- diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/BroadcastConnectedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/BroadcastConnectedStream.scala new file mode 100644 index 0000000..63c7fe0 --- /dev/null +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/BroadcastConnectedStream.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.scala + +import org.apache.flink.annotation.PublicEvolving +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.streaming.api.datastream.{BroadcastConnectedStream => JavaBCStream} +import org.apache.flink.streaming.api.functions.co.{BroadcastProcessFunction, KeyedBroadcastProcessFunction} + +class BroadcastConnectedStream[IN1, IN2](javaStream: JavaBCStream[IN1, IN2]) { + + /** + * Assumes as inputs a [[org.apache.flink.streaming.api.datastream.BroadcastStream]] and a + * [[KeyedStream]] and applies the given [[KeyedBroadcastProcessFunction]] on them, thereby + * creating a transformed output stream. + * + * @param function The [[KeyedBroadcastProcessFunction]] applied to each element in the stream. + * @tparam KS The type of the keys in the keyed stream. + * @tparam OUT The type of the output elements. + * @return The transformed [[DataStream]]. + */ + @PublicEvolving + def process[KS, OUT: TypeInformation]( + function: KeyedBroadcastProcessFunction[KS, IN1, IN2, OUT]) + : DataStream[OUT] = { + + if (function == null) { + throw new NullPointerException("KeyedBroadcastProcessFunction function must not be null.") + } + + val outputTypeInfo : TypeInformation[OUT] = implicitly[TypeInformation[OUT]] + asScalaStream(javaStream.process(function, outputTypeInfo)) + } + + /** + * Assumes as inputs a [[org.apache.flink.streaming.api.datastream.BroadcastStream]] + * and a non-keyed [[DataStream]] and applies the given + * [[org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction]] + * on them, thereby creating a transformed output stream. + * + * @param function The [[BroadcastProcessFunction]] applied to each element in the stream. + * @tparam OUT The type of the output elements. + * @return The transformed [[DataStream]]. 
+ */ + @PublicEvolving + def process[OUT: TypeInformation]( + function: BroadcastProcessFunction[IN1, IN2, OUT]) + : DataStream[OUT] = { + + if (function == null) { + throw new NullPointerException("BroadcastProcessFunction function must not be null.") + } + + val outputTypeInfo : TypeInformation[OUT] = implicitly[TypeInformation[OUT]] + asScalaStream(javaStream.process(function, outputTypeInfo)) + } + + /** + * Returns a "closure-cleaned" version of the given function. Cleans only if closure cleaning + * is not disabled in the [[org.apache.flink.api.common.ExecutionConfig]] + */ + private[flink] def clean[F <: AnyRef](f: F) = { + new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaClean(f) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/9628dc89/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala ---------------------------------------------------------------------- diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala index ef2e741..9170940 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala @@ -24,6 +24,7 @@ import org.apache.flink.api.common.functions.{FilterFunction, FlatMapFunction, M import org.apache.flink.api.common.io.OutputFormat import org.apache.flink.api.common.operators.ResourceSpec import org.apache.flink.api.common.serialization.SerializationSchema +import org.apache.flink.api.common.state.MapStateDescriptor import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.functions.KeySelector import org.apache.flink.api.java.tuple.{Tuple => JavaTuple} @@ -364,6 +365,27 @@ class DataStream[T](stream: JavaStream[T]) { 
asScalaStream(stream.connect(dataStream.javaStream)) /** + * Creates a new [[BroadcastConnectedStream]] by connecting the current + * [[DataStream]] or [[KeyedStream]] with a [[BroadcastStream]]. + * + * The latter can be created using the [[broadcast(MapStateDescriptor[])]] method. + * + * The resulting stream can be further processed using the + * ``broadcastConnectedStream.process(myFunction)`` + * method, where ``myFunction`` can be either a + * [[org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction]] + * or a [[org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction]] + * depending on the current stream being a [[KeyedStream]] or not. + * + * @param broadcastStream The broadcast stream with the broadcast state to be + * connected with this stream. + * @return The [[BroadcastConnectedStream]]. + */ + @PublicEvolving + def connect[R](broadcastStream: BroadcastStream[R]): BroadcastConnectedStream[T, R] = + asScalaStream(stream.connect(broadcastStream)) + + /** * Groups the elements of a DataStream by the given key positions (for tuple/array types) to * be used with grouped operators like grouped reduce or grouped aggregations. */ @@ -442,6 +464,26 @@ class DataStream[T](stream: JavaStream[T]) { def broadcast: DataStream[T] = asScalaStream(stream.broadcast()) /** + * Sets the partitioning of the [[DataStream]] so that the output elements + * are broadcasted to every parallel instance of the next operation. In addition, + * it implicitly creates as many + * [[org.apache.flink.api.common.state.BroadcastState broadcast states]] + * as the specified descriptors which can be used to store the element of the stream. + * + * @param broadcastStateDescriptors the descriptors of the broadcast states to create. + * @return A [[BroadcastStream]] which can be used in the + * [[DataStream.connect(BroadcastStream)]] to create a + * [[BroadcastConnectedStream]] for further processing of the elements. 
+ */ + @PublicEvolving + def broadcast(broadcastStateDescriptors: MapStateDescriptor[_, _]*): BroadcastStream[T] = { + if (broadcastStateDescriptors == null) { + throw new NullPointerException("Map function must not be null.") + } + stream.broadcast(broadcastStateDescriptors: _*) + } + + /** * Sets the partitioning of the DataStream so that the output values all go to * the first instance of the next processing operator. Use this setting with care * since it might cause a serious performance bottleneck in the application. http://git-wip-us.apache.org/repos/asf/flink/blob/9628dc89/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/package.scala ---------------------------------------------------------------------- diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/package.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/package.scala index 90f255c..ef96fd1 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/package.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/package.scala @@ -24,6 +24,7 @@ import org.apache.flink.api.scala.typeutils.{CaseClassTypeInfo, TypeUtils} import org.apache.flink.streaming.api.datastream.{ DataStream => JavaStream } import org.apache.flink.streaming.api.datastream.{ SplitStream => SplitJavaStream } import org.apache.flink.streaming.api.datastream.{ ConnectedStreams => ConnectedJavaStreams } +import org.apache.flink.streaming.api.datastream.{ BroadcastConnectedStream => BroadcastConnectedJavaStreams } import org.apache.flink.streaming.api.datastream.{ KeyedStream => KeyedJavaStream } import language.implicitConversions @@ -61,8 +62,13 @@ package object scala { */ private[flink] def asScalaStream[IN1, IN2](stream: ConnectedJavaStreams[IN1, IN2]) = new ConnectedStreams[IN1, IN2](stream) + /** + * Converts an [[org.apache.flink.streaming.api.datastream.BroadcastConnectedStream]] 
to a + * [[org.apache.flink.streaming.api.scala.BroadcastConnectedStream]]. + */ + private[flink] def asScalaStream[IN1, IN2](stream: BroadcastConnectedJavaStreams[IN1, IN2]) + = new BroadcastConnectedStream[IN1, IN2](stream) - private[flink] def fieldNames2Indices( typeInfo: TypeInformation[_], fields: Array[String]): Array[Int] = { http://git-wip-us.apache.org/repos/asf/flink/blob/9628dc89/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala ---------------------------------------------------------------------- diff --git a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala new file mode 100644 index 0000000..af883e1 --- /dev/null +++ b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.scala + +import org.apache.flink.api.common.state.MapStateDescriptor +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.streaming.api.TimeCharacteristic +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks +import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction +import org.apache.flink.streaming.api.functions.sink.RichSinkFunction +import org.apache.flink.streaming.api.watermark.Watermark +import org.apache.flink.test.util.AbstractTestBase +import org.apache.flink.util.Collector +import org.junit.Assert.assertEquals +import org.junit.{Assert, Test} + +/** + * ITCase for the [[org.apache.flink.api.common.state.BroadcastState]]. + */ +class BroadcastStateITCase extends AbstractTestBase { + + @Test + @throws[Exception] + def testConnectWithBroadcastTranslation(): Unit = { + + val timerTimestamp = 100000L + + val DESCRIPTOR = new MapStateDescriptor[Long, String]( + "broadcast-state", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], + BasicTypeInfo.STRING_TYPE_INFO) + + val expected = Map[Long, String]( + 0L -> "test:0", + 1L -> "test:1", + 2L -> "test:2", + 3L -> "test:3", + 4L -> "test:4", + 5L -> "test:5") + + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + + val srcOne = env + .generateSequence(0L, 5L) + .assignTimestampsAndWatermarks(new AssignerWithPunctuatedWatermarks[Long]() { + + override def extractTimestamp(element: Long, previousElementTimestamp: Long): Long = + element + + override def checkAndGetNextWatermark(lastElement: Long, extractedTimestamp: Long) = + new Watermark(extractedTimestamp) + + }) + .keyBy((value: Long) => value) + + val srcTwo = env + .fromCollection(expected.values.toSeq) + .assignTimestampsAndWatermarks(new AssignerWithPunctuatedWatermarks[String]() { + + override def 
extractTimestamp(element: String, previousElementTimestamp: Long): Long = + element.split(":")(1).toLong + + override def checkAndGetNextWatermark(lastElement: String, extractedTimestamp: Long) = + new Watermark(extractedTimestamp) + }) + + val broadcast = srcTwo.broadcast(DESCRIPTOR) + // the timestamp should be high enough to trigger the timer after all the elements arrive. + val output = srcOne.connect(broadcast) + .process(new TestBroadcastProcessFunction(100000L, expected)) + + output + .addSink(new TestSink(expected.size)) + .setParallelism(1) + env.execute + } +} + +class TestBroadcastProcessFunction( + expectedTimestamp: Long, + expectedBroadcastState: Map[Long, String]) + extends KeyedBroadcastProcessFunction[Long, Long, String, String] { + + val localDescriptor = new MapStateDescriptor[Long, String]( + "broadcast-state", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], + BasicTypeInfo.STRING_TYPE_INFO) + + @throws[Exception] + override def processElement( + value: Long, + ctx: KeyedBroadcastProcessFunction[Long, Long, String, String]#KeyedReadOnlyContext, + out: Collector[String]): Unit = { + + ctx.timerService.registerEventTimeTimer(expectedTimestamp) + } + + @throws[Exception] + override def processBroadcastElement( + value: String, + ctx: KeyedBroadcastProcessFunction[Long, Long, String, String]#KeyedContext, + out: Collector[String]): Unit = { + + val key = value.split(":")(1).toLong + ctx.getBroadcastState(localDescriptor).put(key, value) + } + + @throws[Exception] + override def onTimer( + timestamp: Long, + ctx: KeyedBroadcastProcessFunction[Long, Long, String, String]#OnTimerContext, + out: Collector[String]): Unit = { + + var map = Map[Long, String]() + + import scala.collection.JavaConversions._ + for (entry <- ctx.getBroadcastState(localDescriptor).immutableEntries()) { + val v = expectedBroadcastState.get(entry.getKey).get + assertEquals(v, entry.getValue) + map += (entry.getKey -> entry.getValue) + } + + 
Assert.assertEquals(expectedBroadcastState, map) + + out.collect(timestamp.toString) + } +} + +class TestSink(val expectedOutputCounter: Int) extends RichSinkFunction[String] { + + var outputCounter: Int = 0 + + override def invoke(value: String) = { + outputCounter = outputCounter + 1 + } + + @throws[Exception] + override def close(): Unit = { + super.close() + + // make sure that all the timers fired + assertEquals(expectedOutputCounter, outputCounter) + } +}