Repository: spark
Updated Branches:
  refs/heads/master a5849ad9a -> f596ebe4d
[SPARK-24327][SQL] Verify and normalize a partition column name based on the JDBC resolved schema

## What changes were proposed in this pull request?
This PR modifies the JDBC datasource code to verify and normalize a partition column against the JDBC-resolved schema before building `JDBCRelation`.

Closes #20370

## How was this patch tested?
Added tests in `JDBCSuite`.

Author: Takeshi Yamamuro <yamam...@apache.org>

Closes #21379 from maropu/SPARK-24327.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f596ebe4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f596ebe4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f596ebe4

Branch: refs/heads/master
Commit: f596ebe4d3170590b6fce34c179e51ee80c965d3
Parents: a5849ad
Author: Takeshi Yamamuro <yamam...@apache.org>
Authored: Sun Jun 24 23:14:42 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Sun Jun 24 23:14:42 2018 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/util/Utils.scala     |  2 +-
 .../datasources/jdbc/JDBCRelation.scala         | 76 ++++++++++++++++----
 .../datasources/jdbc/JdbcRelationProvider.scala |  6 +-
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   | 51 ++++++++++++-
 4 files changed, 118 insertions(+), 17 deletions(-)
----------------------------------------------------------------------
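For context, the user-facing path this patch touches is the partitioned JDBC read. A minimal sketch, assuming a SparkSession named `spark` and a reachable JDBC endpoint (the URL and table name are placeholders; the option names are the real DataFrameReader options exercised by the new test):

    // Hypothetical URL and table; only the option names are real.
    val df = spark.read.format("jdbc")
      .option("url", "jdbc:h2:mem:testdb0")    // placeholder URL
      .option("dbtable", "TEST.PARTITION")     // placeholder table
      .option("partitionColumn", "ThEiD")      // differs in case from the resolved "THEID"
      .option("lowerBound", 1)
      .option("upperBound", 4)
      .option("numPartitions", 3)
      .load()
    // Before this patch the user-supplied string was embedded verbatim in the
    // generated WHERE clauses; now it is matched against the resolved schema
    // (honoring spark.sql.caseSensitive) and rewritten to the dialect-quoted
    // field name, or an AnalysisException is raised if no field matches.
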
http://git-wip-us.apache.org/repos/asf/spark/blob/f596ebe4/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index c139db4..a6fd363 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -100,7 +100,7 @@ private[spark] object Utils extends Logging {
    */
   val DEFAULT_MAX_TO_STRING_FIELDS = 25
 
-  private def maxNumToStringFields = {
+  private[spark] def maxNumToStringFields = {
     if (SparkEnv.get != null) {
       SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS)
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/f596ebe4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index b23e5a7..b84543c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.Partition
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 /**
  * Instructions on how to partition the table among workers.
@@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging {
    * Null value predicate is added to the first partition where clause to include
    * the rows with null values in the partition column.
    *
+   * @param schema resolved schema of a JDBC table
    * @param partitioning partition information to generate the where clause for each partition
+   * @param resolver function used to determine if two identifiers are equal
+   * @param jdbcOptions JDBC options that contain the url
    * @return an array of partitions with where clause for each partition
    */
-  def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
+  def columnPartition(
+      schema: StructType,
+      partitioning: JDBCPartitioningInfo,
+      resolver: Resolver,
+      jdbcOptions: JDBCOptions): Array[Partition] = {
     if (partitioning == null || partitioning.numPartitions <= 1 ||
       partitioning.lowerBound == partitioning.upperBound) {
       return Array[Partition](JDBCPartition(null, 0))
@@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging {
     // Overflow and silliness can happen if you subtract then divide.
     // Here we get a little roundoff, but that's (hopefully) OK.
     val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
-    val column = partitioning.column
+
+    val column = verifyAndGetNormalizedColumnName(
+      schema, partitioning.column, resolver, jdbcOptions)
+
     var i: Int = 0
     var currentValue: Long = lowerBound
     val ans = new ArrayBuffer[Partition]()
@@ -99,10 +111,57 @@ private[sql] object JDBCRelation extends Logging {
     }
     ans.toArray
   }
+
+  // Verify column name based on the JDBC resolved schema
+  private def verifyAndGetNormalizedColumnName(
+      schema: StructType,
+      columnName: String,
+      resolver: Resolver,
+      jdbcOptions: JDBCOptions): String = {
+    val dialect = JdbcDialects.get(jdbcOptions.url)
+    schema.map(_.name).find { fieldName =>
+      resolver(fieldName, columnName) ||
+        resolver(dialect.quoteIdentifier(fieldName), columnName)
+    }.map(dialect.quoteIdentifier).getOrElse {
+      throw new AnalysisException(s"User-defined partition column $columnName not " +
+        s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}")
+    }
+  }
+
+  /**
+   * Takes a (schema, table) specification and returns the table's Catalyst schema.
+   * If `customSchema` is defined in the JDBC options, replaces the schema's dataType with the
+   * custom schema's type.
+   *
+   * @param resolver function used to determine if two identifiers are equal
+   * @param jdbcOptions JDBC options that contain the url, table and other information.
+   * @return resolved Catalyst schema of a JDBC table
+   */
+  def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = {
+    val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
+    jdbcOptions.customSchema match {
+      case Some(customSchema) => JdbcUtils.getCustomSchema(
+        tableSchema, customSchema, resolver)
+      case None => tableSchema
+    }
+  }
+
+  /**
+   * Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema.
+   */
+  def apply(
+      parts: Array[Partition],
+      jdbcOptions: JDBCOptions)(
+      sparkSession: SparkSession): JDBCRelation = {
+    val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions)
+    JDBCRelation(schema, parts, jdbcOptions)(sparkSession)
+  }
 }
 
 private[sql] case class JDBCRelation(
-    parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
+    override val schema: StructType,
+    parts: Array[Partition],
+    jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
   extends BaseRelation
   with PrunedFilteredScan
   with InsertableRelation {
@@ -111,15 +170,6 @@ private[sql] case class JDBCRelation(
 
   override val needConversion: Boolean = false
 
-  override val schema: StructType = {
-    val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
-    jdbcOptions.customSchema match {
-      case Some(customSchema) => JdbcUtils.getCustomSchema(
-        tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
-      case None => tableSchema
-    }
-  }
-
   // Check if JDBCRDD.compileFilter can accept input filters
   override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
     filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
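The heart of the change is `verifyAndGetNormalizedColumnName` above: each resolved field name is compared with the user's string both raw and dialect-quoted, and the first match is returned in dialect-quoted form. A self-contained illustration of that matching rule, not Spark code; a hand-rolled case-insensitive resolver and double-quote quoting stand in for Catalyst's `Resolver` and `JdbcDialect.quoteIdentifier`:

    object NormalizePartitionColumnSketch {
      type Resolver = (String, String) => Boolean

      val caseInsensitive: Resolver = (a, b) => a.equalsIgnoreCase(b)

      // Stands in for JdbcDialect.quoteIdentifier; H2 quotes with double quotes.
      def quoteIdentifier(name: String): String = "\"" + name + "\""

      // Mirrors the lookup: match raw or quoted, return the quoted field name.
      def normalize(fieldNames: Seq[String], column: String, resolver: Resolver): String =
        fieldNames.find { f =>
          resolver(f, column) || resolver(quoteIdentifier(f), column)
        }.map(quoteIdentifier).getOrElse(
          sys.error(s"User-defined partition column $column not found"))

      def main(args: Array[String]): Unit = {
        val fields = Seq("THEID", "THE ID")
        println(normalize(fields, "ThEiD", caseInsensitive))      // prints "THEID" (quoted)
        println(normalize(fields, "\"THEID\"", caseInsensitive))  // prints "THEID" (quoted)
        println(normalize(fields, "THE ID", caseInsensitive))     // prints "THE ID" (quoted)
      }
    }
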
http://git-wip-us.apache.org/repos/asf/spark/blob/f596ebe4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index f8c5677..2b488bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -48,8 +48,10 @@ class JdbcRelationProvider extends CreatableRelationProvider
       JDBCPartitioningInfo(
         partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get)
     }
-    val parts = JDBCRelation.columnPartition(partitionInfo)
-    JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
+    val resolver = sqlContext.conf.resolver
+    val schema = JDBCRelation.getSchema(resolver, jdbcOptions)
+    val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions)
+    JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession)
   }
 
   override def createRelation(
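To see what the normalized column buys end to end, a worked example with the values the new test below uses: lowerBound = 1, upperBound = 4, numPartitions = 3. The stride is 4/3 - 1/3 = 1 - 0 = 1 under integer division, so `columnPartition` generates three where clauses over the normalized, dialect-quoted column:

    "THEID" < 2 or "THEID" is null
    "THEID" >= 2 AND "THEID" < 3
    "THEID" >= 3
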
http://git-wip-us.apache.org/repos/asf/spark/blob/f596ebe4/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index bc2aca6..6ea61f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils}
 import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -238,6 +239,11 @@ class JDBCSuite extends SparkFunSuite
       |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass')
     """.stripMargin.replaceAll("\n", " "))
 
+    conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE ID` INTEGER) " +
+      "AS SELECT 1, 1")
+      .executeUpdate()
+    conn.commit()
+
     // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
   }
 
@@ -1206,4 +1212,47 @@ class JDBCSuite extends SparkFunSuite
     }.getMessage
     assert(errMsg.contains("Statement was canceled or the session timed out"))
   }
+
+  test("SPARK-24327 verify and normalize a partition column based on a JDBC resolved schema") {
+    def testJdbcPartitionColumn(partColName: String, expectedColumnName: String): Unit = {
+      val df = spark.read.format("jdbc")
+        .option("url", urlWithUserAndPass)
+        .option("dbtable", "TEST.PARTITION")
+        .option("partitionColumn", partColName)
+        .option("lowerBound", 1)
+        .option("upperBound", 4)
+        .option("numPartitions", 3)
+        .load()
+
+      val quotedPartColName = testH2Dialect.quoteIdentifier(expectedColumnName)
+      df.logicalPlan match {
+        case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) =>
+          val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
+          assert(whereClauses === Set(
+            s"$quotedPartColName < 2 or $quotedPartColName is null",
+            s"$quotedPartColName >= 2 AND $quotedPartColName < 3",
+            s"$quotedPartColName >= 3"))
+      }
+    }
+
+    testJdbcPartitionColumn("THEID", "THEID")
+    testJdbcPartitionColumn("\"THEID\"", "THEID")
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      testJdbcPartitionColumn("ThEiD", "THEID")
+    }
+    testJdbcPartitionColumn("THE ID", "THE ID")
+
+    def testIncorrectJdbcPartitionColumn(partColName: String): Unit = {
+      val errMsg = intercept[AnalysisException] {
+        testJdbcPartitionColumn(partColName, "THEID")
+      }.getMessage
+      assert(errMsg.contains(s"User-defined partition column $partColName not found " +
+        "in the JDBC relation:"))
+    }
+
+    testIncorrectJdbcPartitionColumn("NoExistingColumn")
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD"))
+    }
+  }
 }
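For reference, when no field matches (the last two cases above), the user now sees an `AnalysisException` built from the resolved schema rather than a runtime SQL error. For the H2 test table the message would read roughly as follows; the exact struct rendering comes from `schema.simpleString`, so this is a derived example, not captured output:

    User-defined partition column NoExistingColumn not found in the
    JDBC relation: struct<THEID:int,THE ID:int>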