This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 82a3219e1b [CALCITE-6312] Add LOG function (enabled in PostgreSQL library)
82a3219e1b is described below

commit 82a3219e1bc8ffac9bbd20be3dc8c11f3882ba50
Author: caicancai <[email protected]>
AuthorDate: Tue Jul 2 00:43:46 2024 +0800

    [CALCITE-6312] Add LOG function (enabled in PostgreSQL library)
---
 .../calcite/adapter/enumerable/RexImpTable.java    | 100 ++++++++++-----------
 .../org/apache/calcite/runtime/SqlFunctions.java   |  63 +++++++++----
 .../calcite/sql/fun/SqlLibraryOperators.java       |  12 ++-
 .../org/apache/calcite/util/BuiltInMethod.java     |   2 +-
 site/_docs/reference.md                            |   5 +-
 .../org/apache/calcite/test/SqlOperatorTest.java   |  62 ++++++++++---
 6 files changed, 158 insertions(+), 86 deletions(-)

diff --git 
a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java 
b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
index d8a3bf6f34..7e17beb93e 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
@@ -65,6 +65,7 @@ import org.apache.calcite.sql.SqlWindowTableFunction;
 import org.apache.calcite.sql.fun.SqlItemOperator;
 import org.apache.calcite.sql.fun.SqlJsonArrayAggAggFunction;
 import org.apache.calcite.sql.fun.SqlJsonObjectAggAggFunction;
+import org.apache.calcite.sql.fun.SqlLibrary;
 import org.apache.calcite.sql.fun.SqlQuantifyOperator;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.fun.SqlTrimFunction;
@@ -75,6 +76,7 @@ import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
 import org.apache.calcite.sql.validate.SqlUserDefinedTableFunction;
 import org.apache.calcite.sql.validate.SqlUserDefinedTableMacro;
 import org.apache.calcite.util.BuiltInMethod;
+import org.apache.calcite.util.Pair;
 import org.apache.calcite.util.Util;
 
 import com.google.common.collect.ImmutableList;
@@ -218,6 +220,7 @@ import static 
org.apache.calcite.sql.fun.SqlLibraryOperators.LOG2;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_AND;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_OR;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOG_MYSQL;
+import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOG_POSTGRES;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.LPAD;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP_CONCAT;
@@ -667,12 +670,13 @@ public class RexImpTable {
       defineMethod(POWER_PG, BuiltInMethod.POWER_PG.method, NullPolicy.STRICT);
       defineMethod(ABS, BuiltInMethod.ABS.method, NullPolicy.STRICT);
 
-      map.put(LN, new LogImplementor());
-      map.put(LOG, new LogImplementor());
-      map.put(LOG10, new LogImplementor());
+      map.put(LN, new LogImplementor(SqlLibrary.BIG_QUERY));
+      map.put(LOG, new LogImplementor(SqlLibrary.BIG_QUERY));
+      map.put(LOG10, new LogImplementor(SqlLibrary.BIG_QUERY));
 
-      map.put(LOG_MYSQL, new LogMysqlImplementor());
-      map.put(LOG2, new LogMysqlImplementor());
+      map.put(LOG_POSTGRES, new LogImplementor(SqlLibrary.POSTGRESQL));
+      map.put(LOG_MYSQL, new LogImplementor(SqlLibrary.MYSQL));
+      map.put(LOG2, new LogImplementor(SqlLibrary.MYSQL));
 
       defineReflective(RAND, BuiltInMethod.RAND.method,
           BuiltInMethod.RAND_SEED.method);
@@ -4217,67 +4221,57 @@ public class RexImpTable {
    * appropriate base (i.e. base e for LN).
    */
   private static class LogImplementor extends AbstractRexCallImplementor {
-    LogImplementor() {
+    private final SqlLibrary library;
+    LogImplementor(SqlLibrary library) {
       super("log", NullPolicy.STRICT, true);
+      this.library = library;
     }
 
     @Override Expression implementSafe(final RexToLixTranslator translator,
         final RexCall call, final List<Expression> argValueList) {
-      return Expressions.call(BuiltInMethod.LOG.method, args(call, 
argValueList));
-    }
-
+      return Expressions.
+          call(BuiltInMethod.LOG.method, args(call, argValueList, library));
+    }
+
+    /**
+     * Builds the arguments {@code (number, base, nonPositiveIsNull)} for a
+     * call to {@code BuiltInMethod.LOG}; the SQL argument order and default
+     * base depend on {@code library} (PostgreSQL: base 10, otherwise base e).
+     *
+     * @param call The RexCall that contains the function call information
+     * @param argValueList The list of argument expressions
+     * @param library The SQL library whose argument order and defaults apply
+     * @return Arguments for the {@code BuiltInMethod.LOG} method
+     */
     private static List<Expression> args(RexCall call,
-        List<Expression> argValueList) {
-      Expression operand0 = argValueList.get(0);
-      final Expressions.FluentList<Expression> list = 
Expressions.list(operand0);
-      switch (call.getOperator().getName()) {
-      case "LOG":
-        if (argValueList.size() == 2) {
-          return 
list.append(argValueList.get(1)).append(Expressions.constant(0));
-        }
-        // fall through
-      case "LN":
-        return 
list.append(Expressions.constant(Math.exp(1))).append(Expressions.constant(0));
-      case "LOG10":
-        return 
list.append(Expressions.constant(BigDecimal.TEN)).append(Expressions.constant(0));
-      default:
-        throw new AssertionError("Operator not found: " + call.getOperator());
+        List<Expression> argValueList, SqlLibrary library) {
+      Pair<Expression, Expression> operands;
+      Expression operand0;
+      Expression operand1;
+      if (argValueList.size() == 1) {
+        operands = library == SqlLibrary.POSTGRESQL
+            ? Pair.of(argValueList.get(0), 
Expressions.constant(BigDecimal.TEN))
+            : Pair.of(argValueList.get(0), Expressions.constant(Math.exp(1)));
+      } else {
+        operands = library == SqlLibrary.BIG_QUERY
+            ? Pair.of(argValueList.get(0), argValueList.get(1))
+            : Pair.of(argValueList.get(1), argValueList.get(0));
       }
-    }
-  }
-
-  /** Implementor for the {@code LN}, {@code LOG}, {@code LOG2} and {@code 
LOG10} operators
-   *  on Mysql and Spark library
-   *
-   * <p>Handles all logarithm functions using log rules to determine the
-   * appropriate base (i.e. base e for LN).
-   */
-  private static class LogMysqlImplementor extends AbstractRexCallImplementor {
-    LogMysqlImplementor() {
-      super("log", NullPolicy.STRICT, true);
-    }
-
-    @Override Expression implementSafe(final RexToLixTranslator translator,
-        final RexCall call, final List<Expression> argValueList) {
-      return Expressions.call(BuiltInMethod.LOG.method, args(call, 
argValueList));
-    }
-
-    private static List<Expression> args(RexCall call,
-        List<Expression> argValueList) {
-      Expression operand0 = argValueList.get(0);
+      operand0 = operands.left;
+      operand1 = operands.right;
+      boolean nonPositiveIsNull = library == SqlLibrary.MYSQL ? true : false;
       final Expressions.FluentList<Expression> list = 
Expressions.list(operand0);
       switch (call.getOperator().getName()) {
       case "LOG":
-        if (argValueList.size() == 2) {
-          return 
list.append(argValueList.get(1)).append(Expressions.constant(1));
-        }
-        // fall through
+        return 
list.append(operand1).append(Expressions.constant(nonPositiveIsNull));
       case "LN":
-        return 
list.append(Expressions.constant(Math.exp(1))).append(Expressions.constant(1));
+        return list.append(Expressions.constant(Math.exp(1)))
+            .append(Expressions.constant(nonPositiveIsNull));
       case "LOG2":
-        return 
list.append(Expressions.constant(2)).append(Expressions.constant(1));
+        return 
list.append(Expressions.constant(2)).append(Expressions.constant(nonPositiveIsNull));
       case "LOG10":
-        return 
list.append(Expressions.constant(BigDecimal.TEN)).append(Expressions.constant(1));
+        return list.append(Expressions.constant(BigDecimal.TEN))
+            .append(Expressions.constant(nonPositiveIsNull));
       default:
         throw new AssertionError("Operator not found: " + call.getOperator());
       }
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java 
b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index b41990060b..5ff6e6a7ed 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -2863,41 +2863,66 @@ public class SqlFunctions {
     return Math.pow(b0.doubleValue(), b1.doubleValue());
   }
 
-
   // LN, LOG, LOG10, LOG2
 
-  /** SQL {@code LOG(number, number2)} function applied to double values. */
-  public static @Nullable Double log(double number, double number2, int 
nullFlag) {
-    if (nullFlag == 1 && number <= 0) {
+  /**
+   * SQL {@code LOG(number, base)} function applied to double values.
+   *
+   * @param nonPositiveIsNull if true return null for non-positive values
+   */
+  public static @Nullable Double log(double number, double base, boolean 
nonPositiveIsNull) {
+    if (nonPositiveIsNull && number <= 0) {
       return null;
     }
-    return Math.log(number) / Math.log(number2);
+    if (number <= 0 || base <= 0) {
+      throw new IllegalArgumentException("Cannot take logarithm of zero or 
negative number");
+    }
+    return Math.log(number) / Math.log(base);
   }
 
-  /** SQL {@code LOG(number, number2)} function applied to
-   * double and BigDecimal values. */
-  public static @Nullable Double log(double number, BigDecimal number2, int 
nullFlag) {
-    if (nullFlag == 1 && number <= 0) {
+  /** SQL {@code LOG(number, base)} function applied to
+   * double and BigDecimal values.
+   *
+   * @param nonPositiveIsNull if true return null for non-positive values
+   */
+  public static @Nullable Double log(double number, BigDecimal base, boolean 
nonPositiveIsNull) {
+    if (nonPositiveIsNull && number <= 0) {
       return null;
     }
-    return Math.log(number) / Math.log(number2.doubleValue());
+    if (number <= 0 || base.doubleValue() <= 0) {
+      throw new IllegalArgumentException("Cannot take logarithm of zero or 
negative number");
+    }
+    return  Math.log(number) / Math.log(base.doubleValue());
   }
 
-  /** SQL {@code LOG(number, number2)} function applied to
-   * BigDecimal and double values. */
-  public static @Nullable Double log(BigDecimal number, double number2, int 
nullFlag) {
-    if (nullFlag == 1 && number.doubleValue() <= 0) {
+  /** SQL {@code LOG(number, base)} function applied to
+   * BigDecimal and double values.
+   *
+   * @param nonPositiveIsNull if true return null for non-positive values
+   */
+  public static @Nullable Double log(BigDecimal number, double base, Boolean 
nonPositiveIsNull) {
+    if (nonPositiveIsNull && number.doubleValue() <= 0) {
       return null;
     }
-    return Math.log(number.doubleValue()) / Math.log(number2);
+    if (number.doubleValue() <= 0 || base <= 0) {
+      throw new IllegalArgumentException("Cannot take logarithm of zero or 
negative number");
+    }
+    return Math.log(number.doubleValue()) / Math.log(base);
   }
 
-  /** SQL {@code LOG(number, number2)} function applied to double values. */
-  public static @Nullable Double log(BigDecimal number, BigDecimal number2, 
int nullFlag) {
-    if (nullFlag == 1 && number.doubleValue() <= 0) {
+  /** SQL {@code LOG(number, base)} function applied to double values.
+   *
+   * @param nonPositiveIsNull if true return null for non-positive values
+   */
+  public static @Nullable Double log(BigDecimal number, BigDecimal base,
+      Boolean nonPositiveIsNull) {
+    if (nonPositiveIsNull && number.doubleValue() <= 0) {
       return null;
     }
-    return Math.log(number.doubleValue()) / Math.log(number2.doubleValue());
+    if (number.doubleValue() <= 0 || base.doubleValue() <= 0) {
+      throw new IllegalArgumentException("Cannot take logarithm of zero or 
negative number");
+    }
+    return Math.log(number.doubleValue()) / Math.log(base.doubleValue());
   }
 
   // MOD
diff --git 
a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java 
b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 26db9029c2..c582885f87 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -2319,13 +2319,23 @@ public abstract class SqlLibraryOperators {
           OperandTypes.NUMERIC_OPTIONAL_NUMERIC,
           SqlFunctionCategory.NUMERIC);
 
-  /** The "LOG(numeric, numeric1)" function. Returns the base numeric1 logarithm of numeric. */
+  /** The "LOG([numeric1, ] numeric2)" function. Returns the logarithm of
+   * numeric2 to base numeric1, or to base e if numeric1 is omitted. */
   @LibraryOperator(libraries = {MYSQL, SPARK})
   public static final SqlFunction LOG_MYSQL =
       SqlBasicFunction.create(SqlKind.LOG,
           ReturnTypes.DOUBLE_FORCE_NULLABLE,
           OperandTypes.NUMERIC_OPTIONAL_NUMERIC);
 
+  /** The "LOG([numeric1, ] numeric2)" function. Returns the logarithm of
+   * numeric2 to base numeric1, or to base 10 if numeric1 is omitted. */
+  @LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT})
+  public static final SqlFunction LOG_POSTGRES =
+      new SqlBasicFunction("LOG", SqlKind.LOG,
+          SqlSyntax.FUNCTION, true, ReturnTypes.DOUBLE_NULLABLE, null,
+          OperandHandlers.DEFAULT, OperandTypes.NUMERIC_OPTIONAL_NUMERIC, 0,
+          SqlFunctionCategory.NUMERIC, call -> SqlMonotonicity.NOT_MONOTONIC, 
false) { };
+
   /** The "LOG2(numeric)" function. Returns the base 2 logarithm of numeric. */
   @LibraryOperator(libraries = {MYSQL, SPARK})
   public static final SqlFunction LOG2 =
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java 
b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index d90f00222a..5f88ecc586 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -526,7 +526,7 @@ public enum BuiltInMethod {
   SAFE_DIVIDE(SqlFunctions.class, "safeDivide", double.class, double.class),
   SAFE_MULTIPLY(SqlFunctions.class, "safeMultiply", double.class, 
double.class),
   SAFE_SUBTRACT(SqlFunctions.class, "safeSubtract", double.class, 
double.class),
-  LOG(SqlFunctions.class, "log", long.class, long.class, int.class),
+  LOG(SqlFunctions.class, "log", long.class, long.class, boolean.class),
   SEC(SqlFunctions.class, "sec", double.class),
   SECH(SqlFunctions.class, "sech", double.class),
   SIGN(SqlFunctions.class, "sign", long.class),
diff --git a/site/_docs/reference.md b/site/_docs/reference.md
index b1567c8da7..a39f9f1db6 100644
--- a/site/_docs/reference.md
+++ b/site/_docs/reference.md
@@ -2807,8 +2807,9 @@ In the following:
 | f s | LEN(string)                                  | Equivalent to 
`CHAR_LENGTH(string)`
 | b f s | LENGTH(string)                             | Equivalent to 
`CHAR_LENGTH(string)`
 | h s | LEVENSHTEIN(string1, string2)                | Returns the Levenshtein 
distance between *string1* and *string2*
-| b | LOG(numeric1 [, numeric2 ])                    | Returns the logarithm 
of *numeric1* to base *numeric2*, or base e if *numeric2* is not present
-| m s | LOG(numeric1 [, numeric2 ])                  | Returns the logarithm 
of *numeric1* to base *numeric2*, or base e if *numeric2* is not present, or 
null if *numeric1* is 0 or negative
+| b | LOG(numeric1 [, base ])                        | Returns the logarithm of *numeric1* to base *base*, or base e if *base* is not present, or error if *numeric1* or *base* is 0 or negative
+| m s | LOG([ base, ] numeric1)                      | Returns the logarithm of *numeric1* to base *base*, or base e if *base* is not present, or null if *numeric1* is 0 or negative
+| p | LOG([ base, ] numeric1)                        | Returns the logarithm of *numeric1* to base *base*, or base 10 if *base* is not present, or error if *numeric1* or *base* is 0 or negative
 | m s | LOG2(numeric)                                | Returns the base 2 
logarithm of *numeric*
 | b o p r s | LPAD(string, length [, pattern ])      | Returns a string or 
bytes value that consists of *string* prepended to *length* with *pattern*
 | b | TO_BASE32(string)                              | Converts the *string* 
to base-32 encoded form and returns an encoded string
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java 
b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index 8b94908fbd..f8bb5be715 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -7103,14 +7103,22 @@ public class SqlOperatorTest {
         isWithin(2.0, 0.000001));
     f.checkScalarApprox("log(10, 100)", "DOUBLE NOT NULL",
         isWithin(0.5, 0.000001));
-    f.checkScalarApprox("log(cast(10e6 as double), 10)", "DOUBLE NOT NULL",
+    f.checkScalarApprox("log(cast(1e7 as double), 10)", "DOUBLE NOT NULL",
         isWithin(7.0, 0.000001));
     f.checkScalarApprox("log(cast(10e8 as float), 10)", "DOUBLE NOT NULL",
         isWithin(9.0, 0.000001));
     f.checkScalarApprox("log(cast(10e-3 as real), 10)", "DOUBLE NOT NULL",
         isWithin(-2.0, 0.000001));
+    f.checkScalarApprox("log(10)", "DOUBLE NOT NULL",
+        isWithin(2.302585092994046, 0.000001));
     f.checkNull("log(cast(null as real), 10)");
     f.checkNull("log(10, cast(null as real))");
+    f.checkFails("log(0)",
+        "Cannot take logarithm of zero or negative number", true);
+    f.checkFails("log(0, 64)",
+        "Cannot take logarithm of zero or negative number", true);
+    f.checkFails("log(64, 0)",
+        "Cannot take logarithm of zero or negative number", true);
   }
 
   /** Test case for
@@ -7161,24 +7169,27 @@ public class SqlOperatorTest {
     final Consumer<SqlOperatorFixture> consumer = f -> {
       f.checkScalarApprox("log(10, 10)", "DOUBLE",
           isWithin(1.0, 0.000001));
-      f.checkScalarApprox("log(64, 8)", "DOUBLE",
+      f.checkScalarApprox("log(8, 64)", "DOUBLE",
           isWithin(2.0, 0.000001));
-      f.checkScalarApprox("log(27,3)", "DOUBLE",
+      f.checkScalarApprox("log(3,27)", "DOUBLE",
           isWithin(3.0, 0.000001));
-      f.checkScalarApprox("log(100, 10)", "DOUBLE",
-          isWithin(2.0, 0.000001));
       f.checkScalarApprox("log(10, 100)", "DOUBLE",
+          isWithin(2.0, 0.000001));
+      f.checkScalarApprox("log(100, 10)", "DOUBLE",
           isWithin(0.5, 0.000001));
-      f.checkScalarApprox("log(cast(10e6 as double), 10)", "DOUBLE",
+      f.checkScalarApprox("log(10, cast(1e7 as double))", "DOUBLE",
           isWithin(7.0, 0.000001));
-      f.checkScalarApprox("log(cast(10e8 as float), 10)", "DOUBLE",
+      f.checkScalarApprox("log(10, cast(1e9 as float))", "DOUBLE",
           isWithin(9.0, 0.000001));
-      f.checkScalarApprox("log(cast(10e-3 as real), 10)", "DOUBLE",
+      // real type is equivalent to double type
+      f.checkScalarApprox("log(10, cast(1e-2 as real))", "DOUBLE",
           isWithin(-2.0, 0.000001));
+      f.checkScalarApprox("log(10)", "DOUBLE",
+          isWithin(2.302585092994046, 0.000001));
       f.checkNull("log(cast(null as real), 10)");
       f.checkNull("log(10, cast(null as real))");
-      f.checkNull("log(0, 2)");
-      f.checkNull("log(0,-2)");
+      f.checkNull("log(2, 0)");
+      f.checkNull("log(-2,0)");
       f.checkNull("log(0, +0.0)");
       f.checkNull("log(0, 0.0)");
       f.checkNull("log(null)");
@@ -7189,6 +7200,37 @@ public class SqlOperatorTest {
     f0.forEachLibrary(list(SqlLibrary.MYSQL, SqlLibrary.SPARK), consumer);
   }
 
+  /** Test case for
+   * <a 
href="https://issues.apache.org/jira/browse/CALCITE-6312";>[CALCITE-6312]
+   * Add LOG function (enabled in PostgreSQL library)</a>. */
+  @Test void testPostgresLogFunc() {
+    final SqlOperatorFixture f0 = fixture()
+        .setFor(SqlLibraryOperators.LOG_POSTGRES, VmName.EXPAND);
+    f0.checkFails("^log(100, 10)^",
+        "No match found for function signature LOG\\(<NUMERIC>, <NUMERIC>\\)", 
false);
+    final SqlOperatorFixture f = f0.withLibrary(SqlLibrary.POSTGRESQL);
+    f.checkScalar("log(10, 10)", 1.0,
+        "DOUBLE NOT NULL");
+    f.checkScalar("log(8, 64)", 2.0,
+        "DOUBLE NOT NULL");
+    f.checkScalar("log(10, 100)", 2.0,
+        "DOUBLE NOT NULL");
+    f.checkScalar("log(100, 10)", 0.5,
+        "DOUBLE NOT NULL");
+    f.checkScalar("log(10, cast(1e7 as double))", 7.0,
+        "DOUBLE NOT NULL");
+    f.checkScalar("log(10)", 1.0,
+        "DOUBLE NOT NULL");
+    f.checkNull("log(cast(null as real), 10)");
+    f.checkNull("log(10, cast(null as real))");
+    f.checkFails("log(0)",
+        "Cannot take logarithm of zero or negative number", true);
+    f.checkFails("log(0, 64)",
+        "Cannot take logarithm of zero or negative number", true);
+    f.checkFails("log(64, 0)",
+        "Cannot take logarithm of zero or negative number", true);
+  }
+
   @Test void testRandFunc() {
     final SqlOperatorFixture f = fixture();
     f.setFor(SqlStdOperatorTable.RAND, VmName.EXPAND);

Reply via email to