Github user liancheng commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5498#discussion_r28403924
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala ---
    @@ -39,33 +39,70 @@ import java.sql.Types
      * if `getJDBCType` returns `(null, None)`, the default type handling is used
      * for the given Catalyst type.
      */
    -private[sql] abstract class DriverQuirks {
    +abstract class DriverQuirks {
    +  def canHandle(url : String): Boolean
       def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType
       def getJDBCType(dt: DataType): (String, Option[Int])
     }
     
    -private[sql] object DriverQuirks {
    +object DriverQuirks {
    +
    +  private var quirks = List[DriverQuirks]()
    +
    +  def registerQuirks(quirk: DriverQuirks) {
    +    quirks = quirk :: quirks
    +  }
    +
    +  def unregisterQuirks(quirk : DriverQuirks) {
    +    quirks = quirks.filterNot(_ == quirk)
    +  }
    +
    +  registerQuirks(new MySQLQuirks())
    +  registerQuirks(new PostgresQuirks())
    +
       /**
        * Fetch the DriverQuirks class corresponding to a given database url.
        */
       def get(url: String): DriverQuirks = {
    -    if (url.substring(0, 10).equals("jdbc:mysql")) {
    -      new MySQLQuirks()
    -    } else if (url.substring(0, 15).equals("jdbc:postgresql")) {
    -      new PostgresQuirks()
    -    } else {
    -      new NoQuirks()
    +    val matchingQuirks = quirks.filter(_.canHandle(url))
    +    matchingQuirks.length match {
    +      case 0 => new NoQuirks()
    +      case 1 => matchingQuirks.head
    +      case _ => new AggregatedQuirks(matchingQuirks)
         }
       }
     }
     
    -private[sql] class NoQuirks extends DriverQuirks {
    +class AggregatedQuirks(quirks: List[DriverQuirks]) extends DriverQuirks {
    +  def canHandle(url : String): Boolean =
    +    quirks.foldLeft(true)((l,r) => l && r.canHandle(url))
    +  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder) : DataType =
    +    quirks.foldLeft(null.asInstanceOf[DataType])((l,r) =>
    +      if (l != null) {
    +        l
    +      } else {
    +        r.getCatalystType(sqlType, typeName, size, md)
    +      }
    +    )
    +  def getJDBCType(dt: DataType): (String, Option[Int]) =
    +    quirks.foldLeft(null.asInstanceOf[(String, Option[Int])])((l,r) =>
    +      if (l != null) {
    +        l
    +      } else {
    +        r.getJDBCType(dt)
    +      }
    +    )
    --- End diff --
    
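    `getJDBCType` always returns a non-null pair; only its first element (the
    type name, `_1`) can be null. So after the first step `l` is never null, and
    this fold simply returns the first quirk's pair, even when that pair's type
    name is null. Instead, pick the first pair whose type name is non-null: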
    
    ```scala
    quirks.map(_.getJDBCType(dt)).collectFirst {
      case p @ (typeName, _) if typeName != null => p
    }.getOrElse((null, None))
    ```


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to