This is an automated email from the ASF dual-hosted git repository. jhyde pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/calcite.git
commit a00d565456b43a245602aff5fd399a7f8794a83d
Author: Norman Jordan <[email protected]>
AuthorDate: Tue Sep 10 15:32:13 2024 -0700

[CALCITE-6550] Improve SQL function overloading

This change improves how RexImpTable handles collisions in the map
containing scalar functions when two operators are so similar that
`equals` and `hashCode` regard them as equal. The rules are now as
follows:

* If only one implementor is found for an operator key, return that
  implementor.
* If there are multiple implementors for an operator key, return the
  one registered for the identical operator instance.
* An operator key consists of the operator's name and kind.

Following this change, we were able to remove tech debt: convert some
trivial anonymous subclasses in SqlLibraryOperators back to
SqlBasicFunction.create, and make SqlBasicFunction's constructor
private again.

Close apache/calcite#3954
--- .../calcite/adapter/enumerable/RexImpTable.java | 54 +++++++++++--- .../org/apache/calcite/sql/SqlBasicFunction.java | 2 +- .../calcite/sql/fun/SqlLibraryOperators.java | 82 ++++++++-------------- 3 files changed, 76 insertions(+), 62 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java index c37796c2a8..258df59bda 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java @@ -46,6 +46,7 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexPatternFieldRef; import org.apache.calcite.rex.RexWindowExclusion; import org.apache.calcite.runtime.FlatLists; +import org.apache.calcite.runtime.PairList; import org.apache.calcite.runtime.SqlFunctions; import org.apache.calcite.schema.FunctionContext; import org.apache.calcite.schema.ImplementableAggFunction; @@ -544,7 +545,7 @@ public class RexImpTable { public static final MemberExpression BOXED_TRUE_EXPR = Expressions.field(null, 
Boolean.class, "TRUE"); - private final ImmutableMap<SqlOperator, RexCallImplementor> map; + private final ImmutableMap<SqlOperator, PairList<SqlOperator, RexCallImplementor>> map; private final ImmutableMap<SqlAggFunction, Supplier<? extends AggImplementor>> aggMap; private final ImmutableMap<SqlAggFunction, Supplier<? extends WinAggImplementor>> winAggMap; private final ImmutableMap<SqlMatchFunction, Supplier<? extends MatchImplementor>> matchMap; @@ -552,7 +553,10 @@ public class RexImpTable { tvfImplementorMap; private RexImpTable(Builder builder) { - this.map = ImmutableMap.copyOf(builder.map); + final ImmutableMap.Builder<SqlOperator, PairList<SqlOperator, RexCallImplementor>> + mapBuilder = ImmutableMap.builder(); + builder.map.forEach((k, v) -> mapBuilder.put(k, v.immutable())); + this.map = ImmutableMap.copyOf(mapBuilder.build()); this.aggMap = ImmutableMap.copyOf(builder.aggMap); this.winAggMap = ImmutableMap.copyOf(builder.winAggMap); this.matchMap = ImmutableMap.copyOf(builder.matchMap); @@ -850,7 +854,6 @@ public class RexImpTable { new SafeArithmeticImplementor(BuiltInMethod.SAFE_SUBTRACT.method)); define(PI, new PiImplementor()); - populate2(); } /** Second step of population. */ @@ -1279,7 +1282,8 @@ public class RexImpTable { /** Holds intermediate state from which a RexImpTable can be constructed. */ private static class Builder extends AbstractBuilder { - private final Map<SqlOperator, RexCallImplementor> map = new HashMap<>(); + private final Map<SqlOperator, PairList<SqlOperator, RexCallImplementor>> map = + new HashMap<>(); private final Map<SqlAggFunction, Supplier<? extends AggImplementor>> aggMap = new HashMap<>(); private final Map<SqlAggFunction, Supplier<? 
extends WinAggImplementor>> winAggMap = @@ -1290,13 +1294,27 @@ public class RexImpTable { tvfImplementorMap = new HashMap<>(); @Override protected RexCallImplementor get(SqlOperator operator) { - return requireNonNull(map.get(operator), - () -> "no implementor for " + operator); + final PairList<SqlOperator, RexCallImplementor> implementors = + requireNonNull(map.get(operator)); + if (implementors.size() == 1) { + return implementors.get(0).getValue(); + } else { + for (Map.Entry<SqlOperator, RexCallImplementor> entry : implementors) { + if (operator == entry.getKey()) { + return entry.getValue(); + } + } + throw new NullPointerException(); + } } @Override <T extends RexCallImplementor> T define(SqlOperator operator, T implementor) { - map.put(operator, requireNonNull(implementor, "implementor")); + if (map.containsKey(operator)) { + map.get(operator).add(operator, implementor); + } else { + map.put(operator, PairList.of(operator, implementor)); + } return implementor; } @@ -1371,9 +1389,27 @@ public class RexImpTable { ((ImplementableFunction) udf).getImplementor(); return wrapAsRexCallImplementor(implementor); } else if (operator instanceof SqlTypeConstructorFunction) { - return map.get(SqlStdOperatorTable.ROW); + final PairList<SqlOperator, RexCallImplementor> implementors = + map.get(SqlStdOperatorTable.ROW); + if (implementors != null && implementors.size() == 1) { + return implementors.get(0).getValue(); + } + } else { + final PairList<SqlOperator, RexCallImplementor> implementors = + map.get(operator); + if (implementors != null) { + if (implementors.size() == 1) { + return implementors.get(0).getValue(); + } else { + for (Map.Entry<SqlOperator, RexCallImplementor> entry : implementors) { + if (operator == entry.getKey()) { + return entry.getValue(); + } + } + } + } } - return map.get(operator); + return null; } public @Nullable AggImplementor get(final SqlAggFunction aggregation, diff --git a/core/src/main/java/org/apache/calcite/sql/SqlBasicFunction.java 
b/core/src/main/java/org/apache/calcite/sql/SqlBasicFunction.java index 05c1f80c0b..ceba36579e 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlBasicFunction.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlBasicFunction.java @@ -67,7 +67,7 @@ public class SqlBasicFunction extends SqlFunction { * @param category Categorization for function * @param monotonicityInference Strategy to infer monotonicity of a call */ - protected SqlBasicFunction(String name, SqlKind kind, SqlSyntax syntax, + private SqlBasicFunction(String name, SqlKind kind, SqlSyntax syntax, boolean deterministic, SqlReturnTypeInference returnTypeInference, @Nullable SqlOperandTypeInference operandTypeInference, SqlOperandHandler operandHandler, diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java index bbdaab5096..2b6466034d 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java @@ -42,7 +42,6 @@ import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.calcite.sql.type.SqlTypeUtil; -import org.apache.calcite.sql.validate.SqlMonotonicity; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorUtil; import org.apache.calcite.util.Litmus; @@ -577,10 +576,8 @@ public abstract class SqlLibraryOperators { * {@code rep} and returns modified value. 
*/ @LibraryOperator(libraries = {REDSHIFT}) public static final SqlFunction REGEXP_REPLACE_2 = - new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION, - SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null, - OperandHandlers.DEFAULT, OperandTypes.STRING_STRING, 0, - SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE, + OperandTypes.STRING_STRING, SqlFunctionCategory.STRING); /** The "REGEXP_REPLACE(value, regexp, rep)" * function. Replaces all substrings of value that match regexp with @@ -596,11 +593,10 @@ public abstract class SqlLibraryOperators { * pos. */ @LibraryOperator(libraries = {MYSQL, ORACLE, REDSHIFT}) public static final SqlFunction REGEXP_REPLACE_4 = - new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION, - SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null, - OperandHandlers.DEFAULT, OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, - SqlTypeFamily.STRING, SqlTypeFamily.INTEGER), - 0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE, + OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING, + SqlTypeFamily.INTEGER), + SqlFunctionCategory.STRING); /** The "REGEXP_REPLACE(value, regexp, rep, pos, [ occurrence | matchType ])" * function. Replaces all substrings of value that match regexp with @@ -609,15 +605,13 @@ public abstract class SqlLibraryOperators { * is a string of flags to apply to the search. 
*/ @LibraryOperator(libraries = {MYSQL, REDSHIFT}) public static final SqlFunction REGEXP_REPLACE_5 = - new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION, - SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null, - OperandHandlers.DEFAULT, + SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE, OperandTypes.or( OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER), OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.STRING)), - 0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlFunctionCategory.STRING); /** The "REGEXP_REPLACE(value, regexp, rep, pos, matchType)" * function. Replaces all substrings of value that match regexp with @@ -625,11 +619,10 @@ public abstract class SqlLibraryOperators { * pos. Replace only the occurrence match or all matches if occurrence is 0. */ @LibraryOperator(libraries = {ORACLE}) public static final SqlFunction REGEXP_REPLACE_5_ORACLE = - new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION, - SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null, - OperandHandlers.DEFAULT, OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, - SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER), - 0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE, + OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, + SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER), + SqlFunctionCategory.STRING); /** The "REGEXP_REPLACE(value, regexp, rep, pos, occurrence, matchType)" * function. Replaces all substrings of value that match regexp with @@ -638,41 +631,34 @@ public abstract class SqlLibraryOperators { * is a string of flags to apply to the search. 
*/ @LibraryOperator(libraries = {MYSQL, ORACLE, REDSHIFT}) public static final SqlFunction REGEXP_REPLACE_6 = - new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION, - SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null, - OperandHandlers.DEFAULT, OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, - SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.STRING), - 0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE, + OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING, + SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.STRING), + SqlFunctionCategory.STRING); /** The "REGEXP_REPLACE(value, regexp, rep)" * function. Replaces all substrings of value that match regexp with * {@code rep} and returns modified value. */ @LibraryOperator(libraries = {BIG_QUERY}) public static final SqlFunction REGEXP_REPLACE_BIG_QUERY_3 = - new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION, - SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null, - OperandHandlers.DEFAULT, OperandTypes.STRING_STRING_STRING, 0, - SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE, + OperandTypes.STRING_STRING_STRING, SqlFunctionCategory.STRING); /** The "REGEXP_REPLACE(value, regexp, rep)" * function. Replaces all substrings of value that match regexp with * {@code rep} and returns modified value. 
*/ @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = REDSHIFT) public static final SqlFunction REGEXP_REPLACE_PG_3 = - new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION, - SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null, - OperandHandlers.DEFAULT, OperandTypes.STRING_STRING_STRING, 0, - SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE, + OperandTypes.STRING_STRING_STRING, SqlFunctionCategory.STRING); /** The "REGEXP_REPLACE(value, regexp, rep, flags)" * function. Replaces all substrings of value that match regexp with * {@code rep} and returns modified value. flags are applied to the search. */ @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = REDSHIFT) public static final SqlFunction REGEXP_REPLACE_PG_4 = - new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION, - SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null, - OperandHandlers.DEFAULT, OperandTypes.STRING_STRING_STRING_STRING, 0, - SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE, + OperandTypes.STRING_STRING_STRING_STRING, SqlFunctionCategory.STRING); /** The "REGEXP_SUBSTR(value, regexp[, position[, occurrence]])" function. * Returns the substring in value that matches the regexp. Returns NULL if there is no match. */ @@ -1875,10 +1861,8 @@ public abstract class SqlLibraryOperators { * converts {@code timestamp} to string according to the given {@code format}. 
*/ @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT}) public static final SqlFunction TO_CHAR_PG = - new SqlBasicFunction("TO_CHAR", SqlKind.OTHER_FUNCTION, - SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null, - OperandHandlers.DEFAULT, OperandTypes.TIMESTAMP_STRING, 0, - SqlFunctionCategory.TIMEDATE, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlBasicFunction.create("TO_CHAR", ReturnTypes.VARCHAR_NULLABLE, + OperandTypes.TIMESTAMP_STRING, SqlFunctionCategory.TIMEDATE); /** The "TO_DATE(string1, string2)" function; casts string1 * to a DATE using the format specified in string2. */ @@ -1893,10 +1877,8 @@ public abstract class SqlLibraryOperators { * to a DATE using the format specified in string2. */ @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT}) public static final SqlFunction TO_DATE_PG = - new SqlBasicFunction("TO_DATE", SqlKind.OTHER_FUNCTION, - SqlSyntax.FUNCTION, true, ReturnTypes.DATE_NULLABLE, null, - OperandHandlers.DEFAULT, OperandTypes.STRING_STRING, 0, - SqlFunctionCategory.TIMEDATE, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlBasicFunction.create("TO_DATE", ReturnTypes.DATE_NULLABLE, + OperandTypes.STRING_STRING, SqlFunctionCategory.TIMEDATE); /** The "TO_TIMESTAMP(string1, string2)" function; casts string1 * to a TIMESTAMP using the format specified in string2. */ @@ -1911,10 +1893,8 @@ public abstract class SqlLibraryOperators { * to a TIMESTAMP using the format specified in string2. 
*/ @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT}) public static final SqlFunction TO_TIMESTAMP_PG = - new SqlBasicFunction("TO_TIMESTAMP", SqlKind.OTHER_FUNCTION, - SqlSyntax.FUNCTION, true, ReturnTypes.TIMESTAMP_TZ_NULLABLE, null, - OperandHandlers.DEFAULT, OperandTypes.STRING_STRING, 0, - SqlFunctionCategory.TIMEDATE, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlBasicFunction.create("TO_TIMESTAMP", ReturnTypes.TIMESTAMP_TZ_NULLABLE, + OperandTypes.STRING_STRING, SqlFunctionCategory.TIMEDATE); /** * The "PARSE_TIME(string, string)" function (BigQuery); @@ -2512,10 +2492,8 @@ public abstract class SqlLibraryOperators { * to base numeric1.*/ @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT}) public static final SqlFunction LOG_POSTGRES = - new SqlBasicFunction("LOG", SqlKind.LOG, - SqlSyntax.FUNCTION, true, ReturnTypes.DOUBLE_NULLABLE, null, - OperandHandlers.DEFAULT, OperandTypes.NUMERIC_OPTIONAL_NUMERIC, 0, - SqlFunctionCategory.NUMERIC, call -> SqlMonotonicity.NOT_MONOTONIC, false) { }; + SqlBasicFunction.create("LOG", ReturnTypes.DOUBLE_NULLABLE, + OperandTypes.NUMERIC_OPTIONAL_NUMERIC, SqlFunctionCategory.NUMERIC); /** The "LOG2(numeric)" function. Returns the base 2 logarithm of numeric. */ @LibraryOperator(libraries = {MYSQL, SPARK})
