dawidwys commented on a change in pull request #8865:  [FLINK-12924][table] 
Introduce basic type inference interfaces
URL: https://github.com/apache/flink/pull/8865#discussion_r297045313
 
 

 ##########
 File path: 
flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java
 ##########
 @@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.inference;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.functions.FunctionKind;
+import org.apache.flink.table.types.DataType;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Utility for performing type inference.
+ */
+@Internal
+public class TypeInferenceUtil {
+
+       /**
+        * Runs the entire type inference process for a function call.
+        *
+        * <p>Validation failures are rethrown as a {@link ValidationException} that names the function and
+        * lists the given argument types. Any other failure is wrapped into a {@link TableException}.
+        *
+        * @param typeInference the type inference logic of the called function
+        * @param callContext the call that provides the function name and the argument data types
+        * @return the expected argument types, the accumulator type, and the output type of the call
+        * @throws ValidationException if the call does not pass validation
+        * @throws TableException if the inference logic fails with any other error
+        */
+       public static Result runTypeInference(TypeInference typeInference, CallContext callContext) {
+               try {
+                       return runTypeInferenceInternal(typeInference, callContext);
+               } catch (ValidationException e) {
+                       throw new ValidationException(
+                               String.format(
+                                       "Invalid call to function '%s'. Given arguments: %s",
+                                       callContext.getName(),
+                                       callContext.getArgumentDataTypes().stream()
+                                               .map(DataType::toString)
+                                               // separate the types, otherwise they run together in the message
+                                               .collect(Collectors.joining(", "))),
+                               e);
+               } catch (Throwable t) {
+                       throw new TableException(
+                               String.format(
+                                       "Unexpected error in type inference logic of function '%s'. This is a bug.",
+                                       callContext.getName()),
+                               t);
+               }
+       }
+
+       /**
+        * Outcome of a type inference run.
+        *
+        * <p>It describes how the arguments of a call have to be adapted so that they match the signature of
+        * the function. Currently this covers casts that need to be inserted. Reordering of arguments and
+        * insertion of default values are future work.
+        */
+       public static class Result {
+
+               /** Data types that the arguments are expected to have after adaptation. */
+               private final List<DataType> expectedArgumentTypes;
+
+               /** Data type that was inferred for the accumulator. */
+               private final DataType accumulatorDataType;
+
+               /** Data type that was inferred for the output. */
+               private final DataType outputDataType;
+
+               /**
+                * Creates a result from the given types. The values are stored as passed, without copying.
+                */
+               public Result(
+                               List<DataType> expectedArgumentTypes,
+                               DataType accumulatorDataType,
+                               DataType outputDataType) {
+                       this.outputDataType = outputDataType;
+                       this.accumulatorDataType = accumulatorDataType;
+                       this.expectedArgumentTypes = expectedArgumentTypes;
+               }
+
+               /** Returns the data types that the arguments are expected to have. */
+               public List<DataType> getExpectedArgumentTypes() {
+                       return this.expectedArgumentTypes;
+               }
+
+               /** Returns the data type of the accumulator. */
+               public DataType getAccumulatorDataType() {
+                       return this.accumulatorDataType;
+               }
+
+               /** Returns the data type of the output. */
+               public DataType getOutputDataType() {
+                       return this.outputDataType;
+               }
+       }
+
+       // 
--------------------------------------------------------------------------------------------
+
+       /**
+        * Performs the individual steps of the type inference: validates the argument count, adapts the
+        * arguments to the expected types, validates the input types, and infers the result types.
+        *
+        * <p>If the argument count or the input types are rejected, the caught exception is replaced by one
+        * that lists the expected signatures of the function.
+        *
+        * @throws ValidationException if one of the validation or inference steps fails
+        */
+       private static Result runTypeInferenceInternal(TypeInference typeInference, CallContext callContext) {
+               final List<DataType> argumentTypes = callContext.getArgumentDataTypes();
+
+               try {
+                       validateArgumentCount(
+                               typeInference.getInputTypeValidator().getArgumentCount(),
+                               argumentTypes.size());
+               } catch (ValidationException e) {
+                       throw getInvalidInputException(typeInference.getInputTypeValidator(), callContext);
+               }
+
+               // without explicitly declared argument types, the actual types are taken as expected types
+               final List<DataType> expectedTypes = typeInference.getArgumentTypes()
+                       .orElse(argumentTypes);
+
+               // default names are only generated if the function does not declare argument names
+               final List<String> expectedNames = typeInference.getArgumentNames()
+                       .orElseGet(() -> getDefaultArgumentNames(expectedTypes.size()));
+
+               final AdaptedCallContext adaptedCallContext = adaptArguments(
+                       callContext,
+                       expectedNames,
+                       expectedTypes);
+
+               try {
+                       validateInputTypes(
+                               typeInference.getInputTypeValidator(),
+                               adaptedCallContext);
+               } catch (ValidationException e) {
+                       throw getInvalidInputException(typeInference.getInputTypeValidator(), adaptedCallContext);
+               }
+
+               return inferTypes(
+                       adaptedCallContext,
+                       typeInference.getAccumulatorTypeStrategy(),
+                       typeInference.getOutputTypeStrategy());
+       }
+
+       private static List<String> getDefaultArgumentNames(int argumentCount) {
+               return IntStream.range(0, argumentCount)
+                       .mapToObj(i -> "arg" + i)
+                       .collect(Collectors.toList());
+       }
+
+       private static ValidationException getInvalidInputException(
+                       InputTypeValidator validator,
+                       CallContext callContext) {
+               return new ValidationException(
+                       String.format(
+                               "Invalid input arguments. Expected signatures 
are:\n%s",
+                               String.join(
+                                       "\n",
+                                       validator.getExpectedSignatures(
+                                               callContext.getName(),
+                                               
callContext.getFunctionDefinition()))));
+       }
+
+       /**
+        * Checks the number of passed arguments against the argument count declared by the function.
+        *
+        * @param argumentCount the declared constraints on the number of arguments
+        * @param actualCount the number of arguments of the call
+        * @throws ValidationException if fewer arguments than the minimum or more arguments than the maximum
+        *         are passed, or if {@code argumentCount.isValidCount(actualCount)} rejects the count
+        */
+       private static void validateArgumentCount(ArgumentCount argumentCount, int actualCount) {
+               argumentCount.getMinCount().ifPresent((min) -> {
+                       if (actualCount < min) {
+                               throw new ValidationException(
+                                       String.format(
+                                               "Invalid number of arguments. At least %d arguments expected but %d passed.",
+                                               min,
+                                               actualCount));
+                       }
+               });
+
+               argumentCount.getMaxCount().ifPresent((max) -> {
+                       if (actualCount > max) {
+                               throw new ValidationException(
+                                       String.format(
+                                               "Invalid number of arguments. At most %d arguments expected but %d passed.",
+                                               max,
+                                               actualCount));
+                       }
+               });
+
+               // fail only if the count is NOT valid; the bounds above may both be satisfied while the
+               // count is still rejected
+               if (!argumentCount.isValidCount(actualCount)) {
+                       throw new ValidationException(
+                               String.format(
+                                       "Invalid number of arguments. %d arguments passed.",
+                                       actualCount));
+               }
+       }
+
+       private static void validateInputTypes(InputTypeValidator 
inputTypeValidator, CallContext callContext) {
+               if (!inputTypeValidator.validate(callContext, true)) {
+                       throw new ValidationException("Invalid input 
arguments.");
+               }
+       }
+
+       /**
+        * Checks the arguments of the call against the expected types and wraps the call into an
+        * {@link AdaptedCallContext} that carries the expected types.
+        *
+        * <p>An argument is accepted if its type equals the expected type or can be cast to it. Adapting
+        * covers casts that need to be inserted; reordering of arguments and insertion of default values
+        * are future work. The expected names are currently not used.
+        *
+        * @throws ValidationException if an argument type neither equals nor can be cast to the expected type
+        */
+       private static AdaptedCallContext adaptArguments(
+                       CallContext callContext,
+                       List<String> expectedNames,
+                       List<DataType> expectedTypes) {
+
+               final List<DataType> actualTypes = callContext.getArgumentDataTypes();
+               final int actualCount = actualTypes.size();
+
+               for (int index = 0; index < actualCount; index++) {
+                       final DataType expected = expectedTypes.get(index);
+                       final DataType actual = actualTypes.get(index);
+
+                       final boolean accepted = actual.equals(expected) || canCast(actual, expected);
+                       if (accepted) {
+                               continue;
+                       }
+
+                       throw new ValidationException(
+                               String.format(
+                                       "Invalid argument type at position %d. Data type %s expected but %s passed.",
+                                       index,
+                                       expected,
+                                       actual));
+               }
+
+               return new AdaptedCallContext(callContext, expectedTypes);
+       }
+
+       /**
+        * Decides whether a value of the source data type can be cast to the target data type.
+        *
+        * <p>Casting is not supported yet, so this currently returns {@code false} for every pair of types.
+        */
+       private static boolean canCast(DataType sourceDataType, DataType 
targetDataType) {
+               return false; // TODO unsupported for now
+       }
+
+       private static Result inferTypes(
+                       AdaptedCallContext adaptedCallContext,
+                       TypeStrategy accumulatorTypeStrategy,
+                       TypeStrategy outputTypeStrategy) {
+
+               // infer output type first for better error message
+               // (logically an accumulator type should be inferred first)
+               final Optional<DataType> potentialOutputType = 
outputTypeStrategy.inferType(adaptedCallContext);
+               if (!potentialOutputType.isPresent()) {
+                       throw new ValidationException("Could not infer an 
output type for the given arguments.");
+               }
+               final DataType outputType = potentialOutputType.get();
+
+               if (adaptedCallContext.getFunctionDefinition().getKind() == 
FunctionKind.TABLE_AGGREGATE ||
+                               
adaptedCallContext.getFunctionDefinition().getKind() == FunctionKind.AGGREGATE) 
{
+                       final Optional<DataType> potentialAccumulatorType = 
accumulatorTypeStrategy.inferType(adaptedCallContext);
+                       if (!potentialAccumulatorType.isPresent()) {
+                               throw new ValidationException("Could not infer 
an accumulator type for the given arguments.");
+                       }
+                       return new Result(adaptedCallContext.expectedArguments, 
potentialAccumulatorType.get(), outputType);
+               } else {
+                       return new Result(adaptedCallContext.expectedArguments, 
outputType, outputType);
 
 Review comment:
   Isn't it a bit misleading that we pass the output type as the accumulator 
type even though we do not expect an accumulator at all?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to