This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4eb694c  [SPARK-27443][SQL] Support UDF input_file_name in file source V2
4eb694c is described below

commit 4eb694c58f4210f0d87a28b153605ba3c960cc38
Author: Gengliang Wang <[email protected]>
AuthorDate: Fri Apr 12 20:30:42 2019 +0800

    [SPARK-27443][SQL] Support UDF input_file_name in file source V2
    
    ## What changes were proposed in this pull request?
    
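    Currently, selecting the UDF `input_file_name` as a column when reading
    through file source V2 returns empty file names instead of the path of
    the file each row was read from. We should support it in file source V2.
    This patch makes FilePartitionReader record the current file's path,
    start offset and length in InputFileBlockHolder whenever it opens the
    next file, and clear that information when the reader is closed or when
    a corrupted file is skipped.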
    
    ## How was this patch tested?
    
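    Added a unit test, "UDF input_file_name()", to FileBasedDataSourceSuite.
    It writes an ORC dataset and checks that `input_file_name` returns a path
    containing the output directory. The check runs both with ORC on the V1
    reader list and with the V2 reader.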
    
    Closes #24347 from gengliangwang/input_file_name.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../execution/datasources/v2/FilePartitionReader.scala | 18 ++++++++++++++----
 .../datasources/v2/FilePartitionReaderFactory.scala    |  6 +++---
 .../apache/spark/sql/FileBasedDataSourceSuite.scala    | 13 +++++++++++++
 3 files changed, 30 insertions(+), 7 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReader.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReader.scala
index 7ecd516..7c7b468 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReader.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReader.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 import java.io.{FileNotFoundException, IOException}
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.InputFileBlockHolder
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.reader.PartitionReader
 
@@ -35,8 +36,7 @@ class FilePartitionReader[T](readers: 
Iterator[PartitionedFileReader[T]])
       if (readers.hasNext) {
         if (ignoreMissingFiles || ignoreCorruptFiles) {
           try {
-            currentReader = readers.next()
-            logInfo(s"Reading file $currentReader")
+            currentReader = getNextReader()
           } catch {
             case e: FileNotFoundException if ignoreMissingFiles =>
               logWarning(s"Skipped missing file: $currentReader", e)
@@ -48,11 +48,11 @@ class FilePartitionReader[T](readers: 
Iterator[PartitionedFileReader[T]])
               logWarning(
                 s"Skipped the rest of the content in the corrupted file: 
$currentReader", e)
               currentReader = null
+              InputFileBlockHolder.unset()
               return false
           }
         } else {
-          currentReader = readers.next()
-          logInfo(s"Reading file $currentReader")
+          currentReader = getNextReader()
         }
       } else {
         return false
@@ -84,5 +84,15 @@ class FilePartitionReader[T](readers: 
Iterator[PartitionedFileReader[T]])
     if (currentReader != null) {
       currentReader.close()
     }
+    InputFileBlockHolder.unset()
+  }
+
+  private def getNextReader(): PartitionedFileReader[T] = {
+    val reader = readers.next()
+    logInfo(s"Reading file $reader")
+    // Sets InputFileBlockHolder for the file block's information
+    val file = reader.file
+    InputFileBlockHolder.set(file.filePath, file.start, file.length)
+    reader
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala
index d053ea9..5a19412 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala
@@ -27,7 +27,7 @@ abstract class FilePartitionReaderFactory extends 
PartitionReaderFactory {
     assert(partition.isInstanceOf[FilePartition])
     val filePartition = partition.asInstanceOf[FilePartition]
     val iter = filePartition.files.toIterator.map { file =>
-      new PartitionedFileReader(file, buildReader(file))
+      PartitionedFileReader(file, buildReader(file))
     }
     new FilePartitionReader[InternalRow](iter)
   }
@@ -36,7 +36,7 @@ abstract class FilePartitionReaderFactory extends 
PartitionReaderFactory {
     assert(partition.isInstanceOf[FilePartition])
     val filePartition = partition.asInstanceOf[FilePartition]
     val iter = filePartition.files.toIterator.map { file =>
-      new PartitionedFileReader(file, buildColumnarReader(file))
+      PartitionedFileReader(file, buildColumnarReader(file))
     }
     new FilePartitionReader[ColumnarBatch](iter)
   }
@@ -49,7 +49,7 @@ abstract class FilePartitionReaderFactory extends 
PartitionReaderFactory {
 }
 
 // A compound class for combining file and its corresponding reader.
-private[v2] class PartitionedFileReader[T](
+private[v2] case class PartitionedFileReader[T](
     file: PartitionedFile,
     reader: PartitionReader[T]) extends PartitionReader[T] {
   override def next(): Boolean = reader.next()
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 506156d..8fcffbf 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -526,6 +526,19 @@ class FileBasedDataSourceSuite extends QueryTest with 
SharedSQLContext with Befo
     }
   }
 
+  test("UDF input_file_name()") {
+    Seq("", "orc").foreach { useV1SourceReaderList =>
+      withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> 
useV1SourceReaderList) {
+        withTempPath { dir =>
+          val path = dir.getCanonicalPath
+          spark.range(10).write.orc(path)
+          val row = spark.read.orc(path).select(input_file_name).first()
+          assert(row.getString(0).contains(path))
+        }
+      }
+    }
+  }
+
   test("Return correct results when data columns overlap with partition 
columns") {
     Seq("parquet", "orc", "json").foreach { format =>
       withTempPath { path =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to