dawidwys commented on a change in pull request #8865: [FLINK-12924][table] Introduce basic type inference interfaces URL: https://github.com/apache/flink/pull/8865#discussion_r297033180
########## File path: flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java ########## @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.functions.FunctionKind; +import org.apache.flink.table.types.DataType; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * Utility for performing type inference. + */ +@Internal +public class TypeInferenceUtil { + + public static Result runTypeInference(TypeInference typeInference, CallContext callContext) { + try { + return runTypeInferenceInternal(typeInference, callContext); + } catch (ValidationException e) { + throw new ValidationException( + String.format( + "Invalid call to function '%s'. 
Given arguments: %s", + callContext.getName(), + callContext.getArgumentDataTypes().stream() + .map(DataType::toString) + .collect(Collectors.joining())), + e); + } catch (Throwable t) { + throw new TableException( + String.format( + "Unexpected error in type inference logic of function '%s'. This is a bug.", + callContext.getName()), + t); + } + } + + /** + * The result of a type inference run. It contains information about how arguments need to be + * adapted in order to comply with the function's signature. + * + * <p>This includes casts that need to be inserted, reordering of arguments (*), or insertion of default + * values (*) where (*) is future work. + */ + public static class Result { + + private final List<DataType> expectedArgumentTypes; + + private final DataType accumulatorDataType; + + private final DataType outputDataType; + + public Result( + List<DataType> expectedArgumentTypes, + DataType accumulatorDataType, + DataType outputDataType) { + this.expectedArgumentTypes = expectedArgumentTypes; + this.accumulatorDataType = accumulatorDataType; + this.outputDataType = outputDataType; + } + + public List<DataType> getExpectedArgumentTypes() { + return expectedArgumentTypes; + } + + public DataType getAccumulatorDataType() { + return accumulatorDataType; + } + + public DataType getOutputDataType() { + return outputDataType; + } + } + + // -------------------------------------------------------------------------------------------- + + private static Result runTypeInferenceInternal(TypeInference typeInference, CallContext callContext) { + final List<DataType> argumentTypes = callContext.getArgumentDataTypes(); + + try { + validateArgumentCount( + typeInference.getInputTypeValidator().getArgumentCount(), + callContext.getArgumentDataTypes().size()); + } catch (ValidationException e) { + throw getInvalidInputException(typeInference.getInputTypeValidator(), callContext); + } + + final List<DataType> expectedTypes = typeInference.getArgumentTypes() + 
.orElse(callContext.getArgumentDataTypes()); + + final List<String> expectedNames = typeInference.getArgumentNames() + .orElse(getDefaultArgumentNames(expectedTypes.size())); + + final AdaptedCallContext adaptedCallContext = adaptArguments( + callContext, + expectedNames, + expectedTypes); + + try { + validateInputTypes( + typeInference.getInputTypeValidator(), + adaptedCallContext); + } catch (ValidationException e) { + throw getInvalidInputException(typeInference.getInputTypeValidator(), adaptedCallContext); + } + + return inferTypes( + adaptedCallContext, + typeInference.getAccumulatorTypeStrategy(), + typeInference.getOutputTypeStrategy()); + } + + private static List<String> getDefaultArgumentNames(int argumentCount) { + return IntStream.range(0, argumentCount) + .mapToObj(i -> "arg" + i) + .collect(Collectors.toList()); + } + + private static ValidationException getInvalidInputException( + InputTypeValidator validator, + CallContext callContext) { + return new ValidationException( + String.format( + "Invalid input arguments. Expected signatures are:\n%s", + String.join( + "\n", + validator.getExpectedSignatures( + callContext.getName(), + callContext.getFunctionDefinition())))); + } + + private static void validateArgumentCount(ArgumentCount argumentCount, int actualCount) { + argumentCount.getMinCount().ifPresent((min) -> { + if (actualCount < min) { + throw new ValidationException( + String.format( + "Invalid number of arguments. At least %d arguments expected but %d passed.", + min, + actualCount)); + } + }); + + argumentCount.getMaxCount().ifPresent((max) -> { + if (actualCount > max) { + throw new ValidationException( + String.format( + "Invalid number of arguments. At most %d arguments expected but %d passed.", + max, + actualCount)); + } + }); + + if (argumentCount.isValidCount(actualCount)) { + throw new ValidationException( + String.format( + "Invalid number of arguments. 
%d arguments passed.", + actualCount)); + } + } + + private static void validateInputTypes(InputTypeValidator inputTypeValidator, CallContext callContext) { + if (!inputTypeValidator.validate(callContext, true)) { + throw new ValidationException("Invalid input arguments."); + } + } + + /** + * Adapts the call's argument if necessary. + * + * <p>This includes casts that need to be inserted, reordering of arguments (*), or insertion of default + * values (*) where (*) is future work. + */ + private static AdaptedCallContext adaptArguments( + CallContext callContext, + List<String> expectedNames, + List<DataType> expectedTypes) { + + for (int pos = 0; pos < callContext.getArgumentDataTypes().size(); pos++) { Review comment: Extract `callContext.getArgumentDataTypes()` to a variable ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services
