This is an automated email from the ASF dual-hosted git repository. beliefer pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 971d318cae8e [SPARK-51585][SQL] Oracle dialect supports pushdown datetime functions 971d318cae8e is described below commit 971d318cae8e0af33d09f49d0fe66a3a3cb02c90 Author: beliefer <belie...@163.com> AuthorDate: Wed Aug 6 15:34:55 2025 +0800 [SPARK-51585][SQL] Oracle dialect supports pushdown datetime functions ### What changes were proposed in this pull request? This PR proposes to make the Oracle dialect support pushdown of datetime functions. ### Why are the changes needed? Currently, the DS V2 pushdown framework pushes down the datetime functions in a common way. But Oracle doesn't support some datetime functions. ### Does this PR introduce _any_ user-facing change? 'No'. This is a new feature for Oracle dialect. ### How was this patch tested? GA. ### Was this patch authored or co-authored using generative AI tooling? 'No'. Closes #50353 from beliefer/SPARK-51585. Authored-by: beliefer <belie...@163.com> Signed-off-by: beliefer <belie...@163.com> --- .../spark/sql/jdbc/v2/OracleIntegrationSuite.scala | 164 ++++++++++++++++++++- .../org/apache/spark/sql/jdbc/OracleDialect.scala | 28 +++- 2 files changed, 186 insertions(+), 6 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index 7a58fca17970..c71f9ae7688f 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -118,11 +118,28 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes "CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," + " bonus BINARY_DOUBLE)").executeUpdate() connection.prepareStatement( - s"""CREATE TABLE 
pattern_testing_table ( - |pattern_testing_col VARCHAR(50) - |) - """.stripMargin + """CREATE TABLE pattern_testing_table ( + |pattern_testing_col VARCHAR(50) + |) + """.stripMargin ).executeUpdate() + connection.prepareStatement( + "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + .executeUpdate() + } + + override def dataPreparation(connection: Connection): Unit = { + super.dataPreparation(connection) + connection.prepareStatement( + "INSERT INTO datetime VALUES ('amy', TO_DATE('2022-05-19', 'YYYY-MM-DD')," + + " TO_TIMESTAMP('2022-05-19 00:00:00', 'YYYY-MM-DD HH24:MI:SS'))").executeUpdate() + connection.prepareStatement( + "INSERT INTO datetime VALUES ('alex', TO_DATE('2022-05-18', 'YYYY-MM-DD')," + + " TO_TIMESTAMP('2022-05-18 00:00:00', 'YYYY-MM-DD HH24:MI:SS'))").executeUpdate() + // '2022-01-01' is Saturday and is in ISO year 2021. + connection.prepareStatement( + "INSERT INTO datetime VALUES ('tom', TO_DATE('2022-01-01', 'YYYY-MM-DD')," + + " TO_TIMESTAMP('2022-01-01 00:00:00', 'YYYY-MM-DD HH24:MI:SS'))").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -185,4 +202,143 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row("Eason", "Y "))) } } + + override def testDatetime(tbl: String): Unit = { + val df1 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ") + checkFilterPushed(df1, false) + val rows1 = df1.collect() + assert(rows1.length === 2) + assert(rows1(0).getString(0) === "amy") + assert(rows1(1).getString(0) === "alex") + + val df2 = sql(s"SELECT name FROM $tbl WHERE year(date1) = 2022 AND quarter(date1) = 2") + checkFilterPushed(df2, false) + val rows2 = df2.collect() + assert(rows2.length === 2) + assert(rows2(0).getString(0) === "amy") + assert(rows2(1).getString(0) === "alex") + + val df3 = sql(s"SELECT name FROM $tbl WHERE month(date1) = 5") + checkFilterPushed(df3) + 
val rows3 = df3.collect() + assert(rows3.length === 2) + assert(rows3(0).getString(0) === "amy") + assert(rows3(1).getString(0) === "alex") + + val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0") + checkFilterPushed(df4) + val rows4 = df4.collect() + assert(rows4.length === 3) + assert(rows4(0).getString(0) === "amy") + assert(rows4(1).getString(0) === "alex") + assert(rows4(2).getString(0) === "tom") + + val df5 = sql(s"SELECT name FROM $tbl WHERE " + + "extract(WEEK from date1) > 10 AND extract(YEAR from date1) = 2022") + checkFilterPushed(df5, false) + val rows5 = df5.collect() + assert(rows5.length === 3) + assert(rows5(0).getString(0) === "amy") + assert(rows5(1).getString(0) === "alex") + assert(rows5(2).getString(0) === "tom") + + val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " + + "AND datediff(date1, '2022-05-10') > 0") + checkFilterPushed(df6, false) + val rows6 = df6.collect() + assert(rows6.length === 1) + assert(rows6(0).getString(0) === "amy") + + val df7 = sql(s"SELECT name FROM $tbl WHERE weekday(date1) = 2") + checkFilterPushed(df7, false) + val rows7 = df7.collect() + assert(rows7.length === 1) + assert(rows7(0).getString(0) === "alex") + + withClue("weekofyear") { + val woy = sql(s"SELECT weekofyear(date1) FROM $tbl WHERE name = 'tom'") + .collect().head.getInt(0) + val df = sql(s"SELECT name FROM $tbl WHERE weekofyear(date1) = $woy") + checkFilterPushed(df, false) + val rows = df.collect() + assert(rows.length === 1) + assert(rows(0).getString(0) === "tom") + } + + withClue("dayofweek") { + val dow = sql(s"SELECT dayofweek(date1) FROM $tbl WHERE name = 'alex'") + .collect().head.getInt(0) + val df = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = $dow") + checkFilterPushed(df, false) + val rows = df.collect() + assert(rows.length === 1) + assert(rows(0).getString(0) === "alex") + } + + withClue("yearofweek") { + val yow = sql(s"SELECT extract(YEAROFWEEK from date1) FROM 
$tbl WHERE name = 'tom'") + .collect().head.getInt(0) + val df = sql(s"SELECT name FROM $tbl WHERE extract(YEAROFWEEK from date1) = $yow") + checkFilterPushed(df, false) + val rows = df.collect() + assert(rows.length === 1) + assert(rows(0).getString(0) === "tom") + } + + withClue("dayofyear") { + val doy = sql(s"SELECT dayofyear(date1) FROM $tbl WHERE name = 'amy'") + .collect().head.getInt(0) + val df = sql(s"SELECT name FROM $tbl WHERE dayofyear(date1) = $doy") + checkFilterPushed(df, false) + val rows = df.collect() + assert(rows.length === 1) + assert(rows(0).getString(0) === "amy") + } + + withClue("dayofmonth") { + val dom = sql(s"SELECT dayofmonth(date1) FROM $tbl WHERE name = 'amy'") + .collect().head.getInt(0) + val df = sql(s"SELECT name FROM $tbl WHERE dayofmonth(date1) = $dom") + checkFilterPushed(df) + val rows = df.collect() + assert(rows.length === 1) + assert(rows(0).getString(0) === "amy") + } + + withClue("year") { + val year = sql(s"SELECT year(date1) FROM $tbl WHERE name = 'amy'") + .collect().head.getInt(0) + val df = sql(s"SELECT name FROM $tbl WHERE year(date1) = $year") + checkFilterPushed(df) + val rows = df.collect() + assert(rows.length === 3) + assert(rows(0).getString(0) === "amy") + assert(rows(1).getString(0) === "alex") + assert(rows(2).getString(0) === "tom") + } + + withClue("second") { + val df = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5") + checkFilterPushed(df, false) + val rows = df.collect() + assert(rows.length === 2) + assert(rows(0).getString(0) === "amy") + assert(rows(1).getString(0) === "alex") + } + + val df9 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 order by dayofyear(date1) limit 1") + checkFilterPushed(df9, false) + val rows9 = df9.collect() + assert(rows9.length === 1) + assert(rows9(0).getString(0) === "alex") + + val df10 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'week') = date'2022-05-16'") + checkFilterPushed(df10) + val rows10 = df10.collect() + 
assert(rows10.length === 2) + assert(rows10(0).getString(0) === "amy") + assert(rows10(1).getString(0) === "alex") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 0c9c84f3f3e7..81031b1ec13d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -24,7 +24,7 @@ import scala.util.control.NonFatal import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.connector.expressions.{Expression, Literal} +import org.apache.spark.sql.connector.expressions.{Expression, Extract, Literal} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.jdbc.OracleDialect._ @@ -44,7 +44,7 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N // scalastyle:on line.size.limit private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions - private val supportedFunctions = supportedAggregateFunctions + private val supportedFunctions = supportedAggregateFunctions ++ Set("TRUNC") override def isSupportedFunction(funcName: String): Boolean = supportedFunctions.contains(funcName) @@ -56,6 +56,30 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N class OracleSQLBuilder extends JDBCSQLBuilder { + override def visitExtract(extract: Extract): String = { + val field = extract.field + field match { + // YEAR, MONTH, DAY, HOUR, MINUTE are identical on Oracle and Spark for + // both datetime and interval types. 
+ case "YEAR" | "MONTH" | "DAY" | "HOUR" | "MINUTE" => + super.visitExtract(field, build(extract.source())) + // Oracle does not support the following date fields: DAY_OF_YEAR, WEEK, QUARTER, + // DAY_OF_WEEK, or YEAR_OF_WEEK. + // We can't push down SECOND due to the difference in result types between Spark and + // Oracle. Spark returns decimal(8, 6), but Oracle returns integer. + case _ => + visitUnexpectedExpr(extract) + } + } + + override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { + funcName match { + case "TRUNC" => + s"TRUNC(${inputs(0)}, 'IW')" + case _ => super.visitSQLFunction(funcName, inputs) + } + } + override def visitAggregateFunction( funcName: String, isDistinct: Boolean, inputs: Array[String]): String = if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org