http://git-wip-us.apache.org/repos/asf/flink/blob/870e219d/flink-streaming-java/src/main/java/org/apache/flink/streaming/util/typeutils/FieldAccessorFactory.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/util/typeutils/FieldAccessorFactory.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/util/typeutils/FieldAccessorFactory.java new file mode 100644 index 0000000..6dbeedd --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/util/typeutils/FieldAccessorFactory.java @@ -0,0 +1,242 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.streaming.util.typeutils;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.operators.Keys;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.java.typeutils.PojoField;
import org.apache.flink.api.java.typeutils.PojoTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;

import java.io.Serializable;
import java.util.regex.Matcher;
import java.util.regex.Pattern;


/**
 * Static factories for the {@link FieldAccessor} utilities.
 */
@Internal
public class FieldAccessorFactory implements Serializable {

	/**
	 * Creates a {@link FieldAccessor} for the given field position, which can be used to get and set
	 * the specified field on instances of this type.
	 *
	 * <p>Supported types: basic and primitive arrays (the position is the array index), basic types
	 * (only position 0, which selects the value itself), Scala case classes, and Java tuples.
	 *
	 * @param typeInfo The type information of the records whose field is accessed
	 * @param pos The field position (zero-based)
	 * @param config Configuration object
	 * @param <T> The type of the records whose field is accessed
	 * @param <F> The type of the field to access
	 * @return The created FieldAccessor
	 * @throws CompositeType.InvalidFieldReferenceException if the position cannot be selected on the type
	 */
	@Internal
	public static <T, F> FieldAccessor<T, F> getAccessor(TypeInformation<T> typeInfo, int pos, ExecutionConfig config) {

		// In case of arrays
		if (typeInfo instanceof BasicArrayTypeInfo || typeInfo instanceof PrimitiveArrayTypeInfo) {
			return new FieldAccessor.ArrayFieldAccessor<>(pos, typeInfo);

		// In case of basic types
		} else if (typeInfo instanceof BasicTypeInfo) {
			if (pos != 0) {
				throw new CompositeType.InvalidFieldReferenceException("The " + pos + ". field selected on a " +
					"basic type (" + typeInfo.toString() + "). A field expression on a basic type can only select " +
					"the 0th field (which means selecting the entire basic type).");
			}
			@SuppressWarnings("unchecked")
			FieldAccessor<T, F> result = (FieldAccessor<T, F>) new FieldAccessor.SimpleFieldAccessor<>(typeInfo);
			return result;

		// In case of case classes
		} else if (typeInfo.isTupleType() && ((TupleTypeInfoBase) typeInfo).isCaseClass()) {
			TupleTypeInfoBase tupleTypeInfo = (TupleTypeInfoBase) typeInfo;
			// Reject out-of-range positions up front, so callers always get an
			// InvalidFieldReferenceException regardless of how getTypeAt treats a bad index.
			checkFieldPosition(tupleTypeInfo, pos);
			@SuppressWarnings("unchecked")
			TypeInformation<F> fieldTypeInfo = (TypeInformation<F>) tupleTypeInfo.getTypeAt(pos);
			return new FieldAccessor.RecursiveProductFieldAccessor<>(
				pos, typeInfo, new FieldAccessor.SimpleFieldAccessor<>(fieldTypeInfo), config);

		// In case of tuples
		} else if (typeInfo.isTupleType()) {
			@SuppressWarnings("unchecked")
			FieldAccessor<T, F> result = new FieldAccessor.SimpleTupleFieldAccessor(pos, typeInfo);
			return result;

		// Default case, PojoType is directed to this statement
		} else {
			throw new CompositeType.InvalidFieldReferenceException("Cannot reference field by position on " + typeInfo.toString()
				+ ". Referencing a field by position is supported on tuples, case classes, and arrays. "
				+ "Additionally, you can select the 0th field of a primitive/basic type (e.g. int).");
		}
	}

	/**
	 * Creates a {@link FieldAccessor} for the field that is given by a field expression,
	 * which can be used to get and set the specified field on instances of this type.
	 *
	 * <p>Nested fields are selected with dots (e.g. {@code "f1.t.f0"}); each part is resolved against
	 * the type reached so far. On arrays the expression must be an integer index, and on basic types it
	 * must be {@code "*"} or {@code "0"}.
	 *
	 * @param typeInfo The type information of the records whose field is accessed
	 * @param field The field expression
	 * @param config Configuration object
	 * @param <T> The type of the records whose field is accessed
	 * @param <F> The type of the field to access
	 * @return The created FieldAccessor
	 * @throws CompositeType.InvalidFieldReferenceException if the expression cannot be resolved on the type
	 */
	@Internal
	public static <T, F> FieldAccessor<T, F> getAccessor(TypeInformation<T> typeInfo, String field, ExecutionConfig config) {

		// In case of arrays
		if (typeInfo instanceof BasicArrayTypeInfo || typeInfo instanceof PrimitiveArrayTypeInfo) {
			try {
				return new FieldAccessor.ArrayFieldAccessor<>(Integer.parseInt(field), typeInfo);
			} catch (NumberFormatException ex) {
				throw new CompositeType.InvalidFieldReferenceException
					("A field expression on an array must be an integer index (that might be given as a string).");
			}

		// In case of basic types
		} else if (typeInfo instanceof BasicTypeInfo) {
			try {
				int pos = field.equals(Keys.ExpressionKeys.SELECT_ALL_CHAR) ? 0 : Integer.parseInt(field);
				return FieldAccessorFactory.getAccessor(typeInfo, pos, config);
			} catch (NumberFormatException ex) {
				throw new CompositeType.InvalidFieldReferenceException("You tried to select the field \"" + field +
					"\" on a " + typeInfo.toString() + ". A field expression on a basic type can only be \"*\" or \"0\"" +
					" (both of which mean selecting the entire basic type).");
			}

		// In case of Pojos
		} else if (typeInfo instanceof PojoTypeInfo) {
			FieldExpression decomp = decomposeFieldExpression(field);
			PojoTypeInfo<?> pojoTypeInfo = (PojoTypeInfo) typeInfo;

			int fieldIndex = pojoTypeInfo.getFieldIndex(decomp.head);

			if (fieldIndex == -1) {
				throw new CompositeType.InvalidFieldReferenceException(
					"Unable to find field \"" + decomp.head + "\" in type " + typeInfo + ".");
			} else {
				PojoField pojoField = pojoTypeInfo.getPojoFieldAt(fieldIndex);
				TypeInformation<?> fieldType = pojoTypeInfo.getTypeAt(fieldIndex);
				if (decomp.tail == null) {
					@SuppressWarnings("unchecked")
					FieldAccessor<F, F> innerAccessor = new FieldAccessor.SimpleFieldAccessor<>((TypeInformation<F>) fieldType);
					return new FieldAccessor.PojoFieldAccessor<>(pojoField.getField(), innerAccessor);
				} else {
					@SuppressWarnings("unchecked")
					FieldAccessor<Object, F> innerAccessor = FieldAccessorFactory
						.getAccessor((TypeInformation<Object>) fieldType, decomp.tail, config);
					return new FieldAccessor.PojoFieldAccessor<>(pojoField.getField(), innerAccessor);
				}
			}

		// In case of case classes
		} else if (typeInfo.isTupleType() && ((TupleTypeInfoBase) typeInfo).isCaseClass()) {
			TupleTypeInfoBase tupleTypeInfo = (TupleTypeInfoBase) typeInfo;
			FieldExpression decomp = decomposeFieldExpression(field);
			int fieldPos = tupleTypeInfo.getFieldIndex(decomp.head);
			if (fieldPos < 0) {
				throw new CompositeType.InvalidFieldReferenceException("Invalid field selected: " + field);
			}

			if (decomp.tail == null) {
				return new FieldAccessor.SimpleProductFieldAccessor<>(fieldPos, typeInfo, config);
			} else {
				@SuppressWarnings("unchecked")
				FieldAccessor<Object, F> innerAccessor = getAccessor(tupleTypeInfo.getTypeAt(fieldPos), decomp.tail, config);
				return new FieldAccessor.RecursiveProductFieldAccessor<>(fieldPos, typeInfo, innerAccessor, config);
			}

		// In case of tuples
		} else if (typeInfo.isTupleType()) {
			TupleTypeInfo tupleTypeInfo = (TupleTypeInfo) typeInfo;
			FieldExpression decomp = decomposeFieldExpression(field);
			int fieldPos = tupleTypeInfo.getFieldIndex(decomp.head);
			if (fieldPos == -1) {
				// Tuples may also be addressed by a plain integer position, e.g. "1" instead of "f1".
				try {
					fieldPos = Integer.parseInt(decomp.head);
				} catch (NumberFormatException ex) {
					throw new CompositeType.InvalidFieldReferenceException("Tried to select field \"" + decomp.head
						+ "\" on " + typeInfo.toString() + " . Only integer values are allowed here.");
				}
			}
			// A parsed integer position can still be out of range (e.g. "5.f0" on a Tuple2); reject it
			// before getTypeAt is consulted for a nested expression.
			checkFieldPosition(tupleTypeInfo, fieldPos);
			if (decomp.tail == null) {
				@SuppressWarnings("unchecked")
				FieldAccessor<T, F> result = new FieldAccessor.SimpleTupleFieldAccessor(fieldPos, tupleTypeInfo);
				return result;
			} else {
				@SuppressWarnings("unchecked")
				FieldAccessor<?, F> innerAccessor = getAccessor(tupleTypeInfo.getTypeAt(fieldPos), decomp.tail, config);
				@SuppressWarnings("unchecked")
				FieldAccessor<T, F> result = new FieldAccessor.RecursiveTupleFieldAccessor(fieldPos, innerAccessor, tupleTypeInfo);
				return result;
			}

		// Default statement
		} else {
			throw new CompositeType.InvalidFieldReferenceException("Cannot reference field by field expression on " + typeInfo.toString()
				+ ". Field expressions are only supported on POJO types, tuples, and case classes. "
				+ "(See the Flink documentation on what is considered a POJO.)");
		}
	}

	/**
	 * Throws a {@link CompositeType.InvalidFieldReferenceException} if {@code pos} is not a valid field
	 * position of the given type, i.e. if it lies outside {@code [0, typeInfo.getArity())}.
	 */
	private static void checkFieldPosition(TypeInformation<?> typeInfo, int pos) {
		int arity = typeInfo.getArity();
		if (pos < 0 || pos >= arity) {
			throw new CompositeType.InvalidFieldReferenceException("Field position " + pos + " is out of range for "
				+ typeInfo.toString() + " (valid positions: 0 to " + (arity - 1) + ").");
		}
	}

	// --------------------------------------------------------------------------------------------------

	private static final String REGEX_FIELD = "[\\p{L}\\p{Digit}_\\$]*"; // This can start with a digit (because of Tuples)
	private static final String REGEX_NESTED_FIELDS = "(" + REGEX_FIELD + ")(\\.(.+))?";
	private static final String REGEX_NESTED_FIELDS_WILDCARD = REGEX_NESTED_FIELDS
		+ "|\\" + Keys.ExpressionKeys.SELECT_ALL_CHAR
		+ "|\\" + Keys.ExpressionKeys.SELECT_ALL_CHAR_SCALA;

	private static final Pattern PATTERN_NESTED_FIELDS_WILDCARD = Pattern.compile(REGEX_NESTED_FIELDS_WILDCARD);

	/**
	 * Splits a field expression at its first dot into the first part ({@code head}) and the rest
	 * ({@code tail}, or {@code null} if there is no dot).
	 *
	 * @throws CompositeType.InvalidFieldReferenceException if the expression is malformed or is a
	 *         bare wildcard ({@code "*"} or {@code "_"})
	 */
	private static FieldExpression decomposeFieldExpression(String fieldExpression) {
		Matcher matcher = PATTERN_NESTED_FIELDS_WILDCARD.matcher(fieldExpression);
		if (!matcher.matches()) {
			throw new CompositeType.InvalidFieldReferenceException("Invalid field expression \"" + fieldExpression + "\".");
		}

		// matches() succeeded, so group(0) is the entire expression.
		String head = matcher.group(0);
		if (head.equals(Keys.ExpressionKeys.SELECT_ALL_CHAR) || head.equals(Keys.ExpressionKeys.SELECT_ALL_CHAR_SCALA)) {
			throw new CompositeType.InvalidFieldReferenceException("No wildcards are allowed here.");
		} else {
			head = matcher.group(1);
		}

		String tail = matcher.group(3);

		return new FieldExpression(head, tail);
	}

	/**
	 * Represents a decomposition of a field expression into its first part, and the rest.
	 * E.g. "foo.f1.bar" is decomposed into "foo" and "f1.bar".
	 */
	private static class FieldExpression implements Serializable {

		private static final long serialVersionUID = 1L;

		public final String head;

		// tail can be null, if the field expression had just one part
		public final String tail;

		FieldExpression(String head, String tail) {
			this.head = head;
			this.tail = tail;
		}
	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/870e219d/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/typeutils/FieldAccessorTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/typeutils/FieldAccessorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/typeutils/FieldAccessorTest.java new file mode 100644 index 0000000..5e7dd35 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/typeutils/FieldAccessorTest.java @@ -0,0 +1,358 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.streaming.util.typeutils;

import static org.junit.Assert.assertEquals;

import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.PojoTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.junit.Test;

/**
 * Tests for {@link FieldAccessorFactory} and the {@link FieldAccessor}s it creates for tuples,
 * POJOs, arrays, and basic types.
 */
public class FieldAccessorTest {

	// Note, that AggregationFunctionTest indirectly also tests FieldAccessors.
	// ProductFieldAccessors are tested in CaseClassFieldAccessorTest.

	@Test
	@SuppressWarnings("unchecked")
	public void testFlatTuple() {
		Tuple2<String, Integer> t = Tuple2.of("aa", 5);
		TupleTypeInfo<Tuple2<String, Integer>> tpeInfo =
			(TupleTypeInfo<Tuple2<String, Integer>>) TypeExtractor.getForObject(t);

		FieldAccessor<Tuple2<String, Integer>, String> f0 = FieldAccessorFactory.getAccessor(tpeInfo, "f0", null);
		assertEquals("aa", f0.get(t));
		assertEquals("aa", t.f0);
		t = f0.set(t, "b");
		assertEquals("b", f0.get(t));
		assertEquals("b", t.f0);

		FieldAccessor<Tuple2<String, Integer>, Integer> f1 = FieldAccessorFactory.getAccessor(tpeInfo, "f1", null);
		assertEquals(5, (int) f1.get(t));
		assertEquals(5, (int) t.f1);
		t = f1.set(t, 7);
		assertEquals(7, (int) f1.get(t));
		assertEquals(7, (int) t.f1);
		assertEquals("b", f0.get(t));
		assertEquals("b", t.f0);


		FieldAccessor<Tuple2<String, Integer>, Integer> f1n = FieldAccessorFactory.getAccessor(tpeInfo, 1, null);
		assertEquals(7, (int) f1n.get(t));
		assertEquals(7, (int) t.f1);
		t = f1n.set(t, 10);
		assertEquals(10, (int) f1n.get(t));
		assertEquals(10, (int) f1.get(t));
		assertEquals(10, (int) t.f1);
		assertEquals("b", f0.get(t));
		assertEquals("b", t.f0);

		FieldAccessor<Tuple2<String, Integer>, Integer> f1ns = FieldAccessorFactory.getAccessor(tpeInfo, "1", null);
		assertEquals(10, (int) f1ns.get(t));
		assertEquals(10, (int) t.f1);
		t = f1ns.set(t, 11);
		assertEquals(11, (int) f1ns.get(t));
		assertEquals(11, (int) f1.get(t));
		assertEquals(11, (int) t.f1);
		assertEquals("b", f0.get(t));
		assertEquals("b", t.f0);

		// This is technically valid (the ".0" is selecting the 0th field of a basic type).
		FieldAccessor<Tuple2<String, Integer>, String> f0_0 = FieldAccessorFactory.getAccessor(tpeInfo, "f0.0", null);
		assertEquals("b", f0_0.get(t));
		assertEquals("b", t.f0);
		t = f0_0.set(t, "cc");
		assertEquals("cc", f0_0.get(t));
		assertEquals("cc", t.f0);

	}

	@Test(expected = CompositeType.InvalidFieldReferenceException.class)
	@SuppressWarnings("unchecked")
	public void testIllegalFlatTuple() {
		Tuple2<String, Integer> t = Tuple2.of("aa", 5);
		TupleTypeInfo<Tuple2<String, Integer>> tpeInfo =
			(TupleTypeInfo<Tuple2<String, Integer>>) TypeExtractor.getForObject(t);

		FieldAccessorFactory.getAccessor(tpeInfo, "illegal", null);
	}

	@Test
	@SuppressWarnings("unchecked")
	public void testTupleInTuple() {
		Tuple2<String, Tuple3<Integer, Long, Double>> t = Tuple2.of("aa", Tuple3.of(5, 9L, 2.0));
		TupleTypeInfo<Tuple2<String, Tuple3<Integer, Long, Double>>> tpeInfo =
			(TupleTypeInfo<Tuple2<String, Tuple3<Integer, Long, Double>>>) TypeExtractor.getForObject(t);

		FieldAccessor<Tuple2<String, Tuple3<Integer, Long, Double>>, String> f0 = FieldAccessorFactory
			.getAccessor(tpeInfo, "f0", null);
		assertEquals("aa", f0.get(t));
		assertEquals("aa", t.f0);

		FieldAccessor<Tuple2<String, Tuple3<Integer, Long, Double>>, Double> f1f2 = FieldAccessorFactory
			.getAccessor(tpeInfo, "f1.f2", null);
		assertEquals(2.0, f1f2.get(t), 0);
		assertEquals(2.0, t.f1.f2, 0);
		t = f1f2.set(t, 3.0);
		assertEquals(3.0, f1f2.get(t), 0);
		assertEquals(3.0, t.f1.f2, 0);
		assertEquals("aa", f0.get(t));
		assertEquals("aa", t.f0);

		FieldAccessor<Tuple2<String, Tuple3<Integer, Long, Double>>, Tuple3<Integer, Long, Double>> f1 =
			FieldAccessorFactory.getAccessor(tpeInfo, "f1", null);
		assertEquals(Tuple3.of(5, 9L, 3.0), f1.get(t));
		assertEquals(Tuple3.of(5, 9L, 3.0), t.f1);
		t = f1.set(t, Tuple3.of(8, 12L, 4.0));
		assertEquals(Tuple3.of(8, 12L, 4.0), f1.get(t));
		assertEquals(Tuple3.of(8, 12L, 4.0), t.f1);
		assertEquals("aa", f0.get(t));
		assertEquals("aa", t.f0);

		FieldAccessor<Tuple2<String, Tuple3<Integer, Long, Double>>, Tuple3<Integer, Long, Double>> f1n =
			FieldAccessorFactory.getAccessor(tpeInfo, 1, null);
		assertEquals(Tuple3.of(8, 12L, 4.0), f1n.get(t));
		assertEquals(Tuple3.of(8, 12L, 4.0), t.f1);
		t = f1n.set(t, Tuple3.of(10, 13L, 5.0));
		assertEquals(Tuple3.of(10, 13L, 5.0), f1n.get(t));
		assertEquals(Tuple3.of(10, 13L, 5.0), f1.get(t));
		assertEquals(Tuple3.of(10, 13L, 5.0), t.f1);
		assertEquals("aa", f0.get(t));
		assertEquals("aa", t.f0);
	}

	@Test(expected = CompositeType.InvalidFieldReferenceException.class)
	@SuppressWarnings("unchecked")
	public void testIllegalTupleField() {
		FieldAccessorFactory.getAccessor(TupleTypeInfo.getBasicTupleTypeInfo(Integer.class, Integer.class), 2, null);
	}

	/** POJO that nests a tuple; used to test tuple -> POJO -> tuple field paths. */
	public static class Foo {
		public int x;
		public Tuple2<String, Long> t;
		public Short y;

		public Foo() {}

		public Foo(int x, Tuple2<String, Long> t, Short y) {
			this.x = x;
			this.t = t;
			this.y = y;
		}
	}

	@Test
	@SuppressWarnings("unchecked")
	public void testTupleInPojoInTuple() {
		Tuple2<String, Foo> t = Tuple2.of("aa", new Foo(8, Tuple2.of("ddd", 9L), (short) 2));
		TupleTypeInfo<Tuple2<String, Foo>> tpeInfo =
			(TupleTypeInfo<Tuple2<String, Foo>>) TypeExtractor.getForObject(t);

		FieldAccessor<Tuple2<String, Foo>, Long> f1tf1 = FieldAccessorFactory.getAccessor(tpeInfo, "f1.t.f1", null);
		assertEquals(9L, (long) f1tf1.get(t));
		assertEquals(9L, (long) t.f1.t.f1);
		t = f1tf1.set(t, 12L);
		assertEquals(12L, (long) f1tf1.get(t));
		assertEquals(12L, (long) t.f1.t.f1);

		FieldAccessor<Tuple2<String, Foo>, String> f1tf0 = FieldAccessorFactory.getAccessor(tpeInfo, "f1.t.f0", null);
		assertEquals("ddd", f1tf0.get(t));
		assertEquals("ddd", t.f1.t.f0);
		t = f1tf0.set(t, "alma");
		assertEquals("alma", f1tf0.get(t));
		assertEquals("alma", t.f1.t.f0);

		FieldAccessor<Tuple2<String, Foo>, Foo> f1 = FieldAccessorFactory.getAccessor(tpeInfo, "f1", null);
		FieldAccessor<Tuple2<String, Foo>, Foo> f1n = FieldAccessorFactory.getAccessor(tpeInfo, 1, null);
		assertEquals(Tuple2.of("alma", 12L), f1.get(t).t);
		assertEquals(Tuple2.of("alma", 12L), f1n.get(t).t);
		assertEquals(Tuple2.of("alma", 12L), t.f1.t);
		Foo newFoo = new Foo(8, Tuple2.of("ddd", 9L), (short) 2);
		// Use the returned record, as everywhere else: set() hands back the updated record.
		t = f1.set(t, newFoo);
		assertEquals(newFoo, f1.get(t));
		assertEquals(newFoo, f1n.get(t));
		assertEquals(newFoo, t.f1);
	}

	@Test(expected = CompositeType.InvalidFieldReferenceException.class)
	@SuppressWarnings("unchecked")
	public void testIllegalTupleInPojoInTuple() {
		Tuple2<String, Foo> t = Tuple2.of("aa", new Foo(8, Tuple2.of("ddd", 9L), (short) 2));
		TupleTypeInfo<Tuple2<String, Foo>> tpeInfo =
			(TupleTypeInfo<Tuple2<String, Foo>>) TypeExtractor.getForObject(t);

		FieldAccessorFactory.getAccessor(tpeInfo, "illegal.illegal.illegal", null);
	}

	/** Inner POJO of {@link Outer}; also used by CaseClassFieldAccessorTest. */
	public static class Inner {
		public long x;
		public boolean b;

		public Inner(){}

		public Inner(long x) {
			this.x = x;
		}

		public Inner(long x, boolean b) {
			this.x = x;
			this.b = b;
		}

		@Override
		public String toString() {
			return ((Long) x).toString() + ", " + b;
		}
	}

	/** POJO nesting another POJO ({@link Inner}). */
	public static class Outer {
		public int a;
		public Inner i;
		public short b;

		public Outer(){}

		public Outer(int a, Inner i, short b) {
			this.a = a;
			this.i = i;
			this.b = b;
		}

		@Override
		public String toString() {
			return a + ", " + i.toString() + ", " + b;
		}
	}

	@Test
	public void testPojoInPojo() {
		Outer o = new Outer(10, new Inner(4L), (short) 12);
		PojoTypeInfo<Outer> tpeInfo =
			(PojoTypeInfo<Outer>) TypeInformation.of(Outer.class);

		FieldAccessor<Outer, Long> fix = FieldAccessorFactory.getAccessor(tpeInfo, "i.x", null);
		assertEquals(4L, (long) fix.get(o));
		assertEquals(4L, o.i.x);
		o = fix.set(o, 22L);
		assertEquals(22L, (long) fix.get(o));
		assertEquals(22L, o.i.x);

		FieldAccessor<Outer, Inner> fi = FieldAccessorFactory.getAccessor(tpeInfo, "i", null);
		assertEquals(22L, fi.get(o).x);
		assertEquals(22L, (long) fix.get(o));
		assertEquals(22L, o.i.x);
		o = fi.set(o, new Inner(30L));
		assertEquals(30L, fi.get(o).x);
		assertEquals(30L, (long) fix.get(o));
		assertEquals(30L, o.i.x);
	}

	@Test
	@SuppressWarnings("unchecked")
	public void testArray() {
		int[] a = new int[]{3, 5};
		FieldAccessor<int[], Integer> fieldAccessor =
			(FieldAccessor<int[], Integer>) (Object)
				FieldAccessorFactory.getAccessor(PrimitiveArrayTypeInfo.getInfoFor(a.getClass()), 1, null);

		assertEquals(Integer.class, fieldAccessor.getFieldType().getTypeClass());

		assertEquals((Integer) a[1], fieldAccessor.get(a));

		a = fieldAccessor.set(a, 6);
		assertEquals((Integer) a[1], fieldAccessor.get(a));



		Integer[] b = new Integer[]{3, 5};
		FieldAccessor<Integer[], Integer> fieldAccessor2 =
			(FieldAccessor<Integer[], Integer>) (Object)
				FieldAccessorFactory.getAccessor(BasicArrayTypeInfo.getInfoFor(b.getClass()), 1, null);

		assertEquals(Integer.class, fieldAccessor2.getFieldType().getTypeClass());

		assertEquals(b[1], fieldAccessor2.get(b));

		b = fieldAccessor2.set(b, 6);
		assertEquals(b[1], fieldAccessor2.get(b));
	}

	/** POJO holding a primitive array, to test array indexing inside a field expression. */
	public static class ArrayInPojo {
		public long x;
		public int[] arr;
		public int y;

		public ArrayInPojo() {}

		public ArrayInPojo(long x, int[] arr, int y) {
			this.x = x;
			this.arr = arr;
			this.y = y;
		}
	}

	@Test
	public void testArrayInPojo() {
		ArrayInPojo o = new ArrayInPojo(10L, new int[]{3, 4, 5}, 12);
		PojoTypeInfo<ArrayInPojo> tpeInfo =
			(PojoTypeInfo<ArrayInPojo>) TypeInformation.of(ArrayInPojo.class);

		FieldAccessor<ArrayInPojo, Integer> fix = FieldAccessorFactory.getAccessor(tpeInfo, "arr.1", null);
		assertEquals(4, (int) fix.get(o));
		assertEquals(4, o.arr[1]);
		o = fix.set(o, 8);
		assertEquals(8, (int) fix.get(o));
		assertEquals(8, o.arr[1]);
	}

	@Test
	public void testBasicType() {
		Long x = 7L;
		TypeInformation<Long> tpeInfo = BasicTypeInfo.LONG_TYPE_INFO;

		FieldAccessor<Long, Long> f = FieldAccessorFactory.getAccessor(tpeInfo, 0, null);
		assertEquals(7L, (long) f.get(x));
		x = f.set(x, 12L);
		assertEquals(12L, (long) f.get(x));
		assertEquals(12L, (long) x);

		FieldAccessor<Long, Long> f2 = FieldAccessorFactory.getAccessor(tpeInfo, "*", null);
		assertEquals(12L, (long) f2.get(x));
		x = f2.set(x, 14L);
		assertEquals(14L, (long) f2.get(x));
		assertEquals(14L, (long) x);
	}

	/** Only position 0 may be selected on a basic type. */
	@Test(expected = CompositeType.InvalidFieldReferenceException.class)
	public void testIllegalBasicType1() {
		TypeInformation<Long> tpeInfo = BasicTypeInfo.LONG_TYPE_INFO;

		FieldAccessorFactory.getAccessor(tpeInfo, 1, null);
	}

	/** Only "*" or "0" may be used as a field expression on a basic type. */
	@Test(expected = CompositeType.InvalidFieldReferenceException.class)
	public void testIllegalBasicType2() {
		TypeInformation<Long> tpeInfo = BasicTypeInfo.LONG_TYPE_INFO;

		FieldAccessorFactory.getAccessor(tpeInfo, "foo", null);
	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/870e219d/flink-streaming-scala/pom.xml ---------------------------------------------------------------------- diff --git a/flink-streaming-scala/pom.xml b/flink-streaming-scala/pom.xml index 26f7cc2..b0cc961 100644 --- a/flink-streaming-scala/pom.xml +++ b/flink-streaming-scala/pom.xml @@ -92,6 +92,14 @@ under the License.
<dependency> <groupId>org.apache.flink</groupId> + <artifactId>flink-streaming-java_2.10</artifactId> + <version>${project.version}</version> + <scope>test</scope> + <type>test-jar</type> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> <artifactId>flink-tests_2.10</artifactId> <version>${project.version}</version> <scope>test</scope> http://git-wip-us.apache.org/repos/asf/flink/blob/870e219d/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala ---------------------------------------------------------------------- diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala index d5cc013..66d80c2 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala @@ -217,9 +217,9 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * the given position by the given key. An independent aggregate is kept per key. * * @param position - * The field position in the data points to maximize. This is applicable to - * Tuple types, basic and primitive array types, Scala case classes, - * and primitive types (which is considered as having one field). + * The field position in the data points to maximize. This is applicable to + * Tuple types, Scala case classes, and primitive types (which is considered + * as having one field). */ def max(position: Int): DataStream[T] = aggregate(AggregationType.MAX, position) @@ -232,9 +232,8 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * name of the (public) field on which to perform the aggregation. * Additionally, a dot can be used to drill down into nested * objects, as in `"field1.fieldxy"`.
- * Furthermore, an array index can also be specified in case of an array of - * a primitive or basic type; or "0" or "*" can be specified in case of a - * basic type (which is considered as having only one field). + * Furthermore "*" can be specified in case of a basic type + * (which is considered as having only one field). */ def max(field: String): DataStream[T] = aggregate(AggregationType.MAX, field) @@ -244,8 +243,8 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * * @param position * The field position in the data points to minimize. This is applicable to - * Tuple types, basic and primitive array types, Scala case classes, - * and primitive types (which is considered as having one field). + * Tuple types, Scala case classes, and primitive types (which is considered + * as having one field). */ def min(position: Int): DataStream[T] = aggregate(AggregationType.MIN, position) @@ -258,9 +257,8 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * name of the (public) field on which to perform the aggregation. * Additionally, a dot can be used to drill down into nested * objects, as in `"field1.fieldxy"`. - * Furthermore, an array index can also be specified in case of an array of - * a primitive or basic type; or "0" or "*" can be specified in case of a - * basic type (which is considered as having only one field). + * Furthermore "*" can be specified in case of a basic type + * (which is considered as having only one field). */ def min(field: String): DataStream[T] = aggregate(AggregationType.MIN, field) @@ -269,9 +267,9 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * key. An independent aggregate is kept per key. * * @param position - * The field position in the data points to sum. This is applicable to - * Tuple types, basic and primitive array types, Scala case classes, - * and primitive types (which is considered as having one field). 
+ * The field position in the data points to sum. This is applicable to + * Tuple types, Scala case classes, and primitive types (which is considered + * as having one field). */ def sum(position: Int): DataStream[T] = aggregate(AggregationType.SUM, position) @@ -284,9 +282,8 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * name of the (public) field on which to perform the aggregation. * Additionally, a dot can be used to drill down into nested * objects, as in `"field1.fieldxy"`. - * Furthermore, an array index can also be specified in case of an array of - * a primitive or basic type; or "0" or "*" can be specified in case of a - * basic type (which is considered as having only one field). + * Furthermore "*" can be specified in case of a basic type + * (which is considered as having only one field). */ def sum(field: String): DataStream[T] = aggregate(AggregationType.SUM, field) @@ -297,8 +294,8 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * * @param position * The field position in the data points to minimize. This is applicable to - * Tuple types, basic and primitive array types, Scala case classes, - * and primitive types (which is considered as having one field). + * Tuple types, Scala case classes, and primitive types (which is considered + * as having one field). */ def minBy(position: Int): DataStream[T] = aggregate(AggregationType .MINBY, position) @@ -313,9 +310,8 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * name of the (public) field on which to perform the aggregation. * Additionally, a dot can be used to drill down into nested * objects, as in `"field1.fieldxy"`. - * Furthermore, an array index can also be specified in case of an array of - * a primitive or basic type; or "0" or "*" can be specified in case of a - * basic type (which is considered as having only one field).
+ * Furthermore "*" can be specified in case of a basic type + * (which is considered as having only one field). */ def minBy(field: String): DataStream[T] = aggregate(AggregationType .MINBY, field ) @@ -327,8 +323,8 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * * @param position * The field position in the data points to minimize. This is applicable to - * Tuple types, basic and primitive array types, Scala case classes, - * and primitive types (which is considered as having one field). + * Tuple types, Scala case classes, and primitive types (which is considered + * as having one field). */ def maxBy(position: Int): DataStream[T] = aggregate(AggregationType.MAXBY, position) @@ -343,9 +339,8 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] * name of the (public) field on which to perform the aggregation. * Additionally, a dot can be used to drill down into nested * objects, as in `"field1.fieldxy"`. - * Furthermore, an array index can also be specified in case of an array of - * a primitive or basic type; or "0" or "*" can be specified in case of a - * basic type (which is considered as having only one field). + * Furthermore "*" can be specified in case of a basic type + * (which is considered as having only one field). 
*/ def maxBy(field: String): DataStream[T] = aggregate(AggregationType.MAXBY, field)
http://git-wip-us.apache.org/repos/asf/flink/blob/870e219d/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/CaseClassFieldAccessorTest.scala ---------------------------------------------------------------------- diff --git a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/CaseClassFieldAccessorTest.scala b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/CaseClassFieldAccessorTest.scala new file mode 100644 index 0000000..b61a812 --- /dev/null +++ b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/CaseClassFieldAccessorTest.scala @@ -0,0 +1,137 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.streaming.api.scala

import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.streaming.util.typeutils.{FieldAccessorFactory, FieldAccessorTest}
import org.apache.flink.util.TestLogger
import org.junit.Test
import org.scalatest.junit.JUnitSuiteLike

/**
 * Tests the product (case class / Scala tuple) accessors created by [[FieldAccessorFactory]].
 */
class CaseClassFieldAccessorTest extends TestLogger with JUnitSuiteLike {

  /** A flat case class, accessed both by field name and by field position. */
  @Test
  def testFieldAccessorFlatCaseClass(): Unit = {
    case class IntBoolean(foo: Int, bar: Boolean)
    val tpeInfo = createTypeInformation[IntBoolean]

    val byName = (
      FieldAccessorFactory.getAccessor[IntBoolean, Int](tpeInfo, "foo", null),
      FieldAccessorFactory.getAccessor[IntBoolean, Boolean](tpeInfo, "bar", null))
    val byPosition = (
      FieldAccessorFactory.getAccessor[IntBoolean, Int](tpeInfo, 0, null),
      FieldAccessorFactory.getAccessor[IntBoolean, Boolean](tpeInfo, 1, null))

    // Both ways of obtaining the accessors must behave identically.
    for ((fooAccessor, barAccessor) <- Seq(byName, byPosition)) {
      val initial = IntBoolean(5, false)
      assert(fooAccessor.get(initial) == 5)
      assert(!barAccessor.get(initial))
      assert(initial.foo == 5)
      assert(!initial.bar)

      // set() returns the updated record, which is what gets checked from here on
      val withFoo: IntBoolean = fooAccessor.set(initial, 6)
      assert(fooAccessor.get(withFoo) == 6)
      assert(withFoo.foo == 6)

      val withBar = barAccessor.set(withFoo, true)
      assert(withBar.bar)
      assert(barAccessor.get(withBar))
      assert(withBar.foo == 6)
    }
  }

  /** A Java POJO nested inside a case class, reached by expression and by position. */
  @Test
  def testFieldAccessorPojoInCaseClass(): Unit = {
    case class Outer(a: Int, i: FieldAccessorTest.Inner, b: Boolean)
    val tpeInfo = createTypeInformation[Outer]
    val config = new ExecutionConfig
    var record = Outer(1, new FieldAccessorTest.Inner(3L, true), false)

    val innerFlag = FieldAccessorFactory.getAccessor[Outer, Boolean](tpeInfo, "i.b", config)
    assert(innerFlag.get(record))
    assert(record.i.b)
    record = innerFlag.set(record, false)
    assert(!innerFlag.get(record))
    assert(!record.i.b)

    val innerByName =
      FieldAccessorFactory.getAccessor[Outer, FieldAccessorTest.Inner](tpeInfo, "i", config)
    assert(innerByName.get(record).x == 3L)
    assert(record.i.x == 3L)
    record = innerByName.set(record, new FieldAccessorTest.Inner(4L, true))
    assert(innerByName.get(record).x == 4L)
    assert(record.i.x == 4L)

    val innerByPosition =
      FieldAccessorFactory.getAccessor[Outer, FieldAccessorTest.Inner](tpeInfo, 1, config)
    assert(innerByPosition.get(record).x == 4L)
    assert(record.i.x == 4L)
    record = innerByPosition.set(record, new FieldAccessorTest.Inner(5L, true))
    assert(innerByPosition.get(record).x == 5L)
    assert(record.i.x == 5L)
  }

  /** A Scala tuple accessed by position. */
  @Test
  def testFieldAccessorTuple(): Unit = {
    val tpeInfo = createTypeInformation[(Int, Long)]
    var pair = (5, 6L)
    val first = FieldAccessorFactory.getAccessor[(Int, Long), Int](tpeInfo, 0, null)
    assert(first.get(pair) == 5)
    pair = first.set(pair, 8)
    assert(first.get(pair) == 8)
    assert(pair._1 == 8)
  }

  /** A case class nested in a case class: both a leaf field and the whole inner value. */
  @Test
  def testFieldAccessorCaseClassInCaseClass(): Unit = {
    case class Inner(a: Short, b: String)
    case class Outer(a: Int, i: Inner, b: Boolean)
    val tpeInfo = createTypeInformation[Outer]

    var record = Outer(1, Inner(2, "alma"), true)

    val leaf = FieldAccessorFactory.getAccessor[Outer, String](tpeInfo, "i.b", null)
    assert(leaf.get(record) == "alma")
    assert(record.i.b == "alma")
    record = leaf.set(record, "korte")
    assert(leaf.get(record) == "korte")
    assert(record.i.b == "korte")

    val whole = FieldAccessorFactory.getAccessor[Outer, Inner](tpeInfo, "i", null)
    assert(whole.get(record) == Inner(2, "korte"))
    record = whole.set(record, Inner(3, "aaa"))
    assert(record.i == Inner(3, "aaa"))
  }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/870e219d/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/DataStreamPojoITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/DataStreamPojoITCase.java
b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/DataStreamPojoITCase.java index eb03b45..0949c68 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/DataStreamPojoITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/DataStreamPojoITCase.java @@ -18,7 +18,7 @@ package org.apache.flink.test.streaming.runtime; import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.api.common.typeinfo.InvalidFieldReferenceException; +import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase; @@ -175,7 +175,7 @@ public class DataStreamPojoITCase extends StreamingMultipleProgramsTestBase { see.execute(); } - @Test(expected = InvalidFieldReferenceException.class) + @Test(expected = CompositeType.InvalidFieldReferenceException.class) public void testFailOnNestedPojoFieldAccessor() throws Exception { StreamExecutionEnvironment see = StreamExecutionEnvironment.getExecutionEnvironment();
