This is an automated email from the ASF dual-hosted git repository.

tanner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 1b8dc60b76 [CALCITE-5735] Add SAFE_MULTIPLY function (enabled for 
BigQuery)
1b8dc60b76 is described below

commit 1b8dc60b7673d83445f06ecaebd5f89dfd5781f6
Author: Tanner Clary <[email protected]>
AuthorDate: Mon May 22 23:54:29 2023 -0700

    [CALCITE-5735] Add SAFE_MULTIPLY function (enabled for BigQuery)
---
 babel/src/test/resources/sql/big-query.iq          | 69 +++++++++++++++++++
 .../calcite/adapter/enumerable/RexImpTable.java    | 27 ++++++++
 .../org/apache/calcite/runtime/SqlFunctions.java   | 79 ++++++++++++++++++++++
 .../calcite/sql/fun/SqlLibraryOperators.java       |  9 +++
 .../org/apache/calcite/sql/type/ReturnTypes.java   |  9 +++
 site/_docs/reference.md                            |  1 +
 .../org/apache/calcite/test/SqlOperatorTest.java   | 71 +++++++++++++++++++
 7 files changed, 265 insertions(+)

diff --git a/babel/src/test/resources/sql/big-query.iq 
b/babel/src/test/resources/sql/big-query.iq
index 429ec7b830..6a6970dfba 100755
--- a/babel/src/test/resources/sql/big-query.iq
+++ b/babel/src/test/resources/sql/big-query.iq
@@ -600,6 +600,75 @@ FROM t;
 !ok
 !}
 
+#####################################################################
+# SAFE_MULTIPLY
+#
+# SAFE_MULTIPLY(value1, value2)
+#
+# Equivalent to the multiply operator (*), but returns NULL if 
overflow/underflow occurs.
+SELECT SAFE_MULTIPLY(5, 4) as result;
++--------+
+| result |
++--------+
+|     20 |
++--------+
+(1 row)
+
+!ok
+
+# Overflow occurs if result is greater than 2^63 - 1
+SELECT SAFE_MULTIPLY(9223372036854775807, 2) as overflow_result;
++-----------------+
+| overflow_result |
++-----------------+
+|                 |
++-----------------+
+(1 row)
+
+!ok
+
+# Underflow occurs if result is less than -2^63
+SELECT SAFE_MULTIPLY(-9223372036854775806, 3) as underflow_result;
++------------------+
+| underflow_result |
++------------------+
+|                  |
++------------------+
+(1 row)
+
+!ok
+
+SELECT SAFE_MULTIPLY(CAST(1.7e308 as DOUBLE), CAST(3 as BIGINT)) as 
double_overflow;
++-----------------+
+| double_overflow |
++-----------------+
+|                 |
++-----------------+
+(1 row)
+
+!ok
+
+SELECT SAFE_MULTIPLY(CAST(-3.5e75 AS DECIMAL(76, 0)), CAST(10 AS BIGINT)) as 
decimal_overflow;
++------------------+
+| decimal_overflow |
++------------------+
+|                  |
++------------------+
+(1 row)
+
+!ok
+
+# NaN arguments should return NaN
+SELECT SAFE_MULTIPLY(CAST('NaN' AS DOUBLE), CAST(3 as BIGINT)) as NaN_result;
++------------+
+| NaN_result |
++------------+
+|        NaN |
++------------+
+(1 row)
+
+!ok
+
 #####################################################################
 # NOT EQUAL Operator (value1 != value2)
 #
diff --git 
a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java 
b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
index 4dfba29852..94daa1016d 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
@@ -220,6 +220,7 @@ import static 
org.apache.calcite.sql.fun.SqlLibraryOperators.RIGHT;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.RLIKE;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.RPAD;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_CAST;
+import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_MULTIPLY;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_OFFSET;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_ORDINAL;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.SEC;
@@ -614,6 +615,8 @@ public class RexImpTable {
       defineMethod(TRUNC, "struncate", NullPolicy.STRICT);
       defineMethod(TRUNCATE, "struncate", NullPolicy.STRICT);
 
+      map.put(SAFE_MULTIPLY, new SafeArithmeticImplementor());
+
       map.put(PI, new PiImplementor());
       return populate2();
     }
@@ -2380,6 +2383,30 @@ public class RexImpTable {
     }
   }
 
+  /** Implementor for the {@code SAFE_MULTIPLY} function. */
+  private static class SafeArithmeticImplementor extends MethodNameImplementor 
{
+    SafeArithmeticImplementor() {
+      super("safeMultiply", NullPolicy.STRICT, false);
+    }
+
+    @Override Expression implementSafe(final RexToLixTranslator translator,
+        final RexCall call, final List<Expression> argValueList) {
+      Expression arg0 = convertType(argValueList.get(0), call.operands.get(0));
+      Expression arg1 = convertType(argValueList.get(1), call.operands.get(1));
+      return Expressions.call(SqlFunctions.class, "safeMultiply", arg0, arg1);
+    }
+
+    // Because BigQuery treats all int types as aliases for BIGINT (Java's 
long)
+    // they can all be converted to LONG to minimize entries in the 
SqlFunctions class.
+    private Expression convertType(Expression arg, RexNode node) {
+      if (SqlTypeName.INT_TYPES.contains(node.getType().getSqlTypeName())) {
+        return Expressions.convert_(arg, long.class);
+      } else {
+        return arg;
+      }
+    }
+  }
+
   /** Implementor for the {@code FLOOR} and {@code CEIL} functions. */
   private static class FloorImplementor extends MethodNameImplementor {
     final Method timestampMethod;
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java 
b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index 1bf6e855f9..4c3233bbfd 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -1705,6 +1705,85 @@ public class SqlFunctions {
     throw notArithmetic("*", b0, b1);
   }
 
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to long values. */
+  public static @Nullable Long safeMultiply(long b0, long b1) {
+    try {
+      return Math.multiplyExact(b0, b1);
+    } catch (ArithmeticException e) {
+      return null;
+    }
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to long and BigDecimal 
values. */
+  public static @Nullable BigDecimal safeMultiply(long b0, BigDecimal b1) {
+    BigDecimal ans = BigDecimal.valueOf(b0).multiply(b1);
+    return safeDecimal(ans) ? ans : null;
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal and long 
values. */
+  public static @Nullable BigDecimal safeMultiply(BigDecimal b0, long b1) {
+    BigDecimal ans = b0.multiply(BigDecimal.valueOf(b1));
+    return safeDecimal(ans) ? ans : null;
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal values. */
+  public static @Nullable BigDecimal safeMultiply(BigDecimal b0, BigDecimal 
b1) {
+    BigDecimal ans = b0.multiply(b1);
+    return safeDecimal(ans) ? ans : null;
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to double and long 
values. */
+  public static @Nullable Double safeMultiply(double b0, long b1) {
+    double ans = b0 * b1;
+    return safeDouble(ans) || !Double.isFinite(b0) ? ans : null;
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to long and double 
values. */
+  public static @Nullable Double safeMultiply(long b0, double b1) {
+    double ans = b0 * b1;
+    return safeDouble(ans) || !Double.isFinite(b1) ? ans : null;
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to double and BigDecimal 
values. */
+  public static @Nullable Double safeMultiply(double b0, BigDecimal b1) {
+    double ans = b0 * b1.doubleValue();
+    return safeDouble(ans) || !Double.isFinite(b0) ? ans : null;
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal and double 
values. */
+  public static @Nullable Double safeMultiply(BigDecimal b0, double b1) {
+    double ans = b0.doubleValue() * b1;
+    return safeDouble(ans) || !Double.isFinite(b1) ? ans : null;
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to double values. */
+  public static @Nullable Double safeMultiply(double b0, double b1) {
+    double ans = b0 * b1;
+    boolean isFinite = Double.isFinite(b0) && Double.isFinite(b1);
+    return safeDouble(ans) || !isFinite ? ans : null;
+  }
+
+  /** Returns whether a BigDecimal value is safe (that is, has not overflowed).
+   * According to BigQuery, BigDecimal overflow occurs if the precision is 
greater
+   * than 76 or the scale is greater than 38. */
+  private static boolean safeDecimal(BigDecimal b) {
+    return b.scale() <= 38 && b.precision() <= 76;
+  }
+
+  /** Returns whether a double value is safe (that is, has not overflowed).
+   *
+   * <p>IEEE 754 arithmetic signals overflow by producing an infinity, so
+   * every finite value is safe. That includes zero, subnormal values and
+   * {@code +/-Double.MAX_VALUE}: a strict range check between
+   * {@code Double.MIN_VALUE} and {@code Double.MAX_VALUE} would wrongly
+   * reject them and turn, for example, {@code 0.0 * 5} into NULL.
+   *
+   * <p>NaN and both infinities are reported as not safe.
+   *
+   * @param d value to check, typically the result of a multiplication
+   * @return {@code true} if {@code d} is finite */
+  private static boolean safeDouble(double d) {
+    // A magnitude too large for a double is rounded to +/-Infinity, which
+    // is exactly what isFinite rejects; finite values cannot be overflows.
+    return Double.isFinite(d);
+  }
+
   private static RuntimeException notArithmetic(String op, Object b0,
       Object b1) {
     return RESOURCE.invalidTypesForArithmetic(b0.getClass().toString(),
diff --git 
a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java 
b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 4bff7cb103..749d12f978 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -1642,6 +1642,15 @@ public abstract class SqlLibraryOperators {
           OperandTypes.family(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.TIMESTAMP,
               SqlTypeFamily.ANY));
 
+  /** The "SAFE_MULTIPLY(numeric1, numeric2)" function; equivalent to the 
{@code *} operator but
+   * returns null if overflow occurs. */
+  @LibraryOperator(libraries = {BIG_QUERY})
+  public static final SqlFunction SAFE_MULTIPLY =
+      SqlBasicFunction.create("SAFE_MULTIPLY",
+          ReturnTypes.PRODUCT_FORCE_NULLABLE,
+          OperandTypes.NUMERIC_NUMERIC,
+          SqlFunctionCategory.NUMERIC);
+
   /** The "CHAR(n)" function; returns the character whose ASCII code is
    * {@code n} % 256, or null if {@code n} &lt; 0. */
   @LibraryOperator(libraries = {MYSQL, SPARK})
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java 
b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
index e10107f28b..8bf20706f8 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
@@ -773,6 +773,15 @@ public abstract class ReturnTypes {
   public static final SqlReturnTypeInference DECIMAL_PRODUCT_NULLABLE =
       DECIMAL_PRODUCT.andThen(SqlTypeTransforms.TO_NULLABLE);
 
+  /**
+   * Same as {@link #DECIMAL_PRODUCT_NULLABLE} but returns with nullability if 
any of
+   * the operands is nullable or the operation results in overflow by using
+   * {@link org.apache.calcite.sql.type.SqlTypeTransforms#FORCE_NULLABLE}. 
Also handles
+   * multiplication for integers, not just decimals.
+   */
+  public static final SqlReturnTypeInference PRODUCT_FORCE_NULLABLE =
+      
DECIMAL_PRODUCT_NULLABLE.orElse(LEAST_RESTRICTIVE).andThen(SqlTypeTransforms.FORCE_NULLABLE);
+
   /**
    * Type-inference strategy whereby the result type of a call is
    * {@link #DECIMAL_PRODUCT_NULLABLE} with a fallback to
diff --git a/site/_docs/reference.md b/site/_docs/reference.md
index a9a147a0ca..1211499a25 100644
--- a/site/_docs/reference.md
+++ b/site/_docs/reference.md
@@ -2785,6 +2785,7 @@ BigQuery's type system uses confusingly different names 
for types and functions:
 | b o | RPAD(string, length[, pattern ])             | Returns a string or 
bytes value that consists of *string* appended to *length* with *pattern*
 | b o | RTRIM(string)                                | Returns *string* with 
all blanks removed from the end
 | b | SAFE_CAST(value AS type)                       | Converts *value* to 
*type*, returning NULL if conversion fails
+| b | SAFE_MULTIPLY(numeric1, numeric2)              | Returns *numeric1* * 
*numeric2*, or NULL on overflow
 | b | SAFE_OFFSET(index)                             | Similar to `OFFSET` 
except null is returned if *index* is out of bounds
 | b | SAFE_ORDINAL(index)                            | Similar to `OFFSET` 
except *index* begins at 1 and null is returned if *index* is out of bounds
 | * | SEC(numeric)                                   | Returns the secant of 
*numeric* in radians
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java 
b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index 82ba08dbe5..ac54e24c70 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -7002,6 +7002,77 @@ public class SqlOperatorTest {
     f.checkNull("truncate(cast(null as double))");
   }
 
+  @Test void testSafeMultiplyFunc() {
+    final SqlOperatorFixture f0 = 
fixture().setFor(SqlLibraryOperators.SAFE_MULTIPLY);
+    f0.checkFails("^safe_multiply(2, 3)^",
+        "No match found for function signature "
+        + "SAFE_MULTIPLY\\(<NUMERIC>, <NUMERIC>\\)", false);
+    final SqlOperatorFixture f = f0.withLibrary(SqlLibrary.BIG_QUERY);
+    // Basic test for each of the 9 ordered pairs of BIGINT, DECIMAL, and 
DOUBLE
+    f.checkScalar("safe_multiply(cast(20 as bigint), cast(20 as bigint))",
+        "400", "BIGINT");
+    f.checkScalar("safe_multiply(cast(20 as bigint), cast(1.2345 as 
decimal(5,4)))",
+        "24.6900", "DECIMAL(19, 4)");
+    f.checkScalar("safe_multiply(cast(1.2345 as decimal(5,4)), cast(20 as 
bigint))",
+        "24.6900", "DECIMAL(19, 4)");
+    f.checkScalar("safe_multiply(cast(1.2345 as decimal(5,4)), "
+        + "cast(2.0 as decimal(2, 1)))", "2.46900", "DECIMAL(7, 5)");
+    f.checkScalar("safe_multiply(cast(3 as double), cast(3 as bigint))",
+        "9.0", "DOUBLE");
+    f.checkScalar("safe_multiply(cast(3 as bigint), cast(3 as double))",
+        "9.0", "DOUBLE");
+    f.checkScalar("safe_multiply(cast(3 as double), cast(1.2345 as decimal(5, 
4)))",
+        "3.7035", "DOUBLE");
+    f.checkScalar("safe_multiply(cast(1.2345 as decimal(5, 4)), cast(3 as 
double))",
+        "3.7035", "DOUBLE");
+    f.checkScalar("safe_multiply(cast(3 as double), cast(3 as double))",
+        "9.0", "DOUBLE");
+    // Tests for + and - Infinity
+    f.checkScalar("safe_multiply(cast('Infinity' as double), cast(3 as 
double))",
+        "Infinity", "DOUBLE");
+    f.checkScalar("safe_multiply(cast('-Infinity' as double), cast(3 as 
double))",
+        "-Infinity", "DOUBLE");
+    f.checkScalar("safe_multiply(cast('-Infinity' as double), "
+        + "cast('Infinity' as double))", "-Infinity", "DOUBLE");
+    // Tests for NaN
+    f.checkScalar("safe_multiply(cast('NaN' as double), cast(3 as bigint))",
+        "NaN", "DOUBLE");
+    f.checkScalar("safe_multiply(cast('NaN' as double), cast(1.23 as 
decimal(3, 2)))",
+        "NaN", "DOUBLE");
+    f.checkScalar("safe_multiply(cast('NaN' as double), cast('Infinity' as 
double))",
+        "NaN", "DOUBLE");
+    f.checkScalar("safe_multiply(cast(3 as bigint), cast('NaN' as double))",
+        "NaN", "DOUBLE");
+    f.checkScalar("safe_multiply(cast(1.23 as decimal(3, 2)), cast('NaN' as 
double))",
+        "NaN", "DOUBLE");
+    // Overflow test for each pairing
+    f.checkNull("safe_multiply(cast(20 as bigint), "
+        + "cast(9223372036854775807 as bigint))");
+    f.checkNull("safe_multiply(cast(20 as bigint), "
+        + "cast(-9223372036854775807 as bigint))");
+    f.checkNull("safe_multiply(cast(10 as bigint), cast(3.5e75 as DECIMAL(76, 
0)))");
+    f.checkNull("safe_multiply(cast(10 as bigint), cast(-3.5e75 as DECIMAL(76, 
0)))");
+    f.checkNull("safe_multiply(cast(3.5e75 as DECIMAL(76, 0)), cast(10 as 
bigint))");
+    f.checkNull("safe_multiply(cast(-3.5e75 as DECIMAL(76, 0)), cast(10 as 
bigint))");
+    f.checkNull("safe_multiply(cast(3.5e75 as DECIMAL(76, 0)), "
+        + "cast(1.5 as DECIMAL(2, 1)))");
+    f.checkNull("safe_multiply(cast(-3.5e75 as DECIMAL(76, 0)), "
+        + "cast(1.5 as DECIMAL(2, 1)))");
+    f.checkNull("safe_multiply(cast(1.7e308 as double), cast(1.23 as 
decimal(3, 2)))");
+    f.checkNull("safe_multiply(cast(-1.7e308 as double), cast(1.2 as 
decimal(2, 1)))");
+    f.checkNull("safe_multiply(cast(1.2 as decimal(2, 1)), cast(1.7e308 as 
double))");
+    f.checkNull("safe_multiply(cast(1.2 as decimal(2, 1)), cast(-1.7e308 as 
double))");
+    f.checkNull("safe_multiply(cast(1.7e308 as double), cast(3 as bigint))");
+    f.checkNull("safe_multiply(cast(-1.7e308 as double), cast(3 as bigint))");
+    f.checkNull("safe_multiply(cast(3 as bigint), cast(1.7e308 as double))");
+    f.checkNull("safe_multiply(cast(3 as bigint), cast(-1.7e308 as double))");
+    f.checkNull("safe_multiply(cast(3 as double), cast(1.7e308 as double))");
+    f.checkNull("safe_multiply(cast(3 as double), cast(-1.7e308 as double))");
+    // Check that null argument returns null
+    f.checkNull("safe_multiply(cast(null as double), cast(3 as bigint))");
+    f.checkNull("safe_multiply(cast(3 as double), cast(null as bigint))");
+  }
+
   @Test void testNullifFunc() {
     final SqlOperatorFixture f = fixture();
     f.setFor(SqlStdOperatorTable.NULLIF, VM_EXPAND);

Reply via email to