Repository: flink
Updated Branches:
  refs/heads/master 159986292 -> 836998bd6


[FLINK-8667] Expose key in KeyedBroadcastProcessFunction#onTimer()

This closes #5500.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/836998bd
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/836998bd
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/836998bd

Branch: refs/heads/master
Commit: 836998bd65ef2d0d0276faed189a0dfe8a7a6dc3
Parents: 1599862
Author: Bowen Li <bowenl...@gmail.com>
Authored: Thu Feb 15 21:37:44 2018 +0100
Committer: kkloudas <kklou...@gmail.com>
Committed: Tue Mar 6 17:35:03 2018 +0100

----------------------------------------------------------------------
 .../co/KeyedBroadcastProcessFunction.java       |  5 ++
 .../co/CoBroadcastWithKeyedOperator.java        |  5 ++
 .../flink/streaming/api/DataStreamTest.java     |  8 +-
 .../co/CoBroadcastWithKeyedOperatorTest.java    | 83 +++++++++++---------
 .../api/scala/BroadcastStateITCase.scala        | 14 +++-
 .../streaming/runtime/BroadcastStateITCase.java | 24 +++---
 6 files changed, 86 insertions(+), 53 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/836998bd/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
index de9cb32..6e6ae5c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
@@ -170,5 +170,10 @@ public abstract class KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> extends B
                 * event or processing time timer.
                 */
                public abstract TimeDomain timeDomain();
+
+               /**
+                * Get the key of the firing timer.
+                */
+               public abstract KS getCurrentKey();
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/836998bd/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java
index 2bdb683..871363b 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java
@@ -325,6 +325,11 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT>
                }
 
                @Override
+               public KS getCurrentKey() {
+                       return timer.getKey();
+               }
+
+               @Override
                public TimerService timerService() {
                        return timerService;
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/836998bd/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
index 4fa3fc8..6326672 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
@@ -707,7 +707,7 @@ public class DataStreamTest extends TestLogger {
                                        Long value,
                                        Context ctx,
                                        Collector<Integer> out) throws 
Exception {
-
+                               // Do nothing
                        }
 
                        @Override
@@ -715,7 +715,7 @@ public class DataStreamTest extends TestLogger {
                                        long timestamp,
                                        OnTimerContext ctx,
                                        Collector<Integer> out) throws 
Exception {
-
+                               // Do nothing
                        }
                };
 
@@ -777,7 +777,7 @@ public class DataStreamTest extends TestLogger {
                                        Long value,
                                        Context ctx,
                                        Collector<Integer> out) throws 
Exception {
-
+                               // Do nothing
                        }
 
                        @Override
@@ -785,7 +785,7 @@ public class DataStreamTest extends TestLogger {
                                        long timestamp,
                                        OnTimerContext ctx,
                                        Collector<Integer> out) throws 
Exception {
-
+                               // Do nothing
                        }
                };
 

http://git-wip-us.apache.org/repos/asf/flink/blob/836998bd/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java
index 96607d4..b923b75 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java
@@ -38,7 +38,6 @@ import org.apache.flink.util.Collector;
 import org.apache.flink.util.OutputTag;
 import org.apache.flink.util.Preconditions;
 
-import org.junit.Assert;
 import org.junit.Test;
 
 import java.util.ArrayList;
@@ -54,6 +53,11 @@ import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.function.Function;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
 /**
  * Tests for the {@link CoBroadcastWithKeyedOperator}.
  */
@@ -148,7 +152,7 @@ public class CoBroadcastWithKeyedOperatorTest {
                                                        while (it.hasNext()) {
                                                                
list.add(it.next());
                                                        }
-                                                       
Assert.assertEquals(expectedKeyedStates.get(key), list);
+                                                       
assertEquals(expectedKeyedStates.get(key), list);
                                                }
                                        });
                }
@@ -161,12 +165,13 @@ public class CoBroadcastWithKeyedOperatorTest {
 
        @Test
        public void testFunctionWithTimer() throws Exception {
+               final String expectedKey = "6";
 
                try (
                                TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness = getInitializedTestHarness(
                                                BasicTypeInfo.STRING_TYPE_INFO,
                                                new IdentityKeySelector<>(),
-                                               new 
FunctionWithTimerOnKeyed(41L))
+                                               new 
FunctionWithTimerOnKeyed(41L, expectedKey))
                ) {
                        testHarness.processWatermark1(new Watermark(10L));
                        testHarness.processWatermark2(new Watermark(10L));
@@ -174,8 +179,8 @@ public class CoBroadcastWithKeyedOperatorTest {
 
                        testHarness.processWatermark1(new Watermark(40L));
                        testHarness.processWatermark2(new Watermark(40L));
-                       testHarness.processElement1(new StreamRecord<>("6", 
13L));
-                       testHarness.processElement1(new StreamRecord<>("6", 
15L));
+                       testHarness.processElement1(new 
StreamRecord<>(expectedKey, 13L));
+                       testHarness.processElement1(new 
StreamRecord<>(expectedKey, 15L));
 
                        testHarness.processWatermark1(new Watermark(50L));
                        testHarness.processWatermark2(new Watermark(50L));
@@ -203,9 +208,11 @@ public class CoBroadcastWithKeyedOperatorTest {
                private static final long serialVersionUID = 
7496674620398203933L;
 
                private final long timerTS;
+               private final String expectedKey;
 
-               FunctionWithTimerOnKeyed(long timerTS) {
+               FunctionWithTimerOnKeyed(long timerTS, String expectedKey) {
                        this.timerTS = timerTS;
+                       this.expectedKey = expectedKey;
                }
 
                @Override
@@ -221,6 +228,7 @@ public class CoBroadcastWithKeyedOperatorTest {
 
                @Override
                public void onTimer(long timestamp, OnTimerContext ctx, 
Collector<String> out) throws Exception {
+                       assertEquals(expectedKey, ctx.getCurrentKey());
                        out.collect("TIMER:" + timestamp);
                }
        }
@@ -293,7 +301,6 @@ public class CoBroadcastWithKeyedOperatorTest {
 
        @Test
        public void testFunctionWithBroadcastState() throws Exception {
-
                final Map<String, Integer> expectedBroadcastState = new 
HashMap<>();
                expectedBroadcastState.put("5.key", 5);
                expectedBroadcastState.put("34.key", 34);
@@ -301,11 +308,13 @@ public class CoBroadcastWithKeyedOperatorTest {
                expectedBroadcastState.put("12.key", 12);
                expectedBroadcastState.put("98.key", 98);
 
+               final String expectedKey = "trigger";
+
                try (
                                TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness = getInitializedTestHarness(
                                                BasicTypeInfo.STRING_TYPE_INFO,
                                                new IdentityKeySelector<>(),
-                                               new 
FunctionWithBroadcastState("key", expectedBroadcastState, 41L))
+                                               new 
FunctionWithBroadcastState("key", expectedBroadcastState, 41L, expectedKey))
                ) {
                        testHarness.processWatermark1(new Watermark(10L));
                        testHarness.processWatermark2(new Watermark(10L));
@@ -316,7 +325,7 @@ public class CoBroadcastWithKeyedOperatorTest {
                        testHarness.processElement2(new StreamRecord<>(12, 
16L));
                        testHarness.processElement2(new StreamRecord<>(98, 
19L));
 
-                       testHarness.processElement1(new 
StreamRecord<>("trigger", 13L));
+                       testHarness.processElement1(new 
StreamRecord<>(expectedKey, 13L));
 
                        testHarness.processElement2(new StreamRecord<>(51, 
21L));
 
@@ -324,29 +333,29 @@ public class CoBroadcastWithKeyedOperatorTest {
                        testHarness.processWatermark2(new Watermark(50L));
 
                        Queue<Object> output = testHarness.getOutput();
-                       Assert.assertEquals(3L, output.size());
+                       assertEquals(3L, output.size());
 
                        Object firstRawWm = output.poll();
-                       Assert.assertTrue(firstRawWm instanceof Watermark);
+                       assertTrue(firstRawWm instanceof Watermark);
                        Watermark firstWm = (Watermark) firstRawWm;
-                       Assert.assertEquals(10L, firstWm.getTimestamp());
+                       assertEquals(10L, firstWm.getTimestamp());
 
                        Object rawOutputElem = output.poll();
-                       Assert.assertTrue(rawOutputElem instanceof 
StreamRecord);
+                       assertTrue(rawOutputElem instanceof StreamRecord);
                        StreamRecord<?> outputRec = (StreamRecord<?>) 
rawOutputElem;
-                       Assert.assertTrue(outputRec.getValue() instanceof 
String);
+                       assertTrue(outputRec.getValue() instanceof String);
                        String outputElem = (String) outputRec.getValue();
 
                        expectedBroadcastState.put("51.key", 51);
                        List<Map.Entry<String, Integer>> expectedEntries = new 
ArrayList<>();
                        
expectedEntries.addAll(expectedBroadcastState.entrySet());
                        String expected = "TS:41 " + 
mapToString(expectedEntries);
-                       Assert.assertEquals(expected, outputElem);
+                       assertEquals(expected, outputElem);
 
                        Object secondRawWm = output.poll();
-                       Assert.assertTrue(secondRawWm instanceof Watermark);
+                       assertTrue(secondRawWm instanceof Watermark);
                        Watermark secondWm = (Watermark) secondRawWm;
-                       Assert.assertEquals(50L, secondWm.getTimestamp());
+                       assertEquals(50L, secondWm.getTimestamp());
                }
        }
 
@@ -357,15 +366,17 @@ public class CoBroadcastWithKeyedOperatorTest {
                private final String keyPostfix;
                private final Map<String, Integer> expectedBroadcastState;
                private final long timerTs;
+               private final String expectedKey;
 
                FunctionWithBroadcastState(
                                final String keyPostfix,
                                final Map<String, Integer> 
expectedBroadcastState,
-                               final long timerTs
-               ) {
+                               final long timerTs,
+                               final String expectedKey) {
                        this.keyPostfix = 
Preconditions.checkNotNull(keyPostfix);
                        this.expectedBroadcastState = 
Preconditions.checkNotNull(expectedBroadcastState);
                        this.timerTs = timerTs;
+                       this.expectedKey = expectedKey;
                }
 
                @Override
@@ -381,14 +392,14 @@ public class CoBroadcastWithKeyedOperatorTest {
                        Iterator<Map.Entry<String, Integer>> iter = 
broadcastStateIt.iterator();
 
                        for (int i = 0; i < expectedBroadcastState.size(); i++) 
{
-                               Assert.assertTrue(iter.hasNext());
+                               assertTrue(iter.hasNext());
 
                                Map.Entry<String, Integer> entry = iter.next();
-                               
Assert.assertTrue(expectedBroadcastState.containsKey(entry.getKey()));
-                               
Assert.assertEquals(expectedBroadcastState.get(entry.getKey()), 
entry.getValue());
+                               
assertTrue(expectedBroadcastState.containsKey(entry.getKey()));
+                               
assertEquals(expectedBroadcastState.get(entry.getKey()), entry.getValue());
                        }
 
-                       Assert.assertFalse(iter.hasNext());
+                       assertFalse(iter.hasNext());
 
                        ctx.timerService().registerEventTimeTimer(timerTs);
                }
@@ -401,6 +412,8 @@ public class CoBroadcastWithKeyedOperatorTest {
                        while (iter.hasNext()) {
                                map.add(iter.next());
                        }
+
+                       assertEquals(expectedKey, ctx.getCurrentKey());
                        final String mapToStr = mapToString(map);
                        out.collect("TS:" + timestamp + " " + mapToStr);
                }
@@ -485,22 +498,22 @@ public class CoBroadcastWithKeyedOperatorTest {
                        Queue<?> output2 = testHarness2.getOutput();
                        Queue<?> output3 = testHarness3.getOutput();
 
-                       Assert.assertEquals(expected.size(), output1.size());
+                       assertEquals(expected.size(), output1.size());
                        for (Object o: output1) {
                                StreamRecord<String> rec = 
(StreamRecord<String>) o;
-                               
Assert.assertTrue(expected.contains(rec.getValue()));
+                               assertTrue(expected.contains(rec.getValue()));
                        }
 
-                       Assert.assertEquals(expected.size(), output2.size());
+                       assertEquals(expected.size(), output2.size());
                        for (Object o: output2) {
                                StreamRecord<String> rec = 
(StreamRecord<String>) o;
-                               
Assert.assertTrue(expected.contains(rec.getValue()));
+                               assertTrue(expected.contains(rec.getValue()));
                        }
 
-                       Assert.assertEquals(expected.size(), output3.size());
+                       assertEquals(expected.size(), output3.size());
                        for (Object o: output3) {
                                StreamRecord<String> rec = 
(StreamRecord<String>) o;
-                               
Assert.assertTrue(expected.contains(rec.getValue()));
+                               assertTrue(expected.contains(rec.getValue()));
                        }
                }
        }
@@ -583,16 +596,16 @@ public class CoBroadcastWithKeyedOperatorTest {
                        Queue<?> output1 = testHarness1.getOutput();
                        Queue<?> output2 = testHarness2.getOutput();
 
-                       Assert.assertEquals(expected.size(), output1.size());
+                       assertEquals(expected.size(), output1.size());
                        for (Object o: output1) {
                                StreamRecord<String> rec = 
(StreamRecord<String>) o;
-                               
Assert.assertTrue(expected.contains(rec.getValue()));
+                               assertTrue(expected.contains(rec.getValue()));
                        }
 
-                       Assert.assertEquals(expected.size(), output2.size());
+                       assertEquals(expected.size(), output2.size());
                        for (Object o: output2) {
                                StreamRecord<String> rec = 
(StreamRecord<String>) o;
-                               
Assert.assertTrue(expected.contains(rec.getValue()));
+                               assertTrue(expected.contains(rec.getValue()));
                        }
                }
        }
@@ -653,12 +666,12 @@ public class CoBroadcastWithKeyedOperatorTest {
                        testHarness.processWatermark2(new Watermark(10L));
                        testHarness.processElement2(new StreamRecord<>(5, 12L));
                } catch (NullPointerException e) {
-                       Assert.assertEquals("No key set. This method should not 
be called outside of a keyed context.", e.getMessage());
+                       assertEquals("No key set. This method should not be 
called outside of a keyed context.", e.getMessage());
                        exceptionThrown = true;
                }
 
                if (!exceptionThrown) {
-                       Assert.fail("No exception thrown");
+                       fail("No exception thrown");
                }
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/836998bd/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala
index 6c382d5..55bb3ba 100644
--- a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala
+++ b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala
@@ -28,7 +28,7 @@ import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.test.util.AbstractTestBase
 import org.apache.flink.util.Collector
 import org.junit.Assert.assertEquals
-import org.junit.{Assert, Test}
+import org.junit.{Test}
 
 /**
   * ITCase for the [[org.apache.flink.api.common.state.BroadcastState]].
@@ -103,13 +103,19 @@ class TestBroadcastProcessFunction(
     BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
     BasicTypeInfo.STRING_TYPE_INFO)
 
+  var timerToExpectedKey = Map[Long, Long]()
+  var nextTimerTimestamp :Long = expectedTimestamp
+
   @throws[Exception]
   override def processElement(
       value: Long,
       ctx: KeyedBroadcastProcessFunction[Long, Long, String, 
String]#KeyedReadOnlyContext,
       out: Collector[String]): Unit = {
 
-    ctx.timerService.registerEventTimeTimer(expectedTimestamp)
+    val currentTime = nextTimerTimestamp
+    nextTimerTimestamp += 1
+    ctx.timerService.registerEventTimeTimer(currentTime)
+    timerToExpectedKey += (currentTime -> value)
   }
 
   @throws[Exception]
@@ -128,6 +134,8 @@ class TestBroadcastProcessFunction(
       ctx: KeyedBroadcastProcessFunction[Long, Long, String, 
String]#OnTimerContext,
       out: Collector[String]): Unit = {
 
+    assertEquals(timerToExpectedKey(timestamp), ctx.getCurrentKey)
+
     var map = Map[Long, String]()
 
     import scala.collection.JavaConversions._
@@ -137,7 +145,7 @@ class TestBroadcastProcessFunction(
       map += (entry.getKey -> entry.getValue)
     }
 
-    Assert.assertEquals(expectedBroadcastState, map)
+    assertEquals(expectedBroadcastState, map)
 
     out.collect(timestamp.toString)
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/836998bd/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/BroadcastStateITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/BroadcastStateITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/BroadcastStateITCase.java
index 868aca9..7ccba33 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/BroadcastStateITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/BroadcastStateITCase.java
@@ -32,7 +32,6 @@ import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.util.Collector;
 
-import org.junit.Assert;
 import org.junit.Test;
 
 import javax.annotation.Nullable;
@@ -40,6 +39,8 @@ import javax.annotation.Nullable;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.junit.Assert.assertEquals;
+
 /**
  * ITCase for the {@link org.apache.flink.api.common.state.BroadcastState}.
  */
@@ -120,7 +121,7 @@ public class BroadcastStateITCase {
                        super.close();
 
                        // make sure that all the timers fired
-                       Assert.assertEquals(expectedOutputCounter, 
outputCounter);
+                       assertEquals(expectedOutputCounter, outputCounter);
                }
        }
 
@@ -145,17 +146,15 @@ public class BroadcastStateITCase {
                private static final long serialVersionUID = 
7616910653561100842L;
 
                private final Map<Long, String> expectedState;
+               private final Map<Long, Long> timerToExpectedKey = new 
HashMap<>();
 
-               private final long timerTimestamp;
+               private long nextTimerTimestamp;
 
                private transient MapStateDescriptor<Long, String> descriptor;
 
-               TestBroadcastProcessFunction(
-                               final long timerTS,
-                               final Map<Long, String> expectedBroadcastState
-               ) {
+               TestBroadcastProcessFunction(final long initialTimerTimestamp, 
final Map<Long, String> expectedBroadcastState) {
                        expectedState = expectedBroadcastState;
-                       timerTimestamp = timerTS;
+                       nextTimerTimestamp = initialTimerTimestamp;
                }
 
                @Override
@@ -169,7 +168,10 @@ public class BroadcastStateITCase {
 
                @Override
                public void processElement(Long value, KeyedReadOnlyContext 
ctx, Collector<String> out) throws Exception {
-                       
ctx.timerService().registerEventTimeTimer(timerTimestamp);
+                       long currentTime = nextTimerTimestamp;
+                       nextTimerTimestamp++;
+                       ctx.timerService().registerEventTimeTimer(currentTime);
+                       timerToExpectedKey.put(currentTime, value);
                }
 
                @Override
@@ -180,14 +182,14 @@ public class BroadcastStateITCase {
 
                @Override
                public void onTimer(long timestamp, OnTimerContext ctx, 
Collector<String> out) throws Exception {
-                       Assert.assertEquals(timerTimestamp, timestamp);
+                       assertEquals(timerToExpectedKey.get(timestamp), 
ctx.getCurrentKey());
 
                        Map<Long, String> map = new HashMap<>();
                        for (Map.Entry<Long, String> entry : 
ctx.getBroadcastState(descriptor).immutableEntries()) {
                                map.put(entry.getKey(), entry.getValue());
                        }
 
-                       Assert.assertEquals(expectedState, map);
+                       assertEquals(expectedState, map);
 
                        out.collect(Long.toString(timestamp));
                }

Reply via email to