This is an automated email from the ASF dual-hosted git repository. twalthr pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 06bfd9b [FLINK-14983][table-common] Add interfaces for input type inference 06bfd9b is described below commit 06bfd9b2e37a3cb58573f485968217f6802869ff Author: Timo Walther <twal...@apache.org> AuthorDate: Fri Nov 29 17:00:15 2019 +0100 [FLINK-14983][table-common] Add interfaces for input type inference This adds a similar class as Calcite's SqlOperandTypeInference to Flink's type inference. For FLIP-65, we will need to implement those interfaces as the planner needs to infer the DataTypes out of logical types that come from the logical query. This is also beneficial to support the NULL literal in the future. This closes #10368. --- .../resolver/rules/ResolveCallByArgumentsRule.java | 88 ++++--- .../flink/table/types/inference/CallContext.java | 47 +++- .../table/types/inference/CallContextBase.java | 72 ------ .../table/types/inference/InputTypeStrategies.java | 73 ++++++ .../table/types/inference/InputTypeStrategy.java | 49 ++++ .../{CallContext.java => MutableCallContext.java} | 25 +- .../flink/table/types/inference/TypeInference.java | 116 +++++---- .../table/types/inference/TypeInferenceUtil.java | 266 ++++++++++++++------- .../strategies/BridgingInputTypeStrategy.java | 135 +++++++++++ .../strategies/ExplicitInputTypeStrategy.java | 67 ++++++ .../NopInputTypeStrategy.java} | 34 ++- .../strategies/OutputTypeInputTypeStrategy.java | 58 +++++ .../types/inference/utils/AdaptedCallContext.java | 107 +++++++++ .../types/inference/utils/UnknownCallContext.java | 93 +++++++ .../types/inference/InputTypeStrategiesTest.java | 227 ++++++++++++++++++ .../types/inference/InputTypeValidatorsTest.java | 15 +- .../table/types/inference/TypeStrategiesTest.java | 10 +- 17 files changed, 1215 insertions(+), 267 deletions(-) diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java index 2903504..122b207 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java @@ -37,9 +37,14 @@ import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.inference.CallContext; import org.apache.flink.table.types.inference.TypeInference; import org.apache.flink.table.types.inference.TypeInferenceUtil; +import org.apache.flink.table.types.inference.TypeInferenceUtil.Result; +import org.apache.flink.table.types.inference.TypeInferenceUtil.SurroundingInfo; import org.apache.flink.table.types.inference.TypeStrategies; import org.apache.flink.util.Preconditions; +import javax.annotation.Nullable; + +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -67,43 +72,53 @@ final class ResolveCallByArgumentsRule implements ResolverRule { @Override public List<Expression> apply(List<Expression> expression, ResolutionContext context) { return expression.stream() - .flatMap(expr -> expr.accept(new ResolvingCallVisitor(context)).stream()) + .flatMap(expr -> expr.accept(new ResolvingCallVisitor(context, null)).stream()) .collect(Collectors.toList()); } // -------------------------------------------------------------------------------------------- - private class ResolvingCallVisitor extends RuleExpressionVisitor<List<ResolvedExpression>> { + private static class ResolvingCallVisitor extends RuleExpressionVisitor<List<ResolvedExpression>> { + + private @Nullable SurroundingInfo surroundingInfo; - ResolvingCallVisitor(ResolutionContext context) { + ResolvingCallVisitor(ResolutionContext context, @Nullable SurroundingInfo surroundingInfo) { super(context); + this.surroundingInfo = surroundingInfo; } @Override public List<ResolvedExpression> visit(UnresolvedCallExpression unresolvedCall) { + final FunctionDefinition definition = unresolvedCall.getFunctionDefinition(); - final List<ResolvedExpression> resolvedArgs = unresolvedCall.getChildren().stream() - .flatMap(c -> c.accept(this).stream()) - .collect(Collectors.toList()); + final String name = unresolvedCall.getFunctionIdentifier() + .map(FunctionIdentifier::toString) + .orElseGet(definition::toString); + + final Optional<TypeInference> typeInference = getOptionalTypeInference(definition); + + // resolve the children with information from the current call + final List<ResolvedExpression> resolvedArgs = new ArrayList<>(); + final int argCount = unresolvedCall.getChildren().size(); + for (int i = 0; i < argCount; i++) { + final int currentPos = i; + final ResolvingCallVisitor childResolver = new ResolvingCallVisitor( + resolutionContext, + typeInference + .map(inference -> new SurroundingInfo(name, definition, inference, argCount, currentPos)) + .orElse(null)); + resolvedArgs.addAll(unresolvedCall.getChildren().get(i).accept(childResolver)); + } - if (unresolvedCall.getFunctionDefinition() == BuiltInFunctionDefinitions.FLATTEN) { + if (definition == BuiltInFunctionDefinitions.FLATTEN) { return executeFlatten(resolvedArgs); } - if (unresolvedCall.getFunctionDefinition() instanceof BuiltInFunctionDefinition) { - final BuiltInFunctionDefinition definition = - (BuiltInFunctionDefinition) unresolvedCall.getFunctionDefinition(); - - if (definition.getTypeInference().getOutputTypeStrategy() != TypeStrategies.MISSING) { - return Collections.singletonList( - runTypeInference( - unresolvedCall, - definition.getTypeInference(), - resolvedArgs)); - } - } return Collections.singletonList( - runLegacyTypeInference(unresolvedCall, resolvedArgs)); + typeInference + .map(newInference -> runTypeInference(name, unresolvedCall, newInference, resolvedArgs, surroundingInfo)) + .orElseGet(() -> runLegacyTypeInference(unresolvedCall, resolvedArgs)) + ); } @Override @@ -140,18 +155,29 @@ final class ResolveCallByArgumentsRule implements ResolverRule { .collect(Collectors.toList()); } + /** + * Temporary method until all calls define a type inference. + */ + private Optional<TypeInference> getOptionalTypeInference(FunctionDefinition definition) { + if (definition instanceof BuiltInFunctionDefinition) { + final BuiltInFunctionDefinition builtInDefinition = (BuiltInFunctionDefinition) definition; + if (builtInDefinition.getTypeInference().getOutputTypeStrategy() != TypeStrategies.MISSING) { + return Optional.of(builtInDefinition.getTypeInference()); + } + } + return Optional.empty(); + } + private ResolvedExpression runTypeInference( + String name, UnresolvedCallExpression unresolvedCall, TypeInference inference, - List<ResolvedExpression> resolvedArgs) { - - final String name = unresolvedCall.getFunctionIdentifier() - .map(FunctionIdentifier::toString) - .orElseGet(() -> unresolvedCall.getFunctionDefinition().toString()); + List<ResolvedExpression> resolvedArgs, + @Nullable SurroundingInfo surroundingInfo) { - final TypeInferenceUtil.Result inferenceResult = TypeInferenceUtil.runTypeInference( + final Result inferenceResult = TypeInferenceUtil.runTypeInference( inference, - new TableApiCallContext(name, unresolvedCall.getFunctionDefinition(), resolvedArgs)); + new TableApiCallContext(name, unresolvedCall.getFunctionDefinition(), resolvedArgs), surroundingInfo); final List<ResolvedExpression> adaptedArguments = adaptArguments(inferenceResult, resolvedArgs); @@ -164,7 +190,7 @@ final class ResolveCallByArgumentsRule implements ResolverRule { final PlannerTypeInferenceUtil util = resolutionContext.functionLookup().getPlannerTypeInferenceUtil(); - final TypeInferenceUtil.Result inferenceResult = util.runTypeInference( + final Result inferenceResult = util.runTypeInference( unresolvedCall, resolvedArgs); @@ -174,10 +200,10 @@ final class ResolveCallByArgumentsRule implements ResolverRule { } /** - * Adapts the arguments according to the properties of the {@link TypeInferenceUtil.Result}. + * Adapts the arguments according to the properties of the {@link Result}. */ private List<ResolvedExpression> adaptArguments( - TypeInferenceUtil.Result inferenceResult, + Result inferenceResult, List<ResolvedExpression> resolvedArgs) { return IntStream.range(0, resolvedArgs.size()) @@ -198,7 +224,7 @@ final class ResolveCallByArgumentsRule implements ResolverRule { // -------------------------------------------------------------------------------------------- - private class TableApiCallContext implements CallContext { + private static class TableApiCallContext implements CallContext { private final String name; diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContext.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContext.java index 3e54884..780a12c 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContext.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContext.java @@ -19,19 +19,62 @@ package org.apache.flink.table.types.inference; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.types.DataType; import java.util.List; +import java.util.Optional; /** - * Provides details about the function call for {@link InputTypeValidator} and {@link TypeStrategy}. + * Provides details about a function call during {@link TypeInference}. */ @PublicEvolving -public interface CallContext extends CallContextBase { +public interface CallContext { + + /** + * Returns the function definition that defines the function currently being called. + */ + FunctionDefinition getFunctionDefinition(); + + /** + * Returns whether the argument at the given position is a value literal. + */ + boolean isArgumentLiteral(int pos); + + /** + * Returns {@code true} if the argument at the given position is a literal and {@code null}, + * {@code false} otherwise. + * + * <p>Use {@link #isArgumentLiteral(int)} before to check if the argument is actually a literal. + */ + boolean isArgumentNull(int pos); + + /** + * Returns the literal value of the argument at the given position, given that the argument is a + * literal, is not null, and can be expressed as an instance of the provided class. + * + * <p>Use {@link #isArgumentLiteral(int)} before to check if the argument is actually a literal. + */ + <T> Optional<T> getArgumentValue(int pos, Class<T> clazz); + + /** + * Returns the function's name usually referencing the function in a catalog. + * + * <p>Note: The name is meant for debugging purposes only. + */ + String getName(); /** * Returns a resolved list of the call's argument types. It also includes a type for every argument * in a vararg function call. */ List<DataType> getArgumentDataTypes(); + + /** + * Creates a validation error for exiting the type inference process with a meaningful exception. + */ + default ValidationException newValidationError(String message, Object... args) { + return new ValidationException(String.format(message, args)); + } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContextBase.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContextBase.java deleted file mode 100644 index 6abaafb..0000000 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContextBase.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.types.inference; - -import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.table.api.ValidationException; -import org.apache.flink.table.functions.FunctionDefinition; - -import java.util.Optional; - -/** - * Provides details about the function call for which type inference is performed. - */ -@PublicEvolving -public interface CallContextBase { - - /** - * Returns the function definition that defines the function currently being called. - */ - FunctionDefinition getFunctionDefinition(); - - /** - * Returns whether the argument at the given position is a value literal. - */ - boolean isArgumentLiteral(int pos); - - /** - * Returns {@code true} if the argument at the given position is a literal and {@code null}, - * {@code false} otherwise. - * - * <p>Use {@link #isArgumentLiteral(int)} before to check if the argument is actually a literal. - */ - boolean isArgumentNull(int pos); - - /** - * Returns the literal value of the argument at the given position, given that the argument is a - * literal, is not null, and can be expressed as an instance of the provided class. - * - * <p>Use {@link #isArgumentLiteral(int)} before to check if the argument is actually a literal. - */ - <T> Optional<T> getArgumentValue(int pos, Class<T> clazz); - - /** - * Returns the function's name usually referencing the function in a catalog. - * - * <p>Note: The name is meant for debugging purposes only. - */ - String getName(); - - /** - * Creates a validation error for exiting the type inference process with a meaningful exception. - */ - default ValidationException newValidationError(String message, Object... args) { - return new ValidationException(String.format(message, args)); - } -} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/InputTypeStrategies.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/InputTypeStrategies.java new file mode 100644 index 0000000..bd45c36 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/InputTypeStrategies.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.strategies.BridgingInputTypeStrategy; +import org.apache.flink.table.types.inference.strategies.BridgingInputTypeStrategy.BridgingSignature; +import org.apache.flink.table.types.inference.strategies.ExplicitInputTypeStrategy; +import org.apache.flink.table.types.inference.strategies.NopInputTypeStrategy; +import org.apache.flink.table.types.inference.strategies.OutputTypeInputTypeStrategy; +import org.apache.flink.table.types.logical.LogicalType; + +import java.util.Arrays; +import java.util.List; + +/** + * Strategies for inferring missing or incomplete input argument data types. + * + * @see InputTypeStrategy + */ +@Internal +public final class InputTypeStrategies { + + /** + * Input strategy that does nothing. + */ + public static final InputTypeStrategy NOP = new NopInputTypeStrategy(); + + /** + * Input type strategy that supplies the function's output {@link DataType} for each unknown + * argument if available. + */ + public static final OutputTypeInputTypeStrategy OUTPUT_TYPE = new OutputTypeInputTypeStrategy(); + + /** + * Input type strategy that supplies a fixed {@link DataType} for each argument. + */ + public static ExplicitInputTypeStrategy explicit(DataType... dataTypes) { + return new ExplicitInputTypeStrategy(Arrays.asList(dataTypes)); + } + + /** + * Special input type strategy for enriching data types with an expected conversion class. This + * is in particular useful when a data type has been created out of a {@link LogicalType} but + * runtime hints are still missing. + */ + public static BridgingInputTypeStrategy bridging(List<BridgingSignature> bridgingSignatures) { + return new BridgingInputTypeStrategy(bridgingSignatures); + } + + // -------------------------------------------------------------------------------------------- + + private InputTypeStrategies() { + // no instantiation + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/InputTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/InputTypeStrategy.java new file mode 100644 index 0000000..dd93afd --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/InputTypeStrategy.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.types.DataType; + +/** + * Strategy for inferring missing or incomplete input argument data types in a function call. + * + * <p>This interface has two responsibilities: + * + * <p>In the {@link TypeInference} process, it is called before the validation of input arguments and + * can help in resolving the type of untyped {@code NULL} literals. + * + * <p>During the planning process, it can help in resolving the actual {@link DataType} including the + * conversion class that a function implementation expects from the runtime. This requires that a + * strategy can also be called on already validated arguments without affecting the logical type. This + * is different from Calcite where unknown types are resolved first and might be overwritten by more + * concrete types if available. + * + * <p>Note: Implementations should implement {@link Object#hashCode()} and {@link Object#equals(Object)}. + * + * @see InputTypeStrategies + */ +@PublicEvolving +public interface InputTypeStrategy { + + /** + * Infers the argument types of a function call. + */ + void inferInputTypes(MutableCallContext callContext); +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContext.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/MutableCallContext.java similarity index 50% copy from flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContext.java copy to flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/MutableCallContext.java index 3e54884..d07697a 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContext.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/MutableCallContext.java @@ -19,19 +19,32 @@ package org.apache.flink.table.types.inference; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.types.DataType; -import java.util.List; +import java.util.Optional; /** - * Provides details about the function call for {@link InputTypeValidator} and {@link TypeStrategy}. + * Provides details about a function call during {@link TypeInference} and allows to mutate argument + * data types for {@link InputTypeStrategy}. + * + * <p>Note: In particular, this method allows to enrich {@link DataTypes#NULL()} to a meaningful data + * type or modify a {@link DataType}'s conversion class. */ @PublicEvolving -public interface CallContext extends CallContextBase { +public interface MutableCallContext extends CallContext { + + /** + * Mutates the data type of an argument at the given position. + */ + void mutateArgumentDataType(int pos, DataType newDataType); /** - * Returns a resolved list of the call's argument types. It also includes a type for every argument - * in a vararg function call. + * Returns the inferred output data type of the function call. + * + * <p>It does this by inferring the input argument data type of a wrapping call (if available) + * where this function call is an argument. For example, {@code takes_string(this_function(NULL))} + * would lead to a {@link DataTypes#STRING()} because the wrapping call expects a string argument. */ - List<DataType> getArgumentDataTypes(); + Optional<DataType> getOutputDataType(); } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java index af6e3c5..220121f 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java @@ -31,7 +31,7 @@ import java.util.Optional; * Provides logic for the type inference of function calls. It includes: * <ul> * <li>explicit input specification for (possibly named) arguments</li> - * <li>inference of missing input types</li> + * <li>inference of missing or incomplete input types</li> * <li>validation of input types</li> * <li>inference of an intermediate accumulation type</li> * <li>inference of the final output type</li> @@ -42,22 +42,28 @@ import java.util.Optional; @PublicEvolving public final class TypeInference { + private final @Nullable List<String> argumentNames; + + private final @Nullable List<DataType> argumentTypes; + + private final @Nullable InputTypeStrategy inputTypeStrategy; + private final InputTypeValidator inputTypeValidator; private final @Nullable TypeStrategy accumulatorTypeStrategy; private final TypeStrategy outputTypeStrategy; - private final @Nullable List<String> argumentNames; - - private final @Nullable List<DataType> argumentTypes; - private TypeInference( + @Nullable List<String> argumentNames, + @Nullable List<DataType> argumentTypes, + @Nullable InputTypeStrategy inputTypeStrategy, InputTypeValidator inputTypeValidator, @Nullable TypeStrategy accumulatorTypeStrategy, - TypeStrategy outputTypeStrategy, - @Nullable List<String> argumentNames, - @Nullable List<DataType> argumentTypes) { + TypeStrategy outputTypeStrategy) { + this.argumentNames = argumentNames; + this.argumentTypes = argumentTypes; + this.inputTypeStrategy = inputTypeStrategy; this.inputTypeValidator = inputTypeValidator; this.accumulatorTypeStrategy = accumulatorTypeStrategy; this.outputTypeStrategy = outputTypeStrategy; @@ -68,8 +74,6 @@ public final class TypeInference { argumentNames.size(), argumentTypes.size())); } - this.argumentNames = argumentNames; - this.argumentTypes = argumentTypes; } /** @@ -79,6 +83,18 @@ public final class TypeInference { return new TypeInference.Builder(); } + public Optional<List<String>> getArgumentNames() { + return Optional.ofNullable(argumentNames); + } + + public Optional<List<DataType>> getArgumentTypes() { + return Optional.ofNullable(argumentTypes); + } + + public Optional<InputTypeStrategy> getInputTypeStrategy() { + return Optional.ofNullable(inputTypeStrategy); + } + public InputTypeValidator getInputTypeValidator() { return inputTypeValidator; } @@ -91,14 +107,6 @@ public final class TypeInference { return outputTypeStrategy; } - public Optional<List<String>> getArgumentNames() { - return Optional.ofNullable(argumentNames); - } - - public Optional<List<DataType>> getArgumentTypes() { - return Optional.ofNullable(argumentTypes); - } - // -------------------------------------------------------------------------------------------- /** @@ -106,21 +114,55 @@ public final class TypeInference { */ public static class Builder { + private @Nullable List<String> argumentNames; + + private @Nullable List<DataType> argumentTypes; + + private @Nullable InputTypeStrategy inputTypeStrategy; + private InputTypeValidator inputTypeValidator = InputTypeValidators.PASSING; private @Nullable TypeStrategy accumulatorTypeStrategy; private @Nullable TypeStrategy outputTypeStrategy; - private @Nullable List<String> argumentNames; - - private @Nullable List<DataType> argumentTypes; - public Builder() { // default constructor to allow a fluent definition } /** + * Sets the list of argument names for specifying static input explicitly. + * + * <p>This information is useful for SQL's concept of named arguments using the assignment + * operator (e.g. {@code FUNC(max => 42)}). + */ + public Builder namedArguments(List<String> argumentNames) { + this.argumentNames = + Preconditions.checkNotNull(argumentNames, "List of argument names must not be null."); + return this; + } + + /** + * Sets the list of argument types for specifying static input explicitly. + * + * <p>This information is useful for implicit and safe casting. + */ + public Builder typedArguments(List<DataType> argumentTypes) { + this.argumentTypes = + Preconditions.checkNotNull(argumentTypes, "List of argument types must not be null."); + return this; + } + + /** + * Sets the strategy for inferring missing or incomplete input argument data types. + */ + public Builder inputTypeStrategy(InputTypeStrategy inputTypeStrategy) { + this.inputTypeStrategy = + Preconditions.checkNotNull(inputTypeStrategy, "Input type strategy must not be null."); + return this; + } + + /** * Sets the validator for checking the input data types of a function call. * * <p>A always passing function is assumed by default (see {@link InputTypeValidators#PASSING}). @@ -151,36 +193,14 @@ public final class TypeInference { return this; } - /** - * Sets the list of argument names for specifying static input explicitly. - * - * <p>This information is useful for SQL's concept of named arguments using the assignment - * operator (e.g. {@code FUNC(max => 42)}). - */ - public Builder namedArguments(List<String> argumentNames) { - this.argumentNames = - Preconditions.checkNotNull(argumentNames, "List of argument names must not be null."); - return this; - } - - /** - * Sets the list of argument types for specifying static input explicitly. - * - * <p>This information is useful for implicit and safe casting. - */ - public Builder typedArguments(List<DataType> argumentTypes) { - this.argumentTypes = - Preconditions.checkNotNull(argumentTypes, "List of argument types must not be null."); - return this; - } - public TypeInference build() { return new TypeInference( + argumentNames, + argumentTypes, + inputTypeStrategy, inputTypeValidator, accumulatorTypeStrategy, - Preconditions.checkNotNull(outputTypeStrategy, "Output type strategy must not be null."), - argumentNames, - argumentTypes); + Preconditions.checkNotNull(outputTypeStrategy, "Output type strategy must not be null.")); } } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java index 06a1826..ffdd99b 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java @@ -24,6 +24,9 @@ import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.utils.AdaptedCallContext; +import org.apache.flink.table.types.inference.utils.UnknownCallContext; +import org.apache.flink.table.types.logical.LogicalTypeRoot; import org.apache.flink.table.types.logical.utils.LogicalTypeCasts; import javax.annotation.Nullable; @@ -32,15 +35,49 @@ import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasRoot; + /** * Utility for performing type inference. + * + * <p>The following steps summarize the envisioned type inference process. Not all features are implemented + * or exposed through the API yet (*). + * + * <p><ul> + * <li>1. Validate number of arguments. + * <li>2. (*) Apply assignment operators on the call expression by permuting operands and adding default + * expressions. These are preparations for {@link CallContext}. + * <li>3. For unknown (NULL) operands: Access the outer wrapping call and try to get its operand + * type for the return type of the actual call. E.g. for {@code takes_string(this_function(NULL))} + * infer operands from {@code takes_string(NULL)} and use the inferred string type as the return + * type of {@code this_function(NULL)}. + * <li>4. Try infer unknown operands, fail if not possible. + * <li>5. (*) Check the usage of DEFAULT operands are correct using validator.isOptional(). + * <li>6. Perform input type validation. + * <li>7. (Optional) Infer accumulator type. + * <li>8. (*) Check for an implementation evaluation method matching the operands. The matching happens + * class-based. Thus, for example, eval(Object) is valid for (INT). Or eval(Object...) is valid + * for (INT, STRING). We rely on the conversion classes specified by DataType. + * <li>9. Infer return type. + * </ul> */ @Internal public final class TypeInferenceUtil { - public static Result runTypeInference(TypeInference typeInference, CallContext callContext) { + /** + * Runs the type inference process. + * + * @param typeInference type inference of the current call + * @param callContext call context of the current call + * @param surroundingInfo information about the outer wrapping call of a current function call for + * performing input type inference + */ + public static Result runTypeInference( + TypeInference typeInference, + CallContext callContext, + @Nullable SurroundingInfo surroundingInfo) { try { - return runTypeInferenceInternal(typeInference, callContext); + return runTypeInferenceInternal(typeInference, callContext, surroundingInfo); } catch (ValidationException e) { throw new ValidationException( String.format( @@ -60,6 +97,59 @@ public final class TypeInferenceUtil { } /** + * Information what the outer world (i.e. an outer wrapping call) expects from the current + * function call. This can be helpful for {@link InputTypeStrategy}. + * + * @see MutableCallContext#getOutputDataType() + */ + public static final class SurroundingInfo { + + private final String name; + + private final FunctionDefinition functionDefinition; + + private final TypeInference typeInference; + + private final int argumentCount; + + private final int innerCallPosition; + + public SurroundingInfo( + String name, + FunctionDefinition functionDefinition, + TypeInference typeInference, + int argumentCount, + int innerCallPosition) { + this.name = name; + this.functionDefinition = functionDefinition; + this.typeInference = typeInference; + this.argumentCount = argumentCount; + this.innerCallPosition = innerCallPosition; + } + + private Optional<DataType> inferOutputType() { + // no strategy for inference + if (!typeInference.getInputTypeStrategy().isPresent()) { + return Optional.empty(); + } + final boolean isValidCount = validateArgumentCount( + typeInference.getInputTypeValidator().getArgumentCount(), + argumentCount, + false); + if (!isValidCount) { + return Optional.empty(); + } + // for "takes_string(this_function(NULL))" simulate "takes_string(NULL)" + // for retrieving the output type of "this_function(NULL)" + final CallContext callContext = new UnknownCallContext(name, functionDefinition, argumentCount); + final AdaptedCallContext adaptedContext = adaptArguments(typeInference, callContext, null); + final InputTypeStrategy inputTypeStrategy = typeInference.getInputTypeStrategy().get(); + inputTypeStrategy.inferInputTypes(adaptedContext); + return Optional.of(adaptedContext.getArgumentDataTypes().get(innerCallPosition)); + } + } + + /** * The result of a type inference run. It contains information about how arguments need to be * adapted in order to comply with the function's signature. * @@ -98,23 +188,23 @@ public final class TypeInferenceUtil { // -------------------------------------------------------------------------------------------- - private static Result runTypeInferenceInternal(TypeInference typeInference, CallContext callContext) { - final List<DataType> argumentTypes = callContext.getArgumentDataTypes(); - + private static Result runTypeInferenceInternal( + TypeInference typeInference, + CallContext callContext, + @Nullable SurroundingInfo surroundingInfo) { try { validateArgumentCount( typeInference.getInputTypeValidator().getArgumentCount(), - argumentTypes.size()); + callContext.getArgumentDataTypes().size(), + true); } catch (ValidationException e) { throw getInvalidInputException(typeInference.getInputTypeValidator(), callContext, e); } - final List<DataType> expectedTypes = typeInference.getArgumentTypes() - .orElse(argumentTypes); - final AdaptedCallContext adaptedCallContext = adaptArguments( + typeInference, callContext, - expectedTypes); + surroundingInfo); try { validateInputTypes( @@ -165,36 +255,49 @@ public final class TypeInferenceUtil { return stringBuilder.toString(); } - private static void validateArgumentCount(ArgumentCount argumentCount, int actualCount) { - argumentCount.getMinCount().ifPresent((min) -> { - if (actualCount < min) { + private static boolean validateArgumentCount( + ArgumentCount argumentCount, + int actualCount, + boolean throwOnFailure) { + final int minCount = argumentCount.getMinCount().orElse(0); + if (actualCount < minCount) { + if (throwOnFailure) { throw new ValidationException( String.format( "Invalid number of arguments. At least %d arguments expected but %d passed.", - min, + minCount, actualCount)); } - }); - - argumentCount.getMaxCount().ifPresent((max) -> { - if (actualCount > max) { + return false; + } + final int maxCount = argumentCount.getMaxCount().orElse(Integer.MAX_VALUE); + if (actualCount > maxCount) { + if (throwOnFailure) { throw new ValidationException( String.format( "Invalid number of arguments. At most %d arguments expected but %d passed.", - max, + maxCount, actualCount)); } - }); - + return false; + } if (!argumentCount.isValidCount(actualCount)) { - throw new ValidationException( - String.format( - "Invalid number of arguments. %d arguments passed.", - actualCount)); + if (throwOnFailure) { + throw new ValidationException( + String.format( + "Invalid number of arguments. %d arguments passed.", + actualCount)); + } + return false; } + return true; } private static void validateInputTypes(InputTypeValidator inputTypeValidator, CallContext callContext) { + // check for unknown types first + if (callContext.getArgumentDataTypes().stream().anyMatch(TypeInferenceUtil::isUnknown)) { + throw new ValidationException("Invalid use of untyped NULL in arguments."); + } if (!inputTypeValidator.validate(callContext, true)) { throw new ValidationException("Invalid input arguments."); } @@ -207,10 +310,21 @@ public final class TypeInferenceUtil { * values (*) where (*) is future work. */ private static AdaptedCallContext adaptArguments( + TypeInference typeInference, CallContext callContext, - List<DataType> expectedTypes) { + @Nullable SurroundingInfo surroundingInfo) { + + final List<DataType> expectedOrActualTypes = typeInference.getArgumentTypes() + .orElse(callContext.getArgumentDataTypes()); + + final AdaptedCallContext adaptedCallContext = inferInputTypes( + typeInference, + expectedOrActualTypes, + callContext, + surroundingInfo); final List<DataType> actualTypes = callContext.getArgumentDataTypes(); + final List<DataType> expectedTypes = adaptedCallContext.getArgumentDataTypes(); for (int pos = 0; pos < actualTypes.size(); pos++) { final DataType expectedType = expectedTypes.get(pos); final DataType actualType = actualTypes.get(pos); @@ -224,8 +338,32 @@ public final class TypeInferenceUtil { actualType)); } } + return adaptedCallContext; + } + + private static AdaptedCallContext inferInputTypes( + TypeInference typeInference, + List<DataType> expectedTypes, + CallContext callContext, + @Nullable SurroundingInfo surroundingInfo) { + // use information of surrounding call to determine output type of this call + final DataType outputType; + if (surroundingInfo != null) { + outputType = surroundingInfo.inferOutputType().orElse(null); + } else { + outputType = null; + } + + final AdaptedCallContext adaptedCallContext = new AdaptedCallContext( + callContext, + expectedTypes, + outputType); - return new AdaptedCallContext(callContext, expectedTypes); + // further adapt the arguments by calling an input strategy + typeInference.getInputTypeStrategy() + .ifPresent(s -> s.inferInputTypes(adaptedCallContext)); + + return adaptedCallContext; } private static boolean canCast(DataType sourceDataType, DataType targetDataType) { @@ -247,80 +385,38 @@ public final class TypeInferenceUtil { } final DataType outputType = potentialOutputType.get(); + if (isUnknown(outputType)) { + throw new ValidationException( + "Could not infer an output type for the given arguments. Untyped NULL received."); + } + if (adaptedCallContext.getFunctionDefinition().getKind() == FunctionKind.TABLE_AGGREGATE || adaptedCallContext.getFunctionDefinition().getKind() == FunctionKind.AGGREGATE) { // an accumulator might be an internal feature of the planner, therefore it is not // mandatory here; we assume the output type to be the accumulator type in this case if (accumulatorTypeStrategy == null) { - return new Result(adaptedCallContext.expectedArguments, outputType, outputType); + return new Result(adaptedCallContext.getArgumentDataTypes(), outputType, outputType); } final Optional<DataType> potentialAccumulatorType = accumulatorTypeStrategy.inferType(adaptedCallContext); if (!potentialAccumulatorType.isPresent()) { throw new ValidationException("Could not infer an accumulator type for the given arguments."); } - return new Result(adaptedCallContext.expectedArguments, potentialAccumulatorType.get(), outputType); - - } else { - return new Result(adaptedCallContext.expectedArguments, null, outputType); - } - } + final DataType accumulatorType = potentialAccumulatorType.get(); - /** - * Helper context that deals with adapted arguments. - * - * <p>For example, if an argument needs to be casted to a target type, an expression that was a - * literal before is not a literal anymore in this call context. - */ - private static class AdaptedCallContext implements CallContext { - - private final CallContext originalContext; - private final List<DataType> expectedArguments; - - public AdaptedCallContext(CallContext originalContext, List<DataType> castedArguments) { - this.originalContext = originalContext; - this.expectedArguments = castedArguments; - } - - @Override - public List<DataType> getArgumentDataTypes() { - return expectedArguments; - } - - @Override - public FunctionDefinition getFunctionDefinition() { - return originalContext.getFunctionDefinition(); - } - - @Override - public boolean isArgumentLiteral(int pos) { - if (isCasted(pos)) { - return false; + if (isUnknown(accumulatorType)) { + throw new ValidationException( + "Could not infer an accumulator type for the given arguments. Untyped NULL received."); } - return originalContext.isArgumentLiteral(pos); - } - @Override - public boolean isArgumentNull(int pos) { - // null remains null regardless of casting - return originalContext.isArgumentNull(pos); - } + return new Result(adaptedCallContext.getArgumentDataTypes(), potentialAccumulatorType.get(), outputType); - @Override - public <T> Optional<T> getArgumentValue(int pos, Class<T> clazz) { - if (isCasted(pos)) { - return Optional.empty(); - } - return originalContext.getArgumentValue(pos, clazz); - } - - @Override - public String getName() { - return originalContext.getName(); + } else { + return new Result(adaptedCallContext.getArgumentDataTypes(), null, outputType); } + } - private boolean isCasted(int pos) { - return !originalContext.getArgumentDataTypes().get(pos).equals(expectedArguments.get(pos)); - } + private static boolean isUnknown(DataType dataType) { + return hasRoot(dataType.getLogicalType(), LogicalTypeRoot.NULL); } private TypeInferenceUtil() { diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/BridgingInputTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/BridgingInputTypeStrategy.java new file mode 100644 index 0000000..4def14d --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/BridgingInputTypeStrategy.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.InputTypeStrategy; +import org.apache.flink.table.types.inference.MutableCallContext; +import org.apache.flink.table.types.logical.LogicalType; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Objects; + +/** + * Input type strategy for enriching data types with an expected conversion class. This is in particular + * useful when a data type has been created out of a {@link LogicalType} but runtime hints are still missing. + */ +@Internal +public final class BridgingInputTypeStrategy implements InputTypeStrategy { + + private final List<BridgingSignature> bridgingSignatures; + + public BridgingInputTypeStrategy(List<BridgingSignature> bridgingSignatures) { + this.bridgingSignatures = bridgingSignatures; + } + + @Override + public void inferInputTypes(MutableCallContext callContext) { + final List<DataType> actualDataTypes = callContext.getArgumentDataTypes(); + for (BridgingSignature bridgingSignature : bridgingSignatures) { + if (bridgingSignature.matches(actualDataTypes)) { + bridgingSignature.enrich(actualDataTypes.size(), callContext); + return; // there should be only one matching signature + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BridgingInputTypeStrategy that = (BridgingInputTypeStrategy) o; + return bridgingSignatures.equals(that.bridgingSignatures); + } + + @Override + public int hashCode() { + return Objects.hash(bridgingSignatures); + } + + /** + * Helper class that represents a signature of input arguments. + * + * <p>Note: Array elements can be null for skipping the enrichment of certain arguments. If the + * signature has varargs, the last data type represents the varying argument type. + */ + public static final class BridgingSignature { + + private final DataType[] expectedDataTypes; + + private final boolean isVarying; + + public BridgingSignature(DataType[] expectedDataTypes, boolean isVarying) { + this.expectedDataTypes = expectedDataTypes; + this.isVarying = isVarying; + } + + public boolean matches(List<DataType> actualDataTypes) { + if (!isValidCount(actualDataTypes.size())) { + return false; + } + for (int i = 0; i < actualDataTypes.size(); i++) { + final LogicalType actualType = actualDataTypes.get(i).getLogicalType(); + final DataType expectedType = getExpectedDataType(i); + if (expectedType != null && !actualType.equals(expectedType.getLogicalType())) { + return false; + } + } + return true; + } + + public void enrich(int argumentCount, MutableCallContext mutableCallContext) { + for (int i = 0; i < argumentCount; i++) { + mutableCallContext.mutateArgumentDataType(i, getExpectedDataType(i)); + } + } + + private boolean isValidCount(int actualCount) { + final int minCount; + if (isVarying) { + minCount = expectedDataTypes.length - 1; + } else { + minCount = expectedDataTypes.length; + } + final int maxCount; + if (isVarying) { + maxCount = Integer.MAX_VALUE; + } else { + maxCount = expectedDataTypes.length; + } + return actualCount >= minCount && actualCount <= maxCount; + } + + private @Nullable DataType getExpectedDataType(int pos) { + if (pos < expectedDataTypes.length) { + return expectedDataTypes[pos]; + } else if (isVarying) { + return expectedDataTypes[expectedDataTypes.length - 1]; + } + throw new IllegalStateException("Argument count should have been validated before."); + } + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ExplicitInputTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ExplicitInputTypeStrategy.java new file mode 100644 index 0000000..cf4c354 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ExplicitInputTypeStrategy.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.InputTypeStrategy; +import org.apache.flink.table.types.inference.MutableCallContext; + +import java.util.List; +import java.util.Objects; + +/** + * Input type strategy that supplies a fixed {@link DataType} for each argument. + */ +@Internal +public final class ExplicitInputTypeStrategy implements InputTypeStrategy { + + private final List<DataType> dataTypes; + + public ExplicitInputTypeStrategy(List<DataType> dataTypes) { + this.dataTypes = dataTypes; + } + + @Override + public void inferInputTypes(MutableCallContext callContext) { + if (callContext.getArgumentDataTypes().size() != dataTypes.size()) { + return; + } + for (int i = 0; i < dataTypes.size(); i++) { + callContext.mutateArgumentDataType(i, dataTypes.get(i)); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ExplicitInputTypeStrategy that = (ExplicitInputTypeStrategy) o; + return dataTypes.equals(that.dataTypes); + } + + @Override + public int hashCode() { + return Objects.hash(dataTypes); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContext.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/NopInputTypeStrategy.java similarity index 56% copy from flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContext.java copy to flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/NopInputTypeStrategy.java index 3e54884..8c33c39 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/CallContext.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/NopInputTypeStrategy.java @@ -16,22 +16,30 @@ * limitations under the License. */ -package org.apache.flink.table.types.inference; +package org.apache.flink.table.types.inference.strategies; -import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.table.types.DataType; - -import java.util.List; +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.inference.InputTypeStrategy; +import org.apache.flink.table.types.inference.MutableCallContext; /** - * Provides details about the function call for {@link InputTypeValidator} and {@link TypeStrategy}. + * Input strategy that does nothing. */ -@PublicEvolving -public interface CallContext extends CallContextBase { +@Internal +public final class NopInputTypeStrategy implements InputTypeStrategy { + + @Override + public void inferInputTypes(MutableCallContext callContext) { + // nothing to do + } + + @Override + public boolean equals(Object o) { + return this == o || o instanceof NopInputTypeStrategy; + } - /** - * Returns a resolved list of the call's argument types. It also includes a type for every argument - * in a vararg function call. - */ - List<DataType> getArgumentDataTypes(); + @Override + public int hashCode() { + return NopInputTypeStrategy.class.hashCode(); + } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/OutputTypeInputTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/OutputTypeInputTypeStrategy.java new file mode 100644 index 0000000..fa154a0 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/OutputTypeInputTypeStrategy.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.InputTypeStrategy; +import org.apache.flink.table.types.inference.MutableCallContext; +import org.apache.flink.table.types.logical.LogicalTypeRoot; + +import java.util.List; + +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasRoot; + +/** + * Input type strategy that supplies the function's output {@link DataType} for each unknown + * argument if available. + */ +@Internal +public final class OutputTypeInputTypeStrategy implements InputTypeStrategy { + + @Override + public void inferInputTypes(MutableCallContext callContext) { + callContext.getOutputDataType().ifPresent(t -> { + final List<DataType> dataTypes = callContext.getArgumentDataTypes(); + for (int i = 0; i < dataTypes.size(); i++) { + if (hasRoot(dataTypes.get(i).getLogicalType(), LogicalTypeRoot.NULL)) { + callContext.mutateArgumentDataType(i, t); + } + } + }); + } + + @Override + public boolean equals(Object o) { + return this == o || o instanceof OutputTypeInputTypeStrategy; + } + + @Override + public int hashCode() { + return OutputTypeInputTypeStrategy.class.hashCode(); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/utils/AdaptedCallContext.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/utils/AdaptedCallContext.java new file mode 100644 index 0000000..5240a64 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/utils/AdaptedCallContext.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.utils; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.MutableCallContext; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * Helper context that deals with adapted arguments. + * + * <p>For example, if an argument needs to be casted to a target type, an expression that was a + * literal before is not a literal anymore in this call context. + */ +@Internal +public final class AdaptedCallContext implements MutableCallContext { + + private final CallContext originalContext; + + private final List<DataType> expectedArguments; + + private final @Nullable DataType outputDataType; + + public AdaptedCallContext( + CallContext originalContext, + List<DataType> castedArguments, + @Nullable DataType outputDataType) { + this.originalContext = originalContext; + this.expectedArguments = new ArrayList<>(castedArguments); + this.outputDataType = outputDataType; + } + + @Override + public List<DataType> getArgumentDataTypes() { + return expectedArguments; + } + + @Override + public FunctionDefinition getFunctionDefinition() { + return originalContext.getFunctionDefinition(); + } + + @Override + public boolean isArgumentLiteral(int pos) { + if (isCasted(pos)) { + return false; + } + return originalContext.isArgumentLiteral(pos); + } + + @Override + public boolean isArgumentNull(int pos) { + // null remains null regardless of casting + return originalContext.isArgumentNull(pos); + } + + @Override + public <T> Optional<T> getArgumentValue(int pos, Class<T> clazz) { + if (isCasted(pos)) { + return Optional.empty(); + } + return originalContext.getArgumentValue(pos, clazz); + } + + @Override + public String getName() { + return originalContext.getName(); + } + + private boolean isCasted(int pos) { + return !originalContext.getArgumentDataTypes().get(pos).equals(expectedArguments.get(pos)); + } + + @Override + public void mutateArgumentDataType(int pos, DataType newDataType) { + expectedArguments.set(pos, newDataType); + } + + @Override + public Optional<DataType> getOutputDataType() { + return Optional.ofNullable(outputDataType); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/utils/UnknownCallContext.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/utils/UnknownCallContext.java new file mode 100644 index 0000000..f4ba594 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/utils/UnknownCallContext.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.utils; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; + +import java.util.AbstractList; +import java.util.List; +import java.util.Optional; + +/** + * A {@link CallContext} with unknown data types. + */ +@Internal +public final class UnknownCallContext implements CallContext { + + private static final DataType NULL = DataTypes.NULL(); + + private final String name; + + private final FunctionDefinition functionDefinition; + + private final List<DataType> argumentDataTypes; + + public UnknownCallContext( + String name, + FunctionDefinition functionDefinition, + int argumentCount) { + this.name = name; + this.functionDefinition = functionDefinition; + this.argumentDataTypes = new AbstractList<DataType>() { + @Override + public DataType get(int index) { + return NULL; + } + + @Override + public int size() { + return argumentCount; + } + }; + } + + @Override + public FunctionDefinition getFunctionDefinition() { + return functionDefinition; + } + + @Override + public boolean isArgumentLiteral(int pos) { + return false; + } + + @Override + public boolean isArgumentNull(int pos) { + return false; + } + + @Override + public <T> Optional<T> getArgumentValue(int pos, Class<T> clazz) { + return Optional.empty(); + } + + @Override + public String getName() { + return name; + } + + @Override + public List<DataType> getArgumentDataTypes() { + return argumentDataTypes; + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/InputTypeStrategiesTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/InputTypeStrategiesTest.java new file mode 100644 index 0000000..42c8f4e --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/InputTypeStrategiesTest.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.functions.FunctionKind; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.strategies.BridgingInputTypeStrategy.BridgingSignature; +import org.apache.flink.table.types.inference.utils.CallContextMock; +import org.apache.flink.table.types.inference.utils.FunctionDefinitionMock; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.List; + +import static org.apache.flink.util.CoreMatchers.containsCause; +import static org.hamcrest.CoreMatchers.equalTo; + +/** + * Tests for built-in {@link InputTypeStrategies}. + */ +@RunWith(Parameterized.class) +public class InputTypeStrategiesTest { + + @Parameters + public static List<TestSpec> testData() { + return Arrays.asList( + // no inference + TestSpec + .forInputStrategy(InputTypeStrategies.NOP) + .calledWithArgumentTypes(DataTypes.INT(), DataTypes.STRING()) + .expectArgumentTypes(DataTypes.INT(), DataTypes.STRING()), + + // incomplete inference + TestSpec + .forInputStrategy(InputTypeStrategies.NOP) + .calledWithArgumentTypes(DataTypes.NULL(), DataTypes.STRING(), DataTypes.NULL()) + .expectErrorMessage("Invalid use of untyped NULL in arguments."), + + // typed arguments help inferring a type + TestSpec + .forInputStrategy(InputTypeStrategies.NOP) + .typedArguments(DataTypes.INT(), DataTypes.STRING(), DataTypes.BOOLEAN()) + .calledWithArgumentTypes(DataTypes.NULL(), DataTypes.STRING(), DataTypes.NULL()) + .expectArgumentTypes(DataTypes.INT(), DataTypes.STRING(), DataTypes.BOOLEAN()), + + // surrounding function helps inferring a type + TestSpec + .forInputStrategy(InputTypeStrategies.OUTPUT_TYPE) + .surroundingStrategy(InputTypeStrategies.explicit(DataTypes.BOOLEAN())) + .calledWithArgumentTypes(DataTypes.NULL(), DataTypes.STRING(), DataTypes.NULL()) + .expectArgumentTypes(DataTypes.BOOLEAN(), DataTypes.STRING(), DataTypes.BOOLEAN()), + + // enrich data type with conversion class of INT + TestSpec + .forInputStrategy(createBridgingInputTypeStrategy()) + .calledWithArgumentTypes(DataTypes.STRING(), DataTypes.INT()) + .expectArgumentTypes(DataTypes.STRING(), DataTypes.INT().bridgedTo(int.class)), + + // enrich data type with conversion class of varying INT + TestSpec + .forInputStrategy(createBridgingInputTypeStrategy()) + .calledWithArgumentTypes(DataTypes.INT(), DataTypes.INT(), DataTypes.INT()) + .expectArgumentTypes( + DataTypes.INT().bridgedTo(int.class), + DataTypes.INT().bridgedTo(int.class), + DataTypes.INT().bridgedTo(int.class)), + + // check function without arguments + TestSpec + .forInputStrategy(createBridgingInputTypeStrategy()) + .calledWithArgumentTypes() + .expectArgumentTypes() + ); + } + + @Parameter + public TestSpec testSpec; + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testInputTypeStrategy() { + if (testSpec.expectedErrorMessage != null) { + thrown.expect(ValidationException.class); + thrown.expectCause(containsCause(new ValidationException(testSpec.expectedErrorMessage))); + } + TypeInferenceUtil.Result result = runTypeInference(); + if (testSpec.expectedArgumentTypes != null) { + Assert.assertThat(result.getExpectedArgumentTypes(), equalTo(testSpec.expectedArgumentTypes)); + } + } + + // -------------------------------------------------------------------------------------------- + + private TypeInferenceUtil.Result runTypeInference() { + final FunctionDefinitionMock functionDefinitionMock = new FunctionDefinitionMock(); + functionDefinitionMock.functionKind = FunctionKind.SCALAR; + final CallContextMock callContextMock = new CallContextMock(); + callContextMock.functionDefinition = functionDefinitionMock; + callContextMock.argumentDataTypes = testSpec.actualArgumentTypes; + callContextMock.name = "f"; + + final TypeInference.Builder builder = TypeInference.newBuilder() + .inputTypeStrategy(testSpec.strategy) + .inputTypeValidator(InputTypeValidators.PASSING) + .outputTypeStrategy(TypeStrategies.explicit(DataTypes.BOOLEAN())); + + if (testSpec.typedArguments != null) { + builder.typedArguments(testSpec.typedArguments); + } + + final TypeInferenceUtil.SurroundingInfo surroundingInfo; + if (testSpec.surroundingStrategy != null) { + final TypeInference outerTypeInference = TypeInference.newBuilder() + .inputTypeStrategy(testSpec.surroundingStrategy) + .outputTypeStrategy(TypeStrategies.MISSING) + .build(); + surroundingInfo = new TypeInferenceUtil.SurroundingInfo( + "f_outer", + functionDefinitionMock, + outerTypeInference, + 1, + 0); + } else { + surroundingInfo = null; + } + return TypeInferenceUtil.runTypeInference(builder.build(), callContextMock, surroundingInfo); + } + + // -------------------------------------------------------------------------------------------- + + private static class TestSpec { + + private final InputTypeStrategy strategy; + + // types explicitly expected by the type inference + private @Nullable List<DataType> typedArguments; + + private @Nullable InputTypeStrategy surroundingStrategy; + + private @Nullable List<DataType> actualArgumentTypes; + + private @Nullable List<DataType> expectedArgumentTypes; + + private @Nullable String expectedErrorMessage; + + private TestSpec(InputTypeStrategy strategy) { + this.strategy = strategy; + } + + static TestSpec forInputStrategy(InputTypeStrategy strategy) { + return new TestSpec(strategy); + } + + TestSpec typedArguments(DataType... dataTypes) { + this.typedArguments = Arrays.asList(dataTypes); + return this; + } + + TestSpec surroundingStrategy(InputTypeStrategy surroundingStrategy) { + this.surroundingStrategy = surroundingStrategy; + return this; + } + + TestSpec calledWithArgumentTypes(DataType... dataTypes) { + this.actualArgumentTypes = Arrays.asList(dataTypes); + return this; + } + + TestSpec expectArgumentTypes(DataType... dataTypes) { + this.expectedArgumentTypes = Arrays.asList(dataTypes); + return this; + } + + TestSpec expectErrorMessage(String expectedErrorMessage) { + this.expectedErrorMessage = expectedErrorMessage; + return this; + } + } + + private static InputTypeStrategy createBridgingInputTypeStrategy() { + return InputTypeStrategies.bridging( + Arrays.asList( + new BridgingSignature( + new DataType[]{DataTypes.STRING(), DataTypes.INT().bridgedTo(int.class)}, false + ), + new BridgingSignature( + new DataType[]{DataTypes.TIMESTAMP(3).bridgedTo(java.sql.Timestamp.class), DataTypes.STRING()}, false + ), + new BridgingSignature( + new DataType[]{}, false + ), + new BridgingSignature( + new DataType[]{DataTypes.INT().bridgedTo(int.class)}, true + ) + ) + ); + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/InputTypeValidatorsTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/InputTypeValidatorsTest.java index aac5449..0ac4de1 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/InputTypeValidatorsTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/InputTypeValidatorsTest.java @@ -36,7 +36,6 @@ import org.junit.runners.Parameterized.Parameters; import javax.annotation.Nullable; import java.util.List; -import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -127,21 +126,21 @@ public class InputTypeValidatorsTest { // left of OR TestSpec - .forValidator(or(explicitSequence(DataTypes.INT()), explicitSequence(DataTypes.NULL()))) + .forValidator(or(explicitSequence(DataTypes.INT()), explicitSequence(DataTypes.BOOLEAN()))) .inputTypes(DataTypes.INT()) .expectSuccess(), // right of OR TestSpec - .forValidator(or(explicitSequence(DataTypes.INT()), explicitSequence(DataTypes.NULL()))) - .inputTypes(DataTypes.NULL()) + .forValidator(or(explicitSequence(DataTypes.INT()), explicitSequence(DataTypes.BOOLEAN()))) + .inputTypes(DataTypes.BOOLEAN()) .expectSuccess(), // invalid type in OR TestSpec - .forValidator(or(explicitSequence(DataTypes.INT()), explicitSequence(DataTypes.NULL()))) + .forValidator(or(explicitSequence(DataTypes.INT()), explicitSequence(DataTypes.STRING()))) .inputTypes(DataTypes.BOOLEAN()) - .expectErrorMessage("Invalid input arguments. Expected signatures are:\nf(INT)\nf(NULL)"), + .expectErrorMessage("Invalid input arguments. Expected signatures are:\nf(INT)\nf(STRING)"), // explicit sequence TestSpec @@ -302,9 +301,9 @@ public class InputTypeValidatorsTest { final TypeInference typeInference = TypeInference.newBuilder() .inputTypeValidator(testSpec.validator) - .outputTypeStrategy(callContext -> Optional.of(DataTypes.NULL())) + .outputTypeStrategy(TypeStrategies.explicit(DataTypes.BOOLEAN())) .build(); - TypeInferenceUtil.runTypeInference(typeInference, callContextMock); + TypeInferenceUtil.runTypeInference(typeInference, callContextMock, null); } // -------------------------------------------------------------------------------------------- diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java index d545c7a..64e7a11 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java @@ -83,7 +83,13 @@ public class TypeStrategiesTest { TestSpec .forStrategy(createMatchingTypeStrategy()) .inputTypes(DataTypes.INT(), DataTypes.INT()) - .expectErrorMessage("Could not infer an output type for the given arguments.") + .expectErrorMessage("Could not infer an output type for the given arguments."), + + // invalid return type + TestSpec + .forStrategy(TypeStrategies.explicit(DataTypes.NULL())) + .inputTypes() + .expectErrorMessage("Could not infer an output type for the given arguments. Untyped NULL received.") ); } @@ -119,7 +125,7 @@ public class TypeStrategiesTest { .inputTypeValidator(InputTypeValidators.PASSING) .outputTypeStrategy(testSpec.strategy) .build(); - return TypeInferenceUtil.runTypeInference(typeInference, callContextMock); + return TypeInferenceUtil.runTypeInference(typeInference, callContextMock, null); } // --------------------------------------------------------------------------------------------