This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4dce45a  [SPARK-26744][SQL] Support schema validation in FileDataSourceV2 framework
4dce45a is described below

commit 4dce45a5992e6a89a26b5a0739b33cfeaf979208
Author: Gengliang Wang <gengliang.w...@databricks.com>
AuthorDate: Sat Feb 16 17:11:36 2019 +0800

    [SPARK-26744][SQL] Support schema validation in FileDataSourceV2 framework
    
    ## What changes were proposed in this pull request?
    
    The file source has a schema validation feature, which validates two schemas:
    1. the user-specified schema when reading.
    2. the schema of the input data when writing.
    
    If a file source doesn't support the schema, we can fail the query earlier.
    
    This PR implements the same feature in the `FileDataSourceV2` framework. Compared to `FileFormat`, `FileDataSourceV2` has multiple layers, so the API is added in two places (sketched below):
    1. Read path: the table schema is determined in `TableProvider.getTable`. The actual read schema can be a subset of the table schema, so this PR validates the actual read schema in `FileScan`.
    2. Write path: validate the actual output schema in `FileWriteBuilder`.
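
    For illustration, here is a minimal sketch of how a file source plugs into the new hooks, modeled on the ORC changes in this patch. The `TextScan` class and its supported-type choices are hypothetical placeholders, not part of this PR:

    ```scala
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
    import org.apache.spark.sql.execution.datasources.v2.FileScan
    import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
    import org.apache.spark.sql.types.{CalendarIntervalType, DataType, NullType, StructType}

    // Hypothetical read-path example: FileScan.toBatch now calls supportsDataType on
    // every field of readSchema and throws an AnalysisException mentioning formatName
    // when a field's type is unsupported.
    case class TextScan(
        sparkSession: SparkSession,
        fileIndex: PartitioningAwareFileIndex,
        readSchema: StructType)
      extends FileScan(sparkSession, fileIndex, readSchema) {

      override def createReaderFactory(): PartitionReaderFactory = ???

      // This toy format rejects interval and null columns; everything else passes.
      override def supportsDataType(dataType: DataType): Boolean = dataType match {
        case CalendarIntervalType | NullType => false
        case _ => true
      }

      override def formatName: String = "Text"
    }
    ```

    The write path is symmetric: a `FileWriteBuilder` subclass overrides the same two members (`supportsDataType` and `formatName`), and `validateInputs()` checks every field of the output schema before the write job starts, as `OrcWriteBuilder` does below.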
    
    ## How was this patch tested?
    
    Unit test
    
    Closes #23714 from gengliangwang/schemaValidationV2.
    
    Authored-by: Gengliang Wang <gengliang.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/execution/datasources/v2/FileScan.scala    |  33 ++++-
 .../datasources/v2/FileWriteBuilder.scala          |  24 +++-
 .../datasources/v2/orc/OrcDataSourceV2.scala       |  19 ++-
 .../sql/execution/datasources/v2/orc/OrcScan.scala |  10 +-
 .../datasources/v2/orc/OrcWriteBuilder.scala       |   6 +
 .../spark/sql/FileBasedDataSourceSuite.scala       | 152 +++++++++++----------
 6 files changed, 167 insertions(+), 77 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 3615b15..bdd6a48 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -18,15 +18,16 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.execution.PartitionedFileUtil
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 
 abstract class FileScan(
     sparkSession: SparkSession,
-    fileIndex: PartitioningAwareFileIndex) extends Scan with Batch {
+    fileIndex: PartitioningAwareFileIndex,
+    readSchema: StructType) extends Scan with Batch {
   /**
    * Returns whether a file with `path` could be split or not.
    */
@@ -34,6 +35,22 @@ abstract class FileScan(
     false
   }
 
+  /**
+   * Returns whether this format supports the given [[DataType]] in the read path.
+   * By default all data types are supported.
+   */
+  def supportsDataType(dataType: DataType): Boolean = true
+
+  /**
+   * The string that represents the format that this data source provider uses. This is
+   * overridden by children to provide a nice alias for the data source. For example:
+   *
+   * {{{
+   *   override def formatName: String = "ORC"
+   * }}}
+   */
+  def formatName: String
+
   protected def partitions: Seq[FilePartition] = {
     val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
    val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
@@ -57,5 +74,13 @@ abstract class FileScan(
     partitions.toArray
   }
 
-  override def toBatch: Batch = this
+  override def toBatch: Batch = {
+    readSchema.foreach { field =>
+      if (!supportsDataType(field.dataType)) {
+        throw new AnalysisException(
+          s"$formatName data source does not support ${field.dataType.catalogString} data type.")
+      }
+    }
+    this
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
index ce9b52f..6a94248 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
 abstract class FileWriteBuilder(options: DataSourceOptions)
@@ -104,12 +104,34 @@ abstract class FileWriteBuilder(options: DataSourceOptions)
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory
 
+  /**
+   * Returns whether this format supports the given [[DataType]] in the write path.
+   * By default all data types are supported.
+   */
+  def supportsDataType(dataType: DataType): Boolean = true
+
+  /**
+   * The string that represents the format that this data source provider uses. This is
+   * overridden by children to provide a nice alias for the data source. For example:
+   *
+   * {{{
+   *   override def formatName: String = "ORC"
+   * }}}
+   */
+  def formatName: String
+
   private def validateInputs(): Unit = {
     assert(schema != null, "Missing input data schema")
     assert(queryId != null, "Missing query ID")
     assert(mode != null, "Missing save mode")
     assert(options.paths().length == 1)
     DataSource.validateSchema(schema)
+    schema.foreach { field =>
+      if (!supportsDataType(field.dataType)) {
+        throw new AnalysisException(
+          s"$formatName data source does not support ${field.dataType.catalogString} data type.")
+      }
+    }
   }
 
   private def getJobInstance(hadoopConf: Configuration, path: Path): Job = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
index 74739b4..f279af4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
@@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, Table}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
 
 class OrcDataSourceV2 extends FileDataSourceV2 {
 
@@ -42,3 +42,20 @@ class OrcDataSourceV2 extends FileDataSourceV2 {
     OrcTable(tableName, sparkSession, options, Some(schema))
   }
 }
+
+object OrcDataSourceV2 {
+  def supportsDataType(dataType: DataType): Boolean = dataType match {
+    case _: AtomicType => true
+
+    case st: StructType => st.forall { f => supportsDataType(f.dataType) }
+
+    case ArrayType(elementType, _) => supportsDataType(elementType)
+
+    case MapType(keyType, valueType, _) =>
+      supportsDataType(keyType) && supportsDataType(valueType)
+
+    case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
+
+    case _ => false
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index a792ad3..3c5dc1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.v2.FileScan
 import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
 case class OrcScan(
@@ -31,7 +31,7 @@ case class OrcScan(
     hadoopConf: Configuration,
     fileIndex: PartitioningAwareFileIndex,
     dataSchema: StructType,
-    readSchema: StructType) extends FileScan(sparkSession, fileIndex) {
+    readSchema: StructType) extends FileScan(sparkSession, fileIndex, readSchema) {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -40,4 +40,10 @@ case class OrcScan(
     OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
       dataSchema, fileIndex.partitionSchema, readSchema)
   }
+
+  override def supportsDataType(dataType: DataType): Boolean = {
+    OrcDataSourceV2.supportsDataType(dataType)
+  }
+
+  override def formatName: String = "ORC"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
index 80429d9..1aec4d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
@@ -63,4 +63,10 @@ class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(optio
       }
     }
   }
+
+  override def supportsDataType(dataType: DataType): Boolean = {
+    OrcDataSourceV2.supportsDataType(dataType)
+  }
+
+  override def formatName: String = "ORC"
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index fc87b04..e0c0484 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -329,83 +329,97 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
   test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") {
     withTempDir { dir =>
       val tempDir = new File(dir, "files").getCanonicalPath
-      // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
-      withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
-        // write path
-        Seq("csv", "json", "parquet", "orc").foreach { format =>
-          var msg = intercept[AnalysisException] {
-            sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
-          }.getMessage
-          assert(msg.contains("Cannot save interval data type into external storage."))
-
-          msg = intercept[AnalysisException] {
-            spark.udf.register("testType", () => new IntervalData())
-            sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support calendarinterval data type."))
+      Seq(true, false).foreach { useV1 =>
+        val useV1List = if (useV1) {
+          "orc"
+        } else {
+          ""
         }
+        def errorMessage(format: String, isWrite: Boolean): String = {
+          if (isWrite && (useV1 || format != "orc")) {
+            "cannot save interval data type into external storage."
+          } else {
+            s"$format data source does not support calendarinterval data type."
+          }
+        }
+
+        withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) {
+          // write path
+          Seq("csv", "json", "parquet", "orc").foreach { format =>
+            var msg = intercept[AnalysisException] {
+              sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, true)))
+          }
 
-        // read path
-        Seq("parquet", "csv").foreach { format =>
-          var msg = intercept[AnalysisException] {
-            val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
-            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-            spark.read.schema(schema).format(format).load(tempDir).collect()
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support calendarinterval data type."))
-
-          msg = intercept[AnalysisException] {
-            val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
-            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-            spark.read.schema(schema).format(format).load(tempDir).collect()
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support calendarinterval data type."))
+          // read path
+          Seq("parquet", "csv").foreach { format =>
+            var msg = intercept[AnalysisException] {
+              val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
+              spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+              spark.read.schema(schema).format(format).load(tempDir).collect()
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))
+
+            msg = intercept[AnalysisException] {
+              val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
+              spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+              spark.read.schema(schema).format(format).load(tempDir).collect()
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))
+          }
         }
       }
     }
   }
 
   test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") {
-    // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
-    withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc",
-      SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
-      withTempDir { dir =>
-        val tempDir = new File(dir, "files").getCanonicalPath
-
-        Seq("parquet", "csv", "orc").foreach { format =>
-          // write path
-          var msg = intercept[AnalysisException] {
-            sql("select null").write.format(format).mode("overwrite").save(tempDir)
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support null data type."))
-
-          msg = intercept[AnalysisException] {
-            spark.udf.register("testType", () => new NullData())
-            sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support null data type."))
-
-          // read path
-          msg = intercept[AnalysisException] {
-            val schema = StructType(StructField("a", NullType, true) :: Nil)
-            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-            spark.read.schema(schema).format(format).load(tempDir).collect()
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support null data type."))
-
-          msg = intercept[AnalysisException] {
-            val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
-            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-            spark.read.schema(schema).format(format).load(tempDir).collect()
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support null data type."))
+    Seq(true, false).foreach { useV1 =>
+      val useV1List = if (useV1) {
+        "orc"
+      } else {
+        ""
+      }
+      def errorMessage(format: String): String = {
+        s"$format data source does not support null data type."
+      }
+      withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1List,
+        SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) {
+        withTempDir { dir =>
+          val tempDir = new File(dir, "files").getCanonicalPath
+
+          Seq("parquet", "csv", "orc").foreach { format =>
+            // write path
+            var msg = intercept[AnalysisException] {
+              sql("select null").write.format(format).mode("overwrite").save(tempDir)
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT)
+              .contains(errorMessage(format)))
+
+            msg = intercept[AnalysisException] {
+              spark.udf.register("testType", () => new NullData())
+              sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT)
+              .contains(errorMessage(format)))
+
+            // read path
+            msg = intercept[AnalysisException] {
+              val schema = StructType(StructField("a", NullType, true) :: Nil)
+              spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+              spark.read.schema(schema).format(format).load(tempDir).collect()
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT)
+              .contains(errorMessage(format)))
+
+            msg = intercept[AnalysisException] {
+              val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
+              spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+              spark.read.schema(schema).format(format).load(tempDir).collect()
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT)
+              .contains(errorMessage(format)))
+          }
         }
       }
     }

