This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e4ca8424474 [SPARK-39384][SQL] Compile built-in linear regression aggregate functions for JDBC dialect
e4ca8424474 is described below
commit e4ca8424474e571d8e137388fe5d54732b68c2f3
Author: Jiaan Geng <[email protected]>
AuthorDate: Sat Jul 16 09:05:28 2022 -0500
[SPARK-39384][SQL] Compile built-in linear regression aggregate functions for JDBC dialect
### What changes were proposed in this pull request?
Recently, the Spark DS V2 pushdown framework started translating a number of standard linear regression aggregate functions.
Currently, only H2Dialect compiles these standard linear regression aggregate functions. This PR compiles them for the other built-in JDBC dialects as well.
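For orientation, a minimal, self-contained sketch of the approach (the object and method below are illustrative only and are not part of this commit; the real logic lives in the dialect-specific JDBCSQLBuilder subclasses in the diff): each dialect lists the aggregate functions its database can compile, and rejects DISTINCT for the ones the database cannot express with DISTINCT.

```scala
// Simplified, illustrative sketch of the dialect-side pattern; names are hypothetical.
object ExampleDialectSketch {
  private val distinctUnsupportedAggregateFunctions =
    Set("REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")

  private val supportedAggregateFunctions =
    Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions

  def isSupportedFunction(funcName: String): Boolean =
    supportedAggregateFunctions.contains(funcName)

  // Roughly mirrors the visitAggregateFunction overrides added below: fail fast when
  // DISTINCT is requested for a function the target database does not support with
  // DISTINCT, otherwise render "FUNC(args)" / "FUNC(DISTINCT args)".
  def visitAggregateFunction(
      funcName: String, isDistinct: Boolean, inputs: Array[String]): String = {
    if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
      throw new UnsupportedOperationException(
        s"This dialect does not support aggregate function: $funcName with DISTINCT")
    } else {
      val distinct = if (isDistinct) "DISTINCT " else ""
      s"$funcName($distinct${inputs.mkString(", ")})"
    }
  }
}
```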
### Why are the changes needed?
Make the built-in JDBC dialects support push-down of the linear regression aggregates.
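As a hedged usage example (the catalog name, connection settings, and the employee table are assumptions for illustration, not part of this commit), a query against a JDBC catalog backed by one of these databases can now have its linear regression aggregate compiled by the dialect and evaluated in the database:

```scala
import org.apache.spark.sql.SparkSession

object RegrPushDownExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("regr-aggregate-pushdown")
      // Catalog name "postgres" and the connection options are illustrative.
      .config("spark.sql.catalog.postgres",
        "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
      .config("spark.sql.catalog.postgres.url", "jdbc:postgresql://localhost:5432/testdb")
      .config("spark.sql.catalog.postgres.driver", "org.postgresql.Driver")
      .config("spark.sql.catalog.postgres.pushDownAggregate", "true")
      .getOrCreate()

    // REGR_SLOPE is compiled by PostgresDialect and pushed to the database; the
    // physical plan should list it among the scan's pushed aggregates.
    val df = spark.sql(
      "SELECT REGR_SLOPE(bonus, bonus) FROM postgres.test.employee " +
        "WHERE dept > 0 GROUP BY dept")
    df.explain()
    df.show()
  }
}
```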
### Does this PR introduce _any_ user-facing change?
'No'.
New feature.
### How was this patch tested?
New test cases.
Closes #37188 from beliefer/SPARK-39384.
Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
---
.../spark/sql/jdbc/v2/DB2IntegrationSuite.scala | 4 +
.../spark/sql/jdbc/v2/OracleIntegrationSuite.scala | 4 +
.../sql/jdbc/v2/PostgresIntegrationSuite.scala | 8 ++
.../org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 118 ++++++++++++++++-----
.../org/apache/spark/sql/jdbc/DB2Dialect.scala | 14 ++-
.../org/apache/spark/sql/jdbc/MySQLDialect.scala | 32 +++++-
.../org/apache/spark/sql/jdbc/OracleDialect.scala | 33 +++++-
.../apache/spark/sql/jdbc/PostgresDialect.scala | 3 +-
8 files changed, 185 insertions(+), 31 deletions(-)
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
index 4b2bbbdd849..1a25cd2802d 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
@@ -106,4 +106,8 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
testStddevSamp(true)
testCovarPop()
testCovarSamp()
+ testRegrIntercept()
+ testRegrSlope()
+ testRegrR2()
+ testRegrSXY()
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
index 8bc79a244e7..5de76089188 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
@@ -111,4 +111,8 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
testCovarPop()
testCovarSamp()
testCorr()
+ testRegrIntercept()
+ testRegrSlope()
+ testRegrR2()
+ testRegrSXY()
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
index 77ace3f3f4e..1ff7527c97b 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
@@ -104,4 +104,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
testCovarSamp(true)
testCorr()
testCorr(true)
+ testRegrIntercept()
+ testRegrIntercept(true)
+ testRegrSlope()
+ testRegrSlope(true)
+ testRegrR2()
+ testRegrR2(true)
+ testRegrSXY()
+ testRegrSXY(true)
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
index 0f85bd534c3..543c8465ed2 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
@@ -406,9 +406,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
protected def caseConvert(tableName: String): String = tableName
+ private def withOrWithout(isDistinct: Boolean): String = if (isDistinct) "with" else "without"
+
protected def testVarPop(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: VAR_POP with distinct: $isDistinct")
{
+ test(s"scan with aggregate push-down: VAR_POP ${withOrWithout(isDistinct)}
DISTINCT") {
val df = sql(s"SELECT VAR_POP(${distinct}bonus) FROM
$catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY
dept")
checkFilterPushed(df)
@@ -416,15 +418,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession
with DockerIntegrationFu
checkAggregatePushed(df, "VAR_POP")
val row = df.collect()
assert(row.length === 3)
- assert(row(0).getDouble(0) === 10000d)
- assert(row(1).getDouble(0) === 2500d)
- assert(row(2).getDouble(0) === 0d)
+ assert(row(0).getDouble(0) === 10000.0)
+ assert(row(1).getDouble(0) === 2500.0)
+ assert(row(2).getDouble(0) === 0.0)
}
}
protected def testVarSamp(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: VAR_SAMP with distinct:
$isDistinct") {
+ test(s"scan with aggregate push-down: VAR_SAMP
${withOrWithout(isDistinct)} DISTINCT") {
val df = sql(
s"SELECT VAR_SAMP(${distinct}bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY
dept")
@@ -433,15 +435,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession
with DockerIntegrationFu
checkAggregatePushed(df, "VAR_SAMP")
val row = df.collect()
assert(row.length === 3)
- assert(row(0).getDouble(0) === 20000d)
- assert(row(1).getDouble(0) === 5000d)
+ assert(row(0).getDouble(0) === 20000.0)
+ assert(row(1).getDouble(0) === 5000.0)
assert(row(2).isNullAt(0))
}
}
protected def testStddevPop(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: STDDEV_POP with distinct:
$isDistinct") {
+ test(s"scan with aggregate push-down: STDDEV_POP
${withOrWithout(isDistinct)} DISTINCT") {
val df = sql(
s"SELECT STDDEV_POP(${distinct}bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY
dept")
@@ -450,15 +452,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession
with DockerIntegrationFu
checkAggregatePushed(df, "STDDEV_POP")
val row = df.collect()
assert(row.length === 3)
- assert(row(0).getDouble(0) === 100d)
- assert(row(1).getDouble(0) === 50d)
- assert(row(2).getDouble(0) === 0d)
+ assert(row(0).getDouble(0) === 100.0)
+ assert(row(1).getDouble(0) === 50.0)
+ assert(row(2).getDouble(0) === 0.0)
}
}
protected def testStddevSamp(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: STDDEV_SAMP with distinct:
$isDistinct") {
+ test(s"scan with aggregate push-down: STDDEV_SAMP
${withOrWithout(isDistinct)} DISTINCT") {
val df = sql(
s"SELECT STDDEV_SAMP(${distinct}bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY
dept")
@@ -467,15 +469,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession
with DockerIntegrationFu
checkAggregatePushed(df, "STDDEV_SAMP")
val row = df.collect()
assert(row.length === 3)
- assert(row(0).getDouble(0) === 141.4213562373095d)
- assert(row(1).getDouble(0) === 70.71067811865476d)
+ assert(row(0).getDouble(0) === 141.4213562373095)
+ assert(row(1).getDouble(0) === 70.71067811865476)
assert(row(2).isNullAt(0))
}
}
protected def testCovarPop(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: COVAR_POP with distinct:
$isDistinct") {
+ test(s"scan with aggregate push-down: COVAR_POP
${withOrWithout(isDistinct)} DISTINCT") {
val df = sql(
s"SELECT COVAR_POP(${distinct}bonus, bonus) FROM
$catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY
dept")
@@ -484,15 +486,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession
with DockerIntegrationFu
checkAggregatePushed(df, "COVAR_POP")
val row = df.collect()
assert(row.length === 3)
- assert(row(0).getDouble(0) === 10000d)
- assert(row(1).getDouble(0) === 2500d)
- assert(row(2).getDouble(0) === 0d)
+ assert(row(0).getDouble(0) === 10000.0)
+ assert(row(1).getDouble(0) === 2500.0)
+ assert(row(2).getDouble(0) === 0.0)
}
}
protected def testCovarSamp(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: COVAR_SAMP with distinct:
$isDistinct") {
+ test(s"scan with aggregate push-down: COVAR_SAMP
${withOrWithout(isDistinct)} DISTINCT") {
val df = sql(
s"SELECT COVAR_SAMP(${distinct}bonus, bonus) FROM
$catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY
dept")
@@ -501,15 +503,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession
with DockerIntegrationFu
checkAggregatePushed(df, "COVAR_SAMP")
val row = df.collect()
assert(row.length === 3)
- assert(row(0).getDouble(0) === 20000d)
- assert(row(1).getDouble(0) === 5000d)
+ assert(row(0).getDouble(0) === 20000.0)
+ assert(row(1).getDouble(0) === 5000.0)
assert(row(2).isNullAt(0))
}
}
protected def testCorr(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: CORR with distinct: $isDistinct") {
+ test(s"scan with aggregate push-down: CORR ${withOrWithout(isDistinct)}
DISTINCT") {
val df = sql(
s"SELECT CORR(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY
dept")
@@ -518,9 +520,77 @@ private[v2] trait V2JDBCTest extends SharedSparkSession
with DockerIntegrationFu
checkAggregatePushed(df, "CORR")
val row = df.collect()
assert(row.length === 3)
- assert(row(0).getDouble(0) === 1d)
- assert(row(1).getDouble(0) === 1d)
+ assert(row(0).getDouble(0) === 1.0)
+ assert(row(1).getDouble(0) === 1.0)
+ assert(row(2).isNullAt(0))
+ }
+ }
+
+ protected def testRegrIntercept(isDistinct: Boolean = false): Unit = {
+ val distinct = if (isDistinct) "DISTINCT " else ""
+ test(s"scan with aggregate push-down: REGR_INTERCEPT
${withOrWithout(isDistinct)} DISTINCT") {
+ val df = sql(
+ s"SELECT REGR_INTERCEPT(${distinct}bonus, bonus) FROM
$catalogAndNamespace." +
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY
dept")
+ checkFilterPushed(df)
+ checkAggregateRemoved(df)
+ checkAggregatePushed(df, "REGR_INTERCEPT")
+ val row = df.collect()
+ assert(row.length === 3)
+ assert(row(0).getDouble(0) === 0.0)
+ assert(row(1).getDouble(0) === 0.0)
+ assert(row(2).isNullAt(0))
+ }
+ }
+
+ protected def testRegrSlope(isDistinct: Boolean = false): Unit = {
+ val distinct = if (isDistinct) "DISTINCT " else ""
+ test(s"scan with aggregate push-down: REGR_SLOPE
${withOrWithout(isDistinct)} DISTINCT") {
+ val df = sql(
+ s"SELECT REGR_SLOPE(${distinct}bonus, bonus) FROM
$catalogAndNamespace." +
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY
dept")
+ checkFilterPushed(df)
+ checkAggregateRemoved(df)
+ checkAggregatePushed(df, "REGR_SLOPE")
+ val row = df.collect()
+ assert(row.length === 3)
+ assert(row(0).getDouble(0) === 1.0)
+ assert(row(1).getDouble(0) === 1.0)
+ assert(row(2).isNullAt(0))
+ }
+ }
+
+ protected def testRegrR2(isDistinct: Boolean = false): Unit = {
+ val distinct = if (isDistinct) "DISTINCT " else ""
+ test(s"scan with aggregate push-down: REGR_R2 ${withOrWithout(isDistinct)}
DISTINCT") {
+ val df = sql(
+ s"SELECT REGR_R2(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY
dept")
+ checkFilterPushed(df)
+ checkAggregateRemoved(df)
+ checkAggregatePushed(df, "REGR_R2")
+ val row = df.collect()
+ assert(row.length === 3)
+ assert(row(0).getDouble(0) === 1.0)
+ assert(row(1).getDouble(0) === 1.0)
assert(row(2).isNullAt(0))
}
}
+
+ protected def testRegrSXY(isDistinct: Boolean = false): Unit = {
+ val distinct = if (isDistinct) "DISTINCT " else ""
+ test(s"scan with aggregate push-down: REGR_SXY
${withOrWithout(isDistinct)} DISTINCT") {
+ val df = sql(
+ s"SELECT REGR_SXY(${distinct}bonus, bonus) FROM $catalogAndNamespace."
+
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY
dept")
+ checkFilterPushed(df)
+ checkAggregateRemoved(df)
+ checkAggregatePushed(df, "REGR_SXY")
+ val row = df.collect()
+ assert(row.length === 3)
+ assert(row(0).getDouble(0) === 20000.0)
+ assert(row(1).getDouble(0) === 5000.0)
+ assert(row(2).getDouble(0) === 0.0)
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index a3637e57266..6c7c1bfe737 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -32,15 +32,27 @@ private object DB2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2")
+ private val distinctUnsupportedAggregateFunctions =
+ Set("COVAR_POP", "COVAR_SAMP", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE",
"REGR_SXY")
+
// See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate
private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
- "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP")
+ "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions
private val supportedFunctions = supportedAggregateFunctions
override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)
class DB2SQLBuilder extends JDBCSQLBuilder {
+ override def visitAggregateFunction(
+ funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
+ if (isDistinct &&
distinctUnsupportedAggregateFunctions.contains(funcName)) {
+ throw new
UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
+ s"support aggregate function: $funcName with DISTINCT");
+ } else {
+ super.visitAggregateFunction(funcName, isDistinct, inputs)
+ }
+
override def dialectFunctionName(funcName: String): String = funcName match {
case "VAR_POP" => "VARIANCE"
case "VAR_SAMP" => "VARIANCE_SAMP"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index cc04b5c7c92..7dc76eed49f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -22,13 +22,14 @@ import java.util
import java.util.Locale
import scala.collection.mutable.ArrayBuilder
+import scala.util.control.NonFatal
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException}
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.index.TableIndex
-import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
+import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder}
@@ -38,14 +39,39 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper {
override def canHandle(url : String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql")
+ private val distinctUnsupportedAggregateFunctions =
+ Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")
+
// See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html
- private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
- "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")
+ private val supportedAggregateFunctions =
+ Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions
private val supportedFunctions = supportedAggregateFunctions
override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)
+ class MySQLSQLBuilder extends JDBCSQLBuilder {
+ override def visitAggregateFunction(
+ funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
+ if (isDistinct &&
distinctUnsupportedAggregateFunctions.contains(funcName)) {
+ throw new
UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
+ s"support aggregate function: $funcName with DISTINCT");
+ } else {
+ super.visitAggregateFunction(funcName, isDistinct, inputs)
+ }
+ }
+
+ override def compileExpression(expr: Expression): Option[String] = {
+ val mysqlSQLBuilder = new MySQLSQLBuilder()
+ try {
+ Some(mysqlSQLBuilder.build(expr))
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Error occurs while compiling V2 expression", e)
+ None
+ }
+ }
+
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 820bff354ca..79ac248d723 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -20,7 +20,10 @@ package org.apache.spark.sql.jdbc
import java.sql.{Date, Timestamp, Types}
import java.util.{Locale, TimeZone}
+import scala.util.control.NonFatal
+
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.connector.expressions.Expression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -33,16 +36,42 @@ private case object OracleDialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle")
+ private val distinctUnsupportedAggregateFunctions =
+ Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP",
"COVAR_SAMP", "CORR",
+ "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")
+
// scalastyle:off line.size.limit
// https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848
// scalastyle:on line.size.limit
- private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
- "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR")
+ private val supportedAggregateFunctions =
+ Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions
private val supportedFunctions = supportedAggregateFunctions
override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)
+ class OracleSQLBuilder extends JDBCSQLBuilder {
+ override def visitAggregateFunction(
+ funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
+ if (isDistinct &&
distinctUnsupportedAggregateFunctions.contains(funcName)) {
+ throw new
UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
+ s"support aggregate function: $funcName with DISTINCT");
+ } else {
+ super.visitAggregateFunction(funcName, isDistinct, inputs)
+ }
+ }
+
+ override def compileExpression(expr: Expression): Option[String] = {
+ val oracleSQLBuilder = new OracleSQLBuilder()
+ try {
+ Some(oracleSQLBuilder.build(expr))
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Error occurs while compiling V2 expression", e)
+ None
+ }
+ }
+
private def supportTimeZoneTypes: Boolean = {
val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone)
// TODO: support timezone types when users are not using the JVM timezone, which
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index cb78bc806e2..878d7a7cfe6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -38,7 +38,8 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper {
// See https://www.postgresql.org/docs/8.4/functions-aggregate.html
private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
- "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR")
+ "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR",
+ "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")
private val supportedFunctions = supportedAggregateFunctions
override def isSupportedFunction(funcName: String): Boolean =
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]