This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit a00d565456b43a245602aff5fd399a7f8794a83d
Author: Norman Jordan <[email protected]>
AuthorDate: Tue Sep 10 15:32:13 2024 -0700

    [CALCITE-6550] Improve SQL function overloading
    
    This change improves how RexImpTable handles collisions in
    the map containing scalar functions, which occur when two
    operators are so similar that `equals` and `hashCode` regard
    them as equal.
    
    Rules are now as follows:
     * If only one implementor is found for an operator key,
       return that implementor
     * If there are multiple implementors for an operator key,
       return the one registered for the identical operator
       instance (compared with `==`, not `equals`)
     * An operator key consists of the operator's name and kind
    
    Following this change, we were able to remove tech debt:
    convert some trivial anonymous subclasses in
    SqlLibraryOperators back to SqlBasicFunction.create, and
    make SqlBasicFunction's constructor private again.
    
    Close apache/calcite#3954
---
 .../calcite/adapter/enumerable/RexImpTable.java    | 54 +++++++++++---
 .../org/apache/calcite/sql/SqlBasicFunction.java   |  2 +-
 .../calcite/sql/fun/SqlLibraryOperators.java       | 82 ++++++++--------------
 3 files changed, 76 insertions(+), 62 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
index c37796c2a8..258df59bda 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
@@ -46,6 +46,7 @@ import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexPatternFieldRef;
 import org.apache.calcite.rex.RexWindowExclusion;
 import org.apache.calcite.runtime.FlatLists;
+import org.apache.calcite.runtime.PairList;
 import org.apache.calcite.runtime.SqlFunctions;
 import org.apache.calcite.schema.FunctionContext;
 import org.apache.calcite.schema.ImplementableAggFunction;
@@ -544,7 +545,7 @@ public class RexImpTable {
   public static final MemberExpression BOXED_TRUE_EXPR =
       Expressions.field(null, Boolean.class, "TRUE");
 
-  private final ImmutableMap<SqlOperator, RexCallImplementor> map;
+  private final ImmutableMap<SqlOperator, PairList<SqlOperator, RexCallImplementor>> map;
   private final ImmutableMap<SqlAggFunction, Supplier<? extends AggImplementor>> aggMap;
   private final ImmutableMap<SqlAggFunction, Supplier<? extends WinAggImplementor>> winAggMap;
   private final ImmutableMap<SqlMatchFunction, Supplier<? extends MatchImplementor>> matchMap;
@@ -552,7 +553,10 @@ public class RexImpTable {
       tvfImplementorMap;
 
   private RexImpTable(Builder builder) {
-    this.map = ImmutableMap.copyOf(builder.map);
+    final ImmutableMap.Builder<SqlOperator, PairList<SqlOperator, RexCallImplementor>>
+        mapBuilder = ImmutableMap.builder();
+    builder.map.forEach((k, v) -> mapBuilder.put(k, v.immutable()));
+    this.map = ImmutableMap.copyOf(mapBuilder.build());
     this.aggMap = ImmutableMap.copyOf(builder.aggMap);
     this.winAggMap = ImmutableMap.copyOf(builder.winAggMap);
     this.matchMap = ImmutableMap.copyOf(builder.matchMap);
@@ -850,7 +854,6 @@ public class RexImpTable {
           new SafeArithmeticImplementor(BuiltInMethod.SAFE_SUBTRACT.method));
 
       define(PI, new PiImplementor());
-      populate2();
     }
 
     /** Second step of population. */
@@ -1279,7 +1282,8 @@ public class RexImpTable {
 
   /** Holds intermediate state from which a RexImpTable can be constructed. */
   private static class Builder extends AbstractBuilder {
-    private final Map<SqlOperator, RexCallImplementor> map = new HashMap<>();
+    private final Map<SqlOperator, PairList<SqlOperator, RexCallImplementor>> map =
+        new HashMap<>();
     private final Map<SqlAggFunction, Supplier<? extends AggImplementor>> aggMap =
         new HashMap<>();
     private final Map<SqlAggFunction, Supplier<? extends WinAggImplementor>> winAggMap =
@@ -1290,13 +1294,27 @@ public class RexImpTable {
         tvfImplementorMap = new HashMap<>();
 
     @Override protected RexCallImplementor get(SqlOperator operator) {
-      return requireNonNull(map.get(operator),
-          () -> "no implementor for " + operator);
+      final PairList<SqlOperator, RexCallImplementor> implementors =
+          requireNonNull(map.get(operator));
+      if (implementors.size() == 1) {
+        return implementors.get(0).getValue();
+      } else {
+        for (Map.Entry<SqlOperator, RexCallImplementor> entry : implementors) {
+          if (operator == entry.getKey()) {
+            return entry.getValue();
+          }
+        }
+        throw new NullPointerException();
+      }
     }
 
     @Override <T extends RexCallImplementor> T define(SqlOperator operator,
         T implementor) {
-      map.put(operator, requireNonNull(implementor, "implementor"));
+      if (map.containsKey(operator)) {
+        map.get(operator).add(operator, implementor);
+      } else {
+        map.put(operator, PairList.of(operator, implementor));
+      }
       return implementor;
     }
 
@@ -1371,9 +1389,27 @@ public class RexImpTable {
           ((ImplementableFunction) udf).getImplementor();
       return wrapAsRexCallImplementor(implementor);
     } else if (operator instanceof SqlTypeConstructorFunction) {
-      return map.get(SqlStdOperatorTable.ROW);
+      final PairList<SqlOperator, RexCallImplementor> implementors =
+          map.get(SqlStdOperatorTable.ROW);
+      if (implementors != null && implementors.size() == 1) {
+        return implementors.get(0).getValue();
+      }
+    } else {
+      final PairList<SqlOperator, RexCallImplementor> implementors =
+          map.get(operator);
+      if (implementors != null) {
+        if (implementors.size() == 1) {
+          return implementors.get(0).getValue();
+        } else {
+          for (Map.Entry<SqlOperator, RexCallImplementor> entry : implementors) {
+            if (operator == entry.getKey()) {
+              return entry.getValue();
+            }
+          }
+        }
+      }
     }
-    return map.get(operator);
+    return null;
   }
 
   public @Nullable AggImplementor get(final SqlAggFunction aggregation,
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlBasicFunction.java b/core/src/main/java/org/apache/calcite/sql/SqlBasicFunction.java
index 05c1f80c0b..ceba36579e 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlBasicFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlBasicFunction.java
@@ -67,7 +67,7 @@ public class SqlBasicFunction extends SqlFunction {
    * @param category Categorization for function
    * @param monotonicityInference Strategy to infer monotonicity of a call
    */
-  protected SqlBasicFunction(String name, SqlKind kind, SqlSyntax syntax,
+  private SqlBasicFunction(String name, SqlKind kind, SqlSyntax syntax,
       boolean deterministic, SqlReturnTypeInference returnTypeInference,
       @Nullable SqlOperandTypeInference operandTypeInference,
       SqlOperandHandler operandHandler,
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index bbdaab5096..2b6466034d 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -42,7 +42,6 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.type.SqlTypeTransforms;
 import org.apache.calcite.sql.type.SqlTypeUtil;
-import org.apache.calcite.sql.validate.SqlMonotonicity;
 import org.apache.calcite.sql.validate.SqlValidator;
 import org.apache.calcite.sql.validate.SqlValidatorUtil;
 import org.apache.calcite.util.Litmus;
@@ -577,10 +576,8 @@ public abstract class SqlLibraryOperators {
    * {@code rep} and returns modified value. */
   @LibraryOperator(libraries = {REDSHIFT})
   public static final SqlFunction REGEXP_REPLACE_2 =
-      new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
-          SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
-          OperandHandlers.DEFAULT, OperandTypes.STRING_STRING, 0,
-          SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+      SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
+          OperandTypes.STRING_STRING, SqlFunctionCategory.STRING);
 
   /** The "REGEXP_REPLACE(value, regexp, rep)"
    * function. Replaces all substrings of value that match regexp with
@@ -596,11 +593,10 @@ public abstract class SqlLibraryOperators {
    * pos. */
   @LibraryOperator(libraries = {MYSQL, ORACLE, REDSHIFT})
   public static final SqlFunction REGEXP_REPLACE_4 =
-      new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
-          SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
-          OperandHandlers.DEFAULT, OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
-              SqlTypeFamily.STRING, SqlTypeFamily.INTEGER),
-          0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+      SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
+          OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING,
+              SqlTypeFamily.INTEGER),
+          SqlFunctionCategory.STRING);
 
   /** The "REGEXP_REPLACE(value, regexp, rep, pos, [ occurrence | matchType ])"
    * function. Replaces all substrings of value that match regexp with
@@ -609,15 +605,13 @@ public abstract class SqlLibraryOperators {
    * is a string of flags to apply to the search. */
   @LibraryOperator(libraries = {MYSQL, REDSHIFT})
   public static final SqlFunction REGEXP_REPLACE_5 =
-      new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
-          SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
-          OperandHandlers.DEFAULT,
+      SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
           OperandTypes.or(
               OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
                   SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
               OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
                   SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.STRING)),
-          0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+          SqlFunctionCategory.STRING);
 
   /** The "REGEXP_REPLACE(value, regexp, rep, pos, matchType)"
    * function. Replaces all substrings of value that match regexp with
@@ -625,11 +619,10 @@ public abstract class SqlLibraryOperators {
    * pos. Replace only the occurrence match or all matches if occurrence is 0. */
   @LibraryOperator(libraries = {ORACLE})
   public static final SqlFunction REGEXP_REPLACE_5_ORACLE =
-      new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
-          SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
-          OperandHandlers.DEFAULT, OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
-          SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
-          0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+      SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
+          OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
+              SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
+          SqlFunctionCategory.STRING);
 
   /** The "REGEXP_REPLACE(value, regexp, rep, pos, occurrence, matchType)"
    * function. Replaces all substrings of value that match regexp with
@@ -638,41 +631,34 @@ public abstract class SqlLibraryOperators {
    * is a string of flags to apply to the search. */
   @LibraryOperator(libraries = {MYSQL, ORACLE, REDSHIFT})
   public static final SqlFunction REGEXP_REPLACE_6 =
-      new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
-          SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
-          OperandHandlers.DEFAULT, OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING,
-          SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.STRING),
-          0, SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+      SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
+          OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING,
+              SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.STRING),
+          SqlFunctionCategory.STRING);
 
   /** The "REGEXP_REPLACE(value, regexp, rep)"
    * function. Replaces all substrings of value that match regexp with
    * {@code rep} and returns modified value. */
   @LibraryOperator(libraries = {BIG_QUERY})
   public static final SqlFunction REGEXP_REPLACE_BIG_QUERY_3 =
-      new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
-          SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
-          OperandHandlers.DEFAULT, OperandTypes.STRING_STRING_STRING, 0,
-          SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+      SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
+          OperandTypes.STRING_STRING_STRING, SqlFunctionCategory.STRING);
 
   /** The "REGEXP_REPLACE(value, regexp, rep)"
    * function. Replaces all substrings of value that match regexp with
    * {@code rep} and returns modified value. */
   @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = REDSHIFT)
   public static final SqlFunction REGEXP_REPLACE_PG_3 =
-      new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
-          SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
-          OperandHandlers.DEFAULT, OperandTypes.STRING_STRING_STRING, 0,
-          SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+      SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
+          OperandTypes.STRING_STRING_STRING, SqlFunctionCategory.STRING);
 
   /** The "REGEXP_REPLACE(value, regexp, rep, flags)"
    * function. Replaces all substrings of value that match regexp with
    * {@code rep} and returns modified value. flags are applied to the search. */
   @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = REDSHIFT)
   public static final SqlFunction REGEXP_REPLACE_PG_4 =
-      new SqlBasicFunction("REGEXP_REPLACE", SqlKind.OTHER_FUNCTION,
-          SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
-          OperandHandlers.DEFAULT, OperandTypes.STRING_STRING_STRING_STRING, 0,
-          SqlFunctionCategory.STRING, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+      SqlBasicFunction.create("REGEXP_REPLACE", ReturnTypes.VARCHAR_NULLABLE,
+          OperandTypes.STRING_STRING_STRING_STRING, SqlFunctionCategory.STRING);
 
   /** The "REGEXP_SUBSTR(value, regexp[, position[, occurrence]])" function.
    * Returns the substring in value that matches the regexp. Returns NULL if there is no match. */
@@ -1875,10 +1861,8 @@ public abstract class SqlLibraryOperators {
    * converts {@code timestamp} to string according to the given {@code format}. */
   @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT})
   public static final SqlFunction TO_CHAR_PG =
-      new SqlBasicFunction("TO_CHAR", SqlKind.OTHER_FUNCTION,
-      SqlSyntax.FUNCTION, true, ReturnTypes.VARCHAR_NULLABLE, null,
-      OperandHandlers.DEFAULT, OperandTypes.TIMESTAMP_STRING, 0,
-          SqlFunctionCategory.TIMEDATE, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+      SqlBasicFunction.create("TO_CHAR", ReturnTypes.VARCHAR_NULLABLE,
+          OperandTypes.TIMESTAMP_STRING, SqlFunctionCategory.TIMEDATE);
 
   /** The "TO_DATE(string1, string2)" function; casts string1
    * to a DATE using the format specified in string2. */
@@ -1893,10 +1877,8 @@ public abstract class SqlLibraryOperators {
    * to a DATE using the format specified in string2. */
   @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT})
   public static final SqlFunction TO_DATE_PG =
-      new SqlBasicFunction("TO_DATE", SqlKind.OTHER_FUNCTION,
-          SqlSyntax.FUNCTION, true, ReturnTypes.DATE_NULLABLE, null,
-          OperandHandlers.DEFAULT, OperandTypes.STRING_STRING, 0,
-          SqlFunctionCategory.TIMEDATE, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+      SqlBasicFunction.create("TO_DATE", ReturnTypes.DATE_NULLABLE,
+          OperandTypes.STRING_STRING, SqlFunctionCategory.TIMEDATE);
 
   /** The "TO_TIMESTAMP(string1, string2)" function; casts string1
    * to a TIMESTAMP using the format specified in string2. */
@@ -1911,10 +1893,8 @@ public abstract class SqlLibraryOperators {
    * to a TIMESTAMP using the format specified in string2. */
   @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT})
   public static final SqlFunction TO_TIMESTAMP_PG =
-      new SqlBasicFunction("TO_TIMESTAMP", SqlKind.OTHER_FUNCTION,
-          SqlSyntax.FUNCTION, true, ReturnTypes.TIMESTAMP_TZ_NULLABLE, null,
-          OperandHandlers.DEFAULT, OperandTypes.STRING_STRING, 0,
-          SqlFunctionCategory.TIMEDATE, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+      SqlBasicFunction.create("TO_TIMESTAMP", ReturnTypes.TIMESTAMP_TZ_NULLABLE,
+          OperandTypes.STRING_STRING, SqlFunctionCategory.TIMEDATE);
 
   /**
    * The "PARSE_TIME(string, string)" function (BigQuery);
@@ -2512,10 +2492,8 @@ public abstract class SqlLibraryOperators {
    * to base numeric1.*/
   @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT})
   public static final SqlFunction LOG_POSTGRES =
-      new SqlBasicFunction("LOG", SqlKind.LOG,
-          SqlSyntax.FUNCTION, true, ReturnTypes.DOUBLE_NULLABLE, null,
-          OperandHandlers.DEFAULT, OperandTypes.NUMERIC_OPTIONAL_NUMERIC, 0,
-          SqlFunctionCategory.NUMERIC, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };
+      SqlBasicFunction.create("LOG", ReturnTypes.DOUBLE_NULLABLE,
+          OperandTypes.NUMERIC_OPTIONAL_NUMERIC, SqlFunctionCategory.NUMERIC);
 
   /** The "LOG2(numeric)" function. Returns the base 2 logarithm of numeric. */
   @LibraryOperator(libraries = {MYSQL, SPARK})

Reply via email to