GitHub user marmbrus commented on a diff in the pull request:
https://github.com/apache/spark/pull/5498#discussion_r28470212
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala ---
@@ -39,33 +39,68 @@ import java.sql.Types
* if `getJDBCType` returns `(null, None)`, the default type handling is
used
* for the given Catalyst type.
*/
-private[sql] abstract class DriverQuirks {
+abstract class DriverQuirks {
+ def canHandle(url : String): Boolean
def getCatalystType(sqlType: Int, typeName: String, size: Int, md:
MetadataBuilder): DataType
def getJDBCType(dt: DataType): (String, Option[Int])
}
-private[sql] object DriverQuirks {
+object DriverQuirks {
+
+ private var quirks = List[DriverQuirks]()
+
+ def registerQuirks(quirk: DriverQuirks) : Unit = {
+ quirks = quirk :: quirks
+ }
+
+ def unregisterQuirks(quirk : DriverQuirks) : Unit = {
+ quirks = quirks.filterNot(_ == quirk)
+ }
+
+ registerQuirks(new MySQLQuirks())
+ registerQuirks(new PostgresQuirks())
+
/**
* Fetch the DriverQuirks class corresponding to a given database url.
*/
def get(url: String): DriverQuirks = {
- if (url.substring(0, 10).equals("jdbc:mysql")) {
- new MySQLQuirks()
- } else if (url.substring(0, 15).equals("jdbc:postgresql")) {
- new PostgresQuirks()
- } else {
- new NoQuirks()
+ val matchingQuirks = quirks.filter(_.canHandle(url))
+ matchingQuirks.length match {
+ case 0 => new NoQuirks()
+ case 1 => matchingQuirks.head
+ case _ => new AggregatedQuirks(matchingQuirks)
}
}
}
-private[sql] class NoQuirks extends DriverQuirks {
+class AggregatedQuirks(quirks: List[DriverQuirks]) extends DriverQuirks {
+
+ require(!quirks.isEmpty)
+
+ def canHandle(url : String): Boolean =
+ quirks.map(_.canHandle(url)).reduce(_ && _)
+
+ def getCatalystType(sqlType: Int, typeName: String, size: Int, md:
MetadataBuilder): DataType =
+ quirks.map(_.getCatalystType(sqlType, typeName, size,
md)).collectFirst {
+ case dataType if dataType != null => dataType
+ }.orNull
+
+ def getJDBCType(dt: DataType): (String, Option[Int]) =
+ quirks.map(_.getJDBCType(dt)).collectFirst {
+ case t @ (typeName,sqlType) if typeName != null || sqlType.isDefined
=> t
+ }.getOrElse((null, None))
+
+}
+
+class NoQuirks extends DriverQuirks {
+ def canHandle(url : String): Boolean = true
def getCatalystType(sqlType: Int, typeName: String, size: Int, md:
MetadataBuilder): DataType =
null
def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
}
-private[sql] class PostgresQuirks extends DriverQuirks {
+class PostgresQuirks extends DriverQuirks {
--- End diff --
Should these be `case object`s?
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]