This is an automated email from the ASF dual-hosted git repository.
tanner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 1b8dc60b76 [CALCITE-5735] Add SAFE_MULTIPLY function (enabled for
BigQuery)
1b8dc60b76 is described below
commit 1b8dc60b7673d83445f06ecaebd5f89dfd5781f6
Author: Tanner Clary <[email protected]>
AuthorDate: Mon May 22 23:54:29 2023 -0700
[CALCITE-5735] Add SAFE_MULTIPLY function (enabled for BigQuery)
---
babel/src/test/resources/sql/big-query.iq | 69 +++++++++++++++++++
.../calcite/adapter/enumerable/RexImpTable.java | 27 ++++++++
.../org/apache/calcite/runtime/SqlFunctions.java | 79 ++++++++++++++++++++++
.../calcite/sql/fun/SqlLibraryOperators.java | 9 +++
.../org/apache/calcite/sql/type/ReturnTypes.java | 9 +++
site/_docs/reference.md | 1 +
.../org/apache/calcite/test/SqlOperatorTest.java | 71 +++++++++++++++++++
7 files changed, 265 insertions(+)
diff --git a/babel/src/test/resources/sql/big-query.iq
b/babel/src/test/resources/sql/big-query.iq
index 429ec7b830..6a6970dfba 100755
--- a/babel/src/test/resources/sql/big-query.iq
+++ b/babel/src/test/resources/sql/big-query.iq
@@ -600,6 +600,75 @@ FROM t;
!ok
!}
+#####################################################################
+# SAFE_MULTIPLY
+#
+# SAFE_MULTIPLY(value1, value2)
+#
+# Equivalent to the multiply operator (*), but returns NULL if
+# overflow/underflow occurs.
+SELECT SAFE_MULTIPLY(5, 4) as result;
++--------+
+| result |
++--------+
+| 20 |
++--------+
+(1 row)
+
+!ok
+
+# Overflow occurs if result is greater than 2^63 - 1
+SELECT SAFE_MULTIPLY(9223372036854775807, 2) as overflow_result;
++-----------------+
+| overflow_result |
++-----------------+
+| |
++-----------------+
+(1 row)
+
+!ok
+
+# Underflow occurs if result is less than -2^63
+SELECT SAFE_MULTIPLY(-9223372036854775806, 3) as underflow_result;
++------------------+
+| underflow_result |
++------------------+
+| |
++------------------+
+(1 row)
+
+!ok
+
+SELECT SAFE_MULTIPLY(CAST(1.7e308 as DOUBLE), CAST(3 as BIGINT)) as double_overflow;
++-----------------+
+| double_overflow |
++-----------------+
+| |
++-----------------+
+(1 row)
+
+!ok
+
+SELECT SAFE_MULTIPLY(CAST(-3.5e75 AS DECIMAL(76, 0)), CAST(10 AS BIGINT)) as decimal_overflow;
++------------------+
+| decimal_overflow |
++------------------+
+| |
++------------------+
+(1 row)
+
+!ok
+
+# NaN arguments should return NaN
+SELECT SAFE_MULTIPLY(CAST('NaN' AS DOUBLE), CAST(3 as BIGINT)) as NaN_result;
++------------+
+| NaN_result |
++------------+
+| NaN |
++------------+
+(1 row)
+
+!ok
+
#####################################################################
# NOT EQUAL Operator (value1 != value2)
#
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
index 4dfba29852..94daa1016d 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
@@ -220,6 +220,7 @@ import static
org.apache.calcite.sql.fun.SqlLibraryOperators.RIGHT;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.RLIKE;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.RPAD;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_CAST;
+import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_MULTIPLY;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_OFFSET;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_ORDINAL;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.SEC;
@@ -614,6 +615,8 @@ public class RexImpTable {
defineMethod(TRUNC, "struncate", NullPolicy.STRICT);
defineMethod(TRUNCATE, "struncate", NullPolicy.STRICT);
+ map.put(SAFE_MULTIPLY, new SafeArithmeticImplementor());
+
map.put(PI, new PiImplementor());
return populate2();
}
@@ -2380,6 +2383,30 @@ public class RexImpTable {
}
}
+  /** Implementor for the {@code SAFE_MULTIPLY} function.
+   *
+   * <p>Generates a call to a static method of {@link SqlFunctions} with two
+   * arguments, after widening any integer-typed argument to {@code long}.
+   * The method name defaults to {@code safeMultiply}; other binary "safe"
+   * arithmetic functions whose runtime methods accept the same operand types
+   * can reuse this class by passing their method name. */
+  private static class SafeArithmeticImplementor extends MethodNameImplementor {
+    /** Creates an implementor that calls {@code SqlFunctions.safeMultiply}. */
+    SafeArithmeticImplementor() {
+      this("safeMultiply");
+    }
+
+    /** Creates an implementor that calls the {@link SqlFunctions} method with
+     * the given name. */
+    SafeArithmeticImplementor(String methodName) {
+      super(methodName, NullPolicy.STRICT, false);
+    }
+
+    @Override Expression implementSafe(final RexToLixTranslator translator,
+        final RexCall call, final List<Expression> argValueList) {
+      Expression arg0 = convertType(argValueList.get(0), call.operands.get(0));
+      Expression arg1 = convertType(argValueList.get(1), call.operands.get(1));
+      // Use the name given to the superclass rather than repeating a literal.
+      return Expressions.call(SqlFunctions.class, methodName, arg0, arg1);
+    }
+
+    // Because BigQuery treats all integer types as aliases for BIGINT (Java's
+    // long), they can all be converted to long to minimize the number of
+    // overloads needed in the SqlFunctions class.
+    private static Expression convertType(Expression arg, RexNode node) {
+      if (SqlTypeName.INT_TYPES.contains(node.getType().getSqlTypeName())) {
+        return Expressions.convert_(arg, long.class);
+      } else {
+        return arg;
+      }
+    }
+  }
+
/** Implementor for the {@code FLOOR} and {@code CEIL} functions. */
private static class FloorImplementor extends MethodNameImplementor {
final Method timestampMethod;
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index 1bf6e855f9..4c3233bbfd 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -1705,6 +1705,85 @@ public class SqlFunctions {
throw notArithmetic("*", b0, b1);
}
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to long values;
+   * returns null when the exact product does not fit in a {@code long}. */
+  public static @Nullable Long safeMultiply(long b0, long b1) {
+    final long product;
+    try {
+      product = Math.multiplyExact(b0, b1);
+    } catch (ArithmeticException overflow) {
+      // multiplyExact signals overflow by throwing; SAFE_MULTIPLY maps it to NULL.
+      return null;
+    }
+    return product;
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to long and BigDecimal
+   * values; the long operand is widened to a BigDecimal. */
+  public static @Nullable BigDecimal safeMultiply(long b0, BigDecimal b1) {
+    return safeMultiply(BigDecimal.valueOf(b0), b1);
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal and long
+   * values; the long operand is widened to a BigDecimal. */
+  public static @Nullable BigDecimal safeMultiply(BigDecimal b0, long b1) {
+    return safeMultiply(b0, BigDecimal.valueOf(b1));
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal values;
+   * returns null if the exact product fails the {@code safeDecimal} check. */
+  public static @Nullable BigDecimal safeMultiply(BigDecimal b0, BigDecimal b1) {
+    final BigDecimal product = b0.multiply(b1);
+    if (!safeDecimal(product)) {
+      return null;
+    }
+    return product;
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to double and long
+   * values. */
+  public static @Nullable Double safeMultiply(double b0, long b1) {
+    // A long always converts to a finite double, so this is exactly the
+    // double-double case.
+    return safeMultiply(b0, (double) b1);
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to long and double
+   * values. */
+  public static @Nullable Double safeMultiply(long b0, double b1) {
+    // A long always converts to a finite double, so this is exactly the
+    // double-double case.
+    return safeMultiply((double) b0, b1);
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to double and BigDecimal
+   * values. A non-finite double operand is passed through to the result. */
+  public static @Nullable Double safeMultiply(double b0, BigDecimal b1) {
+    final double product = b0 * b1.doubleValue();
+    if (!Double.isFinite(b0) || safeDouble(product)) {
+      return product;
+    }
+    return null;
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal and double
+   * values. A non-finite double operand is passed through to the result. */
+  public static @Nullable Double safeMultiply(BigDecimal b0, double b1) {
+    final double product = b0.doubleValue() * b1;
+    if (!Double.isFinite(b1) || safeDouble(product)) {
+      return product;
+    }
+    return null;
+  }
+
+  /** SQL <code>SAFE_MULTIPLY</code> function applied to double values.
+   * If either operand is infinite or NaN, the raw product is returned;
+   * otherwise null is returned when the product fails the
+   * {@code safeDouble} check. */
+  public static @Nullable Double safeMultiply(double b0, double b1) {
+    final double product = b0 * b1;
+    if (!Double.isFinite(b0) || !Double.isFinite(b1)) {
+      return product;
+    }
+    if (safeDouble(product)) {
+      return product;
+    }
+    return null;
+  }
+
+ /** Returns whether a BigDecimal value is safe (that is, has not overflowed).
+ * According to BigQuery, BigDecimal overflow occurs if the precision is
greater
+ * than 76 or the scale is greater than 38. */
+ private static boolean safeDecimal(BigDecimal b) {
+ return b.scale() <= 38 && b.precision() <= 76;
+ }
+
+  /** Returns whether a double value is safe (that is, has not overflowed).
+   *
+   * <p>Multiplying two finite doubles overflows exactly when the result is
+   * infinite, and never produces NaN. So every finite value is safe,
+   * including zero (for example {@code 0.0 * 5}), subnormal values, and
+   * {@link Double#MAX_VALUE} itself. Strict range checks against
+   * {@code Double.MIN_VALUE} / {@code Double.MAX_VALUE} would wrongly reject
+   * those values. Results whose operands were already infinite or NaN are
+   * passed through by the callers. */
+  private static boolean safeDouble(double d) {
+    return Double.isFinite(d);
+  }
+
private static RuntimeException notArithmetic(String op, Object b0,
Object b1) {
return RESOURCE.invalidTypesForArithmetic(b0.getClass().toString(),
diff --git
a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 4bff7cb103..749d12f978 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -1642,6 +1642,15 @@ public abstract class SqlLibraryOperators {
OperandTypes.family(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.TIMESTAMP,
SqlTypeFamily.ANY));
+  /** The "SAFE_MULTIPLY(numeric1, numeric2)" function; behaves like the
+   * {@code *} operator, except that it returns null if the result
+   * overflows. Its return type is therefore always nullable. */
+  @LibraryOperator(libraries = {BIG_QUERY})
+  public static final SqlFunction SAFE_MULTIPLY =
+      SqlBasicFunction.create("SAFE_MULTIPLY", ReturnTypes.PRODUCT_FORCE_NULLABLE,
+          OperandTypes.NUMERIC_NUMERIC, SqlFunctionCategory.NUMERIC);
+
/** The "CHAR(n)" function; returns the character whose ASCII code is
* {@code n} % 256, or null if {@code n} < 0. */
@LibraryOperator(libraries = {MYSQL, SPARK})
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
index e10107f28b..8bf20706f8 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
@@ -773,6 +773,15 @@ public abstract class ReturnTypes {
public static final SqlReturnTypeInference DECIMAL_PRODUCT_NULLABLE =
DECIMAL_PRODUCT.andThen(SqlTypeTransforms.TO_NULLABLE);
+ /**
+ * Same as {@link #DECIMAL_PRODUCT_NULLABLE} but returns with nullability if
any of
+ * the operands is nullable or the operation results in overflow by using
+ * {@link org.apache.calcite.sql.type.SqlTypeTransforms#FORCE_NULLABLE}.
Also handles
+ * multiplication for integers, not just decimals.
+ */
+ public static final SqlReturnTypeInference PRODUCT_FORCE_NULLABLE =
+
DECIMAL_PRODUCT_NULLABLE.orElse(LEAST_RESTRICTIVE).andThen(SqlTypeTransforms.FORCE_NULLABLE);
+
/**
* Type-inference strategy whereby the result type of a call is
* {@link #DECIMAL_PRODUCT_NULLABLE} with a fallback to
diff --git a/site/_docs/reference.md b/site/_docs/reference.md
index a9a147a0ca..1211499a25 100644
--- a/site/_docs/reference.md
+++ b/site/_docs/reference.md
@@ -2785,6 +2785,7 @@ BigQuery's type system uses confusingly different names
for types and functions:
| b o | RPAD(string, length[, pattern ]) | Returns a string or
bytes value that consists of *string* appended to *length* with *pattern*
| b o | RTRIM(string) | Returns *string* with
all blanks removed from the end
| b | SAFE_CAST(value AS type) | Converts *value* to
*type*, returning NULL if conversion fails
+| b | SAFE_MULTIPLY(numeric1, numeric2) | Returns *numeric1* * *numeric2*, or NULL on overflow
| b | SAFE_OFFSET(index) | Similar to `OFFSET`
except null is returned if *index* is out of bounds
| b | SAFE_ORDINAL(index) | Similar to `OFFSET`
except *index* begins at 1 and null is returned if *index* is out of bounds
| * | SEC(numeric) | Returns the secant of
*numeric* in radians
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index 82ba08dbe5..ac54e24c70 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -7002,6 +7002,77 @@ public class SqlOperatorTest {
f.checkNull("truncate(cast(null as double))");
}
+ @Test void testSafeMultiplyFunc() {
+ final SqlOperatorFixture f0 =
fixture().setFor(SqlLibraryOperators.SAFE_MULTIPLY);
+ f0.checkFails("^safe_multiply(2, 3)^",
+ "No match found for function signature "
+ + "SAFE_MULTIPLY\\(<NUMERIC>, <NUMERIC>\\)", false);
+ final SqlOperatorFixture f = f0.withLibrary(SqlLibrary.BIG_QUERY);
+ // Basic test for each of the 9 2-permutations of BIGINT, DECIMAL, and
FLOAT
+ f.checkScalar("safe_multiply(cast(20 as bigint), cast(20 as bigint))",
+ "400", "BIGINT");
+ f.checkScalar("safe_multiply(cast(20 as bigint), cast(1.2345 as
decimal(5,4)))",
+ "24.6900", "DECIMAL(19, 4)");
+ f.checkScalar("safe_multiply(cast(1.2345 as decimal(5,4)), cast(20 as
bigint))",
+ "24.6900", "DECIMAL(19, 4)");
+ f.checkScalar("safe_multiply(cast(1.2345 as decimal(5,4)), "
+ + "cast(2.0 as decimal(2, 1)))", "2.46900", "DECIMAL(7, 5)");
+ f.checkScalar("safe_multiply(cast(3 as double), cast(3 as bigint))",
+ "9.0", "DOUBLE");
+ f.checkScalar("safe_multiply(cast(3 as bigint), cast(3 as double))",
+ "9.0", "DOUBLE");
+ f.checkScalar("safe_multiply(cast(3 as double), cast(1.2345 as decimal(5,
4)))",
+ "3.7035", "DOUBLE");
+ f.checkScalar("safe_multiply(cast(1.2345 as decimal(5, 4)), cast(3 as
double))",
+ "3.7035", "DOUBLE");
+ f.checkScalar("safe_multiply(cast(3 as double), cast(3 as double))",
+ "9.0", "DOUBLE");
+ // Tests for + and - Infinity
+ f.checkScalar("safe_multiply(cast('Infinity' as double), cast(3 as
double))",
+ "Infinity", "DOUBLE");
+ f.checkScalar("safe_multiply(cast('-Infinity' as double), cast(3 as
double))",
+ "-Infinity", "DOUBLE");
+ f.checkScalar("safe_multiply(cast('-Infinity' as double), "
+ + "cast('Infinity' as double))", "-Infinity", "DOUBLE");
+ // Tests for NaN
+ f.checkScalar("safe_multiply(cast('NaN' as double), cast(3 as bigint))",
+ "NaN", "DOUBLE");
+ f.checkScalar("safe_multiply(cast('NaN' as double), cast(1.23 as
decimal(3, 2)))",
+ "NaN", "DOUBLE");
+ f.checkScalar("safe_multiply(cast('NaN' as double), cast('Infinity' as
double))",
+ "NaN", "DOUBLE");
+ f.checkScalar("safe_multiply(cast(3 as bigint), cast('NaN' as double))",
+ "NaN", "DOUBLE");
+ f.checkScalar("safe_multiply(cast(1.23 as decimal(3, 2)), cast('NaN' as
double))",
+ "NaN", "DOUBLE");
+ // Overflow test for each pairing
+ f.checkNull("safe_multiply(cast(20 as bigint), "
+ + "cast(9223372036854775807 as bigint))");
+ f.checkNull("safe_multiply(cast(20 as bigint), "
+ + "cast(-9223372036854775807 as bigint))");
+ f.checkNull("safe_multiply(cast(10 as bigint), cast(3.5e75 as DECIMAL(76,
0)))");
+ f.checkNull("safe_multiply(cast(10 as bigint), cast(-3.5e75 as DECIMAL(76,
0)))");
+ f.checkNull("safe_multiply(cast(3.5e75 as DECIMAL(76, 0)), cast(10 as
bigint))");
+ f.checkNull("safe_multiply(cast(-3.5e75 as DECIMAL(76, 0)), cast(10 as
bigint))");
+ f.checkNull("safe_multiply(cast(3.5e75 as DECIMAL(76, 0)), "
+ + "cast(1.5 as DECIMAL(2, 1)))");
+ f.checkNull("safe_multiply(cast(-3.5e75 as DECIMAL(76, 0)), "
+ + "cast(1.5 as DECIMAL(2, 1)))");
+ f.checkNull("safe_multiply(cast(1.7e308 as double), cast(1.23 as
decimal(3, 2)))");
+ f.checkNull("safe_multiply(cast(-1.7e308 as double), cast(1.2 as
decimal(2, 1)))");
+ f.checkNull("safe_multiply(cast(1.2 as decimal(2, 1)), cast(1.7e308 as
double))");
+ f.checkNull("safe_multiply(cast(1.2 as decimal(2, 1)), cast(-1.7e308 as
double))");
+ f.checkNull("safe_multiply(cast(1.7e308 as double), cast(3 as bigint))");
+ f.checkNull("safe_multiply(cast(-1.7e308 as double), cast(3 as bigint))");
+ f.checkNull("safe_multiply(cast(3 as bigint), cast(1.7e308 as double))");
+ f.checkNull("safe_multiply(cast(3 as bigint), cast(-1.7e308 as double))");
+ f.checkNull("safe_multiply(cast(3 as double), cast(1.7e308 as double))");
+ f.checkNull("safe_multiply(cast(3 as double), cast(-1.7e308 as double))");
+ // Check that null argument retuns null
+ f.checkNull("safe_multiply(cast(null as double), cast(3 as bigint))");
+ f.checkNull("safe_multiply(cast(3 as double), cast(null as bigint))");
+ }
+
@Test void testNullifFunc() {
final SqlOperatorFixture f = fixture();
f.setFor(SqlStdOperatorTable.NULLIF, VM_EXPAND);