Repository: flink
Updated Branches:
  refs/heads/master bd2fce6e1 -> 6f09ecded


[FLINK-4801] [types] Input type inference is faulty with custom Tuples and RichFunctions

This closes #2625.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/6f09ecde
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/6f09ecde
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/6f09ecde

Branch: refs/heads/master
Commit: 6f09ecded9e22a5eaa548ebbddb9b28dad4207c2
Parents: bd2fce6
Author: twalthr <[email protected]>
Authored: Wed Oct 12 10:33:47 2016 +0200
Committer: twalthr <[email protected]>
Committed: Tue Nov 15 14:56:59 2016 +0100

----------------------------------------------------------------------
 .../api/java/typeutils/TypeExtractionUtils.java |  38 ++++++
 .../flink/api/java/typeutils/TypeExtractor.java | 128 ++++++-------------
 .../typeutils/runtime/kryo/Serializers.java     |   6 +-
 .../api/java/typeutils/TypeExtractorTest.java   |  37 +++++-
 4 files changed, 114 insertions(+), 95 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/6f09ecde/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
----------------------------------------------------------------------
diff --git 
a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
 
b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
index 4439612..0aac257 100644
--- 
a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
+++ 
b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
@@ -20,7 +20,9 @@ package org.apache.flink.api.java.typeutils;
 
 import java.lang.reflect.Constructor;
 import java.lang.reflect.Method;
+import java.lang.reflect.ParameterizedType;
 import java.lang.reflect.Type;
+import java.lang.reflect.TypeVariable;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
@@ -81,6 +83,12 @@ public class TypeExtractionUtils {
                }
        }
 
+       /**
+        * Checks if the given function has been implemented using a Java 8 
lambda. If yes, a LambdaExecutable
+        * is returned describing the method/constructor. Otherwise null.
+        *
+        * @throws TypeExtractionException lambda extraction is pretty hacky, 
it might fail for unknown JVM issues.
+        */
        public static LambdaExecutable checkAndExtractLambda(Function function) 
throws TypeExtractionException {
                try {
                        // get serialized lambda
@@ -164,4 +172,34 @@ public class TypeExtractionUtils {
                }
                return result;
        }
+
+       /**
+        * Convert ParameterizedType or Class to a Class.
+        */
+       public static Class<?> typeToClass(Type t) {
+               if (t instanceof Class) {
+                       return (Class<?>)t;
+               }
+               else if (t instanceof ParameterizedType) {
+                       return ((Class<?>)((ParameterizedType) t).getRawType());
+               }
+               throw new IllegalArgumentException("Cannot convert type to 
class");
+       }
+
+       /**
+        * Checks if a type can be converted to a Class. This is true for 
ParameterizedType and Class.
+        */
+       public static boolean isClassType(Type t) {
+               return t instanceof Class<?> || t instanceof ParameterizedType;
+       }
+
+       /**
+        * Checks whether two types are type variables describing the same.
+        */
+       public static boolean sameTypeVars(Type t1, Type t2) {
+               return t1 instanceof TypeVariable &&
+                       t2 instanceof TypeVariable &&
+                       ((TypeVariable<?>) 
t1).getName().equals(((TypeVariable<?>) t2).getName()) &&
+                       ((TypeVariable<?>) 
t1).getGenericDeclaration().equals(((TypeVariable<?>) 
t2).getGenericDeclaration());
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/6f09ecde/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
----------------------------------------------------------------------
diff --git 
a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
 
b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
index c1febea..b41bbc1 100644
--- 
a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
+++ 
b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
@@ -68,6 +68,9 @@ import org.apache.flink.api.java.tuple.Tuple0;
 import 
org.apache.flink.api.java.typeutils.TypeExtractionUtils.LambdaExecutable;
 import static 
org.apache.flink.api.java.typeutils.TypeExtractionUtils.checkAndExtractLambda;
 import static 
org.apache.flink.api.java.typeutils.TypeExtractionUtils.getAllDeclaredMethods;
+import static 
org.apache.flink.api.java.typeutils.TypeExtractionUtils.isClassType;
+import static 
org.apache.flink.api.java.typeutils.TypeExtractionUtils.sameTypeVars;
+import static 
org.apache.flink.api.java.typeutils.TypeExtractionUtils.typeToClass;
 import org.apache.flink.types.Either;
 import org.apache.flink.types.Value;
 import org.apache.flink.util.InstantiationUtil;
@@ -859,6 +862,14 @@ public class TypeExtractor {
                return null;
        }
 
+       /**
+        * Finds the type information to a type variable.
+        *
+        * It solve the following:
+        *
+        * Return the type information for "returnTypeVar" given that "inType" 
has type information "inTypeInfo".
+        * Thus "inType" must contain "returnTypeVar" in a 
"inputTypeHierarchy", otherwise null is returned.
+        */
        @SuppressWarnings({"unchecked", "rawtypes"})
        private <IN1> TypeInformation<?> 
createTypeInfoFromInput(TypeVariable<?> returnTypeVar, ArrayList<Type> 
inputTypeHierarchy, Type inType, TypeInformation<IN1> inTypeInfo) {
                TypeInformation<?> info = null;
@@ -891,9 +902,14 @@ public class TypeExtractor {
                        }
                }
                // the input is a type variable
+               else if (sameTypeVars(inType, returnTypeVar)) {
+                       return inTypeInfo;
+               }
                else if (inType instanceof TypeVariable) {
-                       inType = materializeTypeVariable(inputTypeHierarchy, 
(TypeVariable<?>) inType);
-                       info = findCorrespondingInfo(returnTypeVar, inType, 
inTypeInfo, inputTypeHierarchy);
+                       Type resolvedInType = 
materializeTypeVariable(inputTypeHierarchy, (TypeVariable<?>) inType);
+                       if (resolvedInType != inType) {
+                               info = createTypeInfoFromInput(returnTypeVar, 
inputTypeHierarchy, resolvedInType, inTypeInfo);
+                       }
                }
                // input is an array
                else if (inType instanceof GenericArrayType) {
@@ -910,8 +926,7 @@ public class TypeExtractor {
                        info = createTypeInfoFromInput(returnTypeVar, 
inputTypeHierarchy, ((GenericArrayType) inType).getGenericComponentType(), 
componentInfo);
                }
                // the input is a tuple
-               else if (inTypeInfo instanceof TupleTypeInfo && 
isClassType(inType) 
-                               && 
Tuple.class.isAssignableFrom(typeToClass(inType))) {
+               else if (inTypeInfo instanceof TupleTypeInfo && 
isClassType(inType) && Tuple.class.isAssignableFrom(typeToClass(inType))) {
                        ParameterizedType tupleBaseClass;
                        
                        // get tuple from possible tuple subclass
@@ -935,10 +950,25 @@ public class TypeExtractor {
                        }
                }
                // the input is a pojo
-               else if (inTypeInfo instanceof PojoTypeInfo) {
+               else if (inTypeInfo instanceof PojoTypeInfo && 
isClassType(inType)) {
                        // build the entire type hierarchy for the pojo
                        getTypeHierarchy(inputTypeHierarchy, inType, 
Object.class);
-                       info = findCorrespondingInfo(returnTypeVar, inType, 
inTypeInfo, inputTypeHierarchy);
+                       // determine a field containing the type variable
+                       List<Field> fields = 
getAllDeclaredFields(typeToClass(inType));
+                       for (Field field : fields) {
+                               Type fieldType = field.getGenericType();
+                               if (fieldType instanceof TypeVariable && 
sameTypeVars(returnTypeVar, materializeTypeVariable(inputTypeHierarchy, 
(TypeVariable<?>) fieldType))) {
+                                       return getTypeOfPojoField(inTypeInfo, 
field);
+                               }
+                               else if (fieldType instanceof ParameterizedType 
|| fieldType instanceof GenericArrayType) {
+                                       ArrayList<Type> 
typeHierarchyWithFieldType = new ArrayList<>(inputTypeHierarchy);
+                                       
typeHierarchyWithFieldType.add(fieldType);
+                                       TypeInformation<?> foundInfo = 
createTypeInfoFromInput(returnTypeVar, typeHierarchyWithFieldType, fieldType, 
getTypeOfPojoField(inTypeInfo, field));
+                                       if (foundInfo != null) {
+                                               return foundInfo;
+                                       }
+                               }
+                       }
                }
                return info;
        }
@@ -1557,66 +1587,7 @@ public class TypeExtractor {
                }
                throw new InvalidTypesException();
        }
-       
-       private static TypeInformation<?> findCorrespondingInfo(TypeVariable<?> 
typeVar, Type type, TypeInformation<?> corrInfo, ArrayList<Type> typeHierarchy) 
{
-               if (sameTypeVars(type, typeVar)) {
-                       return corrInfo;
-               }
-               else if (type instanceof TypeVariable && 
sameTypeVars(materializeTypeVariable(typeHierarchy, (TypeVariable<?>) type), 
typeVar)) {
-                       return corrInfo;
-               }
-               else if (type instanceof GenericArrayType) {
-                       TypeInformation<?> componentInfo = null;
-                       if (corrInfo instanceof BasicArrayTypeInfo) {
-                               componentInfo = ((BasicArrayTypeInfo<?,?>) 
corrInfo).getComponentInfo();
-                       }
-                       else if (corrInfo instanceof PrimitiveArrayTypeInfo) {
-                               componentInfo = 
BasicTypeInfo.getInfoFor(corrInfo.getTypeClass().getComponentType());
-                       }
-                       else if (corrInfo instanceof ObjectArrayTypeInfo) {
-                               componentInfo = ((ObjectArrayTypeInfo<?,?>) 
corrInfo).getComponentInfo();
-                       }
-                       TypeInformation<?> info = 
findCorrespondingInfo(typeVar, ((GenericArrayType) 
type).getGenericComponentType(), componentInfo, typeHierarchy);
-                       if (info != null) {
-                               return info;
-                       }
-               }
-               else if (corrInfo instanceof TupleTypeInfo
-                               && type instanceof ParameterizedType
-                               && Tuple.class.isAssignableFrom((Class<?>) 
((ParameterizedType) type).getRawType())) {
-                       ParameterizedType tuple = (ParameterizedType) type;
-                       Type[] args = tuple.getActualTypeArguments();
-                       
-                       for (int i = 0; i < args.length; i++) {
-                               TypeInformation<?> info = 
findCorrespondingInfo(typeVar, args[i], ((TupleTypeInfo<?>) 
corrInfo).getTypeAt(i), typeHierarchy);
-                               if (info != null) {
-                                       return info;
-                               }
-                       }
-               }
-               else if (corrInfo instanceof PojoTypeInfo && isClassType(type)) 
{
-                       // determine a field containing the type variable
-                       List<Field> fields = 
getAllDeclaredFields(typeToClass(type));
-                       for (Field field : fields) {
-                               Type fieldType = field.getGenericType();
-                               if (fieldType instanceof TypeVariable 
-                                               && sameTypeVars(typeVar, 
materializeTypeVariable(typeHierarchy, (TypeVariable<?>) fieldType))) {
-                                       return getTypeOfPojoField(corrInfo, 
field);
-                               }
-                               else if (fieldType instanceof ParameterizedType
-                                               || fieldType instanceof 
GenericArrayType) {
-                                       ArrayList<Type> 
typeHierarchyWithFieldType = new ArrayList<Type>(typeHierarchy);
-                                       
typeHierarchyWithFieldType.add(fieldType);
-                                       TypeInformation<?> info = 
findCorrespondingInfo(typeVar, fieldType, getTypeOfPojoField(corrInfo, field), 
typeHierarchyWithFieldType);
-                                       if (info != null) {
-                                               return info;
-                                       }
-                               }
-                       }
-               }
-               return null;
-       }
-       
+
        /**
         * Tries to find a concrete value (Class, ParameterizedType etc. ) for 
a TypeVariable by traversing the type hierarchy downwards.
         * If a value could not be found it will return the most bottom type 
variable in the hierarchy.
@@ -1991,30 +1962,6 @@ public class TypeExtractor {
                }
                return false;
        }
-
-       @Internal
-       public static Class<?> typeToClass(Type t) {
-               if (t instanceof Class) {
-                       return (Class<?>)t;
-               }
-               else if (t instanceof ParameterizedType) {
-                       return ((Class<?>)((ParameterizedType) t).getRawType());
-               }
-               throw new IllegalArgumentException("Cannot convert type to 
class");
-       }
-
-       @Internal
-       public static boolean isClassType(Type t) {
-               return t instanceof Class<?> || t instanceof ParameterizedType;
-       }
-
-       private static boolean sameTypeVars(Type t1, Type t2) {
-               if (!(t1 instanceof TypeVariable) || !(t2 instanceof 
TypeVariable)) {
-                       return false;
-               }
-               return ((TypeVariable<?>) 
t1).getName().equals(((TypeVariable<?>) t2).getName())
-                               && ((TypeVariable<?>) 
t1).getGenericDeclaration().equals(((TypeVariable<?>) 
t2).getGenericDeclaration());
-       }
        
        private static TypeInformation<?> getTypeOfPojoField(TypeInformation<?> 
pojoInfo, Field field) {
                for (int j = 0; j < pojoInfo.getArity(); j++) {
@@ -2026,7 +1973,6 @@ public class TypeExtractor {
                return null;
        }
 
-
        public static <X> TypeInformation<X> getForObject(X value) {
                return new TypeExtractor().privateGetForObject(value);
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/6f09ecde/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java
----------------------------------------------------------------------
diff --git 
a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java
 
b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java
index b6e978f..4976d6a 100644
--- 
a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java
+++ 
b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java
@@ -34,7 +34,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.CompositeType;
 import org.apache.flink.api.java.typeutils.GenericTypeInfo;
 import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
-import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.api.java.typeutils.TypeExtractionUtils;
 
 import java.io.Serializable;
 import java.lang.reflect.Field;
@@ -113,8 +113,8 @@ public class Serializers {
                        ParameterizedType parameterizedFieldType = 
(ParameterizedType) fieldType;
                        
                        for (Type t: 
parameterizedFieldType.getActualTypeArguments()) {
-                               if (TypeExtractor.isClassType(t) ) {
-                                       
recursivelyRegisterType(TypeExtractor.typeToClass(t), config, alreadySeen);
+                               if (TypeExtractionUtils.isClassType(t) ) {
+                                       
recursivelyRegisterType(TypeExtractionUtils.typeToClass(t), config, 
alreadySeen);
                                }
                        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/6f09ecde/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java
----------------------------------------------------------------------
diff --git 
a/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java
 
b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java
index 443cbc3..55cd42d 100644
--- 
a/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java
+++ 
b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java
@@ -30,12 +30,14 @@ import java.util.List;
 import java.util.Map;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.InvalidTypesException;
+import org.apache.flink.api.common.functions.JoinFunction;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.RichCoGroupFunction;
 import org.apache.flink.api.common.functions.RichCrossFunction;
 import org.apache.flink.api.common.functions.RichFlatJoinFunction;
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.functions.RichGroupReduceFunction;
+import org.apache.flink.api.common.functions.RichJoinFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
@@ -1665,7 +1667,40 @@ public class TypeExtractorTest {
                Assert.assertTrue(ti.isBasicType());
                Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, ti);
        }
-       
+
+       public static class CustomTuple2WithArray<K> extends Tuple2<K[], K> {
+
+               public CustomTuple2WithArray() {
+                       // default constructor
+               }
+       }
+
+       public class JoinWithCustomTuple2WithArray<T> extends 
RichJoinFunction<CustomTuple2WithArray<T>, CustomTuple2WithArray<T>, 
CustomTuple2WithArray<T>> {
+
+               @Override
+               public CustomTuple2WithArray<T> join(CustomTuple2WithArray<T> 
first, CustomTuple2WithArray<T> second) throws Exception {
+                       return null;
+               }
+       }
+
+       @Test
+       public void testInputInferenceWithCustomTupleAndRichFunction() {
+               JoinFunction<CustomTuple2WithArray<Long>, 
CustomTuple2WithArray<Long>, CustomTuple2WithArray<Long>> function = new 
JoinWithCustomTuple2WithArray<>();
+
+               TypeInformation<?> ti = TypeExtractor.getJoinReturnTypes(
+                       function,
+                       new 
TypeHint<CustomTuple2WithArray<Long>>(){}.getTypeInfo(),
+                       new 
TypeHint<CustomTuple2WithArray<Long>>(){}.getTypeInfo());
+
+               Assert.assertTrue(ti.isTupleType());
+               TupleTypeInfo<?> tti = (TupleTypeInfo<?>) ti;
+               Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, 
tti.getTypeAt(1));
+
+               Assert.assertTrue(tti.getTypeAt(0) instanceof 
ObjectArrayTypeInfo<?, ?>);
+               ObjectArrayTypeInfo<?, ?> oati = (ObjectArrayTypeInfo<?, ?>) 
tti.getTypeAt(0);
+               Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, 
oati.getComponentInfo());
+       }
+
        public static enum MyEnum {
                ONE, TWO, THREE
        }

Reply via email to