Repository: flink Updated Branches: refs/heads/master bd2fce6e1 -> 6f09ecded
[FLINK-4801] [types] Input type inference is faulty with custom Tuples and RichFunctions This closes #2625. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/6f09ecde Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/6f09ecde Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/6f09ecde Branch: refs/heads/master Commit: 6f09ecded9e22a5eaa548ebbddb9b28dad4207c2 Parents: bd2fce6 Author: twalthr <[email protected]> Authored: Wed Oct 12 10:33:47 2016 +0200 Committer: twalthr <[email protected]> Committed: Tue Nov 15 14:56:59 2016 +0100 ---------------------------------------------------------------------- .../api/java/typeutils/TypeExtractionUtils.java | 38 ++++++ .../flink/api/java/typeutils/TypeExtractor.java | 128 ++++++------------- .../typeutils/runtime/kryo/Serializers.java | 6 +- .../api/java/typeutils/TypeExtractorTest.java | 37 +++++- 4 files changed, 114 insertions(+), 95 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/6f09ecde/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java index 4439612..0aac257 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java @@ -20,7 +20,9 @@ package org.apache.flink.api.java.typeutils; import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; import java.util.ArrayList; import java.util.Collections; import 
java.util.List; @@ -81,6 +83,12 @@ public class TypeExtractionUtils { } } + /** + * Checks if the given function has been implemented using a Java 8 lambda. If yes, a LambdaExecutable + * is returned describing the method/constructor. Otherwise, null is returned. + * + * @throws TypeExtractionException if lambda extraction fails; the extraction is hacky and might fail due to unknown JVM issues. + */ public static LambdaExecutable checkAndExtractLambda(Function function) throws TypeExtractionException { try { // get serialized lambda @@ -164,4 +172,34 @@ public class TypeExtractionUtils { } return result; } + + /** + * Converts a ParameterizedType or Class to a Class (for a ParameterizedType, its raw type is returned). + */ + public static Class<?> typeToClass(Type t) { + if (t instanceof Class) { + return (Class<?>)t; + } + else if (t instanceof ParameterizedType) { + return ((Class<?>)((ParameterizedType) t).getRawType()); + } + throw new IllegalArgumentException("Cannot convert type to class"); + } + + /** + * Checks if a type can be converted to a Class. This is true for ParameterizedType and Class. + */ + public static boolean isClassType(Type t) { + return t instanceof Class<?> || t instanceof ParameterizedType; + } + + /** + * Checks whether two types are type variables describing the same variable, i.e. both have the same name and the same generic declaration. 
+ */ + public static boolean sameTypeVars(Type t1, Type t2) { + return t1 instanceof TypeVariable && + t2 instanceof TypeVariable && + ((TypeVariable<?>) t1).getName().equals(((TypeVariable<?>) t2).getName()) && + ((TypeVariable<?>) t1).getGenericDeclaration().equals(((TypeVariable<?>) t2).getGenericDeclaration()); + } } http://git-wip-us.apache.org/repos/asf/flink/blob/6f09ecde/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java index c1febea..b41bbc1 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java @@ -68,6 +68,9 @@ import org.apache.flink.api.java.tuple.Tuple0; import org.apache.flink.api.java.typeutils.TypeExtractionUtils.LambdaExecutable; import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.checkAndExtractLambda; import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.getAllDeclaredMethods; +import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.isClassType; +import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.sameTypeVars; +import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.typeToClass; import org.apache.flink.types.Either; import org.apache.flink.types.Value; import org.apache.flink.util.InstantiationUtil; @@ -859,6 +862,14 @@ public class TypeExtractor { return null; } + /** + * Finds the type information for a type variable. + * + * It solves the following problem: + * + * Return the type information for "returnTypeVar" given that "inType" has type information "inTypeInfo". + * Thus "inType" must contain "returnTypeVar" in an "inputTypeHierarchy", otherwise null is returned. 
+ */ @SuppressWarnings({"unchecked", "rawtypes"}) private <IN1> TypeInformation<?> createTypeInfoFromInput(TypeVariable<?> returnTypeVar, ArrayList<Type> inputTypeHierarchy, Type inType, TypeInformation<IN1> inTypeInfo) { TypeInformation<?> info = null; @@ -891,9 +902,14 @@ public class TypeExtractor { } } // the input is a type variable + else if (sameTypeVars(inType, returnTypeVar)) { + return inTypeInfo; + } else if (inType instanceof TypeVariable) { - inType = materializeTypeVariable(inputTypeHierarchy, (TypeVariable<?>) inType); - info = findCorrespondingInfo(returnTypeVar, inType, inTypeInfo, inputTypeHierarchy); + Type resolvedInType = materializeTypeVariable(inputTypeHierarchy, (TypeVariable<?>) inType); + if (resolvedInType != inType) { + info = createTypeInfoFromInput(returnTypeVar, inputTypeHierarchy, resolvedInType, inTypeInfo); + } } // input is an array else if (inType instanceof GenericArrayType) { @@ -910,8 +926,7 @@ public class TypeExtractor { info = createTypeInfoFromInput(returnTypeVar, inputTypeHierarchy, ((GenericArrayType) inType).getGenericComponentType(), componentInfo); } // the input is a tuple - else if (inTypeInfo instanceof TupleTypeInfo && isClassType(inType) - && Tuple.class.isAssignableFrom(typeToClass(inType))) { + else if (inTypeInfo instanceof TupleTypeInfo && isClassType(inType) && Tuple.class.isAssignableFrom(typeToClass(inType))) { ParameterizedType tupleBaseClass; // get tuple from possible tuple subclass @@ -935,10 +950,25 @@ public class TypeExtractor { } } // the input is a pojo - else if (inTypeInfo instanceof PojoTypeInfo) { + else if (inTypeInfo instanceof PojoTypeInfo && isClassType(inType)) { // build the entire type hierarchy for the pojo getTypeHierarchy(inputTypeHierarchy, inType, Object.class); - info = findCorrespondingInfo(returnTypeVar, inType, inTypeInfo, inputTypeHierarchy); + // determine a field containing the type variable + List<Field> fields = getAllDeclaredFields(typeToClass(inType)); + for (Field field 
: fields) { + Type fieldType = field.getGenericType(); + if (fieldType instanceof TypeVariable && sameTypeVars(returnTypeVar, materializeTypeVariable(inputTypeHierarchy, (TypeVariable<?>) fieldType))) { + return getTypeOfPojoField(inTypeInfo, field); + } + else if (fieldType instanceof ParameterizedType || fieldType instanceof GenericArrayType) { + ArrayList<Type> typeHierarchyWithFieldType = new ArrayList<>(inputTypeHierarchy); + typeHierarchyWithFieldType.add(fieldType); + TypeInformation<?> foundInfo = createTypeInfoFromInput(returnTypeVar, typeHierarchyWithFieldType, fieldType, getTypeOfPojoField(inTypeInfo, field)); + if (foundInfo != null) { + return foundInfo; + } + } + } } return info; } @@ -1557,66 +1587,7 @@ public class TypeExtractor { } throw new InvalidTypesException(); } - - private static TypeInformation<?> findCorrespondingInfo(TypeVariable<?> typeVar, Type type, TypeInformation<?> corrInfo, ArrayList<Type> typeHierarchy) { - if (sameTypeVars(type, typeVar)) { - return corrInfo; - } - else if (type instanceof TypeVariable && sameTypeVars(materializeTypeVariable(typeHierarchy, (TypeVariable<?>) type), typeVar)) { - return corrInfo; - } - else if (type instanceof GenericArrayType) { - TypeInformation<?> componentInfo = null; - if (corrInfo instanceof BasicArrayTypeInfo) { - componentInfo = ((BasicArrayTypeInfo<?,?>) corrInfo).getComponentInfo(); - } - else if (corrInfo instanceof PrimitiveArrayTypeInfo) { - componentInfo = BasicTypeInfo.getInfoFor(corrInfo.getTypeClass().getComponentType()); - } - else if (corrInfo instanceof ObjectArrayTypeInfo) { - componentInfo = ((ObjectArrayTypeInfo<?,?>) corrInfo).getComponentInfo(); - } - TypeInformation<?> info = findCorrespondingInfo(typeVar, ((GenericArrayType) type).getGenericComponentType(), componentInfo, typeHierarchy); - if (info != null) { - return info; - } - } - else if (corrInfo instanceof TupleTypeInfo - && type instanceof ParameterizedType - && Tuple.class.isAssignableFrom((Class<?>) 
((ParameterizedType) type).getRawType())) { - ParameterizedType tuple = (ParameterizedType) type; - Type[] args = tuple.getActualTypeArguments(); - - for (int i = 0; i < args.length; i++) { - TypeInformation<?> info = findCorrespondingInfo(typeVar, args[i], ((TupleTypeInfo<?>) corrInfo).getTypeAt(i), typeHierarchy); - if (info != null) { - return info; - } - } - } - else if (corrInfo instanceof PojoTypeInfo && isClassType(type)) { - // determine a field containing the type variable - List<Field> fields = getAllDeclaredFields(typeToClass(type)); - for (Field field : fields) { - Type fieldType = field.getGenericType(); - if (fieldType instanceof TypeVariable - && sameTypeVars(typeVar, materializeTypeVariable(typeHierarchy, (TypeVariable<?>) fieldType))) { - return getTypeOfPojoField(corrInfo, field); - } - else if (fieldType instanceof ParameterizedType - || fieldType instanceof GenericArrayType) { - ArrayList<Type> typeHierarchyWithFieldType = new ArrayList<Type>(typeHierarchy); - typeHierarchyWithFieldType.add(fieldType); - TypeInformation<?> info = findCorrespondingInfo(typeVar, fieldType, getTypeOfPojoField(corrInfo, field), typeHierarchyWithFieldType); - if (info != null) { - return info; - } - } - } - } - return null; - } - + /** * Tries to find a concrete value (Class, ParameterizedType etc. ) for a TypeVariable by traversing the type hierarchy downwards. * If a value could not be found it will return the most bottom type variable in the hierarchy. 
@@ -1991,30 +1962,6 @@ public class TypeExtractor { } return false; } - - @Internal - public static Class<?> typeToClass(Type t) { - if (t instanceof Class) { - return (Class<?>)t; - } - else if (t instanceof ParameterizedType) { - return ((Class<?>)((ParameterizedType) t).getRawType()); - } - throw new IllegalArgumentException("Cannot convert type to class"); - } - - @Internal - public static boolean isClassType(Type t) { - return t instanceof Class<?> || t instanceof ParameterizedType; - } - - private static boolean sameTypeVars(Type t1, Type t2) { - if (!(t1 instanceof TypeVariable) || !(t2 instanceof TypeVariable)) { - return false; - } - return ((TypeVariable<?>) t1).getName().equals(((TypeVariable<?>) t2).getName()) - && ((TypeVariable<?>) t1).getGenericDeclaration().equals(((TypeVariable<?>) t2).getGenericDeclaration()); - } private static TypeInformation<?> getTypeOfPojoField(TypeInformation<?> pojoInfo, Field field) { for (int j = 0; j < pojoInfo.getArity(); j++) { @@ -2026,7 +1973,6 @@ public class TypeExtractor { return null; } - public static <X> TypeInformation<X> getForObject(X value) { return new TypeExtractor().privateGetForObject(value); } http://git-wip-us.apache.org/repos/asf/flink/blob/6f09ecde/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java index b6e978f..4976d6a 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java @@ -34,7 +34,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import 
org.apache.flink.api.java.typeutils.GenericTypeInfo; import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; -import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.api.java.typeutils.TypeExtractionUtils; import java.io.Serializable; import java.lang.reflect.Field; @@ -113,8 +113,8 @@ public class Serializers { ParameterizedType parameterizedFieldType = (ParameterizedType) fieldType; for (Type t: parameterizedFieldType.getActualTypeArguments()) { - if (TypeExtractor.isClassType(t) ) { - recursivelyRegisterType(TypeExtractor.typeToClass(t), config, alreadySeen); + if (TypeExtractionUtils.isClassType(t) ) { + recursivelyRegisterType(TypeExtractionUtils.typeToClass(t), config, alreadySeen); } } http://git-wip-us.apache.org/repos/asf/flink/blob/6f09ecde/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java ---------------------------------------------------------------------- diff --git a/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java index 443cbc3..55cd42d 100644 --- a/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java +++ b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java @@ -30,12 +30,14 @@ import java.util.List; import java.util.Map; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.InvalidTypesException; +import org.apache.flink.api.common.functions.JoinFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.RichCoGroupFunction; import org.apache.flink.api.common.functions.RichCrossFunction; import org.apache.flink.api.common.functions.RichFlatJoinFunction; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import 
org.apache.flink.api.common.functions.RichJoinFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; @@ -1665,7 +1667,40 @@ public class TypeExtractorTest { Assert.assertTrue(ti.isBasicType()); Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, ti); } - + + public static class CustomTuple2WithArray<K> extends Tuple2<K[], K> { + + public CustomTuple2WithArray() { + // default constructor + } + } + + public class JoinWithCustomTuple2WithArray<T> extends RichJoinFunction<CustomTuple2WithArray<T>, CustomTuple2WithArray<T>, CustomTuple2WithArray<T>> { + + @Override + public CustomTuple2WithArray<T> join(CustomTuple2WithArray<T> first, CustomTuple2WithArray<T> second) throws Exception { + return null; + } + } + + @Test + public void testInputInferenceWithCustomTupleAndRichFunction() { + JoinFunction<CustomTuple2WithArray<Long>, CustomTuple2WithArray<Long>, CustomTuple2WithArray<Long>> function = new JoinWithCustomTuple2WithArray<>(); + + TypeInformation<?> ti = TypeExtractor.getJoinReturnTypes( + function, + new TypeHint<CustomTuple2WithArray<Long>>(){}.getTypeInfo(), + new TypeHint<CustomTuple2WithArray<Long>>(){}.getTypeInfo()); + + Assert.assertTrue(ti.isTupleType()); + TupleTypeInfo<?> tti = (TupleTypeInfo<?>) ti; + Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, tti.getTypeAt(1)); + + Assert.assertTrue(tti.getTypeAt(0) instanceof ObjectArrayTypeInfo<?, ?>); + ObjectArrayTypeInfo<?, ?> oati = (ObjectArrayTypeInfo<?, ?>) tti.getTypeAt(0); + Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, oati.getComponentInfo()); + } + public static enum MyEnum { ONE, TWO, THREE }
