Repository: flink Updated Branches: refs/heads/master 70e78a620 -> f15a7d2d9
[FLINK-5874] Restrict key types in the DataStream API. Reject a type from being a key in keyBy() if: 1. it is a POJO type but does not override the hashCode() method and relies on the Object.hashCode() implementation. 2. it is an array of any type. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/f15a7d2d Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/f15a7d2d Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/f15a7d2d Branch: refs/heads/master Commit: f15a7d2d9c9aae72bb3ac3eb2478b3ec4759401b Parents: 70e78a6 Author: kl0u <[email protected]> Authored: Wed Mar 8 12:11:07 2017 +0100 Committer: kl0u <[email protected]> Committed: Fri Mar 10 17:58:00 2017 +0100 ---------------------------------------------------------------------- docs/dev/datastream_api.md | 9 + .../streaming/api/datastream/DataStream.java | 4 +- .../streaming/api/datastream/KeyedStream.java | 77 +++++- .../api/graph/StreamGraphGenerator.java | 6 +- .../flink/streaming/api/DataStreamTest.java | 238 +++++++++++++++++++ 5 files changed, 327 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/f15a7d2d/docs/dev/datastream_api.md ---------------------------------------------------------------------- diff --git a/docs/dev/datastream_api.md b/docs/dev/datastream_api.md index df13295..728c945 100644 --- a/docs/dev/datastream_api.md +++ b/docs/dev/datastream_api.md @@ -216,6 +216,15 @@ dataStream.filter(new FilterFunction<Integer>() { dataStream.keyBy("someKey") // Key by field "someKey" dataStream.keyBy(0) // Key by the first element of a Tuple {% endhighlight %} + <p> + <span class="label label-danger">Attention</span> + A type <strong>cannot be a key</strong> if: + <ol> + <li> it is a POJO type but does not override the <em>hashCode()</em> method and + relies on the <em>Object.hashCode()</em> implementation.</li> + <li> it is an 
array of any type.</li> + </ol> + </p> </td> </tr> <tr> http://git-wip-us.apache.org/repos/asf/flink/blob/f15a7d2d/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java index 8fcaf6b..71ef048 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java @@ -282,9 +282,9 @@ public class DataStream<T> { } /** - * Partitions the operator state of a {@link DataStream}using field expressions. + * Partitions the operator state of a {@link DataStream} using field expressions. * A field expression is either the name of a public field or a getter method with parentheses - * of the {@link DataStream}S underlying type. A dot can be used to drill + * of the {@link DataStream}'s underlying type. A dot can be used to drill * down into objects, as in {@code "field1.getInnerField2()" }. 
* * @param fields http://git-wip-us.apache.org/repos/asf/flink/blob/f15a7d2d/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java index 7c9f5bc..860aac6 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java @@ -17,18 +17,25 @@ package org.apache.flink.streaming.api.datastream; +import org.apache.commons.lang3.StringUtils; import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.Public; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.FoldFunction; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.state.FoldingStateDescriptor; import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.Utils; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.api.java.typeutils.PojoTypeInfo; +import org.apache.flink.api.java.typeutils.TupleTypeInfoBase; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.streaming.api.TimeCharacteristic; import 
org.apache.flink.streaming.api.functions.ProcessFunction; @@ -61,6 +68,9 @@ import org.apache.flink.streaming.api.windowing.windows.Window; import org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner; import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; +import java.util.ArrayList; +import java.util.List; +import java.util.Stack; import java.util.UUID; /** @@ -114,9 +124,72 @@ public class KeyedStream<T, KEY> extends DataStream<T> { dataStream.getTransformation(), new KeyGroupStreamPartitioner<>(keySelector, StreamGraphGenerator.DEFAULT_LOWER_BOUND_MAX_PARALLELISM))); this.keySelector = keySelector; - this.keyType = keyType; + this.keyType = validateKeyType(keyType); } - + + /** + * Validates that a given type of element (as encoded by the provided {@link TypeInformation}) can be + * used as a key in the {@code DataStream.keyBy()} operation. This is done by searching depth-first the + * key type and checking if each of the composite types satisfies the required conditions + * (see {@link #validateKeyTypeIsHashable(TypeInformation)}). + * + * @param keyType The {@link TypeInformation} of the key. + */ + private TypeInformation<KEY> validateKeyType(TypeInformation<KEY> keyType) { + Stack<TypeInformation<?>> stack = new Stack<>(); + stack.push(keyType); + + List<TypeInformation<?>> unsupportedTypes = new ArrayList<>(); + + while (!stack.isEmpty()) { + TypeInformation<?> typeInfo = stack.pop(); + + if (!validateKeyTypeIsHashable(typeInfo)) { + unsupportedTypes.add(typeInfo); + } + + if (typeInfo instanceof TupleTypeInfoBase) { + for (int i = 0; i < typeInfo.getArity(); i++) { + stack.push(((TupleTypeInfoBase) typeInfo).getTypeAt(i)); + } + } + } + + if (!unsupportedTypes.isEmpty()) { + throw new InvalidProgramException("Type " + keyType + " cannot be used as key. Contained " + + "UNSUPPORTED key types: " + StringUtils.join(unsupportedTypes, ", ") + ". 
Look " + + "at the keyBy() documentation for the conditions a type has to satisfy in order to be " + + "eligible for a key."); + } + + return keyType; + } + + /** + * Validates that a given type of element (as encoded by the provided {@link TypeInformation}) can be + * used as a key in the {@code DataStream.keyBy()} operation. + * + * @param type The {@link TypeInformation} of the type to check. + * @return {@code false} if: + * <ol> + * <li>it is a POJO type but does not override the {@link #hashCode()} method and relies on + * the {@link Object#hashCode()} implementation.</li> + * <li>it is an array of any type (see {@link PrimitiveArrayTypeInfo}, {@link BasicArrayTypeInfo}, + * {@link ObjectArrayTypeInfo}).</li> + * </ol>, + * {@code true} otherwise. + */ + private boolean validateKeyTypeIsHashable(TypeInformation<?> type) { + try { + return (type instanceof PojoTypeInfo) + ? !type.getTypeClass().getMethod("hashCode").getDeclaringClass().equals(Object.class) + : !(type instanceof PrimitiveArrayTypeInfo || type instanceof BasicArrayTypeInfo || type instanceof ObjectArrayTypeInfo); + } catch (NoSuchMethodException ignored) { + // this should never happen as we are just searching for the hashCode() method. 
+ } + return false; + } + // ------------------------------------------------------------------------ // properties // ------------------------------------------------------------------------ http://git-wip-us.apache.org/repos/asf/flink/blob/f15a7d2d/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java index bd018c3..de87a66 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java @@ -163,7 +163,7 @@ public class StreamGraphGenerator { Collection<Integer> transformedIds; if (transform instanceof OneInputTransformation<?, ?>) { - transformedIds = transformOnInputTransform((OneInputTransformation<?, ?>) transform); + transformedIds = transformOneInputTransform((OneInputTransformation<?, ?>) transform); } else if (transform instanceof TwoInputTransformation<?, ?, ?>) { transformedIds = transformTwoInputTransform((TwoInputTransformation<?, ?, ?>) transform); } else if (transform instanceof SourceTransformation<?>) { @@ -496,10 +496,10 @@ public class StreamGraphGenerator { * Transforms a {@code OneInputTransformation}. * * <p> - * This recusively transforms the inputs, creates a new {@code StreamNode} in the graph and + * This recursively transforms the inputs, creates a new {@code StreamNode} in the graph and * wired the inputs to this new node. 
*/ - private <IN, OUT> Collection<Integer> transformOnInputTransform(OneInputTransformation<IN, OUT> transform) { + private <IN, OUT> Collection<Integer> transformOneInputTransform(OneInputTransformation<IN, OUT> transform) { Collection<Integer> inputIds = transform(transform.getInput()); http://git-wip-us.apache.org/repos/asf/flink/blob/f15a7d2d/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java index a619338..b4d2421 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java @@ -19,14 +19,23 @@ package org.apache.flink.streaming.api; import java.util.List; +import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.functions.FilterFunction; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.FoldFunction; import org.apache.flink.api.common.functions.Function; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.GenericTypeInfo; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; 
import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.streaming.api.collector.selector.OutputSelector; import org.apache.flink.streaming.api.datastream.ConnectedStreams; @@ -63,7 +72,11 @@ import org.apache.flink.streaming.runtime.partitioner.ShufflePartitioner; import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import org.apache.flink.util.Collector; +import org.hamcrest.core.StringStartsWith; +import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import static org.junit.Assert.*; @@ -906,6 +919,231 @@ public class DataStreamTest { } ///////////////////////////////////////////////////////////// + // KeyBy testing + ///////////////////////////////////////////////////////////// + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testPrimitiveArrayKeyRejection() { + + KeySelector<Tuple2<Integer[], String>, int[]> keySelector = + new KeySelector<Tuple2<Integer[], String>, int[]>() { + + @Override + public int[] getKey(Tuple2<Integer[], String> value) throws Exception { + int[] ks = new int[value.f0.length]; + for (int i = 0; i < ks.length; i++) { + ks[i] = value.f0[i]; + } + return ks; + } + }; + + testKeyRejection(keySelector, PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO); + } + + @Test + public void testBasicArrayKeyRejection() { + + KeySelector<Tuple2<Integer[], String>, Integer[]> keySelector = + new KeySelector<Tuple2<Integer[], String>, Integer[]>() { + + @Override + public Integer[] getKey(Tuple2<Integer[], String> value) throws Exception { + return value.f0; + } + }; + + testKeyRejection(keySelector, BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO); + } + + @Test + public void testObjectArrayKeyRejection() { + + KeySelector<Tuple2<Integer[], String>, Object[]> keySelector = + new KeySelector<Tuple2<Integer[], String>, Object[]>() { + + @Override + public Object[] getKey(Tuple2<Integer[], String> 
value) throws Exception { + Object[] ks = new Object[value.f0.length]; + for (int i = 0; i < ks.length; i++) { + ks[i] = new Object(); + } + return ks; + } + }; + + ObjectArrayTypeInfo<Object[], Object> keyTypeInfo = ObjectArrayTypeInfo.getInfoFor( + Object[].class, new GenericTypeInfo<>(Object.class)); + + testKeyRejection(keySelector, keyTypeInfo); + } + + private <K> void testKeyRejection(KeySelector<Tuple2<Integer[], String>, K> keySelector, TypeInformation<K> expectedKeyType) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + + DataStream<Tuple2<Integer[], String>> input = env.fromElements( + new Tuple2<>(new Integer[] {1, 2}, "barfoo") + ); + + Assert.assertEquals(expectedKeyType, TypeExtractor.getKeySelectorTypes(keySelector, input.getType())); + + // adjust the rule + expectedException.expect(InvalidProgramException.class); + expectedException.expectMessage(new StringStartsWith("Type " + expectedKeyType + " cannot be used as key.")); + + input.keyBy(keySelector); + } + + //////////////// Composite Key Tests : POJOs //////////////// + + @Test + public void testPOJOWithNestedArrayNoHashCodeKeyRejection() { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + + DataStream<POJOWithHashCode> input = env.fromElements( + new POJOWithHashCode(new int[] {1, 2})); + + TypeInformation<?> expectedTypeInfo = new TupleTypeInfo<Tuple1<int[]>>( + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO); + + // adjust the rule + expectedException.expect(InvalidProgramException.class); + expectedException.expectMessage(new StringStartsWith("Type " + expectedTypeInfo + " cannot be used as key.")); + + input.keyBy("id"); + } + + @Test + public void testPOJOWithNestedArrayAndHashCodeWorkAround() { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + + DataStream<POJOWithHashCode> input = env.fromElements( + new POJOWithHashCode(new int[] {1, 2})); + + 
input.keyBy(new KeySelector<POJOWithHashCode, POJOWithHashCode>() { + @Override + public POJOWithHashCode getKey(POJOWithHashCode value) throws Exception { + return value; + } + }).addSink(new SinkFunction<POJOWithHashCode>() { + @Override + public void invoke(POJOWithHashCode value) throws Exception { + Assert.assertEquals(value.getId(), new int[]{1, 2}); + } + }); + } + + @Test + public void testPOJOnoHashCodeKeyRejection() { + + KeySelector<POJOWithoutHashCode, POJOWithoutHashCode> keySelector = + new KeySelector<POJOWithoutHashCode, POJOWithoutHashCode>() { + @Override + public POJOWithoutHashCode getKey(POJOWithoutHashCode value) throws Exception { + return value; + } + }; + + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + + DataStream<POJOWithoutHashCode> input = env.fromElements( + new POJOWithoutHashCode(new int[] {1, 2})); + + // adjust the rule + expectedException.expect(InvalidProgramException.class); + + input.keyBy(keySelector); + } + + //////////////// Composite Key Tests : Tuples //////////////// + + @Test + public void testTupleNestedArrayKeyRejection() { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + + DataStream<Tuple2<Integer[], String>> input = env.fromElements( + new Tuple2<>(new Integer[] {1, 2}, "test-test")); + + TypeInformation<?> expectedTypeInfo = new TupleTypeInfo<Tuple2<Integer[], String>>( + BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); + + // adjust the rule + expectedException.expect(InvalidProgramException.class); + expectedException.expectMessage(new StringStartsWith("Type " + expectedTypeInfo + " cannot be used as key.")); + + input.keyBy(new KeySelector<Tuple2<Integer[],String>, Tuple2<Integer[],String>>() { + @Override + public Tuple2<Integer[], String> getKey(Tuple2<Integer[], String> value) throws Exception { + return value; + } + }); + } + + @Test + public void testPrimitiveKeyAcceptance() throws Exception { + 
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + env.setMaxParallelism(1); + + DataStream<Integer> input = env.fromElements(new Integer(10000)); + + KeyedStream<Integer, Object> keyedStream = input.keyBy(new KeySelector<Integer, Object>() { + @Override + public Object getKey(Integer value) throws Exception { + return value; + } + }); + + keyedStream.addSink(new SinkFunction<Integer>() { + @Override + public void invoke(Integer value) throws Exception { + Assert.assertEquals(10000L, (long) value); + } + }); + } + + public static class POJOWithoutHashCode { + + private int[] id; + + public POJOWithoutHashCode() {} + + public POJOWithoutHashCode(int[] id) { + this.id = id; + } + + public int[] getId() { + return id; + } + + public void setId(int[] id) { + this.id = id; + } + } + + public static class POJOWithHashCode extends POJOWithoutHashCode { + + public POJOWithHashCode() { + } + + public POJOWithHashCode(int[] id) { + super(id); + } + + @Override + public int hashCode() { + int hash = 31; + for (int i : getId()) { + hash = 37 * hash + i; + } + return hash; + } + } + + ///////////////////////////////////////////////////////////// // Utilities /////////////////////////////////////////////////////////////
