Repository: flink Updated Branches: refs/heads/master 5d2da128b -> 870e219d9
http://git-wip-us.apache.org/repos/asf/flink/blob/1f04542e/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/aggregation/ComparableAggregator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/aggregation/ComparableAggregator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/aggregation/ComparableAggregator.java index 9331285..465548e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/aggregation/ComparableAggregator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/aggregation/ComparableAggregator.java @@ -20,7 +20,7 @@ package org.apache.flink.streaming.api.functions.aggregation; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.streaming.util.FieldAccessor; +import org.apache.flink.api.java.typeutils.FieldAccessor; @Internal public class ComparableAggregator<T> extends AggregationFunction<T> { @@ -51,7 +51,7 @@ public class ComparableAggregator<T> extends AggregationFunction<T> { AggregationType aggregationType, boolean first, ExecutionConfig config) { - this(aggregationType, FieldAccessor.create(positionToAggregate, typeInfo, config), first); + this(aggregationType, typeInfo.getFieldAccessor(positionToAggregate, config), first); } public ComparableAggregator(String field, @@ -59,7 +59,7 @@ public class ComparableAggregator<T> extends AggregationFunction<T> { AggregationType aggregationType, boolean first, ExecutionConfig config) { - this(aggregationType, FieldAccessor.create(field, typeInfo, config), first); + this(aggregationType, typeInfo.getFieldAccessor(field,config), first); } 
http://git-wip-us.apache.org/repos/asf/flink/blob/1f04542e/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/aggregation/SumAggregator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/aggregation/SumAggregator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/aggregation/SumAggregator.java index cc88eee..90d5e74 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/aggregation/SumAggregator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/aggregation/SumAggregator.java @@ -23,7 +23,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.api.java.typeutils.TupleTypeInfo; -import org.apache.flink.streaming.util.FieldAccessor; +import org.apache.flink.api.java.typeutils.FieldAccessor; @Internal public class SumAggregator<T> extends AggregationFunction<T> { @@ -36,7 +36,7 @@ public class SumAggregator<T> extends AggregationFunction<T> { private final boolean isTuple; public SumAggregator(int pos, TypeInformation<T> typeInfo, ExecutionConfig config) { - fieldAccessor = FieldAccessor.create(pos, typeInfo, config); + fieldAccessor = typeInfo.getFieldAccessor(pos, config); adder = SumFunction.getForClass(fieldAccessor.getFieldType().getTypeClass()); if (typeInfo instanceof TupleTypeInfo) { isTuple = true; @@ -48,7 +48,7 @@ public class SumAggregator<T> extends AggregationFunction<T> { } public SumAggregator(String field, TypeInformation<T> typeInfo, ExecutionConfig config) { - fieldAccessor = FieldAccessor.create(field, typeInfo, config); + fieldAccessor = typeInfo.getFieldAccessor(field, config); adder = SumFunction.getForClass(fieldAccessor.getFieldType().getTypeClass()); if (typeInfo 
instanceof TupleTypeInfo) { isTuple = true; http://git-wip-us.apache.org/repos/asf/flink/blob/1f04542e/flink-streaming-java/src/main/java/org/apache/flink/streaming/util/FieldAccessor.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/util/FieldAccessor.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/util/FieldAccessor.java deleted file mode 100644 index a23353b..0000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/util/FieldAccessor.java +++ /dev/null @@ -1,254 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.streaming.util; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.ExecutionConfig; -import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.common.typeutils.CompositeType; -import org.apache.flink.api.java.tuple.Tuple; -import org.apache.flink.api.java.typeutils.PojoTypeInfo; -import org.apache.flink.api.java.typeutils.TupleTypeInfo; -import org.apache.flink.api.java.typeutils.runtime.PojoComparator; -import org.apache.flink.api.java.typeutils.TupleTypeInfoBase; -import org.apache.flink.api.java.typeutils.runtime.TupleSerializerBase; - -import java.io.Serializable; -import java.lang.reflect.Array; -import java.util.List; - -import scala.Product; - - -/** - * These classes encapsulate the logic of accessing a field specified by the user as either an index - * or a field expression string. TypeInformation can also be requested for the field. - * The position index might specify a field of a Tuple, an array, or a simple type (only "0th field"). - */ -@Internal -public abstract class FieldAccessor<R, F> implements Serializable { - - private static final long serialVersionUID = 1L; - - TypeInformation fieldType; - - // Note: Returns the corresponding basic type for array of a primitive type (Integer for int[]). - @SuppressWarnings("unchecked") - public TypeInformation<F> getFieldType() { - return fieldType; - } - - - public abstract F get(R record); - - // Note: This has to return the result, because the SimpleFieldAccessor might not be able to modify the - // record in place. (for example, when R is simply Double) (Unfortunately there is no passing by reference in Java.) 
- public abstract R set(R record, F fieldValue); - - - - @SuppressWarnings("unchecked") - public static <R, F> FieldAccessor<R, F> create(int pos, TypeInformation<R> typeInfo, ExecutionConfig config) { - if (typeInfo.isTupleType() && ((TupleTypeInfoBase)typeInfo).isCaseClass()) { - return new ProductFieldAccessor<R, F>(pos, typeInfo, config); - } else if (typeInfo.isTupleType()) { - return new TupleFieldAccessor<R, F>(pos, typeInfo); - } else if (typeInfo instanceof BasicArrayTypeInfo || typeInfo instanceof PrimitiveArrayTypeInfo) { - return new ArrayFieldAccessor<R, F>(pos, typeInfo); - } else { - if(pos != 0) { - throw new IndexOutOfBoundsException("Not 0th field selected for a simple type (non-tuple, non-array)."); - } - return (FieldAccessor<R, F>) new SimpleFieldAccessor<R>(typeInfo); - } - } - - public static <R, F> FieldAccessor<R, F> create(String field, TypeInformation<R> typeInfo, ExecutionConfig config) { - if (typeInfo.isTupleType() && ((TupleTypeInfoBase)typeInfo).isCaseClass()) { - int pos = ((TupleTypeInfoBase)typeInfo).getFieldIndex(field); - if(pos == -2) { - throw new RuntimeException("Invalid field selected: " + field); - } - return new ProductFieldAccessor<R, F>(pos, typeInfo, config); - } else if(typeInfo.isTupleType()) { - return new TupleFieldAccessor<R, F>(((TupleTypeInfo)typeInfo).getFieldIndex(field), typeInfo); - } else { - return new PojoFieldAccessor<R, F>(field, typeInfo, config); - } - } - - - - public static class SimpleFieldAccessor<R> extends FieldAccessor<R, R> { - - private static final long serialVersionUID = 1L; - - SimpleFieldAccessor(TypeInformation<R> typeInfo) { - this.fieldType = typeInfo; - } - - @Override - public R get(R record) { - return record; - } - - @Override - public R set(R record, R fieldValue) { - return fieldValue; - } - } - - public static class ArrayFieldAccessor<R, F> extends FieldAccessor<R, F> { - - private static final long serialVersionUID = 1L; - - int pos; - - ArrayFieldAccessor(int pos, 
TypeInformation typeInfo) { - this.pos = pos; - this.fieldType = BasicTypeInfo.getInfoFor(typeInfo.getTypeClass().getComponentType()); - } - - @SuppressWarnings("unchecked") - @Override - public F get(R record) { - return (F) Array.get(record, pos); - } - - @Override - public R set(R record, F fieldValue) { - Array.set(record, pos, fieldValue); - return record; - } - } - - public static class TupleFieldAccessor<R, F> extends FieldAccessor<R, F> { - - private static final long serialVersionUID = 1L; - - int pos; - - TupleFieldAccessor(int pos, TypeInformation<R> typeInfo) { - this.pos = pos; - this.fieldType = ((TupleTypeInfo)typeInfo).getTypeAt(pos); - } - - @SuppressWarnings("unchecked") - @Override - public F get(R record) { - Tuple tuple = (Tuple) record; - return (F)tuple.getField(pos); - } - - @Override - public R set(R record, F fieldValue) { - Tuple tuple = (Tuple) record; - tuple.setField(fieldValue, pos); - return record; - } - } - - public static class PojoFieldAccessor<R, F> extends FieldAccessor<R, F> { - - private static final long serialVersionUID = 1L; - - PojoComparator comparator; - - PojoFieldAccessor(String field, TypeInformation<R> type, ExecutionConfig config) { - if (!(type instanceof CompositeType<?>)) { - throw new IllegalArgumentException( - "Key expressions are only supported on POJO types and Tuples. 
" - + "A type is considered a POJO if all its fields are public, or have both getters and setters defined"); - } - - @SuppressWarnings("unchecked") - CompositeType<R> cType = (CompositeType<R>) type; - - if(field.contains(".")) { - throw new IllegalArgumentException("The Pojo field accessor currently doesn't support nested POJOs"); - } - - List<CompositeType.FlatFieldDescriptor> fieldDescriptors = cType.getFlatFields(field); - - int logicalKeyPosition = fieldDescriptors.get(0).getPosition(); - this.fieldType = fieldDescriptors.get(0).getType(); - - if (cType instanceof PojoTypeInfo) { - comparator = (PojoComparator<R>) cType.createComparator( - new int[] { logicalKeyPosition }, new boolean[] { false }, 0, config); - } else { - throw new IllegalArgumentException( - "Key expressions are only supported on POJO types. " - + "A type is considered a POJO if all its fields are public, or have both getters and setters defined"); - } - } - - @SuppressWarnings("unchecked") - @Override - public F get(R record) { - return (F) comparator.accessField(comparator.getKeyFields()[0], record); - } - - @Override - public R set(R record, F fieldValue) { - try { - comparator.getKeyFields()[0].set(record, fieldValue); - } catch (IllegalAccessException e) { - throw new RuntimeException("Could not modify the specified field.", e); - } - return record; - } - } - - public static class ProductFieldAccessor<R, F> extends FieldAccessor<R, F> { - - private static final long serialVersionUID = 1L; - - int pos; - TupleSerializerBase<R> serializer; - Object[] fields; - int length; - - ProductFieldAccessor(int pos, TypeInformation<R> typeInfo, ExecutionConfig config) { - this.pos = pos; - this.fieldType = ((TupleTypeInfoBase<R>)typeInfo).getTypeAt(pos); - this.serializer = (TupleSerializerBase<R>)typeInfo.createSerializer(config); - this.length = this.serializer.getArity(); - this.fields = new Object[this.length]; - } - - @SuppressWarnings("unchecked") - @Override - public F get(R record) { - return 
(F)((Product)record).productElement(pos); - } - - @Override - public R set(R record, F fieldValue) { - Product prod = (Product)record; - for (int i = 0; i < length; i++) { - fields[i] = prod.productElement(i); - } - fields[pos] = fieldValue; - return serializer.createInstance(fields); - } - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/1f04542e/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/FieldAccessorTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/FieldAccessorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/FieldAccessorTest.java deleted file mode 100644 index d35089a..0000000 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/FieldAccessorTest.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.streaming.util; - -import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; -import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.typeutils.TupleTypeInfo; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; - -// This only tests a fraction of FieldAccessor. The other parts are tested indirectly by AggregationFunctionTest. -public class FieldAccessorTest { - - @Test - @SuppressWarnings("unchecked") - public void arrayFieldAccessorTest() { - int[] a = new int[]{3,5}; - FieldAccessor<int[], Integer> fieldAccessor = - (FieldAccessor<int[], Integer>) (Object) - FieldAccessor.create(1, PrimitiveArrayTypeInfo.getInfoFor(a.getClass()), null); - - assertEquals(Integer.class, fieldAccessor.getFieldType().getTypeClass()); - - assertEquals((Integer)a[1], fieldAccessor.get(a)); - - a = fieldAccessor.set(a, 6); - assertEquals((Integer)a[1], fieldAccessor.get(a)); - - - - Integer[] b = new Integer[]{3,5}; - FieldAccessor<Integer[], Integer> fieldAccessor2 = - (FieldAccessor<Integer[], Integer>) (Object) - FieldAccessor.create(1, BasicArrayTypeInfo.getInfoFor(b.getClass()), null); - - assertEquals(Integer.class, fieldAccessor2.getFieldType().getTypeClass()); - - assertEquals((Integer)b[1], fieldAccessor2.get(b)); - - b = fieldAccessor2.set(b, 6); - assertEquals((Integer)b[1], fieldAccessor2.get(b)); - } - - @Test - @SuppressWarnings("unchecked") - public void tupleFieldAccessorOutOfBoundsTest() { - try { - FieldAccessor<Tuple2<Integer, Integer>, Integer> fieldAccessor = - (FieldAccessor<Tuple2<Integer, Integer>, Integer>) (Object) - FieldAccessor.create(2, TupleTypeInfo.getBasicTupleTypeInfo(Integer.class, Integer.class), - null); - fail(); - } catch (IndexOutOfBoundsException e) { - // Nothing to do here - } - } -} 
http://git-wip-us.apache.org/repos/asf/flink/blob/1f04542e/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala ---------------------------------------------------------------------- diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala index 1971359..d5cc013 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala @@ -216,6 +216,10 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * Applies an aggregation that that gives the current maximum of the data stream at * the given position by the given key. An independent aggregate is kept per key. * + * @param position + * The field position in the data points to maximize. This is applicable to + * Tuple types, basic and primitive array types, Scala case classes, + * and primitive types (which is considered as having one field). */ def max(position: Int): DataStream[T] = aggregate(AggregationType.MAX, position) @@ -223,6 +227,14 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * Applies an aggregation that that gives the current maximum of the data stream at * the given field by the given key. An independent aggregate is kept per key. * + * @param field + * In case of a POJO, Scala case class, or Tuple type, the + * name of the (public) field on which to perform the aggregation. + * Additionally, a dot can be used to drill down into nested + * objects, as in `"field1.fieldxy"`. + * Furthermore, an array index can also be specified in case of an array of + * a primitive or basic type; or "0" or "*" can be specified in case of a + * basic type (which is considered as having only one field). 
*/ def max(field: String): DataStream[T] = aggregate(AggregationType.MAX, field) @@ -230,6 +242,10 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * Applies an aggregation that that gives the current minimum of the data stream at * the given position by the given key. An independent aggregate is kept per key. * + * @param position + * The field position in the data points to minimize. This is applicable to + * Tuple types, basic and primitive array types, Scala case classes, + * and primitive types (which is considered as having one field). */ def min(position: Int): DataStream[T] = aggregate(AggregationType.MIN, position) @@ -237,6 +253,14 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * Applies an aggregation that that gives the current minimum of the data stream at * the given field by the given key. An independent aggregate is kept per key. * + * @param field + * In case of a POJO, Scala case class, or Tuple type, the + * name of the (public) field on which to perform the aggregation. + * Additionally, a dot can be used to drill down into nested + * objects, as in `"field1.fieldxy"`. + * Furthermore, an array index can also be specified in case of an array of + * a primitive or basic type; or "0" or "*" can be specified in case of a + * basic type (which is considered as having only one field). */ def min(field: String): DataStream[T] = aggregate(AggregationType.MIN, field) @@ -244,6 +268,10 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * Applies an aggregation that sums the data stream at the given position by the given * key. An independent aggregate is kept per key. * + * @param position + * The field position in the data points to sum. This is applicable to + * Tuple types, basic and primitive array types, Scala case classes, + * and primitive types (which is considered as having one field). 
*/ def sum(position: Int): DataStream[T] = aggregate(AggregationType.SUM, position) @@ -251,6 +279,14 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * Applies an aggregation that sums the data stream at the given field by the given * key. An independent aggregate is kept per key. * + * @param field + * In case of a POJO, Scala case class, or Tuple type, the + * name of the (public) field on which to perform the aggregation. + * Additionally, a dot can be used to drill down into nested + * objects, as in `"field1.fieldxy"`. + * Furthermore, an array index can also be specified in case of an array of + * a primitive or basic type; or "0" or "*" can be specified in case of a + * basic type (which is considered as having only one field). */ def sum(field: String): DataStream[T] = aggregate(AggregationType.SUM, field) @@ -259,34 +295,58 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * the given position by the given key. An independent aggregate is kept per key. * When equality, the first element is returned with the minimal value. * + * @param position + * The field position in the data points to minimize. This is applicable to + * Tuple types, basic and primitive array types, Scala case classes, + * and primitive types (which is considered as having one field). */ def minBy(position: Int): DataStream[T] = aggregate(AggregationType .MINBY, position) /** - * Applies an aggregation that that gives the current minimum element of the data stream by - * the given field by the given key. An independent aggregate is kept per key. - * When equality, the first element is returned with the minimal value. - * - */ + * Applies an aggregation that that gives the current minimum element of the data stream by + * the given field by the given key. An independent aggregate is kept per key. + * When equality, the first element is returned with the minimal value. 
+ * + * @param field + * In case of a POJO, Scala case class, or Tuple type, the + * name of the (public) field on which to perform the aggregation. + * Additionally, a dot can be used to drill down into nested + * objects, as in `"field1.fieldxy"`. + * Furthermore, an array index can also be specified in case of an array of + * a primitive or basic type; or "0" or "*" can be specified in case of a + * basic type (which is considered as having only one field). + */ def minBy(field: String): DataStream[T] = aggregate(AggregationType .MINBY, field ) /** - * Applies an aggregation that that gives the current maximum element of the data stream by - * the given position by the given key. An independent aggregate is kept per key. - * When equality, the first element is returned with the maximal value. - * - */ + * Applies an aggregation that gives the current maximum element of the data stream by + * the given position by the given key. An independent aggregate is kept per key. + * When equality, the first element is returned with the maximal value. + * + * @param position + * The field position in the data points to maximize. This is applicable to + * Tuple types, basic and primitive array types, Scala case classes, + * and primitive types (which is considered as having one field). + */ def maxBy(position: Int): DataStream[T] = aggregate(AggregationType.MAXBY, position) /** - * Applies an aggregation that that gives the current maximum element of the data stream by - * the given field by the given key. An independent aggregate is kept per key. - * When equality, the first element is returned with the maximal value. - * - */ + * Applies an aggregation that gives the current maximum element of the data stream by + * the given field by the given key. An independent aggregate is kept per key. + * When equality, the first element is returned with the maximal value. 
+ * + * @param field + * In case of a POJO, Scala case class, or Tuple type, the + * name of the (public) field on which to perform the aggregation. + * Additionally, a dot can be used to drill down into nested + * objects, as in `"field1.fieldxy"`. + * Furthermore, an array index can also be specified in case of an array of + * a primitive or basic type; or "0" or "*" can be specified in case of a + * basic type (which is considered as having only one field). + */ def maxBy(field: String): DataStream[T] = aggregate(AggregationType.MAXBY, field) http://git-wip-us.apache.org/repos/asf/flink/blob/1f04542e/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/DataStreamPojoITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/DataStreamPojoITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/DataStreamPojoITCase.java index cc8b699..eb03b45 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/DataStreamPojoITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/DataStreamPojoITCase.java @@ -18,6 +18,7 @@ package org.apache.flink.test.streaming.runtime; import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.typeinfo.InvalidFieldReferenceException; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase; @@ -140,18 +141,49 @@ public class DataStreamPojoITCase extends StreamingMultipleProgramsTestBase { see.execute(); } + @Test + public void testNestedPojoFieldAccessor() throws Exception { + StreamExecutionEnvironment see = StreamExecutionEnvironment.getExecutionEnvironment(); + see.getConfig().disableObjectReuse(); + see.setParallelism(4); - /** - * As per FLINK-3702 Flink 
doesn't support nested pojo fields for sum() - */ - @Test(expected = IllegalArgumentException.class) + DataStream<Data> dataStream = see.fromCollection(elements); + + DataStream<Data> summedStream = dataStream + .keyBy("aaa") + .sum("stats.count") + .keyBy("aaa") + .flatMap(new FlatMapFunction<Data, Data>() { + Data[] first = new Data[3]; + @Override + public void flatMap(Data value, Collector<Data> out) throws Exception { + if(first[value.aaa] == null) { + first[value.aaa] = value; + if(value.stats.count != 123) { + throw new RuntimeException("Expected stats.count to be 123"); + } + } else { + if(value.stats.count != 2 * 123) { + throw new RuntimeException("Expected stats.count to be 2 * 123"); + } + } + } + }); + + summedStream.print(); + + see.execute(); + } + + @Test(expected = InvalidFieldReferenceException.class) public void testFailOnNestedPojoFieldAccessor() throws Exception { StreamExecutionEnvironment see = StreamExecutionEnvironment.getExecutionEnvironment(); DataStream<Data> dataStream = see.fromCollection(elements); - dataStream.keyBy("aaa", "stats.count").sum("stats.count"); + dataStream.keyBy("aaa", "stats.count").sum("stats.nonExistingField"); } + public static class Data { public int sum; // sum public int aaa; // keyBy
