This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new bc69e6c  [SPARK-38523][SQL][3.2] Fix referring to the corrupt record 
column from CSV
bc69e6c is described below

commit bc69e6cc2b58a2eec3491a1b18adafa0d9a8c1ee
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Mon Mar 14 12:22:38 2022 +0300

    [SPARK-38523][SQL][3.2] Fix referring to the corrupt record column from CSV
    
    ### What changes were proposed in this pull request?
    When a user specifies the corrupt record column via the CSV 
option `columnNameOfCorruptRecord`:
    1. Disable the column pruning feature in the CSV parser.
    2. Don't push filters to `UnivocityParser` that refer to the "virtual" 
column `columnNameOfCorruptRecord`. Since the column cannot be present in the 
input CSV, users' queries failed while compiling predicates. After the changes, 
the skipped filters are applied later at the upper layer.
    
    ### Why are the changes needed?
    The changes allow users to refer to the corrupt record column in their queries:
    
    ```Scala
    spark.read.format("csv")
      .option("header", "true")
      .option("columnNameOfCorruptRecord", "corrRec")
      .schema(schema)
      .load("csv_corrupt_record.csv")
      .filter($"corrRec".isNotNull)
      .show()
    ```
    for the input file "csv_corrupt_record.csv":
    ```
    0,2013-111_11 12:13:14
    1,1983-08-04
    ```
    the query returns:
    ```
    +---+----+----------------------+
    |a  |b   |corrRec               |
    +---+----+----------------------+
    |0  |null|0,2013-111_11 12:13:14|
    +---+----+----------------------+
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Before the changes, the query above fails with the exception:
    ```Java
    java.lang.IllegalArgumentException: _corrupt_record does not exist. 
Available: a, b
        at 
org.apache.spark.sql.types.StructType.$anonfun$fieldIndex$1(StructType.scala:310)
 ~[classes/:?]
    ```
    
    ### How was this patch tested?
    By running the new CSV test:
    ```
    $ build/sbt "sql/testOnly *.CSVv1Suite"
    $ build/sbt "sql/testOnly *.CSVv2Suite"
    $ build/sbt "sql/testOnly *.CSVLegacyTimeParserSuite"
    ```
    
    Authored-by: Max Gekk <max.gekkgmail.com>
    Signed-off-by: Wenchen Fan <wenchendatabricks.com>
    (cherry picked from commit 959694271e30879c944d7fd5de2740571012460a)
    Signed-off-by: Max Gekk <max.gekkgmail.com>
    
    Closes #35844 from MaxGekk/csv-ref-_corupt_record-3.2.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../execution/datasources/csv/CSVFileFormat.scala  | 18 +++++-----
 .../v2/csv/CSVPartitionReaderFactory.scala         |  3 +-
 .../sql/execution/datasources/v2/csv/CSVScan.scala | 16 ++++-----
 .../sql/execution/datasources/csv/CSVSuite.scala   | 42 +++++++++-------------
 4 files changed, 34 insertions(+), 45 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 8add63c..00115cf 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, 
UnivocityParser}
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
-import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -101,21 +100,20 @@ class CSVFileFormat extends TextBasedFileFormat with 
DataSourceRegister {
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = 
{
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new 
SerializableConfiguration(hadoopConf))
-
+    val columnPruning = sparkSession.sessionState.conf.csvColumnPruning &&
+      !requiredSchema.exists(_.name == 
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
     val parsedOptions = new CSVOptions(
       options,
-      sparkSession.sessionState.conf.csvColumnPruning,
+      columnPruning,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
     // Check a field requirement for corrupt records here to throw an 
exception in a driver side
     ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, 
parsedOptions.columnNameOfCorruptRecord)
-
-    if (requiredSchema.length == 1 &&
-      requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
-      throw 
QueryCompilationErrors.queryFromRawFilesIncludeCorruptRecordColumnError()
-    }
-    val columnPruning = sparkSession.sessionState.conf.csvColumnPruning
+    // Don't push any filter which refers to the "virtual" column which cannot 
present in the input.
+    // Such filters will be applied later on the upper layer.
+    val actualFilters =
+      
filters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord))
 
     (file: PartitionedFile) => {
       val conf = broadcastedHadoopConf.value.value
@@ -127,7 +125,7 @@ class CSVFileFormat extends TextBasedFileFormat with 
DataSourceRegister {
         actualDataSchema,
         actualRequiredSchema,
         parsedOptions,
-        filters)
+        actualFilters)
       val schema = if (columnPruning) actualRequiredSchema else 
actualDataSchema
       val isStartOfFile = file.start == 0
       val headerChecker = new CSVHeaderChecker(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala
index 31d31bd..bf996ab 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala
@@ -46,7 +46,6 @@ case class CSVPartitionReaderFactory(
     partitionSchema: StructType,
     parsedOptions: CSVOptions,
     filters: Seq[Filter]) extends FilePartitionReaderFactory {
-  private val columnPruning = sqlConf.csvColumnPruning
 
   override def buildReader(file: PartitionedFile): 
PartitionReader[InternalRow] = {
     val conf = broadcastedConf.value.value
@@ -59,7 +58,7 @@ case class CSVPartitionReaderFactory(
       actualReadDataSchema,
       parsedOptions,
       filters)
-    val schema = if (columnPruning) actualReadDataSchema else actualDataSchema
+    val schema = if (parsedOptions.columnPruning) actualReadDataSchema else 
actualDataSchema
     val isStartOfFile = file.start == 0
     val headerChecker = new CSVHeaderChecker(
       schema, parsedOptions, source = s"CSV file: ${file.filePath}", 
isStartOfFile)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
index 3f77b21..1eaa804 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils}
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
-import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
 import org.apache.spark.sql.execution.datasources.v2.{FileScan, 
TextBasedFileScan}
@@ -45,9 +44,11 @@ case class CSVScan(
     dataFilters: Seq[Expression] = Seq.empty)
   extends TextBasedFileScan(sparkSession, options) {
 
+  val columnPruning = sparkSession.sessionState.conf.csvColumnPruning &&
+    !readDataSchema.exists(_.name == 
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
   private lazy val parsedOptions: CSVOptions = new CSVOptions(
     options.asScala.toMap,
-    columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
+    columnPruning = columnPruning,
     sparkSession.sessionState.conf.sessionLocalTimeZone,
     sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
@@ -67,11 +68,10 @@ case class CSVScan(
   override def createReaderFactory(): PartitionReaderFactory = {
     // Check a field requirement for corrupt records here to throw an 
exception in a driver side
     ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, 
parsedOptions.columnNameOfCorruptRecord)
-
-    if (readDataSchema.length == 1 &&
-      readDataSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
-      throw 
QueryCompilationErrors.queryFromRawFilesIncludeCorruptRecordColumnError()
-    }
+    // Don't push any filter which refers to the "virtual" column which cannot 
present in the input.
+    // Such filters will be applied later on the upper layer.
+    val actualFilters =
+      
pushedFilters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord))
 
     val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
     // Hadoop Configurations are case sensitive.
@@ -81,7 +81,7 @@ case class CSVScan(
     // The partition values are already truncated in `FileScan.partitions`.
     // We should use `readPartitionSchema` as the partition schema here.
     CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
-      dataSchema, readDataSchema, readPartitionSchema, parsedOptions, 
pushedFilters)
+      dataSchema, readDataSchema, readPartitionSchema, parsedOptions, 
actualFilters)
   }
 
   override def withFilters(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 7efdf7c..8e501a7 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1415,38 +1415,30 @@ abstract class CSVSuite
     checkAnswer(df, Row("a", null, "a"))
   }
 
-  test("SPARK-21610: Corrupt records are not handled properly when creating a 
dataframe " +
-    "from a file") {
-    val columnNameOfCorruptRecord = "_corrupt_record"
+  test("SPARK-38523: referring to the corrupt record column") {
     val schema = new StructType()
       .add("a", IntegerType)
       .add("b", DateType)
-      .add(columnNameOfCorruptRecord, StringType)
-    // negative cases
-    val msg = intercept[AnalysisException] {
-      spark
-        .read
-        .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
-        .schema(schema)
-        .csv(testFile(valueMalformedFile))
-        .select(columnNameOfCorruptRecord)
-        .collect()
-    }.getMessage
-    assert(msg.contains("only include the internal corrupt record column"))
-
-    // workaround
-    val df = spark
+      .add("corrRec", StringType)
+    val readback = spark
       .read
-      .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+      .option("columnNameOfCorruptRecord", "corrRec")
       .schema(schema)
       .csv(testFile(valueMalformedFile))
-      .cache()
-    assert(df.filter($"_corrupt_record".isNotNull).count() == 1)
-    assert(df.filter($"_corrupt_record".isNull).count() == 1)
     checkAnswer(
-      df.select(columnNameOfCorruptRecord),
-      Row("0,2013-111_11 12:13:14") :: Row(null) :: Nil
-    )
+      readback,
+      Row(0, null, "0,2013-111_11 12:13:14") ::
+      Row(1, Date.valueOf("1983-08-04"), null) :: Nil)
+    checkAnswer(
+      readback.filter($"corrRec".isNotNull),
+      Row(0, null, "0,2013-111_11 12:13:14"))
+    checkAnswer(
+      readback.select($"corrRec", $"b"),
+      Row("0,2013-111_11 12:13:14", null) ::
+      Row(null, Date.valueOf("1983-08-04")) :: Nil)
+    checkAnswer(
+      readback.filter($"corrRec".isNull && $"a" === 1),
+      Row(1, Date.valueOf("1983-08-04"), null) :: Nil)
   }
 
   test("SPARK-23846: schema inferring touches less data if samplingRatio < 
1.0") {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to