Repository: flink
Updated Branches:
  refs/heads/master 8395508b0 -> dea417260


[FLINK-8456] Add Scala API for Connected Streams with Broadcast State.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/9628dc89
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/9628dc89
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/9628dc89

Branch: refs/heads/master
Commit: 9628dc8917b2f69bb04119c1d41932076632de01
Parents: 8395508
Author: kkloudas <kklou...@gmail.com>
Authored: Wed Feb 7 17:22:01 2018 +0100
Committer: kkloudas <kklou...@gmail.com>
Committed: Fri Feb 9 18:14:18 2018 +0100

----------------------------------------------------------------------
 .../streaming/api/datastream/DataStream.java    |   4 +-
 .../api/scala/BroadcastConnectedStream.scala    |  81 ++++++++++
 .../flink/streaming/api/scala/DataStream.scala  |  42 +++++
 .../flink/streaming/api/scala/package.scala     |   8 +-
 .../api/scala/BroadcastStateITCase.scala        | 161 +++++++++++++++++++
 5 files changed, 293 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/9628dc89/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
index 8d18b80..9a17987 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
@@ -399,8 +399,8 @@ public class DataStream<T> {
        /**
         * Sets the partitioning of the {@link DataStream} so that the output elements
         * are broadcasted to every parallel instance of the next operation. In addition,
-        * it implicitly creates a {@link org.apache.flink.api.common.state.BroadcastState broadcast state}
-        * which can be used to store the element of the stream.
+        * it implicitly creates as many {@link org.apache.flink.api.common.state.BroadcastState broadcast states}
+        * as the specified descriptors, which can be used to store the elements of the stream.
         *
         * @param broadcastStateDescriptors the descriptors of the broadcast states to create.
         * @return A {@link BroadcastStream} which can be used in the {@link #connect(BroadcastStream)} to

http://git-wip-us.apache.org/repos/asf/flink/blob/9628dc89/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/BroadcastConnectedStream.scala
----------------------------------------------------------------------
diff --git 
a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/BroadcastConnectedStream.scala
 
b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/BroadcastConnectedStream.scala
new file mode 100644
index 0000000..63c7fe0
--- /dev/null
+++ 
b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/BroadcastConnectedStream.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.scala
+
+import org.apache.flink.annotation.PublicEvolving
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.streaming.api.datastream.{BroadcastConnectedStream => 
JavaBCStream}
+import org.apache.flink.streaming.api.functions.co.{BroadcastProcessFunction, 
KeyedBroadcastProcessFunction}
+
+/**
+  * Scala wrapper around the Java
+  * [[org.apache.flink.streaming.api.datastream.BroadcastConnectedStream]], i.e. the result of
+  * connecting a (keyed or non-keyed) [[DataStream]] with a
+  * [[org.apache.flink.streaming.api.datastream.BroadcastStream]] via `DataStream.connect`.
+  *
+  * @param javaStream The wrapped Java broadcast-connected stream.
+  * @tparam IN1 The type of the elements of the non-broadcast input.
+  * @tparam IN2 The type of the elements of the broadcast input.
+  */
+class BroadcastConnectedStream[IN1, IN2](javaStream: JavaBCStream[IN1, IN2]) {
+
+  /**
+    * Assumes as inputs a [[org.apache.flink.streaming.api.datastream.BroadcastStream]] and a
+    * [[KeyedStream]] and applies the given [[KeyedBroadcastProcessFunction]] on them, thereby
+    * creating a transformed output stream.
+    *
+    * @param function The [[KeyedBroadcastProcessFunction]] applied to each element in the stream.
+    * @tparam KS The type of the keys in the keyed stream.
+    * @tparam OUT The type of the output elements.
+    * @return The transformed [[DataStream]].
+    * @throws NullPointerException if `function` is null.
+    */
+  @PublicEvolving
+  def process[KS, OUT: TypeInformation](
+      function: KeyedBroadcastProcessFunction[KS, IN1, IN2, OUT])
+    : DataStream[OUT] = {
+
+    if (function == null) {
+      throw new NullPointerException("KeyedBroadcastProcessFunction function must not be null.")
+    }
+
+    // the output type comes from the implicit context bound, not from Java-side extraction
+    val outputTypeInfo : TypeInformation[OUT] = implicitly[TypeInformation[OUT]]
+    asScalaStream(javaStream.process(function, outputTypeInfo))
+  }
+
+  /**
+    * Assumes as inputs a [[org.apache.flink.streaming.api.datastream.BroadcastStream]]
+    * and a non-keyed [[DataStream]] and applies the given
+    * [[org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction]]
+    * on them, thereby creating a transformed output stream.
+    *
+    * @param function The [[BroadcastProcessFunction]] applied to each element in the stream.
+    * @tparam OUT The type of the output elements.
+    * @return The transformed [[DataStream]].
+    * @throws NullPointerException if `function` is null.
+    */
+  @PublicEvolving
+  def process[OUT: TypeInformation](
+      function: BroadcastProcessFunction[IN1, IN2, OUT])
+    : DataStream[OUT] = {
+
+    if (function == null) {
+      throw new NullPointerException("BroadcastProcessFunction function must not be null.")
+    }
+
+    // the output type comes from the implicit context bound, not from Java-side extraction
+    val outputTypeInfo : TypeInformation[OUT] = implicitly[TypeInformation[OUT]]
+    asScalaStream(javaStream.process(function, outputTypeInfo))
+  }
+
+  /**
+    * Returns a "closure-cleaned" version of the given function. Cleans only if closure cleaning
+    * is not disabled in the [[org.apache.flink.api.common.ExecutionConfig]].
+    */
+  private[flink] def clean[F <: AnyRef](f: F) = {
+    new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaClean(f)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/9628dc89/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
----------------------------------------------------------------------
diff --git 
a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
 
b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
index ef2e741..9170940 100644
--- 
a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
+++ 
b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.functions.{FilterFunction, 
FlatMapFunction, M
 import org.apache.flink.api.common.io.OutputFormat
 import org.apache.flink.api.common.operators.ResourceSpec
 import org.apache.flink.api.common.serialization.SerializationSchema
+import org.apache.flink.api.common.state.MapStateDescriptor
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.api.java.tuple.{Tuple => JavaTuple}
@@ -364,6 +365,27 @@ class DataStream[T](stream: JavaStream[T]) {
     asScalaStream(stream.connect(dataStream.javaStream))
 
   /**
+    * Creates a new [[BroadcastConnectedStream]] by connecting the current
+    * [[DataStream]] or [[KeyedStream]] with a [[BroadcastStream]].
+    *
+    * The latter can be created using the `broadcast(MapStateDescriptor[_, _]*)` method.
+    *
+    * The resulting stream can be further processed using the
+    * ``broadcastConnectedStream.process(myFunction)``
+    * method, where ``myFunction`` can be either a
+    * 
[[org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction]]
+    * or a 
[[org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction]]
+    * depending on the current stream being a [[KeyedStream]] or not.
+    *
+    * @param broadcastStream The broadcast stream with the broadcast state to 
be
+    *                        connected with this stream.
+    * @return The [[BroadcastConnectedStream]].
+    */
+  @PublicEvolving
+  def connect[R](broadcastStream: BroadcastStream[R]): BroadcastConnectedStream[T, R] = {
+    // let the Java API build the connected stream, then expose it through the Scala wrapper
+    val connectedJavaStream = stream.connect(broadcastStream)
+    asScalaStream(connectedJavaStream)
+  }
+
+  /**
    * Groups the elements of a DataStream by the given key positions (for 
tuple/array types) to
    * be used with grouped operators like grouped reduce or grouped 
aggregations.
    */
@@ -442,6 +464,26 @@ class DataStream[T](stream: JavaStream[T]) {
   def broadcast: DataStream[T] = asScalaStream(stream.broadcast())
 
   /**
+    * Sets the partitioning of the [[DataStream]] so that the output elements
+    * are broadcasted to every parallel instance of the next operation. In 
addition,
+    * it implicitly creates as many
+    * [[org.apache.flink.api.common.state.BroadcastState broadcast states]]
+    * as the specified descriptors which can be used to store the element of 
the stream.
+    *
+    * @param broadcastStateDescriptors the descriptors of the broadcast states 
to create.
+    * @return A [[BroadcastStream]] which can be used in the
+    *         [[DataStream.connect(BroadcastStream)]] to create a
+    *         [[BroadcastConnectedStream]] for further processing of the 
elements.
+    */
+  @PublicEvolving
+  def broadcast(broadcastStateDescriptors: MapStateDescriptor[_, _]*): BroadcastStream[T] = {
+    // A Scala varargs argument is an empty Seq when no descriptor is given, so this only
+    // triggers on an explicit `null: _*`; fail fast with a message naming the real argument.
+    if (broadcastStateDescriptors == null) {
+      throw new NullPointerException("Broadcast state descriptors must not be null.")
+    }
+    stream.broadcast(broadcastStateDescriptors: _*)
+  }
+
+  /**
    * Sets the partitioning of the DataStream so that the output values all go 
to
    * the first instance of the next processing operator. Use this setting with 
care
    * since it might cause a serious performance bottleneck in the application.

http://git-wip-us.apache.org/repos/asf/flink/blob/9628dc89/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/package.scala
----------------------------------------------------------------------
diff --git 
a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/package.scala
 
b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/package.scala
index 90f255c..ef96fd1 100644
--- 
a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/package.scala
+++ 
b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/package.scala
@@ -24,6 +24,7 @@ import 
org.apache.flink.api.scala.typeutils.{CaseClassTypeInfo, TypeUtils}
 import org.apache.flink.streaming.api.datastream.{ DataStream => JavaStream }
 import org.apache.flink.streaming.api.datastream.{ SplitStream => 
SplitJavaStream }
 import org.apache.flink.streaming.api.datastream.{ ConnectedStreams => 
ConnectedJavaStreams }
+import org.apache.flink.streaming.api.datastream.{ BroadcastConnectedStream => 
BroadcastConnectedJavaStreams }
 import org.apache.flink.streaming.api.datastream.{ KeyedStream => 
KeyedJavaStream }
 
 import language.implicitConversions
@@ -61,8 +62,13 @@ package object scala {
    */
   private[flink] def asScalaStream[IN1, IN2](stream: ConnectedJavaStreams[IN1, 
IN2])
                                              = new ConnectedStreams[IN1, 
IN2](stream)
+  /**
+    * Wraps a Java [[org.apache.flink.streaming.api.datastream.BroadcastConnectedStream]]
+    * into its Scala counterpart, [[org.apache.flink.streaming.api.scala.BroadcastConnectedStream]].
+    */
+  private[flink] def asScalaStream[IN1, IN2](
+      stream: BroadcastConnectedJavaStreams[IN1, IN2]): BroadcastConnectedStream[IN1, IN2] =
+    new BroadcastConnectedStream[IN1, IN2](stream)
 
-  
   private[flink] def fieldNames2Indices(
       typeInfo: TypeInformation[_],
       fields: Array[String]): Array[Int] = {

http://git-wip-us.apache.org/repos/asf/flink/blob/9628dc89/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala
----------------------------------------------------------------------
diff --git 
a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala
 
b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala
new file mode 100644
index 0000000..af883e1
--- /dev/null
+++ 
b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.scala
+
+import org.apache.flink.api.common.state.MapStateDescriptor
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.streaming.api.TimeCharacteristic
+import 
org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
+import 
org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction
+import org.apache.flink.streaming.api.watermark.Watermark
+import org.apache.flink.test.util.AbstractTestBase
+import org.apache.flink.util.Collector
+import org.junit.Assert.assertEquals
+import org.junit.{Assert, Test}
+
+/**
+  * ITCase for the [[org.apache.flink.api.common.state.BroadcastState]].
+  */
+class BroadcastStateITCase extends AbstractTestBase {
+
+  @Test
+  @throws[Exception]
+  def testConnectWithBroadcastTranslation(): Unit = {
+
+    // Event time at which every key registers its timer. It must be higher than all element
+    // timestamps (0 to 5) so that the timers fire only after all the elements have arrived.
+    val timerTimestamp = 100000L
+
+    val DESCRIPTOR = new MapStateDescriptor[Long, String](
+      "broadcast-state",
+      BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
+      BasicTypeInfo.STRING_TYPE_INFO)
+
+    val expected = Map[Long, String](
+      0L -> "test:0",
+      1L -> "test:1",
+      2L -> "test:2",
+      3L -> "test:3",
+      4L -> "test:4",
+      5L -> "test:5")
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+
+    // keyed (non-broadcast) side: 0..5, each element is its own timestamp and key
+    val srcOne = env
+      .generateSequence(0L, 5L)
+      .assignTimestampsAndWatermarks(new AssignerWithPunctuatedWatermarks[Long]() {
+
+        override def extractTimestamp(element: Long, previousElementTimestamp: Long): Long =
+          element
+
+        override def checkAndGetNextWatermark(lastElement: Long, extractedTimestamp: Long) =
+          new Watermark(extractedTimestamp)
+
+      })
+      .keyBy((value: Long) => value)
+
+    // broadcast side: "test:<n>" strings, timestamped with their numeric suffix
+    val srcTwo = env
+      .fromCollection(expected.values.toSeq)
+      .assignTimestampsAndWatermarks(new AssignerWithPunctuatedWatermarks[String]() {
+
+        override def extractTimestamp(element: String, previousElementTimestamp: Long): Long =
+          element.split(":")(1).toLong
+
+        override def checkAndGetNextWatermark(lastElement: String, extractedTimestamp: Long) =
+          new Watermark(extractedTimestamp)
+      })
+
+    val broadcast = srcTwo.broadcast(DESCRIPTOR)
+    val output = srcOne.connect(broadcast)
+      .process(new TestBroadcastProcessFunction(timerTimestamp, expected))
+
+    output
+      .addSink(new TestSink(expected.size))
+      .setParallelism(1)
+    env.execute()
+  }
+}
+
+/**
+  * Keyed broadcast function used by [[BroadcastStateITCase]]: stores every broadcast element in
+  * broadcast state and, when the event-time timer fires, checks that state against the expected
+  * contents.
+  *
+  * @param expectedTimestamp      event time at which each keyed element registers a timer.
+  * @param expectedBroadcastState the broadcast state contents expected when the timer fires.
+  */
+class TestBroadcastProcessFunction(
+        expectedTimestamp: Long,
+        expectedBroadcastState: Map[Long, String])
+    extends KeyedBroadcastProcessFunction[Long, Long, String, String] {
+
+  // Same name and types as the descriptor passed to `broadcast(...)` in the test.
+  // NOTE(review): the broadcast state is presumably resolved through this descriptor, so the
+  // two must stay in sync.
+  val localDescriptor = new MapStateDescriptor[Long, String](
+    "broadcast-state",
+    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
+    BasicTypeInfo.STRING_TYPE_INFO)
+
+  /** Registers an event-time timer at `expectedTimestamp` for the key of `value`. */
+  @throws[Exception]
+  override def processElement(
+      value: Long,
+      ctx: KeyedBroadcastProcessFunction[Long, Long, String, String]#KeyedReadOnlyContext,
+      out: Collector[String]): Unit = {
+
+    ctx.timerService.registerEventTimeTimer(expectedTimestamp)
+  }
+
+  /** Stores a "prefix:<n>" element in the broadcast state under key `n`. */
+  @throws[Exception]
+  override def processBroadcastElement(
+      value: String,
+      ctx: KeyedBroadcastProcessFunction[Long, Long, String, String]#KeyedContext,
+      out: Collector[String]): Unit = {
+
+    val key = value.split(":")(1).toLong
+    ctx.getBroadcastState(localDescriptor).put(key, value)
+  }
+
+  /** Asserts the broadcast state equals the expected contents, then emits the timer timestamp. */
+  @throws[Exception]
+  override def onTimer(
+      timestamp: Long,
+      ctx: KeyedBroadcastProcessFunction[Long, Long, String, String]#OnTimerContext,
+      out: Collector[String]): Unit = {
+
+    import scala.collection.JavaConverters._
+
+    // Snapshot the whole state and compare it in a single assertion: an unexpected or missing
+    // key fails with both maps in the message instead of a NoSuchElementException from Option.get.
+    val actualBroadcastState: Map[Long, String] = ctx
+      .getBroadcastState(localDescriptor)
+      .immutableEntries()
+      .asScala
+      .map(entry => entry.getKey -> entry.getValue)
+      .toMap
+
+    Assert.assertEquals(expectedBroadcastState, actualBroadcastState)
+
+    out.collect(timestamp.toString)
+  }
+}
+
+/**
+  * Sink that counts the records it receives and, on close, asserts that exactly
+  * `expectedOutputCounter` of them arrived.
+  */
+class TestSink(val expectedOutputCounter: Int) extends RichSinkFunction[String] {
+
+  /** Number of records this sink instance has received so far. */
+  var outputCounter: Int = 0
+
+  override def invoke(value: String): Unit =
+    outputCounter += 1
+
+  @throws[Exception]
+  override def close(): Unit = {
+    super.close()
+    // each fired timer emits one record, so a short count means some timers never fired
+    assertEquals(expectedOutputCounter, outputCounter)
+  }
+}

Reply via email to