dawidwys commented on a change in pull request #10606: 
[FLINK-15009][table-common] Add a utility for creating type inference logic via 
reflection
URL: https://github.com/apache/flink/pull/10606#discussion_r363764199
 
 

 ##########
 File path: 
flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java
 ##########
 @@ -0,0 +1,679 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.extraction;
+
+import org.apache.flink.table.annotation.DataTypeHint;
+import org.apache.flink.table.annotation.FunctionHint;
+import org.apache.flink.table.annotation.InputGroup;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.DataTypeLookup;
+import org.apache.flink.table.catalog.UnresolvedIdentifier;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.ScalarFunction;
+import org.apache.flink.table.functions.TableAggregateFunction;
+import org.apache.flink.table.functions.TableFunction;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentTypeStrategy;
+import org.apache.flink.table.types.inference.InputTypeStrategies;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.table.types.inference.TypeStrategies;
+import org.apache.flink.table.types.inference.TypeStrategy;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.TypeConversions;
+import org.apache.flink.types.Row;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+import javax.annotation.Nullable;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Supplier;
+
+import static org.apache.flink.util.CoreMatchers.containsCause;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+/**
+ * Tests for {@link TypeInferenceExtractor}.
+ */
+@RunWith(Parameterized.class)
+@SuppressWarnings("unused")
+public class TypeInferenceExtractorTest {
+
+       @Parameters
+       public static List<TestSpec> testData() {
+               return Arrays.asList(
+                       // function hint defines everything
+                       TestSpec
+                               .forScalarFunction(FullFunctionHint.class)
+                               .expectTypedArguments(DataTypes.INT(), 
DataTypes.STRING())
+                               .expectNamedArguments("i", "s")
+                               .expectOutputMapping(
+                                       InputTypeStrategies.sequence(
+                                               new String[] {"i", "s"},
+                                               new ArgumentTypeStrategy[] {
+                                                       
InputTypeStrategies.explicit(DataTypes.INT()),
+                                                       
InputTypeStrategies.explicit(DataTypes.STRING())}
+                                       ),
+                                       
TypeStrategies.explicit(DataTypes.BOOLEAN())),
+
+                       // function hint defines everything with overloading
+                       TestSpec
+                               .forScalarFunction(FullFunctionHints.class)
+                               .expectOutputMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.INT())),
+                                       
TypeStrategies.explicit(DataTypes.INT()))
+                               .expectOutputMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.BIGINT())),
+                                       
TypeStrategies.explicit(DataTypes.BIGINT())),
+
+                       // global output hint with local input overloading
+                       TestSpec
+                               
.forScalarFunction(GlobalOutputFunctionHint.class)
+                               .expectOutputMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.INT())),
+                                       
TypeStrategies.explicit(DataTypes.INT()))
+                               .expectOutputMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.STRING())),
+                                       
TypeStrategies.explicit(DataTypes.INT())),
+
+                       // unsupported output overloading
+                       TestSpec
+                               
.forScalarFunction(InvalidSingleOutputFunctionHint.class)
+                               .expectErrorMessage("Function hints that lead 
to ambiguous results are not allowed."),
+
+                       // global and local overloading
+                       TestSpec
+                               .forScalarFunction(SplitFullFunctionHints.class)
+                               .expectOutputMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.INT())),
+                                       
TypeStrategies.explicit(DataTypes.INT()))
+                               .expectOutputMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.BIGINT())),
+                                       
TypeStrategies.explicit(DataTypes.BIGINT())),
+
+                       // global and local overloading with unsupported output 
overloading
+                       TestSpec
+                               
.forScalarFunction(InvalidFullOutputFunctionHint.class)
+                               .expectErrorMessage("Function hints with same 
input definition but different result types are not allowed."),
+
+                       // invalid data type hint
+                       TestSpec
+                               .forScalarFunction(IncompleteFunctionHint.class)
+                               .expectErrorMessage("Data type hint does 
neither specify a data type nor input group for use as function argument."),
+
+                       // varargs and ANY input group
+                       TestSpec
+                               .forScalarFunction(ComplexFunctionHint.class)
+                               .expectOutputMapping(
+                                       InputTypeStrategies.varyingSequence(
+                                               new String[]{"myInt", "myAny"},
+                                               new 
ArgumentTypeStrategy[]{InputTypeStrategies.explicit(DataTypes.INT()), 
InputTypeStrategies.ANY}),
+                                       
TypeStrategies.explicit(DataTypes.BOOLEAN())),
+
+                       // ignore argument names during overloading
+                       TestSpec
+                               
.forScalarFunction(InvalidOutputWithArgNamesFunctionHint.class)
+                               .expectErrorMessage("Function hints with same 
input definition but different result types are not allowed."),
+
+                       // global input hints and local output hints
+                       TestSpec
+                               
.forScalarFunction(GlobalInputFunctionHints.class)
+                               .expectOutputMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.INT())),
+                                       
TypeStrategies.explicit(DataTypes.INT()))
+                               .expectOutputMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.BIGINT())),
+                                       
TypeStrategies.explicit(DataTypes.INT())),
+
+                       // no arguments
+                       TestSpec
+                               .forScalarFunction(ZeroArgFunction.class)
+                               .expectOutputMapping(
+                                       InputTypeStrategies.sequence(new 
String[0], new ArgumentTypeStrategy[0]),
+                                       
TypeStrategies.explicit(DataTypes.INT())),
+
+                       // test primitive arguments extraction
+                       TestSpec
+                               .forScalarFunction(MixedArgFunction.class)
+                               .expectOutputMapping(
+                                       InputTypeStrategies.sequence(
+                                               new String[]{"i", "d"},
+                                               new ArgumentTypeStrategy[]{
+                                                       
InputTypeStrategies.explicit(DataTypes.INT().notNull().bridgedTo(int.class)),
+                                                       
InputTypeStrategies.explicit(DataTypes.DOUBLE())}),
+                                       
TypeStrategies.explicit(DataTypes.INT())),
+
+                       // test overloaded arguments extraction
+                       TestSpec
+                               .forScalarFunction(OverloadedFunction.class)
+                               .expectOutputMapping(
+                                       InputTypeStrategies.sequence(
+                                               new String[]{"i", "d"},
+                                               new ArgumentTypeStrategy[]{
+                                                       
InputTypeStrategies.explicit(DataTypes.INT().notNull().bridgedTo(int.class)),
+                                                       
InputTypeStrategies.explicit(DataTypes.DOUBLE())}),
+                                       
TypeStrategies.explicit(DataTypes.INT()))
+                               .expectOutputMapping(
+                                       InputTypeStrategies.sequence(
+                                               new String[]{"s"},
+                                               new ArgumentTypeStrategy[]{
+                                                       
InputTypeStrategies.explicit(DataTypes.STRING())
+                                               }),
+                                       
TypeStrategies.explicit(DataTypes.BIGINT().notNull().bridgedTo(long.class))),
+
+                       // test varying arguments extraction
+                       TestSpec
+                               .forScalarFunction(VarArgFunction.class)
+                               .expectOutputMapping(
+                                       InputTypeStrategies.varyingSequence(
+                                               new String[]{"i", "more"},
+                                               new ArgumentTypeStrategy[]{
+                                                       
InputTypeStrategies.explicit(DataTypes.INT().notNull().bridgedTo(int.class)),
+                                                       
InputTypeStrategies.explicit(DataTypes.INT().notNull().bridgedTo(int.class))
+                                               }),
+                                       
TypeStrategies.explicit(DataTypes.STRING())),
+
+                       // output hint with input extraction
+                       TestSpec
+                               
.forScalarFunction(ExtractWithOutputHintFunction.class)
+                               .expectOutputMapping(
+                                       InputTypeStrategies.sequence(
+                                               new String[]{"i"},
+                                               new ArgumentTypeStrategy[]{
+                                                       
InputTypeStrategies.explicit(DataTypes.INT())
+                                               }),
+                                       
TypeStrategies.explicit(DataTypes.INT())),
+
+                       // output extraction with input hints
+                       TestSpec
+                               
.forScalarFunction(ExtractWithInputHintFunction.class)
+                               .expectOutputMapping(
+                                       InputTypeStrategies.sequence(
+                                               new String[]{"i", "b"},
+                                               new ArgumentTypeStrategy[]{
+                                                       
InputTypeStrategies.explicit(DataTypes.INT()),
+                                                       
InputTypeStrategies.explicit(DataTypes.BOOLEAN())
+                                               }),
+                                       
TypeStrategies.explicit(DataTypes.DOUBLE().notNull().bridgedTo(double.class))),
+
+                       // different accumulator depending on input
+                       TestSpec
+                               
.forAggregateFunction(InputDependentAccumulatorFunction.class)
+                               .expectAccumulatorMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.STRING())),
+                                       
TypeStrategies.explicit(DataTypes.ROW(DataTypes.FIELD("f", 
DataTypes.STRING()))))
+                               .expectAccumulatorMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.BIGINT())),
+                                       
TypeStrategies.explicit(DataTypes.ROW(DataTypes.FIELD("f", 
DataTypes.BIGINT()))))
+                               .expectOutputMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.STRING())),
+                                       
TypeStrategies.explicit(DataTypes.STRING()))
+                               .expectOutputMapping(
+                                       
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.BIGINT())),
+                                       
TypeStrategies.explicit(DataTypes.STRING())),
+
+                       // input, accumulator, and output are spread across the 
function
+                       TestSpec
+                               
.forAggregateFunction(AggregateFunctionWithManyAnnotations.class)
+                               .expectAccumulatorMapping(
+                                       InputTypeStrategies.sequence(
+                                               new String[]{"r"},
+                                               new ArgumentTypeStrategy[]{
+                                                       
InputTypeStrategies.explicit(DataTypes.ROW(
+                                                               
DataTypes.FIELD("i", DataTypes.INT()),
+                                                               
DataTypes.FIELD("b", DataTypes.BOOLEAN())))
+                                               }),
+                                       
TypeStrategies.explicit(DataTypes.ROW(DataTypes.FIELD("b", 
DataTypes.BOOLEAN()))))
+                               .expectOutputMapping(
+                                       InputTypeStrategies.sequence(
+                                               new String[]{"r"},
+                                               new ArgumentTypeStrategy[]{
+                                                       
InputTypeStrategies.explicit(DataTypes.ROW(
+                                                               
DataTypes.FIELD("i", DataTypes.INT()),
+                                                               
DataTypes.FIELD("b", DataTypes.BOOLEAN())))
+                                               }),
+                                       
TypeStrategies.explicit(DataTypes.STRING())),
+
+                       // test for table functions
+                       TestSpec
+                               .forTableFunction(OutputHintTableFunction.class)
+                               .expectOutputMapping(
+                                       InputTypeStrategies.sequence(
+                                               new String[]{"i"},
+                                               new ArgumentTypeStrategy[]{
+                                                       
InputTypeStrategies.explicit(DataTypes.INT().notNull().bridgedTo(int.class))
+                                               }),
+                                       TypeStrategies.explicit(
+                                               DataTypes.ROW(
+                                                       DataTypes.FIELD("i", 
DataTypes.INT()),
+                                                       DataTypes.FIELD("b", 
DataTypes.BOOLEAN())))),
+
+                       // mismatch between hints and implementation regarding 
return type
+                       TestSpec
+                               
.forScalarFunction(InvalidMethodScalarFunction.class)
+                               .expectErrorMessage("Considering all hints, the 
method should comply with the signature:\n" +
+                                       "java.lang.Long eval(int)"),
+
+                       // mismatch between hints and implementation regarding 
accumulator
+                       TestSpec
+                               
.forAggregateFunction(InvalidMethodAggregateFunction.class)
+                               .expectErrorMessage("Considering all hints, the 
method should comply with the signature:\n" +
+                                       "accumulate(java.lang.Integer, int, 
boolean)"),
+
+                       // no implementation
+                       TestSpec
+                               
.forTableFunction(MissingMethodTableFunction.class)
+                               .expectErrorMessage("Could not find a publicly 
accessible method named 'eval'.")
+               );
+       }
+
+       @Parameter
+       public TestSpec testSpec;
+
+       @Rule
+       public ExpectedException thrown = ExpectedException.none();
+
+       @Test
+       public void testArgumentNames() {
+               if (testSpec.expectedArgumentNames != null) {
+                       assertThat(
+                               
testSpec.typeInferenceExtraction.get().getNamedArguments(),
+                               
equalTo(Optional.of(testSpec.expectedArgumentNames)));
+               }
+       }
+
+       @Test
+       public void testArgumentTypes() {
+               if (testSpec.expectedArgumentTypes != null) {
+                       assertThat(
+                               
testSpec.typeInferenceExtraction.get().getTypedArguments(),
+                               
equalTo(Optional.of(testSpec.expectedArgumentTypes)));
+               }
+       }
+
+       @Test
+       public void testAccumulatorTypeStrategy() {
+               if (!testSpec.expectedAccumulatorStrategies.isEmpty()) {
+                       assertThat(
+                               
testSpec.typeInferenceExtraction.get().getAccumulatorTypeStrategy().isPresent(),
+                               equalTo(true));
+                       assertThat(
+                               
testSpec.typeInferenceExtraction.get().getAccumulatorTypeStrategy().get(),
+                               
equalTo(TypeStrategies.mapping(testSpec.expectedAccumulatorStrategies)));
+               }
+       }
+
+       @Test
+       public void testOutputTypeStrategy() {
+               if (!testSpec.expectedOutputStrategies.isEmpty()) {
+                       assertThat(
+                               
testSpec.typeInferenceExtraction.get().getOutputTypeStrategy(),
+                               
equalTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies)));
+               }
+       }
+
+       @Test
+       public void testErrorMessage() {
+               if (testSpec.expectedErrorMessage != null) {
+                       thrown.expect(ValidationException.class);
+                       thrown.expectCause(containsCause(new 
ValidationException(testSpec.expectedErrorMessage)));
+                       testSpec.typeInferenceExtraction.get();
+               }
+       }
+
+       // 
--------------------------------------------------------------------------------------------
+       // Test utilities
+       // 
--------------------------------------------------------------------------------------------
+
+       private static class TestSpec {
+
+               final Supplier<TypeInference> typeInferenceExtraction;
+
+               @Nullable List<String> expectedArgumentNames;
+
+               @Nullable List<DataType> expectedArgumentTypes;
+
+               Map<InputTypeStrategy, TypeStrategy> 
expectedAccumulatorStrategies;
+
+               Map<InputTypeStrategy, TypeStrategy> expectedOutputStrategies;
+
+               @Nullable String expectedErrorMessage;
+
+               private TestSpec(Supplier<TypeInference> 
typeInferenceExtraction) {
+                       this.typeInferenceExtraction = typeInferenceExtraction;
+                       this.expectedAccumulatorStrategies = new HashMap<>();
+                       this.expectedOutputStrategies = new HashMap<>();
+               }
+
+               static TestSpec forScalarFunction(Class<? extends 
ScalarFunction> function) {
+                       return new TestSpec(() ->
+                               TypeInferenceExtractor.forScalarFunction(new 
DataTypeLookupMock(), function));
+               }
+
+               static TestSpec forAggregateFunction(Class<? extends 
AggregateFunction> function) {
+                       return new TestSpec(() ->
+                               TypeInferenceExtractor.forAggregateFunction(new 
DataTypeLookupMock(), function));
+               }
+
+               static TestSpec forTableFunction(Class<? extends TableFunction> 
function) {
+                       return new TestSpec(() ->
+                               TypeInferenceExtractor.forTableFunction(new 
DataTypeLookupMock(), function));
+               }
+
+               static TestSpec forTableAggregateFunction(Class<? extends 
TableAggregateFunction> function) {
+                       return new TestSpec(() ->
+                               
TypeInferenceExtractor.forTableAggregateFunction(new DataTypeLookupMock(), 
function));
+               }
+
+               TestSpec expectNamedArguments(String... expectedArgumentNames) {
+                       this.expectedArgumentNames = 
Arrays.asList(expectedArgumentNames);
+                       return this;
+               }
+
+               TestSpec expectTypedArguments(DataType... 
expectedArgumentTypes) {
+                       this.expectedArgumentTypes = 
Arrays.asList(expectedArgumentTypes);
+                       return this;
+               }
+
+               TestSpec expectAccumulatorMapping(InputTypeStrategy validator, 
TypeStrategy accumulatorStrategy) {
+                       this.expectedAccumulatorStrategies.put(validator, 
accumulatorStrategy);
+                       return this;
+               }
+
+               TestSpec expectOutputMapping(InputTypeStrategy validator, 
TypeStrategy outputStrategy) {
+                       this.expectedOutputStrategies.put(validator, 
outputStrategy);
+                       return this;
+               }
+
+               TestSpec expectErrorMessage(String expectedErrorMessage) {
+                       this.expectedErrorMessage = expectedErrorMessage;
+                       return this;
+               }
+       }
+
+       private static class DataTypeLookupMock implements DataTypeLookup {
+
+               @Override
+               public Optional<DataType> lookupDataType(String name) {
+                       return 
Optional.of(TypeConversions.fromLogicalToDataType(LogicalTypeParser.parse(name)));
+               }
+
+               @Override
+               public Optional<DataType> lookupDataType(UnresolvedIdentifier 
identifier) {
+                       return Optional.empty();
+               }
+
+               @Override
+               public DataType resolveRawDataType(Class<?> clazz) {
+                       return null;
+               }
+       }
+
+       // 
--------------------------------------------------------------------------------------------
+       // Test classes for extraction
+       // 
--------------------------------------------------------------------------------------------
+
+       @FunctionHint(
+               input = {@DataTypeHint("INT"), @DataTypeHint("STRING")},
+               argumentNames = {"i", "s"},
+               output = @DataTypeHint("BOOLEAN")
+       )
+       private static class FullFunctionHint extends ScalarFunction {
+               public Boolean eval(Integer i, String s) {
+                       return null;
+               }
+       }
+
+       private static class ComplexFunctionHint extends ScalarFunction {
+               @FunctionHint(
+                       input = {@DataTypeHint("INT"), @DataTypeHint(inputGroup 
= InputGroup.ANY)},
+                       argumentNames = {"myInt", "myAny"},
+                       output = @DataTypeHint("BOOLEAN"),
+                       isVarArgs = true
+               )
+               public Boolean eval(Object... o) {
+                       return null;
+               }
+       }
+
+       @FunctionHint(input = @DataTypeHint("INT"), output = 
@DataTypeHint("INT"))
+       @FunctionHint(input = @DataTypeHint("BIGINT"), output = 
@DataTypeHint("BIGINT"))
+       private static class FullFunctionHints extends ScalarFunction {
+               public Number eval(Number n) {
+                       return null;
+               }
+       }
+
+       @FunctionHint(output = @DataTypeHint("INT"))
+       private static class GlobalOutputFunctionHint extends ScalarFunction {
+               @FunctionHint(input = @DataTypeHint("INT"))
+               public Integer eval(Integer n) {
+                       return null;
+               }
+
+               @FunctionHint(input = @DataTypeHint("STRING"))
+               public Integer eval(String n) {
+                       return null;
+               }
+       }
+
+       @FunctionHint(output = @DataTypeHint("INT"))
+       private static class InvalidSingleOutputFunctionHint extends 
ScalarFunction {
+               @FunctionHint(output = @DataTypeHint("STRING"))
+               public Integer eval(Integer n) {
+                       return null;
+               }
+       }
+
+       @FunctionHint(input = @DataTypeHint("INT"), output = 
@DataTypeHint("INT"))
+       private static class SplitFullFunctionHints extends ScalarFunction {
+               @FunctionHint(input = @DataTypeHint("BIGINT"), output = 
@DataTypeHint("BIGINT"))
+               public Number eval(Number n) {
+                       return null;
+               }
+       }
+
+       @FunctionHint(input = @DataTypeHint("INT"), output = 
@DataTypeHint("INT"))
+       private static class InvalidFullOutputFunctionHint extends 
ScalarFunction {
+               @FunctionHint(input = @DataTypeHint("INT"), output = 
@DataTypeHint("BIGINT"))
 
 Review comment:
   Could we make sure the method return types are compatible with the 
conversion classes in those cases?
   
   It would simplify reasoning. Right now, even if we remove one of the output 
declarations, it still fails, because the method's return type does not match 
the declared output type.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to