This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2b0e841a3534 [SPARK-47496][SQL] Java SPI Support for dynamic JDBC dialect registering
2b0e841a3534 is described below

commit 2b0e841a35343343c82e8ca15225014b64d8c59f
Author: Kent Yao <y...@apache.org>
AuthorDate: Thu Mar 21 19:34:28 2024 +0800

    [SPARK-47496][SQL] Java SPI Support for dynamic JDBC dialect registering
    
    ### What changes were proposed in this pull request?
    
    This PR adds Java Service Provider Interface (SPI) support for dynamically registering JDBC dialects.
    
    A custom JDBC dialect can now be registered through a service descriptor instead of by calling JdbcDialects.registerDialect manually.
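    
    For example, a third-party jar could provide a dialect like the following minimal sketch (the package, class name, and URL scheme are hypothetical; canHandle is the only method that must be overridden):
    
        package com.example.jdbc
    
        import org.apache.spark.sql.jdbc.JdbcDialect
    
        // SPI implementations need a public no-arg constructor so that
        // java.util.ServiceLoader can instantiate them.
        class ExampleDialect extends JdbcDialect {
          override def canHandle(url: String): Boolean =
            url.startsWith("jdbc:example")
        }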
    
    ### Why are the changes needed?
    
    For pure SQL users and other non-Java-API users, it is difficult to register a custom JDBC dialect. With this patch, the dialect is registered automatically as long as the jar containing the dialect class is visible to the Spark classloader.
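    
    As a sketch of the packaging side (the jar and class names are again hypothetical), the jar ships a one-line service descriptor:
    
        # src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect
        com.example.jdbc.ExampleDialect
    
    and is then put on the classpath, e.g. with spark-sql --jars example-dialect.jar.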
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, but mostly for third-party developers.
    
    ### How was this patch tested?
    
    New tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #45626 from yaooqinn/SPARK-47496.
    
    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Kent Yao <y...@apache.org>
---
 project/MimaExcludes.scala                         | 10 +++-
 .../services/org.apache.spark.sql.jdbc.JdbcDialect | 29 ++++++++++
 .../org/apache/spark/sql/jdbc/DB2Dialect.scala     |  2 +-
 .../apache/spark/sql/jdbc/DatabricksDialect.scala  |  2 +-
 .../org/apache/spark/sql/jdbc/DerbyDialect.scala   |  2 +-
 .../org/apache/spark/sql/jdbc/H2Dialect.scala      |  2 +-
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   | 20 +++----
 .../apache/spark/sql/jdbc/MsSqlServerDialect.scala | 20 +++----
 .../org/apache/spark/sql/jdbc/MySQLDialect.scala   |  2 +-
 .../org/apache/spark/sql/jdbc/OracleDialect.scala  | 18 ++++---
 .../apache/spark/sql/jdbc/PostgresDialect.scala    |  2 +-
 .../apache/spark/sql/jdbc/SnowflakeDialect.scala   |  2 +-
 .../apache/spark/sql/jdbc/TeradataDialect.scala    |  2 +-
 .../services/org.apache.spark.sql.jdbc.JdbcDialect | 20 +++++++
 .../spark/sql/jdbc/DummyDatabaseDialect.scala}     | 18 +------
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala      | 61 +++++++++++++---------
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 52 +++++++++---------
 .../org/apache/spark/sql/jdbc/JDBCWriteSuite.scala |  4 +-
 18 files changed, 163 insertions(+), 105 deletions(-)

diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 225a13cd3537..630dd1d77cc7 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -82,7 +82,15 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.scoreLabelsWeight"),
     // SPARK-46938: Javax -> Jakarta namespace change.
     ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.ProxyRedirectHandler$ResponseWrapper"),
-    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ui.ProxyRedirectHandler#ResponseWrapper.this")
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ui.ProxyRedirectHandler#ResponseWrapper.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.DB2Dialect#DB2SQLBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.DB2Dialect#DB2SQLQueryBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.MsSqlServerDialect#MsSqlServerSQLBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.MsSqlServerDialect#MsSqlServerSQLQueryBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.MySQLDialect#MySQLSQLBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.MySQLDialect#MySQLSQLQueryBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.OracleDialect#OracleSQLBuilder.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.jdbc.OracleDialect#OracleSQLQueryBuilder.this")
   )
 
   // Default exclude rules
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect
new file mode 100644
index 000000000000..0b9dda2d14f2
--- /dev/null
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect
@@ -0,0 +1,29 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+org.apache.spark.sql.jdbc.MySQLDialect
+org.apache.spark.sql.jdbc.PostgresDialect
+org.apache.spark.sql.jdbc.DB2Dialect
+org.apache.spark.sql.jdbc.MsSqlServerDialect
+org.apache.spark.sql.jdbc.DerbyDialect
+org.apache.spark.sql.jdbc.OracleDialect
+org.apache.spark.sql.jdbc.TeradataDialect
+org.apache.spark.sql.jdbc.H2Dialect
+org.apache.spark.sql.jdbc.SnowflakeDialect
+org.apache.spark.sql.jdbc.DatabricksDialect
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index 62c31b1c4c5d..31a7c783ba60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types._
 
-private object DB2Dialect extends JdbcDialect {
+private case class DB2Dialect() extends JdbcDialect {
 
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala
index c905374c1678..54b8c2622827 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.types._
 
-private case object DatabricksDialect extends JdbcDialect {
+private case class DatabricksDialect() extends JdbcDialect {
 
   override def canHandle(url: String): Boolean = {
     url.startsWith("jdbc:databricks")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
index 545cbf265bb0..36af0e6aeaf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors
 import org.apache.spark.sql.types._
 
 
-private object DerbyDialect extends JdbcDialect {
+private case class DerbyDialect() extends JdbcDialect {
 
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index f4a1650b3e8c..ebfc6093dc16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, N
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, MetadataBuilder, ShortType, StringType, TimestampType}
 
-private[sql] object H2Dialect extends JdbcDialect {
+private[sql] case class H2Dialect() extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 7d2812d48cae..845161c81ea5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
 import java.sql.{Connection, Date, Driver, Statement, Timestamp}
 import java.time.{Instant, LocalDate, LocalDateTime}
 import java.util
+import java.util.ServiceLoader
 
 import scala.collection.mutable.ArrayBuilder
 import scala.util.control.NonFatal
@@ -46,6 +47,7 @@ import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProv
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -825,16 +827,14 @@ object JdbcDialects {
 
   private[this] var dialects = List[JdbcDialect]()
 
-  registerDialect(MySQLDialect)
-  registerDialect(PostgresDialect)
-  registerDialect(DB2Dialect)
-  registerDialect(MsSqlServerDialect)
-  registerDialect(DerbyDialect)
-  registerDialect(OracleDialect)
-  registerDialect(TeradataDialect)
-  registerDialect(H2Dialect)
-  registerDialect(SnowflakeDialect)
-  registerDialect(DatabricksDialect)
+  private def registerDialects(): Unit = {
+    val loader = ServiceLoader.load(classOf[JdbcDialect], Utils.getContextOrSparkClassLoader)
+    val iter = loader.iterator()
+    while (iter.hasNext) {
+      registerDialect(iter.next())
+    }
+  }
+  registerDialects()
 
   /**
    * Fetch the JdbcDialect class corresponding to a given database url.
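
(A standalone sketch of what the registerDialects() loop above does, for readers unfamiliar with SPI; the context classloader here stands in for Utils.getContextOrSparkClassLoader, and the Scala 2.13 collection converters are assumed:)

    import java.util.ServiceLoader
    import scala.jdk.CollectionConverters._

    import org.apache.spark.sql.jdbc.JdbcDialect

    // Instantiate every JdbcDialect implementation named in a
    // META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect resource
    // visible to the given classloader.
    val loader = ServiceLoader.load(
      classOf[JdbcDialect], Thread.currentThread().getContextClassLoader)
    loader.iterator().asScala.foreach(d => println(d.getClass.getName))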
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index aaee6be24e61..1b6dc1af9ec0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -29,18 +29,11 @@ import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, Sor
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.jdbc.MsSqlServerDialect.{GEOGRAPHY, GEOMETRY}
 import org.apache.spark.sql.types._
 
 
-private object MsSqlServerDialect extends JdbcDialect {
-
-  // Special JDBC types in Microsoft SQL Server.
-  // https://github.com/microsoft/mssql-jdbc/blob/v9.4.1/src/main/java/microsoft/sql/Types.java
-  private object SpecificTypes {
-    val GEOMETRY = -157
-    val GEOGRAPHY = -158
-  }
-
+private case class MsSqlServerDialect() extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver")
 
@@ -113,7 +106,7 @@ private object MsSqlServerDialect extends JdbcDialect {
           // Reference doc: https://learn.microsoft.com/en-us/sql/t-sql/data-types
           case java.sql.Types.SMALLINT | java.sql.Types.TINYINT => Some(ShortType)
           case java.sql.Types.REAL => Some(FloatType)
-          case SpecificTypes.GEOMETRY | SpecificTypes.GEOGRAPHY => Some(BinaryType)
+          case GEOMETRY | GEOGRAPHY => Some(BinaryType)
           case _ => None
         }
       }
@@ -226,3 +219,10 @@ private object MsSqlServerDialect extends JdbcDialect {
 
   override def supportsLimit: Boolean = true
 }
+
+private object MsSqlServerDialect {
+  // Special JDBC types in Microsoft SQL Server.
+  // https://github.com/microsoft/mssql-jdbc/blob/v9.4.1/src/main/java/microsoft/sql/Types.java
+  final val GEOMETRY = -157
+  final val GEOGRAPHY = -158
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index a245458a5cb4..292e3ca2d5e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.types._
 
-private case object MySQLDialect extends JdbcDialect with SQLConfHelper {
+private case class MySQLDialect() extends JdbcDialect with SQLConfHelper {
 
   override def canHandle(url : String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 544c0197dec9..a9c246c93879 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -25,17 +25,11 @@ import scala.util.control.NonFatal
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.jdbc.OracleDialect._
 import org.apache.spark.sql.types._
 
 
-private case object OracleDialect extends JdbcDialect {
-  private[jdbc] val BINARY_FLOAT = 100
-  private[jdbc] val BINARY_DOUBLE = 101
-  private[jdbc] val TIMESTAMP_TZ = -101
-  // oracle.jdbc.OracleType.TIMESTAMP_WITH_LOCAL_TIME_ZONE
-  private[jdbc] val TIMESTAMP_LTZ = -102
-
-
+private case class OracleDialect() extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle")
 
@@ -230,3 +224,11 @@ private case object OracleDialect extends JdbcDialect {
 
   override def supportsOffset: Boolean = true
 }
+
+private[jdbc] object OracleDialect {
+  final val BINARY_FLOAT = 100
+  final val BINARY_DOUBLE = 101
+  final val TIMESTAMP_TZ = -101
+  // oracle.jdbc.OracleType.TIMESTAMP_WITH_LOCAL_TIME_ZONE
+  final val TIMESTAMP_LTZ = -102
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index c9737867d3e0..5c949b28ba7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.types._
 
 
-private object PostgresDialect extends JdbcDialect with SQLConfHelper {
+private case class PostgresDialect() extends JdbcDialect with SQLConfHelper {
 
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala
index d8a8fe6ba4a9..276364d5d89e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
 import org.apache.spark.sql.types.{BooleanType, DataType}
 
-private case object SnowflakeDialect extends JdbcDialect {
+private case class SnowflakeDialect() extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:snowflake")
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
index 0f0812bdaeb9..7acd22a3f10b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.types._
 
 
-private case object TeradataDialect extends JdbcDialect {
+private case class TeradataDialect() extends JdbcDialect {
 
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:teradata")
diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect
new file mode 100644
index 000000000000..ce96a578e50c
--- /dev/null
+++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcDialect
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+org.apache.spark.sql.jdbc.DummyDatabaseDialect
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DummyDatabaseDialect.scala
similarity index 56%
copy from sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala
copy to sql/core/src/test/scala/org/apache/spark/sql/jdbc/DummyDatabaseDialect.scala
index d8a8fe6ba4a9..a8bca85dcb65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/SnowflakeDialect.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DummyDatabaseDialect.scala
@@ -17,20 +17,6 @@
 
 package org.apache.spark.sql.jdbc
 
-import java.util.Locale
-
-import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
-import org.apache.spark.sql.types.{BooleanType, DataType}
-
-private case object SnowflakeDialect extends JdbcDialect {
-  override def canHandle(url: String): Boolean =
-    url.toLowerCase(Locale.ROOT).startsWith("jdbc:snowflake")
-
-  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
-    case BooleanType =>
-      // By default, BOOLEAN is mapped to BIT(1).
-      // but Snowflake does not have a BIT type. It uses BOOLEAN instead.
-      Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
-    case _ => JdbcUtils.getCommonJDBCType(dt)
-  }
+class DummyDatabaseDialect extends JdbcDialect {
+  override def canHandle(url: String): Boolean = url.startsWith("jdbc:dummy")
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index a2dac5a9e1e9..e2bdd8aee97d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -786,12 +786,12 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
   }
 
   test("Default jdbc dialect registration") {
-    assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect)
-    assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == 
PostgresDialect)
-    assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect)
-    assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") == 
MsSqlServerDialect)
-    assert(JdbcDialects.get("jdbc:derby:db") == DerbyDialect)
-    assert(JdbcDialects.get("test.invalid") == NoopDialect)
+    assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") === MySQLDialect())
+    assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") === 
PostgresDialect())
+    assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") === DB2Dialect())
+    assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") === 
MsSqlServerDialect())
+    assert(JdbcDialects.get("jdbc:derby:db") === DerbyDialect())
+    assert(JdbcDialects.get("test.invalid") === NoopDialect)
   }
 
   test("quote column names by jdbc dialect") {
@@ -846,13 +846,13 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
   }
 
   test("Dialect unregister") {
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(H2Dialect())
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       JdbcDialects.unregisterDialect(testH2Dialect)
       assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect)
     } finally {
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(H2Dialect())
     }
   }
 
@@ -997,7 +997,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
     // JDBC url is a required option but is not used in this test.
     val options = new JDBCOptions(Map("url" -> "jdbc:h2://host:port", "dbtable" -> "test"))
     assert(
-      OracleDialect
+      OracleDialect()
         .getJdbcSQLQueryBuilder(options)
         .withColumns(Array("a", "b"))
         .withLimit(123)
@@ -1053,7 +1053,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
     // JDBC url is a required option but is not used in this test.
     val options = new JDBCOptions(Map("url" -> "jdbc:h2://host:port", "dbtable" -> "test"))
     assert(
-      MsSqlServerDialect
+      MsSqlServerDialect()
         .getJdbcSQLQueryBuilder(options)
         .withColumns(Array("a", "b"))
         .withLimit(123)
@@ -1066,7 +1066,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
     // JDBC url is a required option but is not used in this test.
     val options = new JDBCOptions(Map("url" -> "jdbc:db2://host:port", "dbtable" -> "test"))
     assert(
-      DB2Dialect
+      DB2Dialect()
         .getJdbcSQLQueryBuilder(options)
         .withColumns(Array("a", "b"))
         .withLimit(123)
@@ -1938,20 +1938,20 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
   }
 
   test("SPARK-28552: Case-insensitive database URLs in JdbcDialect") {
-    assert(JdbcDialects.get("jdbc:mysql://localhost/db") === MySQLDialect)
-    assert(JdbcDialects.get("jdbc:MySQL://localhost/db") === MySQLDialect)
-    assert(JdbcDialects.get("jdbc:postgresql://localhost/db") === 
PostgresDialect)
-    assert(JdbcDialects.get("jdbc:postGresql://localhost/db") === 
PostgresDialect)
-    assert(JdbcDialects.get("jdbc:db2://localhost/db") === DB2Dialect)
-    assert(JdbcDialects.get("jdbc:DB2://localhost/db") === DB2Dialect)
-    assert(JdbcDialects.get("jdbc:sqlserver://localhost/db") === 
MsSqlServerDialect)
-    assert(JdbcDialects.get("jdbc:sqlServer://localhost/db") === 
MsSqlServerDialect)
-    assert(JdbcDialects.get("jdbc:derby://localhost/db") === DerbyDialect)
-    assert(JdbcDialects.get("jdbc:derBy://localhost/db") === DerbyDialect)
-    assert(JdbcDialects.get("jdbc:oracle://localhost/db") === OracleDialect)
-    assert(JdbcDialects.get("jdbc:Oracle://localhost/db") === OracleDialect)
-    assert(JdbcDialects.get("jdbc:teradata://localhost/db") === 
TeradataDialect)
-    assert(JdbcDialects.get("jdbc:Teradata://localhost/db") === 
TeradataDialect)
+    assert(JdbcDialects.get("jdbc:mysql://localhost/db") === MySQLDialect())
+    assert(JdbcDialects.get("jdbc:MySQL://localhost/db") === MySQLDialect())
+    assert(JdbcDialects.get("jdbc:postgresql://localhost/db") === 
PostgresDialect())
+    assert(JdbcDialects.get("jdbc:postGresql://localhost/db") === 
PostgresDialect())
+    assert(JdbcDialects.get("jdbc:db2://localhost/db") === DB2Dialect())
+    assert(JdbcDialects.get("jdbc:DB2://localhost/db") === DB2Dialect())
+    assert(JdbcDialects.get("jdbc:sqlserver://localhost/db") === 
MsSqlServerDialect())
+    assert(JdbcDialects.get("jdbc:sqlServer://localhost/db") === 
MsSqlServerDialect())
+    assert(JdbcDialects.get("jdbc:derby://localhost/db") === DerbyDialect())
+    assert(JdbcDialects.get("jdbc:derBy://localhost/db") === DerbyDialect())
+    assert(JdbcDialects.get("jdbc:oracle://localhost/db") === OracleDialect())
+    assert(JdbcDialects.get("jdbc:Oracle://localhost/db") === OracleDialect())
+    assert(JdbcDialects.get("jdbc:teradata://localhost/db") === 
TeradataDialect())
+    assert(JdbcDialects.get("jdbc:Teradata://localhost/db") === 
TeradataDialect())
   }
 
   test("SQLContext.jdbc (deprecated)") {
@@ -2099,7 +2099,8 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
   }
 
   test("SPARK-45139: DatabricksDialect url handling") {
-    assert(JdbcDialects.get("jdbc:databricks://account.cloud.databricks.com") 
== DatabricksDialect)
+    assert(JdbcDialects.get("jdbc:databricks://account.cloud.databricks.com") 
===
+      DatabricksDialect())
   }
 
   test("SPARK-45139: DatabricksDialect catalyst type mapping") {
@@ -2154,4 +2155,12 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
     val expected = Map("percentile_approx_val" -> 49)
     assert(namedObservation.get === expected)
   }
+
+  test("SPARK-47496: ServiceLoader support for JDBC dialects") {
+    var dialect = JdbcDialects.get("jdbc:dummy:dummy_host:dummy_port/dummy_db")
+    assert(dialect.isInstanceOf[DummyDatabaseDialect])
+    JdbcDialects.unregisterDialect(dialect)
+    dialect = JdbcDialects.get("jdbc:dummy:dummy_host:dummy_port/dummy_db")
+    assert(dialect === NoopDialect)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 7bae2d77a161..1b3672cdba5a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -52,8 +52,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) ++
     Array.fill(15)(0.toByte)
 
+  private val h2Dialect = JdbcDialects.get(url).asInstanceOf[H2Dialect]
+
   val testH2Dialect = new JdbcDialect {
-    override def canHandle(url: String): Boolean = H2Dialect.canHandle(url)
+    val h2 = JdbcDialects.get(url).asInstanceOf[H2Dialect]
+
+    override def canHandle(url: String): Boolean = h2.canHandle(url)
 
     override def supportsLimit: Boolean = false
 
@@ -102,7 +106,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       }
     }
 
-    override def functions: Seq[(String, UnboundFunction)] = H2Dialect.functions
+    override def functions: Seq[(String, UnboundFunction)] = h2.functions
   }
 
   case object CharLength extends ScalarFunction[Int] {
@@ -225,15 +229,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       stmt.setBytes(2, testBytes)
       stmt.executeUpdate()
     }
-    H2Dialect.registerFunction("my_avg", IntegralAverage)
-    H2Dialect.registerFunction("my_strlen", StrLen(CharLength))
-    H2Dialect.registerFunction("my_strlen_magic", 
StrLen(CharLengthWithMagicMethod))
-    H2Dialect.registerFunction(
+    h2Dialect.registerFunction("my_avg", IntegralAverage)
+    h2Dialect.registerFunction("my_strlen", StrLen(CharLength))
+    h2Dialect.registerFunction("my_strlen_magic", 
StrLen(CharLengthWithMagicMethod))
+    h2Dialect.registerFunction(
       "my_strlen_static_magic", StrLen(new JavaStrLenStaticMagic()))
   }
 
   override def afterAll(): Unit = {
-    H2Dialect.clearFunctions()
+    h2Dialect.clearFunctions()
     Utils.deleteRecursively(tempDir)
     super.afterAll()
   }
@@ -340,7 +344,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkPushedInfo(df5, "PushedFilters: []")
     checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy")))
 
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(h2Dialect)
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       val df6 = spark.read.table("h2.test.employee")
@@ -350,7 +354,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       checkAnswer(df6, Seq(Row(1, "amy", 10000.00, 1000.0, true)))
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(h2Dialect)
     }
   }
 
@@ -437,7 +441,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkPushedInfo(df6, "PushedFilters: []")
     checkAnswer(df6, Seq(Row(10000.00, 1300.0, "dav"), Row(9000.00, 1200.0, "cat")))
 
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(h2Dialect)
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       val df7 = spark.read
@@ -450,7 +454,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       checkAnswer(df7, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(h2Dialect)
     }
   }
 
@@ -1590,7 +1594,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   }
 
   test("scan with filter push-down with UDF") {
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(h2Dialect)
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen(name) > 
2")
@@ -1610,12 +1614,12 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
       checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(h2Dialect)
     }
   }
 
   test("scan with filter push-down with UDF that has magic method") {
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(h2Dialect)
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       val df1 = sql("SELECT * FROM h2.test.people where 
h2.my_strlen_magic(name) > 2")
@@ -1636,12 +1640,12 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
       checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(h2Dialect)
     }
   }
 
   test("scan with filter push-down with UDF that has static magic method") {
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(h2Dialect)
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       val df1 = sql("SELECT * FROM h2.test.people where 
h2.my_strlen_static_magic(name) > 2")
@@ -1662,7 +1666,7 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
       checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(h2Dialect)
     }
   }
 
@@ -2872,8 +2876,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
   }
 
-  test("register dialect specific functions") {
-    JdbcDialects.unregisterDialect(H2Dialect)
+  test("register h2Dialect specific functions") {
+    JdbcDialects.unregisterDialect(h2Dialect)
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       val df = sql("SELECT h2.my_avg(id) FROM h2.test.people")
@@ -2905,12 +2909,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
           stop = 20))
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(h2Dialect)
     }
   }
 
   test("scan with aggregate push-down: complete push-down UDAF") {
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(h2Dialect)
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       val df1 = sql("SELECT h2.my_avg(id) FROM h2.test.people")
@@ -2959,7 +2963,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       }
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(h2Dialect)
     }
   }
 
@@ -3006,7 +3010,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
   test("IDENTIFIER_TOO_MANY_NAME_PARTS: " +
     "jdbc function doesn't support identifiers consisting of more than 2 
parts") {
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(h2Dialect)
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       checkError(
@@ -3019,7 +3023,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       )
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(h2Dialect)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index f904d0e3d3c8..0d9dc2f76faf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -206,7 +206,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
   }
 
   test("Truncate") {
-    JdbcDialects.unregisterDialect(H2Dialect)
+    JdbcDialects.unregisterDialect(H2Dialect())
     try {
       JdbcDialects.registerDialect(testH2Dialect)
       val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
@@ -231,7 +231,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
             "Some(StructType(StructField(name,StringType,true),StructField(id,IntegerType,true)))"))
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
-      JdbcDialects.registerDialect(H2Dialect)
+      JdbcDialects.registerDialect(H2Dialect())
     }
   }
 

