Repository: flink
Updated Branches:
  refs/heads/master 70e78a620 -> f15a7d2d9


[FLINK-5874] Restrict key types in the DataStream API.

Reject a type from being a key in keyBy() if:
1. it is a POJO type that does not override hashCode() and
   relies on the Object.hashCode() implementation, or
2. it is an array of any type.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/f15a7d2d
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/f15a7d2d
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/f15a7d2d

Branch: refs/heads/master
Commit: f15a7d2d9c9aae72bb3ac3eb2478b3ec4759401b
Parents: 70e78a6
Author: kl0u <[email protected]>
Authored: Wed Mar 8 12:11:07 2017 +0100
Committer: kl0u <[email protected]>
Committed: Fri Mar 10 17:58:00 2017 +0100

----------------------------------------------------------------------
 docs/dev/datastream_api.md                      |   9 +
 .../streaming/api/datastream/DataStream.java    |   4 +-
 .../streaming/api/datastream/KeyedStream.java   |  77 +++++-
 .../api/graph/StreamGraphGenerator.java         |   6 +-
 .../flink/streaming/api/DataStreamTest.java     | 238 +++++++++++++++++++
 5 files changed, 327 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/f15a7d2d/docs/dev/datastream_api.md
----------------------------------------------------------------------
diff --git a/docs/dev/datastream_api.md b/docs/dev/datastream_api.md
index df13295..728c945 100644
--- a/docs/dev/datastream_api.md
+++ b/docs/dev/datastream_api.md
@@ -216,6 +216,15 @@ dataStream.filter(new FilterFunction<Integer>() {
 dataStream.keyBy("someKey") // Key by field "someKey"
 dataStream.keyBy(0) // Key by the first element of a Tuple
     {% endhighlight %}
+            <p>
+            <span class="label label-danger">Attention</span> 
+            A type <strong>cannot be a key</strong> if:
+           <ol> 
+           <li> it is a POJO type but does not override the 
<em>hashCode()</em> method and 
+           relies on the <em>Object.hashCode()</em> implementation.</li>
+           <li> it is an array of any type.</li>
+           </ol>
+           </p>
           </td>
         </tr>
         <tr>

http://git-wip-us.apache.org/repos/asf/flink/blob/f15a7d2d/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
index 8fcaf6b..71ef048 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
@@ -282,9 +282,9 @@ public class DataStream<T> {
        }
 
        /**
-        * Partitions the operator state of a {@link DataStream}using field 
expressions. 
+        * Partitions the operator state of a {@link DataStream} using field 
expressions.
         * A field expression is either the name of a public field or a getter 
method with parentheses
-        * of the {@link DataStream}S underlying type. A dot can be used to 
drill
+        * of the {@link DataStream}'s underlying type. A dot can be used to 
drill
         * down into objects, as in {@code "field1.getInnerField2()" }.
         *
         * @param fields

http://git-wip-us.apache.org/repos/asf/flink/blob/f15a7d2d/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java
index 7c9f5bc..860aac6 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java
@@ -17,18 +17,25 @@
 
 package org.apache.flink.streaming.api.datastream;
 
+import org.apache.commons.lang3.StringUtils;
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.Public;
 import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.InvalidProgramException;
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.functions.FoldFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.state.FoldingStateDescriptor;
 import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.Utils;
 import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.PojoTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.functions.ProcessFunction;
@@ -61,6 +68,9 @@ import 
org.apache.flink.streaming.api.windowing.windows.Window;
 import 
org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Stack;
 import java.util.UUID;
 
 /**
@@ -114,9 +124,72 @@ public class KeyedStream<T, KEY> extends DataStream<T> {
                                dataStream.getTransformation(),
                                new KeyGroupStreamPartitioner<>(keySelector, 
StreamGraphGenerator.DEFAULT_LOWER_BOUND_MAX_PARALLELISM)));
                this.keySelector = keySelector;
-               this.keyType = keyType;
+               this.keyType = validateKeyType(keyType);
        }
-       
+
+       /**
+        * Validates that a given type of element (as encoded by the provided 
{@link TypeInformation}) can be
+        * used as a key in the {@code DataStream.keyBy()} operation. This is 
done by traversing the key type
+        * depth-first and checking if each of the composite types satisfies the 
required conditions
+        * (see {@link #validateKeyTypeIsHashable(TypeInformation)}).
+        *
+        * @param keyType The {@link TypeInformation} of the key.
+        */
+       private TypeInformation<KEY> validateKeyType(TypeInformation<KEY> 
keyType) {
+               Stack<TypeInformation<?>> stack = new Stack<>();
+               stack.push(keyType);
+
+               List<TypeInformation<?>> unsupportedTypes = new ArrayList<>();
+
+               while (!stack.isEmpty()) {
+                       TypeInformation<?> typeInfo = stack.pop();
+
+                       if (!validateKeyTypeIsHashable(typeInfo)) {
+                               unsupportedTypes.add(typeInfo);
+                       }
+                       
+                       if (typeInfo instanceof TupleTypeInfoBase) {
+                               for (int i = 0; i < typeInfo.getArity(); i++) {
+                                       stack.push(((TupleTypeInfoBase) 
typeInfo).getTypeAt(i));        
+                               }
+                       }
+               }
+
+               if (!unsupportedTypes.isEmpty()) {
+                       throw new InvalidProgramException("Type " + keyType + " 
cannot be used as key. Contained " +
+                                       "UNSUPPORTED key types: " + 
StringUtils.join(unsupportedTypes, ", ") + ". Look " +
+                                       "at the keyBy() documentation for the 
conditions a type has to satisfy in order to be " +
+                                       "eligible for a key.");
+               }
+
+               return keyType;
+       }
+
+       /**
+        * Validates that a given type of element (as encoded by the provided 
{@link TypeInformation}) can be
+        * used as a key in the {@code DataStream.keyBy()} operation.
+        *
+        * @param type The {@link TypeInformation} of the type to check.
+        * @return {@code false} if:
+        * <ol>
+        *     <li>it is a POJO type but does not override the {@code 
hashCode()} method and relies on
+        *     the {@link Object#hashCode()} implementation.</li>
+        *     <li>it is an array of any type (see {@link 
PrimitiveArrayTypeInfo}, {@link BasicArrayTypeInfo},
+        *     {@link ObjectArrayTypeInfo}).</li>
+        * </ol>
+        * {@code true} otherwise.
+        */
+       private boolean validateKeyTypeIsHashable(TypeInformation<?> type) {
+               try {
+                       return (type instanceof PojoTypeInfo)
+                                       ? 
!type.getTypeClass().getMethod("hashCode").getDeclaringClass().equals(Object.class)
+                                       : !(type instanceof 
PrimitiveArrayTypeInfo || type instanceof BasicArrayTypeInfo || type instanceof 
ObjectArrayTypeInfo);
+               } catch (NoSuchMethodException ignored) {
+                       // this should never happen as we are just searching 
for the hashCode() method.
+               }
+               return false;
+       }
+
        // 
------------------------------------------------------------------------
        //  properties
        // 
------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/f15a7d2d/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
index bd018c3..de87a66 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
@@ -163,7 +163,7 @@ public class StreamGraphGenerator {
 
                Collection<Integer> transformedIds;
                if (transform instanceof OneInputTransformation<?, ?>) {
-                       transformedIds = 
transformOnInputTransform((OneInputTransformation<?, ?>) transform);
+                       transformedIds = 
transformOneInputTransform((OneInputTransformation<?, ?>) transform);
                } else if (transform instanceof TwoInputTransformation<?, ?, 
?>) {
                        transformedIds = 
transformTwoInputTransform((TwoInputTransformation<?, ?, ?>) transform);
                } else if (transform instanceof SourceTransformation<?>) {
@@ -496,10 +496,10 @@ public class StreamGraphGenerator {
         * Transforms a {@code OneInputTransformation}.
         *
         * <p>
-        * This recusively transforms the inputs, creates a new {@code 
StreamNode} in the graph and
+        * This recursively transforms the inputs, creates a new {@code 
StreamNode} in the graph and
         * wired the inputs to this new node.
         */
-       private <IN, OUT> Collection<Integer> 
transformOnInputTransform(OneInputTransformation<IN, OUT> transform) {
+       private <IN, OUT> Collection<Integer> 
transformOneInputTransform(OneInputTransformation<IN, OUT> transform) {
 
                Collection<Integer> inputIds = transform(transform.getInput());
 

http://git-wip-us.apache.org/repos/asf/flink/blob/f15a7d2d/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
index a619338..b4d2421 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
@@ -19,14 +19,23 @@ package org.apache.flink.streaming.api;
 
 import java.util.List;
 
+import org.apache.flink.api.common.InvalidProgramException;
 import org.apache.flink.api.common.functions.FilterFunction;
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.functions.FoldFunction;
 import org.apache.flink.api.common.functions.Function;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple1;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
 import org.apache.flink.streaming.api.datastream.ConnectedStreams;
@@ -63,7 +72,11 @@ import 
org.apache.flink.streaming.runtime.partitioner.ShufflePartitioner;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.util.Collector;
 
+import org.hamcrest.core.StringStartsWith;
+import org.junit.Assert;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
 import static org.junit.Assert.*;
 
@@ -906,6 +919,231 @@ public class DataStreamTest {
        }
 
        /////////////////////////////////////////////////////////////
+       // KeyBy testing
+       /////////////////////////////////////////////////////////////
+
+       @Rule
+       public ExpectedException expectedException = ExpectedException.none();
+
+       @Test
+       public void testPrimitiveArrayKeyRejection() {
+
+               KeySelector<Tuple2<Integer[], String>, int[]> keySelector =
+                               new KeySelector<Tuple2<Integer[], String>, 
int[]>() {
+
+                       @Override
+                       public int[] getKey(Tuple2<Integer[], String> value) 
throws Exception {
+                               int[] ks = new int[value.f0.length];
+                               for (int i = 0; i < ks.length; i++) {
+                                       ks[i] = value.f0[i];
+                               }
+                               return ks;
+                       }
+               };
+
+               testKeyRejection(keySelector, 
PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO);
+       }
+
+       @Test
+       public void testBasicArrayKeyRejection() {
+
+               KeySelector<Tuple2<Integer[], String>, Integer[]> keySelector =
+                               new KeySelector<Tuple2<Integer[], String>, 
Integer[]>() {
+
+                       @Override
+                       public Integer[] getKey(Tuple2<Integer[], String> 
value) throws Exception {
+                               return value.f0;
+                       }
+               };
+
+               testKeyRejection(keySelector, 
BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO);
+       }
+
+       @Test
+       public void testObjectArrayKeyRejection() {
+
+               KeySelector<Tuple2<Integer[], String>, Object[]> keySelector =
+                               new KeySelector<Tuple2<Integer[], String>, 
Object[]>() {
+
+                                       @Override
+                                       public Object[] 
getKey(Tuple2<Integer[], String> value) throws Exception {
+                                               Object[] ks = new 
Object[value.f0.length];
+                                               for (int i = 0; i < ks.length; 
i++) {
+                                                       ks[i] = new Object();
+                                               }
+                                               return ks;
+                                       }
+                               };
+
+               ObjectArrayTypeInfo<Object[], Object> keyTypeInfo = 
ObjectArrayTypeInfo.getInfoFor(
+                               Object[].class, new 
GenericTypeInfo<>(Object.class));
+
+               testKeyRejection(keySelector, keyTypeInfo);
+       }
+
+       private <K> void testKeyRejection(KeySelector<Tuple2<Integer[], 
String>, K> keySelector, TypeInformation<K> expectedKeyType) {
+               StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+
+               DataStream<Tuple2<Integer[], String>> input = env.fromElements(
+                               new Tuple2<>(new Integer[] {1, 2}, "barfoo")
+               );
+
+               Assert.assertEquals(expectedKeyType, 
TypeExtractor.getKeySelectorTypes(keySelector, input.getType()));
+
+               // adjust the rule
+               expectedException.expect(InvalidProgramException.class);
+               expectedException.expectMessage(new StringStartsWith("Type " + 
expectedKeyType + " cannot be used as key."));
+
+               input.keyBy(keySelector);
+       }
+
+       ////////////////                        Composite Key Tests : POJOs     
                ////////////////
+
+       @Test
+       public void testPOJOWithNestedArrayNoHashCodeKeyRejection() {
+               StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+
+               DataStream<POJOWithHashCode> input = env.fromElements(
+                               new POJOWithHashCode(new int[] {1, 2}));
+
+               TypeInformation<?> expectedTypeInfo = new 
TupleTypeInfo<Tuple1<int[]>>(
+                               
PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO);
+
+               // adjust the rule
+               expectedException.expect(InvalidProgramException.class);
+               expectedException.expectMessage(new StringStartsWith("Type " + 
expectedTypeInfo + " cannot be used as key."));
+
+               input.keyBy("id");
+       }
+
+       @Test
+       public void testPOJOWithNestedArrayAndHashCodeWorkAround() {
+               StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+
+               DataStream<POJOWithHashCode> input = env.fromElements(
+                               new POJOWithHashCode(new int[] {1, 2}));
+
+               input.keyBy(new KeySelector<POJOWithHashCode, 
POJOWithHashCode>() {
+                       @Override
+                       public POJOWithHashCode getKey(POJOWithHashCode value) 
throws Exception {
+                               return value;
+                       }
+               }).addSink(new SinkFunction<POJOWithHashCode>() {
+                       @Override
+                       public void invoke(POJOWithHashCode value) throws 
Exception {
+                               Assert.assertEquals(value.getId(), new int[]{1, 
2});
+                       }
+               });
+       }
+
+       @Test
+       public void testPOJOnoHashCodeKeyRejection() {
+
+               KeySelector<POJOWithoutHashCode, POJOWithoutHashCode> 
keySelector =
+                               new KeySelector<POJOWithoutHashCode, 
POJOWithoutHashCode>() {
+                                       @Override
+                                       public POJOWithoutHashCode 
getKey(POJOWithoutHashCode value) throws Exception {
+                                               return value;
+                                       }
+                               };
+
+               StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+
+               DataStream<POJOWithoutHashCode> input = env.fromElements(
+                               new POJOWithoutHashCode(new int[] {1, 2}));
+
+               // adjust the rule
+               expectedException.expect(InvalidProgramException.class);
+
+               input.keyBy(keySelector);
+       }
+
+       ////////////////                        Composite Key Tests : Tuples    
                ////////////////
+
+       @Test
+       public void testTupleNestedArrayKeyRejection() {
+               StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+
+               DataStream<Tuple2<Integer[], String>> input = env.fromElements(
+                               new Tuple2<>(new Integer[] {1, 2}, 
"test-test"));
+
+               TypeInformation<?> expectedTypeInfo = new 
TupleTypeInfo<Tuple2<Integer[], String>>(
+                               BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO, 
BasicTypeInfo.STRING_TYPE_INFO);
+
+               // adjust the rule
+               expectedException.expect(InvalidProgramException.class);
+               expectedException.expectMessage(new StringStartsWith("Type " + 
expectedTypeInfo + " cannot be used as key."));
+
+               input.keyBy(new KeySelector<Tuple2<Integer[],String>, 
Tuple2<Integer[],String>>() {
+                       @Override
+                       public Tuple2<Integer[], String> 
getKey(Tuple2<Integer[], String> value) throws Exception {
+                               return value;
+                       }
+               });
+       }
+
+       @Test
+       public void testPrimitiveKeyAcceptance() throws Exception {
+               StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(1);
+               env.setMaxParallelism(1);
+
+               DataStream<Integer> input = env.fromElements(new 
Integer(10000));
+
+               KeyedStream<Integer, Object> keyedStream = input.keyBy(new 
KeySelector<Integer, Object>() {
+                       @Override
+                       public Object getKey(Integer value) throws Exception {
+                               return value;
+                       }
+               });
+
+               keyedStream.addSink(new SinkFunction<Integer>() {
+                       @Override
+                       public void invoke(Integer value) throws Exception {
+                               Assert.assertEquals(10000L, (long) value);
+                       }
+               });
+       }
+
+       public static class POJOWithoutHashCode {
+
+               private int[] id;
+
+               public POJOWithoutHashCode() {}
+
+               public POJOWithoutHashCode(int[] id) {
+                       this.id = id;
+               }
+
+               public int[] getId() {
+                       return id;
+               }
+
+               public void setId(int[] id) {
+                       this.id = id;
+               }
+       }
+
+       public static class POJOWithHashCode extends POJOWithoutHashCode {
+
+               public POJOWithHashCode() {
+               }
+
+               public POJOWithHashCode(int[] id) {
+                       super(id);
+               }
+
+               @Override
+               public int hashCode() {
+                       int hash = 31;
+                       for (int i : getId()) {
+                               hash = 37 * hash + i;
+                       }
+                       return hash;
+               }
+       }
+
+       /////////////////////////////////////////////////////////////
        // Utilities
        /////////////////////////////////////////////////////////////
 

Reply via email to