Repository: flink Updated Branches: refs/heads/master 6bb023532 -> 6067833fb
[FLINK-1147][Java API] TypeInference on POJOs This closes #315. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/6067833f Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/6067833f Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/6067833f Branch: refs/heads/master Commit: 6067833fb6ad6c11a121d8654d7ca147cc909f05 Parents: 6bb0235 Author: twalthr <[email protected]> Authored: Tue Jan 13 23:59:35 2015 +0100 Committer: twalthr <[email protected]> Committed: Mon Jan 26 16:26:06 2015 +0100 ---------------------------------------------------------------------- .../flink/api/java/typeutils/TypeExtractor.java | 189 ++++++++++++++---- .../api/java/typeutils/TypeInfoParser.java | 7 +- .../type/extractor/PojoTypeExtractionTest.java | 196 ++++++++++++++++++- .../java/type/extractor/TypeExtractorTest.java | 20 +- 4 files changed, 366 insertions(+), 46 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/6067833f/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java index edff09c..a1f5dd6 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java @@ -422,7 +422,12 @@ public class TypeExtractor { int fieldCount = countFieldsInClass(tAsClass); if(fieldCount != tupleSubTypes.length) { // the class is not a real tuple because it contains additional fields. treat as a pojo - return (TypeInformation<OUT>) analyzePojo(tAsClass, new ArrayList<Type>(typeHierarchy), null); // the typeHierarchy here should be sufficient, even though it stops at the Tuple.class. + if (t instanceof ParameterizedType) { + return (TypeInformation<OUT>) analyzePojo(tAsClass, new ArrayList<Type>(typeHierarchy), (ParameterizedType) t, in1Type, in2Type); + } + else { + return (TypeInformation<OUT>) analyzePojo(tAsClass, new ArrayList<Type>(typeHierarchy), null, in1Type, in2Type); + } } return new TupleTypeInfo(tAsClass, tupleSubTypes); @@ -482,9 +487,9 @@ public class TypeExtractor { in1Type, in2Type); return ObjectArrayTypeInfo.getInfoFor(t, componentInfo); } - // objects with generics are treated as raw type - else if (t instanceof ParameterizedType) { //TODO - return privateGetForClass((Class<OUT>) ((ParameterizedType) t).getRawType(), typeHierarchy, (ParameterizedType) t); + // objects with generics are treated as Class first + else if (t instanceof ParameterizedType) { + return (TypeInformation<OUT>) privateGetForClass(typeToClass(t), typeHierarchy, (ParameterizedType) t, in1Type, in2Type); } // no tuple, no TypeVariable, no generic type else if (t instanceof Class) { @@ -553,10 +558,25 @@ public class TypeExtractor { // the input is a type variable if (inType instanceof TypeVariable) { inType = materializeTypeVariable(inputTypeHierarchy, (TypeVariable<?>) inType); - info = findCorrespondingInfo(returnTypeVar, inType, inTypeInfo); + info = findCorrespondingInfo(returnTypeVar, inType, inTypeInfo, inputTypeHierarchy); + } + // input is an array + else if (inType instanceof GenericArrayType) { + TypeInformation<?> componentInfo = null; + if (inTypeInfo instanceof BasicArrayTypeInfo) { + componentInfo = ((BasicArrayTypeInfo<?,?>) inTypeInfo).getComponentInfo(); + } + else if (inTypeInfo instanceof PrimitiveArrayTypeInfo) { + componentInfo = BasicTypeInfo.getInfoFor(inTypeInfo.getTypeClass().getComponentType()); + } + else if (inTypeInfo instanceof ObjectArrayTypeInfo) { + componentInfo = ((ObjectArrayTypeInfo<?,?>) inTypeInfo).getComponentInfo(); + } + info = createTypeInfoFromInput(returnTypeVar, inputTypeHierarchy, ((GenericArrayType) inType).getGenericComponentType(), componentInfo); } - // the input is a tuple that may contains type variables - else if (isClassType(inType) && Tuple.class.isAssignableFrom(typeToClass(inType))) { + // the input is a tuple + else if (inTypeInfo instanceof TupleTypeInfo && isClassType(inType) + && Tuple.class.isAssignableFrom(typeToClass(inType))) { ParameterizedType tupleBaseClass = null; // get tuple from possible tuple subclass @@ -579,6 +599,12 @@ public class TypeExtractor { } } } + // the input is a pojo + else if (inTypeInfo instanceof PojoTypeInfo) { + // build the entire type hierarchy for the pojo + getTypeHierarchy(inputTypeHierarchy, inType, Object.class); + info = findCorrespondingInfo(returnTypeVar, inType, inTypeInfo, inputTypeHierarchy); + } return info; } @@ -841,7 +867,7 @@ public class TypeExtractor { * @param curT : start type * @return Type The immediate child of the top class */ - private Type getTypeHierarchy(ArrayList<Type> typeHierarchy, Type curT, Class<?> stopAtClass) { + private static Type getTypeHierarchy(ArrayList<Type> typeHierarchy, Type curT, Class<?> stopAtClass) { // skip first one if (typeHierarchy.size() > 0 && typeHierarchy.get(0) == curT && isClassType(curT)) { curT = typeToClass(curT).getGenericSuperclass(); @@ -926,26 +952,69 @@ public class TypeExtractor { throw new InvalidTypesException(); } - private static TypeInformation<?> findCorrespondingInfo(TypeVariable<?> typeVar, Type type, TypeInformation<?> corrInfo) { - if (type instanceof TypeVariable) { - TypeVariable<?> variable = (TypeVariable<?>) type; - if (variable.getName().equals(typeVar.getName()) && variable.getGenericDeclaration().equals(typeVar.getGenericDeclaration())) { - return corrInfo; + private static TypeInformation<?> findCorrespondingInfo(TypeVariable<?> typeVar, Type type, TypeInformation<?> corrInfo, ArrayList<Type> typeHierarchy) { + if (sameTypeVars(type, typeVar)) { + return corrInfo; + } + else if (type instanceof TypeVariable && sameTypeVars(materializeTypeVariable(typeHierarchy, (TypeVariable<?>) type), typeVar)) { + return corrInfo; + } + else if (type instanceof GenericArrayType) { + TypeInformation<?> componentInfo = null; + if (corrInfo instanceof BasicArrayTypeInfo) { + componentInfo = ((BasicArrayTypeInfo<?,?>) corrInfo).getComponentInfo(); + } + else if (corrInfo instanceof PrimitiveArrayTypeInfo) { + componentInfo = BasicTypeInfo.getInfoFor(corrInfo.getTypeClass().getComponentType()); + } + else if (corrInfo instanceof ObjectArrayTypeInfo) { + componentInfo = ((ObjectArrayTypeInfo<?,?>) corrInfo).getComponentInfo(); } - } else if (type instanceof ParameterizedType && Tuple.class.isAssignableFrom((Class<?>) ((ParameterizedType) type).getRawType())) { + TypeInformation<?> info = findCorrespondingInfo(typeVar, ((GenericArrayType) type).getGenericComponentType(), componentInfo, typeHierarchy); + if (info != null) { + return info; + } + } + else if (corrInfo instanceof TupleTypeInfo + && type instanceof ParameterizedType + && Tuple.class.isAssignableFrom((Class<?>) ((ParameterizedType) type).getRawType())) { ParameterizedType tuple = (ParameterizedType) type; Type[] args = tuple.getActualTypeArguments(); for (int i = 0; i < args.length; i++) { - TypeInformation<?> info = findCorrespondingInfo(typeVar, args[i], ((TupleTypeInfo<?>) corrInfo).getTypeAt(i)); + TypeInformation<?> info = findCorrespondingInfo(typeVar, args[i], ((TupleTypeInfo<?>) corrInfo).getTypeAt(i), typeHierarchy); if (info != null) { return info; } } } + else if (corrInfo instanceof PojoTypeInfo && isClassType(type)) { + // determine a field containing the type variable + List<Field> fields = getAllDeclaredFields(typeToClass(type)); + for (Field field : fields) { + Type fieldType = field.getGenericType(); + if (fieldType instanceof TypeVariable + && sameTypeVars(typeVar, materializeTypeVariable(typeHierarchy, (TypeVariable<?>) fieldType))) { + return getTypeOfPojoField(corrInfo, field); + } + else if (fieldType instanceof ParameterizedType + || fieldType instanceof GenericArrayType) { + ArrayList<Type> typeHierarchyWithFieldType = new ArrayList<Type>(typeHierarchy); + typeHierarchyWithFieldType.add(fieldType); + TypeInformation<?> info = findCorrespondingInfo(typeVar, fieldType, getTypeOfPojoField(corrInfo, field), typeHierarchyWithFieldType); + if (info != null) { + return info; + } + } + } + } return null; } + /** + * Tries to find a concrete value (Class, ParameterizedType etc. ) for a TypeVariable by traversing the type hierarchy downwards. + * If a value could not be found it will return the most bottom type variable in the hierarchy. + */ private static Type materializeTypeVariable(ArrayList<Type> typeHierarchy, TypeVariable<?> typeVar) { TypeVariable<?> inTypeTypeVar = typeVar; // iterate thru hierarchy from top to bottom until type variable gets a class assigned @@ -961,8 +1030,7 @@ public class TypeExtractor { TypeVariable<?> curVarOfCurT = rawType.getTypeParameters()[paramIndex]; // check if variable names match - if (curVarOfCurT.getName().equals(inTypeTypeVar.getName()) - && curVarOfCurT.getGenericDeclaration().equals(inTypeTypeVar.getGenericDeclaration())) { + if (sameTypeVars(curVarOfCurT, inTypeTypeVar)) { Type curVarType = ((ParameterizedType) curT).getActualTypeArguments()[paramIndex]; // another type variable level @@ -982,15 +1050,26 @@ public class TypeExtractor { return inTypeTypeVar; } + /** + * Creates type information from a given Class such as Integer, String[] or POJOs. + * + * This method does not support ParameterizedTypes such as Tuples or complex type hierarchies. + * In most cases {@link TypeExtractor#createTypeInfo(Type)} is the recommended method for type extraction + * (a Class is a child of Type). + * + * @param clazz a Class to create TypeInformation for + * @return TypeInformation that describes the passed Class + */ public static <X> TypeInformation<X> getForClass(Class<X> clazz) { return new TypeExtractor().privateGetForClass(clazz, new ArrayList<Type>()); } private <X> TypeInformation<X> privateGetForClass(Class<X> clazz, ArrayList<Type> typeHierarchy) { - return privateGetForClass(clazz, typeHierarchy, null); + return privateGetForClass(clazz, typeHierarchy, null, null, null); } @SuppressWarnings({ "unchecked", "rawtypes" }) - private <X> TypeInformation<X> privateGetForClass(Class<X> clazz, ArrayList<Type> typeHierarchy, ParameterizedType clazzTypeHint) { + private <OUT,IN1,IN2> TypeInformation<OUT> privateGetForClass(Class<OUT> clazz, ArrayList<Type> typeHierarchy, + ParameterizedType parameterizedType, TypeInformation<IN1> in1Type, TypeInformation<IN2> in2Type) { Validate.notNull(clazz); // check for abstract classes or interfaces @@ -999,20 +1078,20 @@ public class TypeExtractor { } if (clazz.equals(Object.class)) { - return new GenericTypeInfo<X>(clazz); + return new GenericTypeInfo<OUT>(clazz); } // check for arrays if (clazz.isArray()) { // primitive arrays: int[], byte[], ... - PrimitiveArrayTypeInfo<X> primitiveArrayInfo = PrimitiveArrayTypeInfo.getInfoFor(clazz); + PrimitiveArrayTypeInfo<OUT> primitiveArrayInfo = PrimitiveArrayTypeInfo.getInfoFor(clazz); if (primitiveArrayInfo != null) { return primitiveArrayInfo; } // basic type arrays: String[], Integer[], Double[] - BasicArrayTypeInfo<X, ?> basicArrayInfo = BasicArrayTypeInfo.getInfoFor(clazz); + BasicArrayTypeInfo<OUT, ?> basicArrayInfo = BasicArrayTypeInfo.getInfoFor(clazz); if (basicArrayInfo != null) { return basicArrayInfo; } @@ -1025,11 +1104,11 @@ public class TypeExtractor { // check for writable types if(Writable.class.isAssignableFrom(clazz)) { - return (TypeInformation<X>) WritableTypeInfo.getWritableTypeInfo((Class<? extends Writable>) clazz); + return (TypeInformation<OUT>) WritableTypeInfo.getWritableTypeInfo((Class<? extends Writable>) clazz); } // check for basic types - TypeInformation<X> basicTypeInfo = BasicTypeInfo.getInfoFor(clazz); + TypeInformation<OUT> basicTypeInfo = BasicTypeInfo.getInfoFor(clazz); if (basicTypeInfo != null) { return basicTypeInfo; } @@ -1037,7 +1116,7 @@ public class TypeExtractor { // check for subclasses of Value if (Value.class.isAssignableFrom(clazz)) { Class<? extends Value> valueClass = clazz.asSubclass(Value.class); - return (TypeInformation<X>) ValueTypeInfo.getValueTypeInfo(valueClass); + return (TypeInformation<OUT>) ValueTypeInfo.getValueTypeInfo(valueClass); } // check for subclasses of Tuple @@ -1047,22 +1126,22 @@ public class TypeExtractor { // check for Enums if(Enum.class.isAssignableFrom(clazz)) { - return (TypeInformation<X>) new EnumTypeInfo(clazz); + return (TypeInformation<OUT>) new EnumTypeInfo(clazz); } if (alreadySeen.contains(clazz)) { - return new GenericTypeInfo<X>(clazz); + return new GenericTypeInfo<OUT>(clazz); } alreadySeen.add(clazz); if (clazz.equals(Class.class)) { // special case handling for Class, this should not be handled by the POJO logic - return new GenericTypeInfo<X>(clazz); + return new GenericTypeInfo<OUT>(clazz); } try { - TypeInformation<X> pojoType = analyzePojo(clazz, new ArrayList<Type>(typeHierarchy), clazzTypeHint); + TypeInformation<OUT> pojoType = analyzePojo(clazz, new ArrayList<Type>(typeHierarchy), parameterizedType, in1Type, in2Type); if (pojoType != null) { return pojoType; } @@ -1074,7 +1153,7 @@ public class TypeExtractor { } // return a generic type - return new GenericTypeInfo<X>(clazz); + return new GenericTypeInfo<OUT>(clazz); } /** @@ -1142,14 +1221,16 @@ public class TypeExtractor { } @SuppressWarnings("unchecked") - private <X> TypeInformation<X> analyzePojo(Class<X> clazz, ArrayList<Type> typeHierarchy, ParameterizedType clazzTypeHint) { - // try to create Type hierarchy, if the incoming only contains the most bottom one or none. - if(typeHierarchy.size() <= 1) { + private <OUT, IN1, IN2> TypeInformation<OUT> analyzePojo(Class<OUT> clazz, ArrayList<Type> typeHierarchy, + ParameterizedType parameterizedType, TypeInformation<IN1> in1Type, TypeInformation<IN2> in2Type) { + // add the hierarchy of the POJO itself if it is generic + if (parameterizedType != null) { + getTypeHierarchy(typeHierarchy, parameterizedType, Object.class); + } + // create a type hierarchy, if the incoming only contains the most bottom one or none. + else if(typeHierarchy.size() <= 1) { getTypeHierarchy(typeHierarchy, clazz, Object.class); } - if(clazzTypeHint != null) { - getTypeHierarchy(typeHierarchy, clazzTypeHint, Object.class); - } List<Field> fields = getAllDeclaredFields(clazz); List<PojoField> pojoFields = new ArrayList<PojoField>(); @@ -1162,17 +1243,18 @@ public class TypeExtractor { try { ArrayList<Type> fieldTypeHierarchy = new ArrayList<Type>(typeHierarchy); fieldTypeHierarchy.add(fieldType); - pojoFields.add(new PojoField(field, createTypeInfoWithTypeHierarchy(fieldTypeHierarchy, fieldType, null, null) )); + TypeInformation<?> ti = createTypeInfoWithTypeHierarchy(fieldTypeHierarchy, fieldType, in1Type, in2Type); + pojoFields.add(new PojoField(field, ti)); } catch (InvalidTypesException e) { Class<?> genericClass = Object.class; if(isClassType(fieldType)) { genericClass = typeToClass(fieldType); } - pojoFields.add(new PojoField(field, new GenericTypeInfo<X>( (Class<X>) genericClass ))); + pojoFields.add(new PojoField(field, new GenericTypeInfo<OUT>((Class<OUT>) genericClass))); } } - CompositeType<X> pojoType = new PojoTypeInfo<X>(clazz, pojoFields); + CompositeType<OUT> pojoType = new PojoTypeInfo<OUT>(clazz, pojoFields); // // Validate the correctness of the pojo. @@ -1223,6 +1305,15 @@ public class TypeExtractor { return result; } + public static Field getDeclaredField(Class<?> clazz, String name) { + for (Field field : getAllDeclaredFields(clazz)) { + if (field.getName().equals(name)) { + return field; + } + } + return null; + } + private static boolean hasFieldWithSameName(String name, List<Field> fields) { for(Field field : fields) { if(name.equals(field.getName())) { @@ -1260,6 +1351,24 @@ public class TypeExtractor { return t instanceof Class<?> || t instanceof ParameterizedType; } + private static boolean sameTypeVars(Type t1, Type t2) { + if (!(t1 instanceof TypeVariable) || !(t2 instanceof TypeVariable)) { + return false; + } + return ((TypeVariable<?>) t1).getName().equals(((TypeVariable<?>)t2).getName()) + && ((TypeVariable<?>) t1).getGenericDeclaration().equals(((TypeVariable<?>)t2).getGenericDeclaration()); + } + + private static TypeInformation<?> getTypeOfPojoField(TypeInformation<?> pojoInfo, Field field) { + for (int j = 0; j < pojoInfo.getArity(); j++) { + PojoField pf = ((PojoTypeInfo<?>) pojoInfo).getPojoFieldAt(j); + if (pf.field.getName().equals(field.getName())) { + return pf.type; + } + } + return null; + } + public static <X> TypeInformation<X> getForObject(X value) { return new TypeExtractor().privateGetForObject(value); @@ -1275,7 +1384,7 @@ public class TypeExtractor { int numFields = t.getArity(); if(numFields != countFieldsInClass(value.getClass())) { // not a tuple since it has more fields. - return analyzePojo((Class<X>) value.getClass(), new ArrayList<Type>(), null); // we immediately call analyze Pojo here, because + return analyzePojo((Class<X>) value.getClass(), new ArrayList<Type>(), null, null, null); // we immediately call analyze Pojo here, because // there is currently no other type that can handle such a class. } http://git-wip-us.apache.org/repos/asf/flink/blob/6067833f/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java index 98373da..6890d0c 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java @@ -159,6 +159,7 @@ public class TypeInfoParser { } else { arrayClazz = Class.forName("[L" + TUPLE_PACKAGE + "." + className + ";"); } + sb.delete(0, 2); returnType = ObjectArrayTypeInfo.getInfoFor(arrayClazz, new TupleTypeInfo(clazz, types)); } else if (sb.length() < 1 || sb.charAt(0) != '[') { returnType = new TupleTypeInfo(clazz, types); @@ -308,10 +309,8 @@ public class TypeInfoParser { String fieldName = fieldMatcher.group(1); sb.delete(0, fieldName.length() + 1); - Field field = null; - try { - field = clazz.getDeclaredField(fieldName); - } catch (Exception e) { + Field field = TypeExtractor.getDeclaredField(clazz, fieldName); + if (field == null) { throw new IllegalArgumentException("Field '" + fieldName + "'could not be accessed."); } fields.add(new PojoField(field, parse(sb))); http://git-wip-us.apache.org/repos/asf/flink/blob/6067833f/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java ---------------------------------------------------------------------- diff --git a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java index 39d6e10..27db31d 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java @@ -23,18 +23,21 @@ import java.util.Collection; import java.util.Date; import java.util.List; +import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.tuple.Tuple1; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.typeutils.GenericTypeInfo; import org.apache.flink.api.java.typeutils.PojoField; import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.api.java.typeutils.TypeInfoParser; import org.apache.flink.api.java.typeutils.TypeInfoParserTest.MyWritable; import org.apache.flink.api.java.typeutils.WritableTypeInfo; import org.junit.Assert; @@ -208,6 +211,7 @@ public class PojoTypeExtractionTest { checkWCPojoAsserts(typeForObject); } + @SuppressWarnings({ "unchecked", "rawtypes" }) private void checkWCPojoAsserts(TypeInformation<?> typeInfo) { Assert.assertFalse(typeInfo.isBasicType()); Assert.assertFalse(typeInfo.isTupleType()); @@ -406,7 +410,6 @@ public class PojoTypeExtractionTest { Assert.assertEquals(typeInfo.getArity(), 2); } - // Kryo is required for this, so disable for now. @Test public void testPojoAllPublic() { TypeInformation<?> typeForClass = TypeExtractor.createTypeInfo(AllPublic.class); @@ -616,4 +619,195 @@ public class PojoTypeExtractionTest { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.fromElements(new VertexTyped(0L, 3.0), new VertexTyped(1L, 1.0)); } + + public static class MyMapper<T> implements MapFunction<PojoWithGenerics<Long, T>, PojoWithGenerics<T,T>> { + private static final long serialVersionUID = 1L; + + @Override + public PojoWithGenerics<T, T> map(PojoWithGenerics<Long, T> value) + throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference1() { + MapFunction<?, ?> function = new MyMapper<String>(); + + TypeInformation<?> ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithGenerics<key=int,field1=Long,field2=String>")); + Assert.assertTrue(ti instanceof PojoTypeInfo<?>); + PojoTypeInfo<?> pti = (PojoTypeInfo<?>) ti; + for(int i = 0; i < pti.getArity(); i++) { + PojoField field = pti.getPojoFieldAt(i); + String name = field.field.getName(); + if(name.equals("field1")) { + Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, field.type); + } else if (name.equals("field2")) { + Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, field.type); + } else if (name.equals("key")) { + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, field.type); + } else { + Assert.fail("Unexpected field "+field); + } + } + } + + public static class PojoTuple<A, B, C> extends Tuple3<B, C, Long> { + private static final long serialVersionUID = 1L; + + public A extraField; + } + + public static class MyMapper2<D, E> implements MapFunction<Tuple2<E, D>, PojoTuple<E, D, D>> { + private static final long serialVersionUID = 1L; + + @Override + public PojoTuple<E, D, D> map(Tuple2<E, D> value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference2() { + MapFunction<?, ?> function = new MyMapper2<Boolean, Character>(); + + TypeInformation<?> ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("Tuple2<Character,Boolean>")); + Assert.assertTrue(ti instanceof PojoTypeInfo<?>); + PojoTypeInfo<?> pti = (PojoTypeInfo<?>) ti; + for(int i = 0; i < pti.getArity(); i++) { + PojoField field = pti.getPojoFieldAt(i); + String name = field.field.getName(); + if(name.equals("extraField")) { + Assert.assertEquals(BasicTypeInfo.CHAR_TYPE_INFO, field.type); + } else if (name.equals("f0")) { + Assert.assertEquals(BasicTypeInfo.BOOLEAN_TYPE_INFO, field.type); + } else if (name.equals("f1")) { + Assert.assertEquals(BasicTypeInfo.BOOLEAN_TYPE_INFO, field.type); + } else if (name.equals("f2")) { + Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, field.type); + } else { + Assert.fail("Unexpected field "+field); + } + } + } + + public static class MyMapper3<D, E> implements MapFunction<PojoTuple<E, D, D>, Tuple2<E, D>> { + private static final long serialVersionUID = 1L; + + @Override + public Tuple2<E, D> map(PojoTuple<E, D, D> value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference3() { + MapFunction<?, ?> function = new MyMapper3<Boolean, Character>(); + + TypeInformation<?> ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoTuple<extraField=char,f0=boolean,f1=boolean,f2=long>")); + Assert.assertTrue(ti instanceof TupleTypeInfo<?>); + TupleTypeInfo<?> tti = (TupleTypeInfo<?>) ti; + Assert.assertEquals(BasicTypeInfo.CHAR_TYPE_INFO, tti.getTypeAt(0)); + Assert.assertEquals(BasicTypeInfo.BOOLEAN_TYPE_INFO, tti.getTypeAt(1)); + } + + public static class PojoWithParameterizedFields1<Z> { + public Tuple2<Z, Z> field; + } + + public static class MyMapper4<A> implements MapFunction<PojoWithParameterizedFields1<A>, A> { + private static final long serialVersionUID = 1L; + @Override + public A map(PojoWithParameterizedFields1<A> value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference4() { + MapFunction<?, ?> function = new MyMapper4<Byte>(); + + TypeInformation<?> ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithParameterizedFields1<field=Tuple2<byte,byte>>")); + Assert.assertEquals(BasicTypeInfo.BYTE_TYPE_INFO, ti); + } + + public static class PojoWithParameterizedFields2<Z> { + public PojoWithGenerics<Z, Z> field; + } + + public static class MyMapper5<A> implements MapFunction<PojoWithParameterizedFields2<A>, A> { + private static final long serialVersionUID = 1L; + @Override + public A map(PojoWithParameterizedFields2<A> value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference5() { + MapFunction<?, ?> function = new MyMapper5<Byte>(); + + TypeInformation<?> ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithParameterizedFields2<" + + "field=org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithGenerics<key=int,field1=byte,field2=byte>" + + ">")); + Assert.assertEquals(BasicTypeInfo.BYTE_TYPE_INFO, ti); + } + + public static class PojoWithParameterizedFields3<Z> { + public Z[] field; + } + + public static class MyMapper6<A> implements MapFunction<PojoWithParameterizedFields3<A>, A> { + private static final long serialVersionUID = 1L; + @Override + public A map(PojoWithParameterizedFields3<A> value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference6() { + MapFunction<?, ?> function = new MyMapper6<Integer>(); + + TypeInformation<?> ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithParameterizedFields3<" + + "field=int[]" + + ">")); + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, ti); + } + + public static class MyMapper7<A> implements MapFunction<PojoWithParameterizedFields4<A>, A> { + private static final long serialVersionUID = 1L; + @Override + public A map(PojoWithParameterizedFields4<A> value) throws Exception { + return null; + } + } + + public static class PojoWithParameterizedFields4<Z> { + public Tuple1<Z>[] field; + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference7() { + MapFunction<?, ?> function = new MyMapper7<Integer>(); + + TypeInformation<?> ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithParameterizedFields4<" + + "field=Tuple1<int>[]" + + ">")); + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, ti); + } } http://git-wip-us.apache.org/repos/asf/flink/blob/6067833f/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java ---------------------------------------------------------------------- diff --git a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java index 1364a2f..8a2d675 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java @@ -1260,7 +1260,7 @@ public class TypeExtractorTest { public static class InType extends MyObject<String> {} @SuppressWarnings({ "rawtypes", "unchecked" }) @Test - public void testParamertizedCustomObject() { + public void testParameterizedPojo() { RichMapFunction<?, ?> function = new RichMapFunction<InType, MyObject<String>>() { private static final long serialVersionUID = 1L; @@ -1622,6 +1622,24 @@ public class TypeExtractorTest { Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, ti); } + public static class EdgeMapper4<K, V> implements MapFunction<Edge<K, V>[], V> { + private static final long serialVersionUID = 1L; + + @Override + public V map(Edge<K, V>[] value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testInputInference4() { + EdgeMapper4<Boolean, String> em = new EdgeMapper4<Boolean, String>(); + TypeInformation<?> ti = TypeExtractor.getMapReturnTypes((MapFunction) em, TypeInfoParser.parse("Tuple3<Boolean,Boolean,String>[]")); + Assert.assertTrue(ti.isBasicType()); + Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, ti); + } + public static enum MyEnum { ONE, TWO, THREE }
