c21 commented on a change in pull request #32198:
URL: https://github.com/apache/spark/pull/32198#discussion_r617869815



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -3150,6 +3150,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val MAX_CONCURRENT_OUTPUT_WRITERS = 
buildConf("spark.sql.maxConcurrentOutputWriters")

Review comment:
       @cloud-fan - updated.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
##########
@@ -47,6 +48,7 @@ abstract class FileFormatDataWriter(
   protected val MAX_FILE_COUNTER: Int = 1000 * 1000
   protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]()
   protected var currentWriter: OutputWriter = _

Review comment:
       @cloud-fan - makes sense. At first I was hesitant to make a broader change to the 
`OutputWriter` interface, but it is updated now.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
##########
@@ -255,25 +320,182 @@ class DynamicPartitionDataWriter(
       }
       if (isBucketed) {
         currentBucketId = nextBucketId
-        statsTrackers.foreach(_.newBucket(currentBucketId.get))
       }
 
       fileCounter = 0
-      newOutputWriter(currentPartitionValues, currentBucketId)
-    } else if (description.maxRecordsPerFile > 0 &&
-      recordsInFile >= description.maxRecordsPerFile) {
-      // Exceeded the threshold in terms of the number of records per file.
-      // Create a new file by increasing the file counter.
-      fileCounter += 1
-      assert(fileCounter < MAX_FILE_COUNTER,
-        s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+      newOutputWriter(currentPartitionValues, currentBucketId, true)
+    } else {
+      checkRecordsInFile(currentPartitionValues, currentBucketId)
+    }
+    writeRecord(record)
+  }
+}
+
+/**
+ * Dynamic partition writer with concurrent writers, meaning multiple 
concurrent writers are opened
+ * for writing.
+ *
+ * The process has the following steps:
+ *  - Step 1: Maintain a map of output writers per each partition and/or 
bucket columns. Keep all
+ *            writers opened and write rows one by one.
+ *  - Step 2: If number of concurrent writers exceeds limit, sort rest of rows 
on partition and/or
+ *            bucket column(s). Write rows one by one, and eagerly close the 
writer when finishing
+ *            each partition and/or bucket.
+ *
+ * Caller is expected to call `writeWithIterator()` instead of `write()` to 
write records.
+ */
+class DynamicPartitionDataConcurrentWriter(
+    description: WriteJobDescription,
+    taskAttemptContext: TaskAttemptContext,
+    committer: FileCommitProtocol,
+    concurrentOutputWriterSpec: ConcurrentOutputWriterSpec)
+  extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, 
committer) {
+
+  /** Wrapper class to index a unique concurrent output writer. */
+  private case class WriterIndex(
+    var partitionValues: Option[UnsafeRow],
+    var bucketId: Option[Int])
+
+  /** Wrapper class for status of a unique concurrent output writer. */
+  private case class WriterStatus(

Review comment:
       @cloud-fan - updated.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
##########
@@ -255,25 +320,182 @@ class DynamicPartitionDataWriter(
       }
       if (isBucketed) {
         currentBucketId = nextBucketId
-        statsTrackers.foreach(_.newBucket(currentBucketId.get))
       }
 
       fileCounter = 0
-      newOutputWriter(currentPartitionValues, currentBucketId)
-    } else if (description.maxRecordsPerFile > 0 &&
-      recordsInFile >= description.maxRecordsPerFile) {
-      // Exceeded the threshold in terms of the number of records per file.
-      // Create a new file by increasing the file counter.
-      fileCounter += 1
-      assert(fileCounter < MAX_FILE_COUNTER,
-        s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+      newOutputWriter(currentPartitionValues, currentBucketId, true)
+    } else {
+      checkRecordsInFile(currentPartitionValues, currentBucketId)
+    }
+    writeRecord(record)
+  }
+}
+
+/**
+ * Dynamic partition writer with concurrent writers, meaning multiple 
concurrent writers are opened
+ * for writing.
+ *
+ * The process has the following steps:
+ *  - Step 1: Maintain a map of output writers per each partition and/or 
bucket columns. Keep all
+ *            writers opened and write rows one by one.
+ *  - Step 2: If number of concurrent writers exceeds limit, sort rest of rows 
on partition and/or
+ *            bucket column(s). Write rows one by one, and eagerly close the 
writer when finishing
+ *            each partition and/or bucket.
+ *
+ * Caller is expected to call `writeWithIterator()` instead of `write()` to 
write records.
+ */
+class DynamicPartitionDataConcurrentWriter(
+    description: WriteJobDescription,
+    taskAttemptContext: TaskAttemptContext,
+    committer: FileCommitProtocol,
+    concurrentOutputWriterSpec: ConcurrentOutputWriterSpec)
+  extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, 
committer) {
+
+  /** Wrapper class to index a unique concurrent output writer. */
+  private case class WriterIndex(
+    var partitionValues: Option[UnsafeRow],
+    var bucketId: Option[Int])
+
+  /** Wrapper class for status of a unique concurrent output writer. */
+  private case class WriterStatus(
+    var outputWriter: OutputWriter,
+    var recordsInFile: Long,
+    var fileCounter: Int,
+    var latestFilePath: String)
 
-      newOutputWriter(currentPartitionValues, currentBucketId)
+  /**
+   * State to indicate if we are falling back to sort-based writer.
+   * Because we first try to use concurrent writers, its initial value is 
false.
+   */
+  private var sortBased: Boolean = false
+  private val concurrentWriters = mutable.HashMap[WriterIndex, WriterStatus]()
+  private val currentWriterId = WriterIndex(None, None)
+
+  /**
+   * Release resources for all concurrent output writers.
+   */
+  override protected def releaseResources(): Unit = {
+    currentWriter = null

Review comment:
       @cloud-fan - yes, it should be non-null here, since it refers to the last 
active writer. `currentWriter` is always put into `concurrentWriters` in 
`retrieveWriterInMap()`, so it will be closed below when iterating over 
`concurrentWriters.values`.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
##########
@@ -56,6 +58,7 @@ abstract class FileFormatDataWriter(
     if (currentWriter != null) {

Review comment:
       @cloud-fan - yes, I agree that it looks pretty weird. Good idea; 
updated.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
##########
@@ -255,25 +320,182 @@ class DynamicPartitionDataWriter(
       }
       if (isBucketed) {
         currentBucketId = nextBucketId
-        statsTrackers.foreach(_.newBucket(currentBucketId.get))
       }
 
       fileCounter = 0
-      newOutputWriter(currentPartitionValues, currentBucketId)
-    } else if (description.maxRecordsPerFile > 0 &&
-      recordsInFile >= description.maxRecordsPerFile) {
-      // Exceeded the threshold in terms of the number of records per file.
-      // Create a new file by increasing the file counter.
-      fileCounter += 1
-      assert(fileCounter < MAX_FILE_COUNTER,
-        s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+      newOutputWriter(currentPartitionValues, currentBucketId, true)
+    } else {
+      checkRecordsInFile(currentPartitionValues, currentBucketId)
+    }
+    writeRecord(record)
+  }
+}
+
+/**
+ * Dynamic partition writer with concurrent writers, meaning multiple 
concurrent writers are opened
+ * for writing.
+ *
+ * The process has the following steps:
+ *  - Step 1: Maintain a map of output writers per each partition and/or 
bucket columns. Keep all
+ *            writers opened and write rows one by one.
+ *  - Step 2: If number of concurrent writers exceeds limit, sort rest of rows 
on partition and/or
+ *            bucket column(s). Write rows one by one, and eagerly close the 
writer when finishing
+ *            each partition and/or bucket.
+ *
+ * Caller is expected to call `writeWithIterator()` instead of `write()` to 
write records.
+ */
+class DynamicPartitionDataConcurrentWriter(
+    description: WriteJobDescription,
+    taskAttemptContext: TaskAttemptContext,
+    committer: FileCommitProtocol,
+    concurrentOutputWriterSpec: ConcurrentOutputWriterSpec)
+  extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, 
committer) {
+
+  /** Wrapper class to index a unique concurrent output writer. */
+  private case class WriterIndex(
+    var partitionValues: Option[UnsafeRow],
+    var bucketId: Option[Int])
+
+  /** Wrapper class for status of a unique concurrent output writer. */
+  private case class WriterStatus(
+    var outputWriter: OutputWriter,
+    var recordsInFile: Long,
+    var fileCounter: Int,
+    var latestFilePath: String)
 
-      newOutputWriter(currentPartitionValues, currentBucketId)
+  /**
+   * State to indicate if we are falling back to sort-based writer.
+   * Because we first try to use concurrent writers, its initial value is 
false.
+   */
+  private var sortBased: Boolean = false
+  private val concurrentWriters = mutable.HashMap[WriterIndex, WriterStatus]()
+  private val currentWriterId = WriterIndex(None, None)
+
+  /**
+   * Release resources for all concurrent output writers.
+   */
+  override protected def releaseResources(): Unit = {
+    currentWriter = null
+    concurrentWriters.values.foreach(status => {
+      if (status.outputWriter != null) {
+        try {
+          status.outputWriter.close()
+        } finally {
+          status.outputWriter = null
+        }
+      }
+    })
+    concurrentWriters.clear()
+  }
+
+  override def write(record: InternalRow): Unit = {
+    val nextPartitionValues = if (isPartitioned) 
Some(getPartitionValues(record)) else None
+    val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None
+
+    if (currentWriterId.partitionValues != nextPartitionValues ||
+      currentWriterId.bucketId != nextBucketId) {
+      // See a new partition or bucket - write to a new partition dir (or a 
new bucket file).
+      updateCurrentWriterStatusInMap()
+      if (isBucketed) {
+        currentWriterId.bucketId = nextBucketId
+      }
+      if (isPartitioned && currentWriterId.partitionValues != 
nextPartitionValues) {
+        currentWriterId.partitionValues = Some(nextPartitionValues.get.copy())
+        if (!concurrentWriters.contains(currentWriterId)) {
+          
statsTrackers.foreach(_.newPartition(currentWriterId.partitionValues.get))
+        }
+      }
+      retrieveWriterInMap()
     }
-    val outputRow = getOutputRow(record)
-    currentWriter.write(outputRow)
-    statsTrackers.foreach(_.newRow(outputRow))
-    recordsInFile += 1
+
+    checkRecordsInFile(currentWriterId.partitionValues, 
currentWriterId.bucketId)
+    writeRecord(record)
+  }
+
+  /**
+   * Write iterator of records with concurrent writers.
+   */
+  def writeWithIterator(iterator: Iterator[InternalRow]): Unit = {
+    while (iterator.hasNext && !sortBased) {
+      write(iterator.next())
+    }
+
+    if (iterator.hasNext) {
+      clearCurrentWriterStatus()
+      val sorter = concurrentOutputWriterSpec.createSorter()
+      val sortIterator = 
sorter.sort(iterator.asInstanceOf[Iterator[UnsafeRow]])
+      while (sortIterator.hasNext) {
+        write(sortIterator.next())
+      }
+    }
+  }
+
+  /**
+   * Update current writer status when a new writer is needed for writing row.
+   */
+  private def updateCurrentWriterStatusInMap(): Unit = {
+    if (currentWriterId.partitionValues.isDefined || 
currentWriterId.bucketId.isDefined) {
+      if (!sortBased) {
+        // Update writer status in concurrent writers map, because the writer 
is probably needed
+        // again later for writing other rows.

Review comment:
       I want to avoid extra computation (e.g. a hash map lookup) when it is 
unnecessary. I don't see a benefit to updating the status in the map for every row, 
since the current writer's status is already captured in `(currentWriter, 
recordsInFile, fileCounter)`.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
##########
@@ -273,14 +289,25 @@ object FileFormatWriter extends Logging {
       } else if (description.partitionColumns.isEmpty && 
description.bucketIdExpression.isEmpty) {
         new SingleDirectoryDataWriter(description, taskAttemptContext, 
committer)
       } else {
-        new DynamicPartitionDataWriter(description, taskAttemptContext, 
committer)
+        concurrentOutputWriterSpec match {
+          case Some(spec) =>
+            new DynamicPartitionDataConcurrentWriter(
+              description, taskAttemptContext, committer, spec)
+          case _ =>
+            new DynamicPartitionDataSingleWriter(description, 
taskAttemptContext, committer)
+        }
       }
 
     try {
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Execute the task to write rows out and commit the task.
-        while (iterator.hasNext) {
-          dataWriter.write(iterator.next())
+        dataWriter match {
+          case w: DynamicPartitionDataConcurrentWriter =>
+            w.writeWithIterator(iterator)

Review comment:
       @cloud-fan - I'm wondering what the benefit of doing this is, but I've 
updated it anyway.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
##########
@@ -134,23 +134,20 @@ class BasicWriteTaskStatsTracker(hadoopConf: 
Configuration)
     partitions.append(partitionValues)
   }
 
-  override def newBucket(bucketId: Int): Unit = {
-    // currently unhandled
+  override def newFile(filePath: String): Unit = {
+    submittedFiles += filePath
+    numSubmittedFiles += 1
   }
 
-  override def newFile(filePath: String): Unit = {
-    statCurrentFile()
-    curFile = Some(filePath)
-    submittedFiles += 1
+  override def closeFile(filePath: String): Unit = {
+    getFileStats(filePath)
+    submittedFiles.remove(filePath)
   }
 
-  private def statCurrentFile(): Unit = {
-    curFile.foreach { path =>
-      getFileSize(path).foreach { len =>
-        numBytes += len
-        numFiles += 1
-      }
-      curFile = None
+  private def getFileStats(filePath: String): Unit = {

Review comment:
       @cloud-fan - updated.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to