This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new 96c8b4f47c2 [SPARK-38855][SQL] DS V2 supports push down math functions
96c8b4f47c2 is described below
commit 96c8b4f47c2d0df249efb088882b248b5c230188
Author: Jiaan Geng <[email protected]>
AuthorDate: Wed Apr 13 14:41:47 2022 +0800
[SPARK-38855][SQL] DS V2 supports push down math functions
### What changes were proposed in this pull request?
Currently, Spark has some math functions of the ANSI standard. Please refer to
https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L388
These functions are shown below:
`LN`,
`EXP`,
`POWER`,
`SQRT`,
`FLOOR`,
`CEIL`,
`WIDTH_BUCKET`
The mainstream databases that support these functions are shown below.
| Function | PostgreSQL | ClickHouse | H2 | MySQL | Oracle | Redshift |
Presto | Teradata | Snowflake | DB2 | Vertica | Exasol | SqlServer |
Yellowbrick | Impala | Mariadb | Druid | Pig | SQLite | Influxdata |
Singlestore | ElasticSearch |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
---- | ---- | ---- | ---- |
| `LN` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `EXP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `POWER` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes
| Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes |
| `SQRT` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes
| Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `FLOOR` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes
| Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `CEIL` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes
| Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `WIDTH_BUCKET` | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes |
Yes | No | No | No | Yes | No | No | No | No | No | No | No |
DS V2 should support pushing down these math functions.
### Why are the changes needed?
DS V2 supports pushing down math functions.
### Does this PR introduce _any_ user-facing change?
'No'.
New feature.
### How was this patch tested?
New tests.
Closes #36140 from beliefer/SPARK-38855.
Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit bf75b495e18ed87d0c118bfd5f1ceb52d720cad9)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../expressions/GeneralScalarExpression.java | 54 ++++++++++++++++++++++
.../sql/connector/util/V2ExpressionSQLBuilder.java | 7 +++
.../spark/sql/errors/QueryCompilationErrors.scala | 4 ++
.../sql/catalyst/util/V2ExpressionBuilder.scala | 28 ++++++++++-
.../org/apache/spark/sql/jdbc/H2Dialect.scala | 26 +++++++++++
.../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 28 ++++++++++-
6 files changed, 145 insertions(+), 2 deletions(-)
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index 8952761f9ef..58082d5ee09 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -94,6 +94,60 @@ import
org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder;
* <li>Since version: 3.3.0</li>
* </ul>
* </li>
+ * <li>Name: <code>ABS</code>
+ * <ul>
+ * <li>SQL semantic: <code>ABS(expr)</code></li>
+ * <li>Since version: 3.3.0</li>
+ * </ul>
+ * </li>
+ * <li>Name: <code>COALESCE</code>
+ * <ul>
+ * <li>SQL semantic: <code>COALESCE(expr1, expr2)</code></li>
+ * <li>Since version: 3.3.0</li>
+ * </ul>
+ * </li>
+ * <li>Name: <code>LN</code>
+ * <ul>
+ * <li>SQL semantic: <code>LN(expr)</code></li>
+ * <li>Since version: 3.3.0</li>
+ * </ul>
+ * </li>
+ * <li>Name: <code>EXP</code>
+ * <ul>
+ * <li>SQL semantic: <code>EXP(expr)</code></li>
+ * <li>Since version: 3.3.0</li>
+ * </ul>
+ * </li>
+ * <li>Name: <code>POWER</code>
+ * <ul>
+ * <li>SQL semantic: <code>POWER(expr, number)</code></li>
+ * <li>Since version: 3.3.0</li>
+ * </ul>
+ * </li>
+ * <li>Name: <code>SQRT</code>
+ * <ul>
+ * <li>SQL semantic: <code>SQRT(expr)</code></li>
+ * <li>Since version: 3.3.0</li>
+ * </ul>
+ * </li>
+ * <li>Name: <code>FLOOR</code>
+ * <ul>
+ * <li>SQL semantic: <code>FLOOR(expr)</code></li>
+ * <li>Since version: 3.3.0</li>
+ * </ul>
+ * </li>
+ * <li>Name: <code>CEIL</code>
+ * <ul>
+ * <li>SQL semantic: <code>CEIL(expr)</code></li>
+ * <li>Since version: 3.3.0</li>
+ * </ul>
+ * </li>
+ * <li>Name: <code>WIDTH_BUCKET</code>
+ * <ul>
+ *    <li>SQL semantic: <code>WIDTH_BUCKET(expr, min, max, numBuckets)</code></li>
+ * <li>Since version: 3.3.0</li>
+ * </ul>
+ * </li>
* </ol>
* Note: SQL semantic conforms ANSI standard, so some expressions are not
supported when ANSI off,
* including: add, subtract, multiply, divide, remainder, pmod.
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index a7d1ed7f85e..c9dfa2003e3 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -95,6 +95,13 @@ public class V2ExpressionSQLBuilder {
return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
case "ABS":
case "COALESCE":
+ case "LN":
+ case "EXP":
+ case "POWER":
+ case "SQRT":
+ case "FLOOR":
+ case "CEIL":
+ case "WIDTH_BUCKET":
return visitSQLFunction(name,
Arrays.stream(e.children()).map(c ->
build(c)).toArray(String[]::new));
case "CASE_WHEN": {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 0532a953ef4..f1357f91f9d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2392,4 +2392,8 @@ object QueryCompilationErrors {
new AnalysisException(
"Sinks cannot request distribution and ordering in continuous execution
mode")
}
+
+ def noSuchFunctionError(database: String, funcInfo: String): Throwable = {
+ new AnalysisException(s"$database does not support function: $funcInfo")
+ }
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 37db499470a..487b809d48a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.util
-import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And,
BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr,
BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo,
Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or,
Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And,
BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr,
BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith,
EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log,
Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith,
StringPredicate, Subtract, UnaryMinus, WidthBucket}
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression
=> V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse,
AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -104,6 +104,32 @@ class V2ExpressionBuilder(
} else {
None
}
+ case Log(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v)))
+ case Exp(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v)))
+ case Pow(left, right) =>
+ val l = generateExpression(left)
+ val r = generateExpression(right)
+ if (l.isDefined && r.isDefined) {
+ Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get,
r.get)))
+ } else {
+ None
+ }
+ case Sqrt(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v)))
+ case Floor(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v)))
+ case Ceil(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v)))
+ case wb: WidthBucket =>
+ val childrenExpressions = wb.children.flatMap(generateExpression(_))
+ if (childrenExpressions.length == wb.children.length) {
+ Some(new GeneralScalarExpression("WIDTH_BUCKET",
+ childrenExpressions.toArray[V2Expression]))
+ } else {
+ None
+ }
case and: And =>
// AND expects predicate
val l = generateExpression(and.left, true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 643376cdb12..0aa971c0d3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -20,14 +20,40 @@ package org.apache.spark.sql.jdbc
import java.sql.SQLException
import java.util.Locale
+import scala.util.control.NonFatal
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException,
NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.expressions.Expression
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc,
GeneralAggregateFunc}
+import org.apache.spark.sql.errors.QueryCompilationErrors
private object H2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
+ class H2SQLBuilder extends JDBCSQLBuilder {
+ override def visitSQLFunction(funcName: String, inputs: Array[String]):
String = {
+ funcName match {
+ case "WIDTH_BUCKET" =>
+ val functionInfo = super.visitSQLFunction(funcName, inputs)
+ throw QueryCompilationErrors.noSuchFunctionError("H2", functionInfo)
+ case _ => super.visitSQLFunction(funcName, inputs)
+ }
+ }
+ }
+
+ override def compileExpression(expr: Expression): Option[String] = {
+ val h2SQLBuilder = new H2SQLBuilder()
+ try {
+ Some(h2SQLBuilder.build(expr))
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Error occurs while compiling V2 expression", e)
+ None
+ }
+ }
+
override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
aggFunction match {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 858781f2cde..e28d9ba9ba8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -26,7 +26,7 @@ import
org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
import
org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation,
V1ScanWrapper}
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{abs, avg, coalesce, count,
count_distinct, lit, not, sum, udf, when}
+import org.apache.spark.sql.functions.{abs, avg, ceil, coalesce, count,
count_distinct, exp, floor, lit, log => ln, not, pow, sqrt, sum, udf, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
@@ -440,6 +440,32 @@ class JDBCV2Suite extends QueryTest with
SharedSparkSession with ExplainSuiteHel
checkPushedInfo(df5, expectedPlanFragment5)
checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true),
Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200,
true)))
+
+ val df6 = spark.table("h2.test.employee")
+ .filter(ln($"dept") > 1)
+ .filter(exp($"salary") > 2000)
+ .filter(pow($"dept", 2) > 4)
+ .filter(sqrt($"salary") > 100)
+ .filter(floor($"dept") > 1)
+ .filter(ceil($"dept") > 1)
+ checkFiltersRemoved(df6, ansiMode)
+ val expectedPlanFragment6 = if (ansiMode) {
+ "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL, " +
+ "LN(CAST(DEPT AS double)) > 1.0, EXP(CAST(SALARY AS double)...,"
+ } else {
+ "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL]"
+ }
+ checkPushedInfo(df6, expectedPlanFragment6)
+ checkAnswer(df6, Seq(Row(6, "jen", 12000, 1200, true)))
+
+ // H2 does not support width_bucket
+ val df7 = sql("""
+ |SELECT * FROM h2.test.employee
+ |WHERE width_bucket(dept, 1, 6, 3) > 1
+ |""".stripMargin)
+ checkFiltersRemoved(df7, false)
+ checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL]")
+ checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true)))
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]