This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 648b6bac952 [FLINK-33439] Implement type inference for IN function
648b6bac952 is described below

commit 648b6bac95232c8498b392ac01e5089777553c77
Author: Dawid Wysakowicz <[email protected]>
AuthorDate: Thu Nov 2 16:17:51 2023 +0100

    [FLINK-33439] Implement type inference for IN function
---
 .../functions/BuiltInFunctionDefinitions.java      |   3 +-
 .../strategies/ComparableTypeStrategy.java         | 124 +------------------
 .../strategies/SpecificInputTypeStrategies.java    |   3 +
 .../strategies/SubQueryInputTypeStrategy.java      | 119 ++++++++++++++++++
 .../types/logical/utils/LogicalTypeChecks.java     | 136 +++++++++++++++++++++
 .../strategies/SubQueryInputTypeStrategyTest.java  | 128 +++++++++++++++++++
 .../expressions/PlannerExpressionConverter.scala   |   4 -
 .../flink/table/planner/expressions/subquery.scala |  78 ------------
 .../validation/ScalarFunctionsValidationTest.scala |   9 --
 .../validation/ScalarOperatorsValidationTest.scala |   6 -
 10 files changed, 392 insertions(+), 218 deletions(-)

diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
index 4f9953f45c9..b8012922df2 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
@@ -2242,7 +2242,8 @@ public final class BuiltInFunctionDefinitions {
             BuiltInFunctionDefinition.newBuilder()
                     .name("in")
                     .kind(SCALAR)
-                    .outputTypeStrategy(TypeStrategies.MISSING)
+                    .inputTypeStrategy(SpecificInputTypeStrategies.IN)
+                    
.outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN())))
                     .build();
 
     public static final BuiltInFunctionDefinition CAST =
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ComparableTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ComparableTypeStrategy.java
index 88aa6877c3e..cb62543cce5 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ComparableTypeStrategy.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ComparableTypeStrategy.java
@@ -26,14 +26,9 @@ import org.apache.flink.table.types.inference.CallContext;
 import org.apache.flink.table.types.inference.ConstantArgumentCount;
 import org.apache.flink.table.types.inference.InputTypeStrategy;
 import org.apache.flink.table.types.inference.Signature;
-import org.apache.flink.table.types.logical.DistinctType;
-import org.apache.flink.table.types.logical.LegacyTypeInformationType;
 import org.apache.flink.table.types.logical.LogicalType;
-import org.apache.flink.table.types.logical.LogicalTypeFamily;
-import org.apache.flink.table.types.logical.LogicalTypeRoot;
-import org.apache.flink.table.types.logical.RawType;
-import org.apache.flink.table.types.logical.StructuredType;
 import 
org.apache.flink.table.types.logical.StructuredType.StructuredComparison;
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
 import org.apache.flink.util.Preconditions;
 
 import java.util.Collections;
@@ -49,7 +44,7 @@ import java.util.Optional;
  * with itself (e.g. for aggregations).
  *
  * <p>For the rules which types are comparable with which types see {@link
- * #areComparable(LogicalType, LogicalType)}.
+ * LogicalTypeChecks#areComparable(LogicalType, LogicalType, 
StructuredComparison)}.
  */
 @Internal
 public final class ComparableTypeStrategy implements InputTypeStrategy {
@@ -78,7 +73,7 @@ public final class ComparableTypeStrategy implements InputTypeStrategy {
         final List<DataType> argumentDataTypes = 
callContext.getArgumentDataTypes();
         if (argumentDataTypes.size() == 1) {
             final LogicalType argType = 
argumentDataTypes.get(0).getLogicalType();
-            if (!areComparable(argType, argType)) {
+            if (!LogicalTypeChecks.areComparable(argType, argType, 
requiredComparison)) {
                 return callContext.fail(
                         throwOnFailure,
                         "Type '%s' should support %s comparison with itself.",
@@ -90,7 +85,7 @@ public final class ComparableTypeStrategy implements InputTypeStrategy {
                 final LogicalType firstType = 
argumentDataTypes.get(i).getLogicalType();
                 final LogicalType secondType = argumentDataTypes.get(i + 
1).getLogicalType();
 
-                if (!areComparable(firstType, secondType)) {
+                if (!LogicalTypeChecks.areComparable(firstType, secondType, 
requiredComparison)) {
                     return callContext.fail(
                             throwOnFailure,
                             "All types in a comparison should support %s 
comparison with each other. "
@@ -111,120 +106,9 @@ public final class ComparableTypeStrategy implements InputTypeStrategy {
                 : "both 'EQUALS' and 'ORDER'";
     }
 
-    private boolean areComparable(LogicalType firstType, LogicalType 
secondType) {
-        return areComparableWithNormalizedNullability(firstType.copy(true), 
secondType.copy(true));
-    }
-
-    private boolean areComparableWithNormalizedNullability(
-            LogicalType firstType, LogicalType secondType) {
-        // A hack to support legacy types. To be removed when we drop the 
legacy types.
-        if (firstType instanceof LegacyTypeInformationType
-                || secondType instanceof LegacyTypeInformationType) {
-            return true;
-        }
-
-        // everything is comparable with null, it should return null in that 
case
-        if (firstType.is(LogicalTypeRoot.NULL) || 
secondType.is(LogicalTypeRoot.NULL)) {
-            return true;
-        }
-
-        if (firstType.getTypeRoot() == secondType.getTypeRoot()) {
-            return areTypesOfSameRootComparable(firstType, secondType);
-        }
-
-        if (firstType.is(LogicalTypeFamily.NUMERIC) && 
secondType.is(LogicalTypeFamily.NUMERIC)) {
-            return true;
-        }
-
-        // DATE + ALL TIMESTAMPS
-        if (firstType.is(LogicalTypeFamily.DATETIME) && 
secondType.is(LogicalTypeFamily.DATETIME)) {
-            return true;
-        }
-
-        // VARCHAR + CHAR (we do not compare collations here)
-        if (firstType.is(LogicalTypeFamily.CHARACTER_STRING)
-                && secondType.is(LogicalTypeFamily.CHARACTER_STRING)) {
-            return true;
-        }
-
-        // VARBINARY + BINARY
-        if (firstType.is(LogicalTypeFamily.BINARY_STRING)
-                && secondType.is(LogicalTypeFamily.BINARY_STRING)) {
-            return true;
-        }
-
-        return false;
-    }
-
-    private boolean areTypesOfSameRootComparable(LogicalType firstType, 
LogicalType secondType) {
-        switch (firstType.getTypeRoot()) {
-            case ARRAY:
-            case MULTISET:
-            case MAP:
-            case ROW:
-                return areConstructedTypesComparable(firstType, secondType);
-            case DISTINCT_TYPE:
-                return areDistinctTypesComparable(firstType, secondType);
-            case STRUCTURED_TYPE:
-                return areStructuredTypesComparable(firstType, secondType);
-            case RAW:
-                return areRawTypesComparable(firstType, secondType);
-            default:
-                return true;
-        }
-    }
-
-    private boolean areRawTypesComparable(LogicalType firstType, LogicalType 
secondType) {
-        return firstType.equals(secondType)
-                && Comparable.class.isAssignableFrom(
-                        ((RawType<?>) firstType).getOriginatingClass());
-    }
-
-    private boolean areDistinctTypesComparable(LogicalType firstType, 
LogicalType secondType) {
-        DistinctType firstDistinctType = (DistinctType) firstType;
-        DistinctType secondDistinctType = (DistinctType) secondType;
-        return firstType.equals(secondType)
-                && areComparable(
-                        firstDistinctType.getSourceType(), 
secondDistinctType.getSourceType());
-    }
-
-    private boolean areStructuredTypesComparable(LogicalType firstType, 
LogicalType secondType) {
-        return firstType.equals(secondType) && 
hasRequiredComparison((StructuredType) firstType);
-    }
-
-    private boolean areConstructedTypesComparable(LogicalType firstType, 
LogicalType secondType) {
-        List<LogicalType> firstChildren = firstType.getChildren();
-        List<LogicalType> secondChildren = secondType.getChildren();
-
-        if (firstChildren.size() != secondChildren.size()) {
-            return false;
-        }
-
-        for (int i = 0; i < firstChildren.size(); i++) {
-            if (!areComparable(firstChildren.get(i), secondChildren.get(i))) {
-                return false;
-            }
-        }
-
-        return true;
-    }
-
     @Override
     public List<Signature> getExpectedSignatures(FunctionDefinition 
definition) {
         return Collections.singletonList(
                 Signature.of(Signature.Argument.ofGroupVarying("COMPARABLE")));
     }
-
-    private Boolean hasRequiredComparison(StructuredType structuredType) {
-        switch (requiredComparison) {
-            case EQUALS:
-                return structuredType.getComparison().isEquality();
-            case FULL:
-                return structuredType.getComparison().isComparison();
-            case NONE:
-            default:
-                // this is not important, required comparison will never be 
NONE
-                return true;
-        }
-    }
 }
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificInputTypeStrategies.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificInputTypeStrategies.java
index b004ce20e9d..e0a07150d78 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificInputTypeStrategies.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificInputTypeStrategies.java
@@ -122,6 +122,9 @@ public final class SpecificInputTypeStrategies {
     public static final InputTypeStrategy TWO_EQUALS_COMPARABLE =
             comparable(ConstantArgumentCount.of(2), 
StructuredType.StructuredComparison.EQUALS);
 
+    /** Type strategy specific for {@link BuiltInFunctionDefinitions#IN}. */
+    public static final InputTypeStrategy IN = new SubQueryInputTypeStrategy();
+
     private SpecificInputTypeStrategies() {
         // no instantiation
     }
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SubQueryInputTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SubQueryInputTypeStrategy.java
new file mode 100644
index 00000000000..96ff9450d52
--- /dev/null
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SubQueryInputTypeStrategy.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.inference.strategies;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
+import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.ConstantArgumentCount;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.StructuredType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
+import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
+import org.apache.flink.table.types.utils.TypeConversions;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+/** {@link InputTypeStrategy} for {@link BuiltInFunctionDefinitions#IN}. */
+@Internal
+public class SubQueryInputTypeStrategy implements InputTypeStrategy {
+    @Override
+    public ArgumentCount getArgumentCount() {
+        return ConstantArgumentCount.from(2);
+    }
+
+    @Override
+    public Optional<List<DataType>> inferInputTypes(
+            CallContext callContext, boolean throwOnFailure) {
+        final LogicalType rightType;
+        final DataType leftType = callContext.getArgumentDataTypes().get(0);
+        if (callContext.getArgumentDataTypes().size() > 2) {
+            final Optional<LogicalType> commonType =
+                    LogicalTypeMerging.findCommonType(
+                            callContext.getArgumentDataTypes().stream()
+                                    .map(DataType::getLogicalType)
+                                    .collect(Collectors.toList()));
+            if (!commonType.isPresent()) {
+                return callContext.fail(
+                        throwOnFailure, "Could not find a common type of the 
sublist.");
+            }
+            rightType = commonType.get();
+        } else {
+            rightType = 
callContext.getArgumentDataTypes().get(1).getLogicalType();
+        }
+
+        // check if the types are comparable, if the types are not comparable, 
check if it is not
+        // a sub-query case like SELECT a IN (SELECT b FROM table1). We check 
if the result of the
+        // rightType is of a ROW type with a single column, and if that column 
is comparable with
+        // left type
+        if (!LogicalTypeChecks.areComparable(
+                        leftType.getLogicalType(),
+                        rightType,
+                        StructuredType.StructuredComparison.EQUALS)
+                && !isComparableWithSubQuery(leftType.getLogicalType(), 
rightType)) {
+            return callContext.fail(
+                    throwOnFailure,
+                    "Types on the right side of IN operator (%s) are not 
comparable with %s.",
+                    rightType,
+                    leftType.getLogicalType());
+        }
+
+        return Optional.of(
+                Stream.concat(
+                                Stream.of(leftType),
+                                IntStream.range(1, 
callContext.getArgumentDataTypes().size())
+                                        .mapToObj(
+                                                i ->
+                                                        
TypeConversions.fromLogicalToDataType(
+                                                                rightType)))
+                        .collect(Collectors.toList()));
+    }
+
+    private static boolean isComparableWithSubQuery(LogicalType left, 
LogicalType right) {
+        if (right.is(LogicalTypeRoot.ROW) && right.getChildren().size() == 1) {
+            final RowType rowType = (RowType) right;
+            return LogicalTypeChecks.areComparable(
+                    left, rowType.getTypeAt(0), 
StructuredType.StructuredComparison.EQUALS);
+        }
+        return false;
+    }
+
+    @Override
+    public List<Signature> getExpectedSignatures(FunctionDefinition 
definition) {
+        return Arrays.asList(
+                Signature.of(
+                        Signature.Argument.ofGroup("COMPARABLE"),
+                        Signature.Argument.ofGroupVarying("COMPARABLE")),
+                Signature.of(
+                        Signature.Argument.ofGroup("COMPARABLE"),
+                        Signature.Argument.ofGroup("SUBQUERY")));
+    }
+}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java
index cb2fd4c9668..d6e31ea92da 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java
@@ -32,11 +32,14 @@ import org.apache.flink.table.types.logical.IntType;
 import org.apache.flink.table.types.logical.LegacyTypeInformationType;
 import org.apache.flink.table.types.logical.LocalZonedTimestampType;
 import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeFamily;
 import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.RawType;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.table.types.logical.SmallIntType;
 import org.apache.flink.table.types.logical.StructuredType;
 import org.apache.flink.table.types.logical.StructuredType.StructuredAttribute;
+import 
org.apache.flink.table.types.logical.StructuredType.StructuredComparison;
 import org.apache.flink.table.types.logical.TimeType;
 import org.apache.flink.table.types.logical.TimestampKind;
 import org.apache.flink.table.types.logical.TimestampType;
@@ -242,6 +245,139 @@ public final class LogicalTypeChecks {
         }
     }
 
+    public static boolean areComparable(
+            LogicalType firstType,
+            LogicalType secondType,
+            StructuredComparison requiredComparison) {
+        return areComparableWithNormalizedNullability(
+                firstType.copy(true), secondType.copy(true), 
requiredComparison);
+    }
+
+    private static boolean areComparableWithNormalizedNullability(
+            LogicalType firstType,
+            LogicalType secondType,
+            StructuredComparison requiredComparison) {
+        // A hack to support legacy types. To be removed when we drop the 
legacy types.
+        if (firstType instanceof LegacyTypeInformationType
+                || secondType instanceof LegacyTypeInformationType) {
+            return true;
+        }
+
+        // everything is comparable with null, it should return null in that 
case
+        if (firstType.is(LogicalTypeRoot.NULL) || 
secondType.is(LogicalTypeRoot.NULL)) {
+            return true;
+        }
+
+        if (firstType.getTypeRoot() == secondType.getTypeRoot()) {
+            return areTypesOfSameRootComparable(firstType, secondType, 
requiredComparison);
+        }
+
+        if (firstType.is(LogicalTypeFamily.NUMERIC) && 
secondType.is(LogicalTypeFamily.NUMERIC)) {
+            return true;
+        }
+
+        // DATE + ALL TIMESTAMPS
+        if (firstType.is(LogicalTypeFamily.DATETIME) && 
secondType.is(LogicalTypeFamily.DATETIME)) {
+            return true;
+        }
+
+        // VARCHAR + CHAR (we do not compare collations here)
+        if (firstType.is(LogicalTypeFamily.CHARACTER_STRING)
+                && secondType.is(LogicalTypeFamily.CHARACTER_STRING)) {
+            return true;
+        }
+
+        // VARBINARY + BINARY
+        if (firstType.is(LogicalTypeFamily.BINARY_STRING)
+                && secondType.is(LogicalTypeFamily.BINARY_STRING)) {
+            return true;
+        }
+
+        return false;
+    }
+
+    private static boolean areTypesOfSameRootComparable(
+            LogicalType firstType,
+            LogicalType secondType,
+            StructuredComparison requiredComparison) {
+        switch (firstType.getTypeRoot()) {
+            case ARRAY:
+            case MULTISET:
+            case MAP:
+            case ROW:
+                return areConstructedTypesComparable(firstType, secondType, 
requiredComparison);
+            case DISTINCT_TYPE:
+                return areDistinctTypesComparable(firstType, secondType, 
requiredComparison);
+            case STRUCTURED_TYPE:
+                return areStructuredTypesComparable(firstType, secondType, 
requiredComparison);
+            case RAW:
+                return areRawTypesComparable(firstType, secondType);
+            default:
+                return true;
+        }
+    }
+
+    private static boolean areRawTypesComparable(LogicalType firstType, 
LogicalType secondType) {
+        return firstType.equals(secondType)
+                && Comparable.class.isAssignableFrom(
+                        ((RawType<?>) firstType).getOriginatingClass());
+    }
+
+    private static boolean areDistinctTypesComparable(
+            LogicalType firstType,
+            LogicalType secondType,
+            StructuredComparison requiredComparison) {
+        DistinctType firstDistinctType = (DistinctType) firstType;
+        DistinctType secondDistinctType = (DistinctType) secondType;
+        return firstType.equals(secondType)
+                && areComparable(
+                        firstDistinctType.getSourceType(),
+                        secondDistinctType.getSourceType(),
+                        requiredComparison);
+    }
+
+    private static boolean areStructuredTypesComparable(
+            LogicalType firstType,
+            LogicalType secondType,
+            StructuredComparison requiredComparison) {
+        return firstType.equals(secondType)
+                && hasRequiredComparison((StructuredType) firstType, 
requiredComparison);
+    }
+
+    private static boolean areConstructedTypesComparable(
+            LogicalType firstType,
+            LogicalType secondType,
+            StructuredComparison requiredComparison) {
+        List<LogicalType> firstChildren = firstType.getChildren();
+        List<LogicalType> secondChildren = secondType.getChildren();
+
+        if (firstChildren.size() != secondChildren.size()) {
+            return false;
+        }
+
+        for (int i = 0; i < firstChildren.size(); i++) {
+            if (!areComparable(firstChildren.get(i), secondChildren.get(i), 
requiredComparison)) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    private static Boolean hasRequiredComparison(
+            StructuredType structuredType, StructuredComparison 
requiredComparison) {
+        switch (requiredComparison) {
+            case EQUALS:
+                return structuredType.getComparison().isEquality();
+            case FULL:
+                return structuredType.getComparison().isComparison();
+            case NONE:
+            default:
+                // this is not important, required comparison will never be 
NONE
+                return true;
+        }
+    }
+
     private LogicalTypeChecks() {
         // no instantiation
     }
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/SubQueryInputTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/SubQueryInputTypeStrategyTest.java
new file mode 100644
index 00000000000..82ce646860b
--- /dev/null
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/SubQueryInputTypeStrategyTest.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.inference.strategies;
+
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.types.inference.InputTypeStrategiesTestBase;
+
+import java.util.stream.Stream;
+
+/** Tests for {@link SubQueryInputTypeStrategy}. */
+class SubQueryInputTypeStrategyTest extends InputTypeStrategiesTestBase {
+
+    @Override
+    protected Stream<TestSpec> testData() {
+        return Stream.of(
+                TestSpec.forStrategy("IN a set", 
SpecificInputTypeStrategies.IN)
+                        .calledWithArgumentTypes(
+                                DataTypes.INT(),
+                                DataTypes.BIGINT(),
+                                DataTypes.SMALLINT(),
+                                DataTypes.INT())
+                        .expectArgumentTypes(
+                                DataTypes.INT(),
+                                DataTypes.BIGINT(),
+                                DataTypes.BIGINT(),
+                                DataTypes.BIGINT()),
+                TestSpec.forStrategy("IN a set, binary", 
SpecificInputTypeStrategies.IN)
+                        .calledWithArgumentTypes(
+                                DataTypes.BYTES(),
+                                DataTypes.BYTES(),
+                                DataTypes.BYTES(),
+                                DataTypes.BYTES())
+                        .expectArgumentTypes(
+                                DataTypes.BYTES(),
+                                DataTypes.BYTES(),
+                                DataTypes.BYTES(),
+                                DataTypes.BYTES()),
+                TestSpec.forStrategy("IN a set, string", 
SpecificInputTypeStrategies.IN)
+                        .calledWithArgumentTypes(
+                                DataTypes.STRING(),
+                                DataTypes.STRING(),
+                                DataTypes.STRING(),
+                                DataTypes.STRING())
+                        .expectArgumentTypes(
+                                DataTypes.STRING(),
+                                DataTypes.STRING(),
+                                DataTypes.STRING(),
+                                DataTypes.STRING()),
+                TestSpec.forStrategy(
+                                "IN a set, multiset(timestamp)", 
SpecificInputTypeStrategies.IN)
+                        .calledWithArgumentTypes(
+                                DataTypes.MULTISET(DataTypes.TIMESTAMP()),
+                                DataTypes.MULTISET(DataTypes.TIMESTAMP()),
+                                DataTypes.MULTISET(DataTypes.TIMESTAMP()),
+                                DataTypes.MULTISET(DataTypes.TIMESTAMP()))
+                        .expectArgumentTypes(
+                                DataTypes.MULTISET(DataTypes.TIMESTAMP()),
+                                DataTypes.MULTISET(DataTypes.TIMESTAMP()),
+                                DataTypes.MULTISET(DataTypes.TIMESTAMP()),
+                                DataTypes.MULTISET(DataTypes.TIMESTAMP())),
+                TestSpec.forStrategy("IN a set, arrays", 
SpecificInputTypeStrategies.IN)
+                        .calledWithArgumentTypes(
+                                DataTypes.ARRAY(DataTypes.BIGINT()),
+                                DataTypes.ARRAY(DataTypes.BIGINT()),
+                                DataTypes.ARRAY(DataTypes.INT()),
+                                DataTypes.ARRAY(DataTypes.SMALLINT()))
+                        .expectArgumentTypes(
+                                DataTypes.ARRAY(DataTypes.BIGINT()),
+                                DataTypes.ARRAY(DataTypes.BIGINT()),
+                                DataTypes.ARRAY(DataTypes.BIGINT()),
+                                DataTypes.ARRAY(DataTypes.BIGINT())),
+                TestSpec.forStrategy("IN a set of ROWs", 
SpecificInputTypeStrategies.IN)
+                        .calledWithArgumentTypes(
+                                DataTypes.ROW(DataTypes.FIELD("f0", 
DataTypes.INT())),
+                                DataTypes.ROW(DataTypes.FIELD("f0", 
DataTypes.INT())))
+                        .expectArgumentTypes(
+                                DataTypes.ROW(DataTypes.FIELD("f0", 
DataTypes.INT())),
+                                DataTypes.ROW(DataTypes.FIELD("f0", 
DataTypes.INT()))),
+                TestSpec.forStrategy("IN a subquery", 
SpecificInputTypeStrategies.IN)
+                        .calledWithArgumentTypes(
+                                DataTypes.INT(),
+                                DataTypes.ROW(DataTypes.FIELD("f0", 
DataTypes.BIGINT())))
+                        .expectArgumentTypes(
+                                DataTypes.INT(),
+                                DataTypes.ROW(DataTypes.FIELD("f0", 
DataTypes.BIGINT()))),
+                TestSpec.forStrategy("IN a set not comparable", 
SpecificInputTypeStrategies.IN)
+                        .calledWithArgumentTypes(DataTypes.INT(), 
DataTypes.STRING())
+                        .expectErrorMessage(
+                                "Types on the right side of IN operator 
(STRING) are not comparable with INT."),
+                TestSpec.forStrategy("IN a subquery not comparable", 
SpecificInputTypeStrategies.IN)
+                        .calledWithArgumentTypes(
+                                DataTypes.INT(),
+                                DataTypes.ROW(DataTypes.FIELD("f0", 
DataTypes.STRING())))
+                        .expectErrorMessage(
+                                "Types on the right side of IN operator 
(ROW<`f0` STRING>) are not comparable with INT"),
+                TestSpec.forStrategy("IN a subquery of ROWs", 
SpecificInputTypeStrategies.IN)
+                        .calledWithArgumentTypes(
+                                DataTypes.ROW(DataTypes.FIELD("f0", 
DataTypes.INT())),
+                                DataTypes.ROW(
+                                        DataTypes.FIELD(
+                                                "f0",
+                                                DataTypes.ROW(
+                                                        DataTypes.FIELD("f0", 
DataTypes.INT())))))
+                        .expectArgumentTypes(
+                                DataTypes.ROW(DataTypes.FIELD("f0", 
DataTypes.INT())),
+                                DataTypes.ROW(
+                                        DataTypes.FIELD(
+                                                "f0",
+                                                DataTypes.ROW(
+                                                        DataTypes.FIELD("f0", 
DataTypes.INT()))))));
+    }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala
index 7996be7dd87..cf95c97cf0d 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala
@@ -127,10 +127,6 @@ class PlannerExpressionConverter private extends 
ApiExpressionVisitor[PlannerExp
       case fd: FunctionDefinition =>
         fd match {
 
-          case IN =>
-            assert(args.size > 1)
-            In(args.head, args.drop(1))
-
           case DISTINCT =>
             assert(args.size == 1)
             DistinctAgg(args.head)
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/subquery.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/subquery.scala
deleted file mode 100644
index a5595024e8d..00000000000
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/expressions/subquery.scala
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.planner.expressions
-
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.table.operations.QueryOperation
-import org.apache.flink.table.planner.typeutils.TypeInfoCheckUtils._
-import org.apache.flink.table.planner.validate.{ValidationFailure, 
ValidationResult, ValidationSuccess}
-import org.apache.flink.table.types.utils.TypeConversions
-import 
org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo
-
-case class In(expression: PlannerExpression, elements: Seq[PlannerExpression])
-  extends PlannerExpression {
-
-  override def toString = s"$expression.in(${elements.mkString(", ")})"
-
-  override private[flink] def children: Seq[PlannerExpression] = expression +: 
elements.distinct
-
-  override private[flink] def validateInput(): ValidationResult = {
-    // check if this is a sub-query expression or an element list
-    elements.head match {
-
-      case TableReference(name, tableOperation: QueryOperation) =>
-        if (elements.length != 1) {
-          return ValidationFailure("IN operator supports only one table 
reference.")
-        }
-        val resolvedSchema = tableOperation.getResolvedSchema
-        if (resolvedSchema.getColumnCount > 1) {
-          return ValidationFailure(
-            s"The sub-query table '$name' must not have more than one column.")
-        }
-        (
-          expression.resultType,
-          fromDataTypeToLegacyInfo(resolvedSchema.getColumnDataTypes.get(0))) 
match {
-          case (lType, rType) if lType == rType => ValidationSuccess
-          case (lType, rType) if isNumeric(lType) && isNumeric(rType) => 
ValidationSuccess
-          case (lType, rType) if isArray(lType) && lType.getTypeClass == 
rType.getTypeClass =>
-            ValidationSuccess
-          case (lType, rType) =>
-            ValidationFailure(s"IN operator on incompatible types: $lType and 
$rType.")
-        }
-
-      case _ =>
-        val types = children.tail.map(_.resultType)
-        if (types.distinct.length != 1) {
-          return ValidationFailure(
-            s"Types on the right side of the IN operator must be the same, " +
-              s"got ${types.mkString(", ")}.")
-        }
-        (children.head.resultType, children.last.resultType) match {
-          case (lType, rType) if isNumeric(lType) && isNumeric(rType) => 
ValidationSuccess
-          case (lType, rType) if lType == rType => ValidationSuccess
-          case (lType, rType) if isArray(lType) && lType.getTypeClass == 
rType.getTypeClass =>
-            ValidationSuccess
-          case (lType, rType) =>
-            ValidationFailure(s"IN operator on incompatible types: $lType and 
$rType.")
-        }
-    }
-  }
-
-  override private[flink] def resultType: TypeInformation[_] = 
BOOLEAN_TYPE_INFO
-}
diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarFunctionsValidationTest.scala
 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarFunctionsValidationTest.scala
index 1ce90d8e9e9..4a604ae733c 100644
--- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarFunctionsValidationTest.scala
+++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarFunctionsValidationTest.scala
@@ -123,15 +123,6 @@ class ScalarFunctionsValidationTest extends 
ScalarTypesTestBase {
   // Sub-query functions
   // 
----------------------------------------------------------------------------------------------
 
-  @Test
-  def testInValidationExceptionMoreThanOneTypes(): Unit = {
-    assertThatExceptionOfType(classOf[ValidationException])
-      .isThrownBy(() => testTableApi('f2.in('f3, 'f8), "TRUE"))
-
-    assertThatExceptionOfType(classOf[ValidationException])
-      .isThrownBy(() => testTableApi('f2.in('f3, 'f4, 4), "FALSE"))
-  }
-
   @Test
   def scalaInValidationExceptionDifferentOperandsTest(): Unit = {
     assertThatExceptionOfType(classOf[ValidationException])
diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala
 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala
index f319924c035..189591c691d 100644
--- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala
+++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala
@@ -47,12 +47,6 @@ class ScalarOperatorsValidationTest extends 
ScalarOperatorsTestBase {
   // Sub-query functions
   // 
----------------------------------------------------------------------------------------------
 
-  @Test
-  def testInMoreThanOneTypes(): Unit = {
-    assertThatExceptionOfType(classOf[ValidationException])
-      .isThrownBy(() => testTableApi('f2.in('f3, 'f4, 4), "FAIL"))
-  }
-
   @Test
   def testInDifferentOperands(): Unit = {
     assertThatExceptionOfType(classOf[ValidationException])


Reply via email to