This is an automated email from the ASF dual-hosted git repository. dmsysolyatin pushed a commit to branch CALCITE-5160-v3 in repository https://gitbox.apache.org/repos/asf/calcite.git
commit 32ee6dc6dd4de903df8a5cc4d6c3a7985fd20565 Author: dssysolyatin <[email protected]> AuthorDate: Mon Jun 13 18:25:46 2022 +0300 [CALCITE-5160] ANY, SOME operators should support scalar arrays --- .../calcite/adapter/enumerable/RexImpTable.java | 101 +++++++++++++++++++++ .../org/apache/calcite/runtime/SqlFunctions.java | 31 ++++++- .../calcite/sql/fun/SqlQuantifyOperator.java | 34 +++++++ .../calcite/sql/fun/SqlStdOperatorTable.java | 9 ++ .../apache/calcite/sql2rel/SqlToRelConverter.java | 29 +++++- .../calcite/sql2rel/StandardConvertletTable.java | 27 +++++- .../org/apache/calcite/util/BuiltInMethod.java | 3 + .../org/apache/calcite/test/SqlOperatorTest.java | 61 +++++++++++++ 8 files changed, 291 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java index 3a03c67e68..b1ed010eea 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java @@ -60,6 +60,7 @@ import org.apache.calcite.sql.SqlTypeConstructorFunction; import org.apache.calcite.sql.SqlWindowTableFunction; import org.apache.calcite.sql.fun.SqlJsonArrayAggAggFunction; import org.apache.calcite.sql.fun.SqlJsonObjectAggAggFunction; +import org.apache.calcite.sql.fun.SqlQuantifyOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.fun.SqlTrimFunction; import org.apache.calcite.sql.type.SqlTypeName; @@ -330,6 +331,12 @@ import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SIN; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SINGLE_VALUE; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SLICE; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SOME; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SOME_EQ; +import static 
org.apache.calcite.sql.fun.SqlStdOperatorTable.SOME_GE; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SOME_GT; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SOME_LE; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SOME_LT; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SOME_NE; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.STRUCT_ACCESS; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SUBMULTISET_OF; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SUBSTRING; @@ -743,6 +750,13 @@ public class RexImpTable { map.put(CURRENT_ROLE, systemFunctionImplementor); map.put(CURRENT_CATALOG, systemFunctionImplementor); + defineSome(SOME_EQ, EQUALS); + defineSome(SOME_GT, GREATER_THAN); + defineSome(SOME_GE, GREATER_THAN_OR_EQUAL); + defineSome(SOME_LE, LESS_THAN_OR_EQUAL); + defineSome(SOME_LT, LESS_THAN); + defineSome(SOME_NE, NOT_EQUALS); + // Current time functions map.put(CURRENT_TIME, systemFunctionImplementor); map.put(CURRENT_TIMESTAMP, systemFunctionImplementor); @@ -850,6 +864,12 @@ public class RexImpTable { new BinaryImplementor(nullPolicy, true, expressionType, backupMethodName)); } + + private void defineSome(SqlQuantifyOperator operator, SqlBinaryOperator binaryOperator) { + map.put(operator, + new SomeImplementor(binaryOperator, + (BinaryImplementor) requireNonNull(map.get(binaryOperator)))); + } } public static CallImplementor createImplementor( @@ -4070,4 +4090,85 @@ public class RexImpTable { gapInterval)); } } + + /** + * Implementation of + * <a href="https://www.postgresql.org/docs/current/functions-comparisons.html#id-1.5.8.30.16"> + * ANY/SOME (array)</a>. 
+ */ + private static class SomeImplementor extends AbstractRexCallImplementor { + private final SqlBinaryOperator binaryOperator; + private final BinaryImplementor binaryImplementor; + + SomeImplementor(SqlBinaryOperator binaryOperator, BinaryImplementor binaryImplementor) { + super(NullPolicy.NONE, false); + this.binaryOperator = binaryOperator; + this.binaryImplementor = binaryImplementor; + } + + @Override String getVariableName() { + return "some"; + } + + @Override Expression implementSafe(RexToLixTranslator translator, RexCall call, + List<Expression> argValueList) { + // neither nullable: + // return COLLECTION_SOME(collection, <predicate>) + // right component nullable: + // return COLLECTION_NULLABLE_SOME(collection, <predicate>) + // left nullable: + // left null literal: + // return null; + // left nullable expression: + // return left == null ? null : COLLECTION_<SOME|NULLABLE_SOME> + Expression left = argValueList.get(0); + Expression right = argValueList.get(1); + + RexNode leftRex = call.getOperands().get(0); + if (leftRex.getType().getSqlTypeName() == SqlTypeName.NULL) { + return left; + } + + final Expression nullExpr = Expressions.constant(null); + final RelDataType rightComponentType = + requireNonNull(call.getOperands().get(1).getType().getComponentType()); + final ParameterExpression predicateArg = Expressions.parameter( + translator.typeFactory.getJavaClass(rightComponentType), "el"); + + // left should have final modifier otherwise it can not be passed to lambda + final ParameterExpression leftValueExpr = + Expressions.parameter(left.getType(), + translator.getBlockBuilder().newName("_" + getVariableName() + "_value")); + translator.getBlockBuilder().add(Expressions.declare(Modifier.FINAL, leftValueExpr, left)); + + BlockBuilder predicateBuilder = new BlockBuilder(); + final Expression condition = binaryImplementor.implementSafe(translator, + (RexCall) translator.builder.makeCall(binaryOperator, + ImmutableList.of(leftRex, 
translator.builder.makeZeroLiteral(leftRex.getType()))), + ImmutableList.of(leftValueExpr, predicateArg) + ); + + predicateBuilder.add(Expressions.return_(null, condition)); + final Expression predicate = Expressions.lambda(predicateBuilder.toBlock(), predicateArg); + + final Expression someCallExpr; + if (rightComponentType.isNullable()) { + someCallExpr = Expressions.call(BuiltInMethod.COLLECTION_NULLABLE_SOME.method, + right, predicate); + } else { + someCallExpr = Expressions.call(BuiltInMethod.COLLECTION_SOME.method, right, predicate); + } + + if (leftRex.getType().isNullable()) { + // left == null ? null : someCallExpr + return Expressions.makeTernary( + ExpressionType.Conditional, + Expressions.equal(leftValueExpr, nullExpr), + nullExpr, + someCallExpr + ); + } + return someCallExpr; + } + } } diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java index abbec30a08..29aa3d5c7e 100644 --- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java +++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java @@ -31,6 +31,7 @@ import org.apache.calcite.linq4j.function.Deterministic; import org.apache.calcite.linq4j.function.Experimental; import org.apache.calcite.linq4j.function.Function1; import org.apache.calcite.linq4j.function.NonDeterministic; +import org.apache.calcite.linq4j.function.Predicate1; import org.apache.calcite.linq4j.tree.Primitive; import org.apache.calcite.rel.type.TimeFrame; import org.apache.calcite.rel.type.TimeFrameSet; @@ -3636,6 +3637,34 @@ public class SqlFunctions { return args; } + /** Returns whether there is an element in {@code list} for which + * {@code predicate} is true. 
Also, there are specific rules for handling null values: + * <ul> + * <li>if an element of the {@code list} is null, + * then that element does not satisfy {@code predicate};</li> + * <li>if no element of the {@code list} satisfies {@code predicate} + * and the {@code list} contains at least one null value, then the function returns null.</li> + * </ul> + * + * <p>Supports the {@code SOME(<ARRAY>|<MULTISET>)} operator. + * + * @see org.apache.calcite.util.BuiltInMethod#COLLECTION_NULLABLE_SOME + */ + public static @Nullable <E> Boolean nullableSome(List<? extends E> list, + Predicate1<E> predicate) { + boolean nullExists = false; + for (E e : list) { + if (e == null) { + nullExists = true; + continue; + } + if (predicate.apply(e)) { + return true; + } + } + return nullExists ? null : false; + } + /** Similar to {@link Linq4j#product(Iterable)} but each resulting list * implements {@link FlatLists.ComparableList}. */ public static <E extends Comparable> Enumerable<FlatLists.ComparableList<E>> product( @@ -3648,7 +3677,7 @@ public class SqlFunctions { } }; } - + /** * Implements the {@code .} (field access) operator on an object * whose type is not known until runtime.
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlQuantifyOperator.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlQuantifyOperator.java index 37250cd7cf..e49dd33594 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlQuantifyOperator.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlQuantifyOperator.java @@ -16,10 +16,19 @@ */ package org.apache.calcite.sql.fun; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; import com.google.common.base.Preconditions; +import java.util.List; import java.util.Objects; /** @@ -59,4 +68,29 @@ public class SqlQuantifyOperator extends SqlInOperator { Preconditions.checkArgument(kind == SqlKind.SOME || kind == SqlKind.ALL); } + + @Override public RelDataType deriveType(SqlValidator validator, + SqlValidatorScope scope, SqlCall call) { + final List<SqlNode> operands = call.getOperandList(); + assert operands.size() == 2; + + // Supporting SOME(<MULTISET>|<ARRAY>) + final SqlNode right = operands.get(1); + if (call.getKind() == SqlKind.SOME && right instanceof SqlNodeList + && ((SqlNodeList) right).size() == 1) { + final RelDataType rightType = validator.deriveType(scope, ((SqlNodeList) right).get(0)); + if (SqlTypeUtil.isCollection(rightType)) { + final RelDataType componentRightType = Objects.requireNonNull(rightType.getComponentType()); + final RelDataType leftType = validator.deriveType(scope, operands.get(0)); + if (SqlTypeUtil.sameNamedType(componentRightType, leftType)) { + return validator.getTypeFactory().createTypeWithNullability( + validator.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN), + 
componentRightType.isNullable() || leftType.isNullable() + ); + } + } + } + + return super.deriveType(validator, scope, call); + } } diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java index 6c7f1cfc46..71fcf0709d 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java @@ -457,6 +457,15 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable { public static final SqlQuantifyOperator SOME_NE = new SqlQuantifyOperator(SqlKind.SOME, SqlKind.NOT_EQUALS); + public static final List<SqlQuantifyOperator> SOME_OPERATORS = ImmutableList.of( + SqlStdOperatorTable.SOME_EQ, + SqlStdOperatorTable.SOME_GT, + SqlStdOperatorTable.SOME_GE, + SqlStdOperatorTable.SOME_LE, + SqlStdOperatorTable.SOME_LT, + SqlStdOperatorTable.SOME_NE + ); + /** * The <code>< ALL</code> operator. */ diff --git a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java index 581fa48a82..6469072821 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java @@ -1190,6 +1190,7 @@ public class SqlToRelConverter { case ALL: call = (SqlBasicCall) subQuery.node; query = call.operand(1); + if (!config.isExpand() && !(query instanceof SqlNodeList)) { return; } @@ -2104,9 +2105,9 @@ public class SqlToRelConverter { // register the scalar sub-queries first so they can be converted // before the IN expression is converted. 
switch (kind) { + case SOME: case IN: case NOT_IN: - case SOME: case ALL: switch (logic) { case TRUE_FALSE_UNKNOWN: @@ -2124,7 +2125,27 @@ public class SqlToRelConverter { default: break; } + + if (kind == SqlKind.SOME) { + List<SqlNode> operandList = ((SqlCall) node).getOperandList(); + SqlNode left = operandList.get(0); + SqlNode right = operandList.get(1); + if (right instanceof SqlNodeList && ((SqlNodeList) right).size() == 1) { + SqlNode rightValue = ((SqlNodeList) right).get(0); + RelDataType rightValueType = validator().deriveType(bb.scope(), rightValue); + if (SqlTypeUtil.isCollection(rightValueType)) { + RelDataType leftType = validator().deriveType(bb.scope(), left); + if (SqlTypeUtil.sameNamedType(requireNonNull(rightValueType.getComponentType()), + leftType) || SqlUtil.isNull(left)) { + findSubQueries(bb, rightValue, logic, registerOnlyScalarSubQueries, clause); + return; + } + } + } + } + bb.registerSubQuery(node, logic, clause); + break; default: break; @@ -5451,7 +5472,11 @@ public class SqlToRelConverter { case CURSOR: case IN: case NOT_IN: - subQuery = requireNonNull(getSubQuery(expr, null)); + subQuery = getSubQuery(expr, null); + if (subQuery == null && kind == SqlKind.SOME) { + break; + } + assert subQuery != null; rex = requireNonNull(subQuery.expr); return StandardConvertletTable.castToValidatedType(expr, rex, validator(), rexBuilder); diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java index 28fe7c0374..e5cdeee485 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java @@ -96,6 +96,7 @@ import java.util.Objects; import java.util.function.UnaryOperator; import java.util.stream.Collectors; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SOME_OPERATORS; import static 
org.apache.calcite.sql.type.NonNullableAccessors.getComponentTypeOrThrow; import static org.apache.calcite.util.Util.first; @@ -183,7 +184,7 @@ public class StandardConvertletTable extends ReflectiveConvertletTable { new SubstrConvertlet(SqlLibrary.ORACLE)); registerOp(SqlLibraryOperators.SUBSTR_POSTGRESQL, new SubstrConvertlet(SqlLibrary.POSTGRESQL)); - + registerOp(SqlLibraryOperators.DATE_ADD, new TimestampAddConvertlet()); registerOp(SqlLibraryOperators.DATE_DIFF, @@ -209,6 +210,8 @@ public class StandardConvertletTable extends ReflectiveConvertletTable { registerOp(SqlLibraryOperators.TIMESTAMP_SUB, new TimestampSubConvertlet()); + SOME_OPERATORS.forEach(operator -> registerOp(operator, StandardConvertletTable::convertSome)); + registerOp(SqlLibraryOperators.NVL, StandardConvertletTable::convertNvl); registerOp(SqlLibraryOperators.DECODE, StandardConvertletTable::convertDecode); @@ -346,6 +349,28 @@ public class StandardConvertletTable extends ReflectiveConvertletTable { }); } } + + /** + * Converts SOME, ANY operator. + */ + private static RexNode convertSome(SqlRexContext cx, SqlCall call) { + final RexBuilder rexBuilder = cx.getRexBuilder(); + final RexNode operand0 = + cx.convertExpression(call.getOperandList().get(0)); + assert call.getOperandList().get(1) instanceof SqlNodeList; + final RexNode operand1 = + cx.convertExpression(((SqlNodeList) call.getOperandList().get(1)).get(0)); + final RelDataType operand1ComponentType = + requireNonNull(operand1.getType().getComponentType()); + final RelDataType returnType = cx.getTypeFactory().createTypeWithNullability( + cx.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN), + operand0.getType().isNullable() || operand1ComponentType.isNullable()); + return rexBuilder.makeCall( + returnType, + call.getOperator(), + ImmutableList.of(operand0, operand1) + ); + } /** Converts a call to the {@code NVL} function (and also its synonym, * {@code IFNULL}). 
*/ diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 2b3eaaf085..d3f0e8601a 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -308,6 +308,9 @@ public enum BuiltInMethod { COLLECTIONS_EMPTY_LIST(Collections.class, "emptyList"), COLLECTIONS_SINGLETON_LIST(Collections.class, "singletonList", Object.class), COLLECTION_SIZE(Collection.class, "size"), + COLLECTION_SOME(Functions.class, "exists", List.class, Predicate1.class), + COLLECTION_NULLABLE_SOME(SqlFunctions.class, "nullableSome", + List.class, Predicate1.class), MAP_CLEAR(Map.class, "clear"), MAP_GET(Map.class, "get", Object.class), MAP_GET_OR_DEFAULT(Map.class, "getOrDefault", Object.class, Object.class), diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index 4bbb6f1d69..bbc69af7ac 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -19,6 +19,8 @@ package org.apache.calcite.test; import org.apache.calcite.avatica.util.DateTimeUtils; import org.apache.calcite.config.CalciteConnectionProperty; import org.apache.calcite.linq4j.Linq4j; +import org.apache.calcite.linq4j.function.Function1; +import org.apache.calcite.linq4j.function.Function2; import org.apache.calcite.plan.Strong; import org.apache.calcite.rel.type.DelegatingTypeSystem; import org.apache.calcite.rel.type.RelDataType; @@ -104,6 +106,7 @@ import java.util.regex.Pattern; import static org.apache.calcite.linq4j.tree.Expressions.list; import static org.apache.calcite.rel.type.RelDataTypeImpl.NON_NULLABLE_SUFFIX; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.PI; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.SOME_OPERATORS; import static 
org.apache.calcite.sql.test.ResultCheckers.isExactly; import static org.apache.calcite.sql.test.ResultCheckers.isNullValue; import static org.apache.calcite.sql.test.ResultCheckers.isSet; @@ -9471,6 +9474,64 @@ public class SqlOperatorTest { f.checkAgg("some(x = 2)", values, isSingle("true")); } + @Test void testSomeOperatorFunc() { + final SqlOperatorFixture f = fixture(); + SOME_OPERATORS.forEach(operator -> f.setFor(operator, VM_EXPAND)); + + Function2<String, Boolean, Void> checkBoolean = (sql, result) -> { + f.checkBoolean(sql.replace("COLLECTION", "ARRAY"), result); + f.checkBoolean(sql.replace("COLLECTION", "MULTISET"), result); + + return null; + }; + + Function1<String, Void> checkNull = sql -> { + f.checkNull(sql.replace("COLLECTION", "ARRAY")); + f.checkNull(sql.replace("COLLECTION", "MULTISET")); + return null; + }; + + checkBoolean.apply("1 = some (COLLECTION[1,2,null])", true); + checkBoolean.apply("3 = some (COLLECTION[1,2])", false); + checkNull.apply("3 = some (COLLECTION[1,2,null])"); + + checkBoolean.apply("2 <> some (COLLECTION[1,2,null])", true); + checkBoolean.apply("3 <> some (COLLECTION[1,2,null])", true); + checkBoolean.apply("1 <> some (COLLECTION[1])", false); + checkNull.apply("1 <> some (COLLECTION[1,null])"); + + checkBoolean.apply("1 < some (COLLECTION[1,2,null])", true); // 1 < 2 + checkBoolean.apply("0 < some (COLLECTION[1,2,null])", true); // 0 < [1,2] + checkBoolean.apply("2 < some (COLLECTION[1,2])", false); + checkNull.apply("2 < some (COLLECTION[1,2,null])"); + + checkBoolean.apply("2 <= some (COLLECTION[1,2,null])", true); // 2 <= 2 + checkBoolean.apply("0 <= some (COLLECTION[1,2,null])", true); // 0 <= [1,2] + checkBoolean.apply("3 <= some (COLLECTION[1,2])", false); + checkNull.apply("3 <= some (COLLECTION[1,2,null])"); + + checkBoolean.apply("2 > some (COLLECTION[1,2,null])", true); // 2 > 1 + checkBoolean.apply("3 > some (COLLECTION[1,2,null])", true); // 3 > [1,2] + checkBoolean.apply("1 > some (COLLECTION[1,2])", false); 
+ checkNull.apply("1 > some (COLLECTION[1,2,null])"); + + checkBoolean.apply("2 >= some (COLLECTION[1,2,null])", true); // 2 >= 2 + checkBoolean.apply("3 >= some (COLLECTION[1,2,null])", true); // 3 >= [1,2] + checkBoolean.apply("0 >= some (COLLECTION[1,2])", false); + checkNull.apply("0 >= some (COLLECTION[1,2,null])"); + + SOME_OPERATORS.forEach(operator -> { + checkNull.apply("null" + operator.comparisonKind.sql + " some (COLLECTION[1,2,3])"); + checkNull.apply("null" + operator.comparisonKind.sql + " some (COLLECTION[1,2,3,null])"); + }); + + f.check("SELECT 3 = some(x.t) FROM (SELECT ARRAY[1,2,3,null] as t) as x", + "BOOLEAN", true); + f.check("SELECT 4 = some(x.t) FROM (SELECT ARRAY[1,2,3] as t) as x", + "BOOLEAN NOT NULL", false); + f.check("SELECT 4 = some(x.t) FROM (SELECT ARRAY[1,2,3,null] as t) as x", + "BOOLEAN", isNullValue()); + } @Test void testAnyValueFunc() { final SqlOperatorFixture f = fixture();
