Repository: spark
Updated Branches:
  refs/heads/master 60977889e -> 79636054f


[SPARK-20148][SQL] Extend the file commit API to allow subscribing to task commit messages

## What changes were proposed in this pull request?

The internal FileCommitProtocol interface returns all task commit messages in
bulk to the implementation when a job finishes. However, it is sometimes useful
to access each message as soon as its task commits, so that the driver gets
incremental progress updates before the job finishes.

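This adds an `onTaskCommit` callback to the internal API. It is a no-op by
default. `FileFormatWriter` invokes it on the driver as each write task's result
arrives. The same messages are still passed to `commitJob()` once the whole job
succeeds.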

## How was this patch tested?

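Added a unit test to `DataFrameReaderWriterSuite`. It plugs in a message-capturing
commit protocol via `spark.sql.sources.commitProtocolClass`, writes a 10-partition
dataset, and asserts that `onTaskCommit` received 10 task commit messages.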

cc rxin

Author: Eric Liang <e...@databricks.com>

Closes #17475 from ericl/file-commit-api-ext.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/79636054
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/79636054
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/79636054

Branch: refs/heads/master
Commit: 79636054f60dd639e9d326e1328717e97df13304
Parents: 6097788
Author: Eric Liang <e...@databricks.com>
Authored: Wed Mar 29 20:59:48 2017 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Wed Mar 29 20:59:48 2017 -0700

----------------------------------------------------------------------
 .../spark/internal/io/FileCommitProtocol.scala  |  7 +++++
 .../datasources/FileFormatWriter.scala          | 22 ++++++++++----
 .../sql/test/DataFrameReaderWriterSuite.scala   | 31 +++++++++++++++++++-
 3 files changed, 53 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/79636054/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
index 2394cf3..7efa941 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -121,6 +121,13 @@ abstract class FileCommitProtocol {
   def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = 
{
     fs.delete(path, recursive)
   }
+
+  /**
+   * Called on the driver after a task commits. This can be used to access task commit messages
+   * before the job has finished. These same task commit messages will be passed to commitJob()
+   * if the entire job succeeds.
+   */
+  def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {}
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/79636054/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 7957224..bda64d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -80,6 +80,9 @@ object FileFormatWriter extends Logging {
        """.stripMargin)
   }
 
+  /** The result of a successful write task. */
+  private case class WriteTaskResult(commitMsg: TaskCommitMessage, 
updatedPartitions: Set[String])
+
   /**
    * Basic work flow of this command is:
    * 1. Driver side setup, including output committer initialization and data 
source specific
@@ -172,8 +175,9 @@ object FileFormatWriter extends Logging {
             global = false,
             child = queryExecution.executedPlan).execute()
         }
-
-        val ret = sparkSession.sparkContext.runJob(rdd,
+        val ret = new Array[WriteTaskResult](rdd.partitions.length)
+        sparkSession.sparkContext.runJob(
+          rdd,
           (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
             executeTask(
               description = description,
@@ -182,10 +186,16 @@ object FileFormatWriter extends Logging {
               sparkAttemptNumber = taskContext.attemptNumber(),
               committer,
               iterator = iter)
+          },
+          0 until rdd.partitions.length,
+          (index, res: WriteTaskResult) => {
+            committer.onTaskCommit(res.commitMsg)
+            ret(index) = res
           })
 
-        val commitMsgs = ret.map(_._1)
-        val updatedPartitions = 
ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment)
+        val commitMsgs = ret.map(_.commitMsg)
+        val updatedPartitions = ret.flatMap(_.updatedPartitions)
+          .distinct.map(PartitioningUtils.parsePathFragment)
 
         committer.commitJob(job, commitMsgs)
         logInfo(s"Job ${job.getJobID} committed.")
@@ -205,7 +215,7 @@ object FileFormatWriter extends Logging {
       sparkPartitionId: Int,
       sparkAttemptNumber: Int,
       committer: FileCommitProtocol,
-      iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = {
+      iterator: Iterator[InternalRow]): WriteTaskResult = {
 
     val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId)
     val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -238,7 +248,7 @@ object FileFormatWriter extends Logging {
         // Execute the task to write rows out and commit the task.
         val outputPartitions = writeTask.execute(iterator)
         writeTask.releaseResources()
-        (committer.commitTask(taskAttemptContext), outputPartitions)
+        WriteTaskResult(committer.commitTask(taskAttemptContext), 
outputPartitions)
       })(catchBlock = {
         // If there is an error, release resource and then abort the task
         try {

http://git-wip-us.apache.org/repos/asf/spark/blob/79636054/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 8287776..7c71e72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -18,9 +18,12 @@
 package org.apache.spark.sql.test
 
 import java.io.File
+import java.util.concurrent.ConcurrentLinkedQueue
 
 import org.scalatest.BeforeAndAfter
 
+import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.sources._
@@ -41,7 +44,6 @@ object LastOptions {
   }
 }
 
-
 /** Dummy provider. */
 class DefaultSource
   extends RelationProvider
@@ -107,6 +109,20 @@ class DefaultSourceWithoutUserSpecifiedSchema
   }
 }
 
+object MessageCapturingCommitProtocol {
+  val commitMessages = new ConcurrentLinkedQueue[TaskCommitMessage]()
+}
+
+class MessageCapturingCommitProtocol(jobId: String, path: String)
+    extends HadoopMapReduceCommitProtocol(jobId, path) {
+
+  // captures commit messages for testing
+  override def onTaskCommit(msg: TaskCommitMessage): Unit = {
+    MessageCapturingCommitProtocol.commitMessages.offer(msg)
+  }
+}
+
+
 class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with 
BeforeAndAfter {
   import testImplicits._
 
@@ -291,6 +307,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     Option(dir).map(spark.read.format("org.apache.spark.sql.test").load)
   }
 
+  test("write path implements onTaskCommit API correctly") {
+    withSQLConf(
+        "spark.sql.sources.commitProtocolClass" ->
+          classOf[MessageCapturingCommitProtocol].getCanonicalName) {
+      withTempDir { dir =>
+        val path = dir.getCanonicalPath
+        MessageCapturingCommitProtocol.commitMessages.clear()
+        spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
+        assert(MessageCapturingCommitProtocol.commitMessages.size() == 10)
+      }
+    }
+  }
+
   test("read a data source that does not extend SchemaRelationProvider") {
     val dfReader = spark.read
       .option("from", "1")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to