dawidwys commented on a change in pull request #10606: 
[FLINK-15009][table-common] Add a utility for creating type inference logic via 
reflection
URL: https://github.com/apache/flink/pull/10606#discussion_r363370516
 
 

 ##########
 File path: 
flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java
 ##########
 @@ -0,0 +1,703 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.extraction;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.annotation.DataTypeHint;
+import org.apache.flink.table.annotation.FunctionHint;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.DataTypeLookup;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.AsyncTableFunction;
+import org.apache.flink.table.functions.ScalarFunction;
+import org.apache.flink.table.functions.TableAggregateFunction;
+import org.apache.flink.table.functions.TableFunction;
+import org.apache.flink.table.functions.UserDefinedFunction;
+import org.apache.flink.table.types.CollectionDataType;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.extraction.utils.ExtractionUtils;
+import org.apache.flink.table.types.extraction.utils.FunctionArgumentTemplate;
+import org.apache.flink.table.types.extraction.utils.FunctionResultTemplate;
+import org.apache.flink.table.types.extraction.utils.FunctionSignatureTemplate;
+import org.apache.flink.table.types.extraction.utils.FunctionTemplate;
+import org.apache.flink.table.types.inference.InputTypeStrategies;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.table.types.inference.TypeStrategies;
+import org.apache.flink.table.types.inference.TypeStrategy;
+
+import javax.annotation.Nullable;
+
+import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+import static 
org.apache.flink.table.types.extraction.utils.ExtractionUtils.collectAnnotationsOfClass;
+import static 
org.apache.flink.table.types.extraction.utils.ExtractionUtils.collectAnnotationsOfMethod;
+import static 
org.apache.flink.table.types.extraction.utils.ExtractionUtils.collectMethods;
+import static 
org.apache.flink.table.types.extraction.utils.ExtractionUtils.extractionError;
+import static 
org.apache.flink.table.types.extraction.utils.ExtractionUtils.isAssignable;
+import static 
org.apache.flink.table.types.extraction.utils.ExtractionUtils.isMethodInvokable;
+
+/**
+ * Reflection-based utility for extracting a {@link TypeInference} from a supported subclass of
+ * {@link UserDefinedFunction}.
+ *
+ * <p>The behavior of this utility can be influenced by {@link DataTypeHint}s and {@link FunctionHint}s,
+ * which have higher precedence than the reflective information.
+ *
+ * <p>Note: This utility assumes that functions have already been validated regarding accessibility of
+ * class/methods and serializability.
+ */
+@Internal
+public final class TypeInferenceExtractor {
+
+       /** Lookup that is handed to the {@link DataTypeExtractor} calls of this extractor. */
+       private final DataTypeLookup lookup;
+
+       /** Function class that is analyzed. */
+       private final Class<? extends UserDefinedFunction> function;
+
+       /** Kind of function (e.g. "scalar"), inserted into the top-level error message. */
+       private final String functionExplanation;
+
+       /** Name of the implementation method, printed when the expected signature is reported. */
+       private final String methodName;
+
+       /**
+        * Creates an extractor for one function class. Private: instances are obtained through the static
+        * {@code for...Function} methods of this class.
+        */
+       private TypeInferenceExtractor(
+                       DataTypeLookup lookup,
+                       Class<? extends UserDefinedFunction> function,
+                       String functionExplanation,
+                       String methodName) {
+               this.methodName = methodName;
+               this.functionExplanation = functionExplanation;
+               this.function = function;
+               this.lookup = lookup;
+       }
+
+       /**
+        * Derives the {@link TypeInference} of a {@link ScalarFunction} class.
+        *
+        * <p>The signature is taken from all parameters of the {@code eval} method and the output from
+        * its return type. No accumulator extraction takes place.
+        */
+       public static TypeInference forScalarFunction(DataTypeLookup lookup, Class<? extends ScalarFunction> function) {
+               final TypeInferenceExtractor extractor = new TypeInferenceExtractor(lookup, function, "scalar", "eval");
+               final SignatureExtraction signatureExtraction = extractor.createParameterSignatureExtraction(false);
+               final ResultExtraction outputExtraction = extractor.createReturnTypeResultExtraction();
+               final MethodVerification verification = extractor.createParameterAndReturnTypeVerification();
+               // scalar functions do not accumulate, hence no accumulator extraction
+               return extractor.extractTypeInference(signatureExtraction, null, outputExtraction, verification);
+       }
+
+       /**
+        * Derives the {@link TypeInference} of an {@link AggregateFunction} class.
+        *
+        * <p>The first parameter of the {@code accumulate} method is excluded from the signature.
+        * The accumulator is extracted from generic position 1 and the output from generic position 0
+        * of {@link AggregateFunction}.
+        */
+       public static TypeInference forAggregateFunction(DataTypeLookup lookup, Class<? extends AggregateFunction> function) {
+               final TypeInferenceExtractor extractor =
+                       new TypeInferenceExtractor(lookup, function, "aggregate", "accumulate");
+               final SignatureExtraction signatureExtraction = extractor.createParameterSignatureExtraction(true);
+               final ResultExtraction accumulatorExtraction =
+                       extractor.createGenericResultExtraction(AggregateFunction.class, 1);
+               final ResultExtraction outputExtraction =
+                       extractor.createGenericResultExtraction(AggregateFunction.class, 0);
+               final MethodVerification verification = extractor.createParameterWithAccumulatorVerification();
+               return extractor.extractTypeInference(
+                       signatureExtraction,
+                       accumulatorExtraction,
+                       outputExtraction,
+                       verification);
+       }
+
+       /**
+        * Derives the {@link TypeInference} of a {@link TableFunction} class.
+        *
+        * <p>The signature is taken from all parameters of the {@code eval} method and the output from
+        * generic position 0 of {@link TableFunction}. No accumulator extraction takes place.
+        */
+       public static TypeInference forTableFunction(DataTypeLookup lookup, Class<? extends TableFunction> function) {
+               final TypeInferenceExtractor extractor = new TypeInferenceExtractor(lookup, function, "table", "eval");
+               final SignatureExtraction signatureExtraction = extractor.createParameterSignatureExtraction(false);
+               final ResultExtraction outputExtraction =
+                       extractor.createGenericResultExtraction(TableFunction.class, 0);
+               final MethodVerification verification = extractor.createParameterVerification();
+               // table functions do not accumulate, hence no accumulator extraction
+               return extractor.extractTypeInference(signatureExtraction, null, outputExtraction, verification);
+       }
+
+       /**
+        * Derives the {@link TypeInference} of a {@link TableAggregateFunction} class.
+        *
+        * <p>The first parameter of the {@code accumulate} method is excluded from the signature.
+        * The accumulator is extracted from generic position 1 and the output from generic position 0
+        * of {@link TableAggregateFunction}.
+        */
+       public static TypeInference forTableAggregateFunction(DataTypeLookup lookup, Class<? extends TableAggregateFunction> function) {
+               final TypeInferenceExtractor extractor =
+                       new TypeInferenceExtractor(lookup, function, "table aggregate", "accumulate");
+               final SignatureExtraction signatureExtraction = extractor.createParameterSignatureExtraction(true);
+               final ResultExtraction accumulatorExtraction =
+                       extractor.createGenericResultExtraction(TableAggregateFunction.class, 1);
+               final ResultExtraction outputExtraction =
+                       extractor.createGenericResultExtraction(TableAggregateFunction.class, 0);
+               final MethodVerification verification = extractor.createParameterWithAccumulatorVerification();
+               return extractor.extractTypeInference(
+                       signatureExtraction,
+                       accumulatorExtraction,
+                       outputExtraction,
+                       verification);
+       }
+
+       /**
+        * Derives the {@link TypeInference} of an {@link AsyncTableFunction} class.
+        *
+        * <p>The first parameter of the {@code eval} method is excluded from the signature; during
+        * verification a {@link CompletableFuture} is expected at that position. The output is extracted
+        * from generic position 0 of {@link AsyncTableFunction}. No accumulator extraction takes place.
+        */
+       public static TypeInference forAsyncTableFunction(DataTypeLookup lookup, Class<? extends AsyncTableFunction> function) {
+               final TypeInferenceExtractor extractor =
+                       new TypeInferenceExtractor(lookup, function, "async table", "eval");
+               final SignatureExtraction signatureExtraction = extractor.createParameterSignatureExtraction(true);
+               final ResultExtraction outputExtraction =
+                       extractor.createGenericResultExtraction(AsyncTableFunction.class, 0);
+               final MethodVerification verification =
+                       extractor.createParameterWithArgumentVerification(CompletableFuture.class);
+               return extractor.extractTypeInference(signatureExtraction, null, outputExtraction, verification);
+       }
+
+       // --------------------------------------------------------------------------------------------
+       // Context-sensitive extraction and verification logic
+       // --------------------------------------------------------------------------------------------
+
+       /**
+        * Creates an extraction that builds a {@link FunctionSignatureTemplate} out of the parameters
+        * of a method.
+        *
+        * @param excludeFirstArg if {@code true}, the first method parameter is left out of the signature
+        */
+       private SignatureExtraction createParameterSignatureExtraction(boolean excludeFirstArg) {
+               final int offset = excludeFirstArg ? 1 : 0;
+               return method -> {
+                       final int parameterCount = method.getParameterCount();
+                       final boolean isVarArgs = method.isVarArgs();
+
+                       // argument types
+                       final List<FunctionArgumentTemplate> argumentTemplates = IntStream.range(offset, parameterCount)
+                               .mapToObj(pos -> {
+                                       final DataType extracted =
+                                               DataTypeExtractor.extractFromMethodParameter(lookup, function, method, pos);
+                                       final boolean isLastOfVarArgs = isVarArgs && pos == parameterCount - 1;
+                                       // unwrap from ARRAY data type in case of varargs
+                                       if (isLastOfVarArgs && extracted instanceof CollectionDataType) {
+                                               return ((CollectionDataType) extracted).getElementDataType();
+                                       }
+                                       return extracted;
+                               })
+                               .map(FunctionArgumentTemplate::of)
+                               .collect(Collectors.toList());
+
+                       // argument names; the name list may be null, in which case no names are passed on
+                       final List<String> parameterNames = ExtractionUtils.extractMethodParameterNames(method);
+                       final String[] argumentNames = parameterNames == null
+                               ? null
+                               : parameterNames.subList(offset, parameterNames.size()).toArray(new String[0]);
+
+                       return FunctionSignatureTemplate.of(argumentTemplates, isVarArgs, argumentNames);
+               };
+       }
+
+       /**
+        * Creates an extraction that builds a {@link FunctionResultTemplate} from a generic type
+        * variable instead of from the method.
+        *
+        * @param baseClass class passed to {@code DataTypeExtractor.extractFromGeneric} as base class
+        *                  (presumably the class declaring the type variable)
+        * @param genericPos position of the generic type variable
+        */
+       private ResultExtraction createGenericResultExtraction(
+                       Class<? extends UserDefinedFunction> baseClass,
+                       int genericPos) {
+               // the method handed to the extraction is not consulted here
+               return method -> FunctionResultTemplate.of(
+                       DataTypeExtractor.extractFromGeneric(lookup, baseClass, genericPos, function));
+       }
+
+       /**
+        * Creates an extraction that builds a {@link FunctionResultTemplate} from the output of a
+        * method, as determined by {@code DataTypeExtractor.extractFromMethodOutput}.
+        */
+       private ResultExtraction createReturnTypeResultExtraction() {
+               return method -> FunctionResultTemplate.of(
+                       DataTypeExtractor.extractFromMethodOutput(lookup, function, method));
+       }
+
+       /**
+        * Creates a verification that checks both the parameters and the return type of a method.
+        *
+        * <p>The verification fails with the error of {@code createMethodNotFoundError} if the method
+        * is not invokable with the signature or if the assignability check between the output class
+        * and the method's return type fails.
+        */
+       private MethodVerification createParameterAndReturnTypeVerification() {
+               return (method, signature, accumulator, output) -> {
+                       final Class<?>[] expectedParameters = signature.toArray(new Class[0]);
+                       final Class<?> actualReturnType = method.getReturnType();
+                       // the return type is only checked if the parameters already match
+                       if (!isMethodInvokable(method, expectedParameters) ||
+                                       !isAssignable(output, actualReturnType, true)) {
+                               throw createMethodNotFoundError(expectedParameters, actualReturnType);
+                       }
+               };
+       }
+
+       /**
+        * Verification that checks a method by parameters including an 
accumulator.
+        */
+       private MethodVerification createParameterWithAccumulatorVerification() 
{
+               return (method, signature, accumulator, output) ->
+                       createParameterWithArgumentVerification(accumulator)
+                               .verify(method, signature, accumulator, output);
+       }
+
+       /**
+        * Creates a verification that checks the parameters of a method, with {@code argumentClass}
+        * prepended to the signature as the first parameter.
+        *
+        * @param argumentClass class expected at the first position; may be {@code null}
+        */
+       private MethodVerification createParameterWithArgumentVerification(@Nullable Class<?> argumentClass) {
+               return (method, signature, accumulator, output) -> {
+                       // a null argumentClass is kept as a null entry at the first position
+                       final Class<?>[] expectedParameters = Stream
+                               .concat(Stream.of(argumentClass), signature.stream())
+                               .toArray(Class[]::new);
+                       final boolean isInvokable = isMethodInvokable(method, expectedParameters);
+                       if (!isInvokable) {
+                               throw createMethodNotFoundError(expectedParameters, null);
+                       }
+               };
+       }
+
+       /**
+        * Creates a verification that only checks whether a method is invokable with the classes of
+        * the signature. The return type is not checked.
+        */
+       private MethodVerification createParameterVerification() {
+               return (method, signature, accumulator, output) -> {
+                       final Class<?>[] expectedParameters = signature.toArray(new Class[0]);
+                       final boolean isInvokable = isMethodInvokable(method, expectedParameters);
+                       if (!isInvokable) {
+                               throw createMethodNotFoundError(expectedParameters, null);
+                       }
+               };
+       }
+
+       /**
+        * Creates the error that reports the signature an implementation method is expected to have.
+        *
+        * <p>The reported signature consists of the optional return type name, the method name of this
+        * extractor, and the comma-separated parameter class names in parentheses. Parameters that are
+        * {@code null} are printed as {@code _}.
+        *
+        * @param parameters expected parameter classes; single entries may be {@code null}
+        * @param returnType expected return type, or {@code null} to omit it from the message
+        * @return the exception created by {@code extractionError}; it is returned, not thrown
+        */
+       private ValidationException createMethodNotFoundError(Class<?>[] 
parameters, @Nullable Class<?> returnType) {
+               final StringBuilder builder = new StringBuilder();
+               // the return type is only printed if the caller verified it
+               if (returnType != null) {
+                       builder.append(returnType.getName()).append(" ");
+               }
+               builder
+                       .append(methodName)
+                       .append(
+                               Stream.of(parameters)
+                                       .map(parameter -> {
+                                               // in case we don't know the 
parameter at this location (i.e. for accumulators)
+                                               if (parameter == null) {
+                                                       return "_";
+                                               } else {
+                                                       return 
parameter.getName();
+                                               }
+                                       })
+                                       .collect(Collectors.joining(", ", "(", 
")")));
+               return extractionError(
+                       "Considering all hints, the method should comply with 
the signature:\n%s",
+                       builder.toString());
+       }
+
+       /**
+        * Extracts the {@link TypeInference} using the given strategies and wraps every failure into a
+        * single extraction error that names the kind of function and the function class.
+        *
+        * @param signatureExtraction extraction of the signature from a method
+        * @param accumulatorExtraction extraction of the accumulator, or {@code null} if the function
+        *                              is not accumulating
+        * @param outputExtraction extraction of the output
+        * @param verification verification applied to the implementation methods
+        */
+       private TypeInference extractTypeInference(
+                       SignatureExtraction signatureExtraction,
+                       @Nullable ResultExtraction accumulatorExtraction,
+                       ResultExtraction outputExtraction,
+                       MethodVerification verification) {
+               try {
+                       return extractTypeInferenceOrError(
+                               signatureExtraction,
+                               accumulatorExtraction,
+                               outputExtraction,
+                               verification
+                       );
+               } catch (Throwable t) {
+                       // any failure is reported with the function class as context; the cause is passed along
+                       throw extractionError(
+                               t,
+                               "Could not extract a valid type inference from 
%s function class '%s'. " +
+                                       "Please check for implementation 
mistakes and/or provide a corresponding hint.",
+                               functionExplanation,
+                               function.getName());
+               }
+       }
+
+       /**
+        * Extracts the signature-to-output mapping and, for accumulating functions, the
+        * signature-to-accumulator mapping, and builds the {@link TypeInference} from them.
+        *
+        * <p>A failure in either step is wrapped into an extraction error that names the step.
+        *
+        * <p>NOTE(review): the trailing boolean passed to {@code extractResultMappings} is {@code false}
+        * for the output and {@code true} for the accumulator; its exact meaning is defined by that
+        * method, which is not shown here.
+        *
+        * @param signatureExtraction extraction of the signature from a method
+        * @param accumulatorExtraction extraction of the accumulator, or {@code null} if the function
+        *                              is not accumulating
+        * @param outputExtraction extraction of the output
+        * @param verification verification applied to the implementation methods
+        */
+       private TypeInference extractTypeInferenceOrError(
+                       SignatureExtraction signatureExtraction,
+                       @Nullable ResultExtraction accumulatorExtraction,
+                       ResultExtraction outputExtraction,
+                       MethodVerification verification) {
+
+               // the output mapping is always extracted, and always first
+               final Map<FunctionSignatureTemplate, FunctionResultTemplate> 
outputMapping;
+               try {
+                       outputMapping = extractResultMappings(
+                               signatureExtraction,
+                               outputExtraction,
+                               verification,
+                               false);
+               } catch (Throwable t) {
+                       throw extractionError(t, "Error in extracting a 
signature to output strategy.");
+               }
+
+               // function is accumulating
+               if (accumulatorExtraction != null) {
+                       final Map<FunctionSignatureTemplate, 
FunctionResultTemplate> accumulatorMapping;
+                       try {
+                               accumulatorMapping = extractResultMappings(
+                                       signatureExtraction,
+                                       accumulatorExtraction,
+                                       verification,
+                                       true);
+                       } catch (Throwable t) {
+                               throw extractionError(t, "Error in extracting a 
signature to accumulator strategy.");
+                       }
+                       return buildInference(accumulatorMapping, 
outputMapping);
+               }
+               // no accumulator: the inference is built from the output mapping alone
+               return buildInference(null, outputMapping);
+       }
+
 
 Review comment:
   Continue review from here.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to