Repository: spark
Updated Branches:
  refs/heads/master 65fe902e1 -> 776b8f17c


[SPARK-19563][SQL] avoid unnecessary sort in FileFormatWriter

## What changes were proposed in this pull request?

In `FileFormatWriter`, when writing data out partitioned or bucketed, we sort the 
input rows by partition columns, bucket id, and sort columns.

However, if the data is already sorted, we sort it again, which is unnecessary.

This PR removes the sorting logic in `FileFormatWriter` and uses `SortExec` 
instead. We do not add `SortExec` if the data is already sorted.
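
A condensed sketch of the new planning step (simplified from the diff below; 
`planForWrite` is an illustrative helper name, not part of the patch):
```
// Compare the required write ordering with the child plan's existing output
// ordering, and add a local SortExec only when they don't match.
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}
import org.apache.spark.sql.execution.{SortExec, SparkPlan}

def planForWrite(requiredOrdering: Seq[Expression], child: SparkPlan): SparkPlan = {
  // Only the sort keys matter here, not the sort direction.
  val actualOrdering = child.outputOrdering.map(_.child)
  val orderingMatched =
    requiredOrdering.length <= actualOrdering.length &&
      requiredOrdering.zip(actualOrdering).forall {
        case (required, actual) => required.semanticEquals(actual)
      }
  if (orderingMatched) child
  else SortExec(requiredOrdering.map(SortOrder(_, Ascending)), global = false, child = child)
}
```
The required ordering is the partition columns, then the bucket id expression, then 
the bucket sort columns, which is what the writer needs in order to produce one 
output file at a time.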

## How was this patch tested?

I ran a micro-benchmark manually:
```
val df = spark.range(10000000).select($"id", $"id" % 10 as "part").sort("part")
spark.time(df.write.partitionBy("part").parquet("/tmp/test"))
```
The write took about 6.4 seconds before this PR and 5.7 seconds afterwards.
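
As a rough sanity check (not part of the patch; the path and variable names are 
illustrative), the same write without the pre-sort should still be sorted before 
files are produced, because the writer now plans a local `SortExec` when the input 
ordering does not match:
```
// Hypothetical comparison run: without sorting by "part" first, the required
// ordering is not satisfied, so the writer adds a local SortExec before writing.
val unsorted = spark.range(10000000).select($"id", $"id" % 10 as "part")
spark.time(unsorted.write.partitionBy("part").parquet("/tmp/test_unsorted"))
```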

Closes https://github.com/apache/spark/pull/16724

Author: Wenchen Fan <[email protected]>

Closes #16898 from cloud-fan/writer.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/776b8f17
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/776b8f17
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/776b8f17

Branch: refs/heads/master
Commit: 776b8f17cfc687a57c005a421a81e591c8d44a3f
Parents: 65fe902
Author: Wenchen Fan <[email protected]>
Authored: Sun Feb 19 18:13:12 2017 -0800
Committer: Wenchen Fan <[email protected]>
Committed: Sun Feb 19 18:13:12 2017 -0800

----------------------------------------------------------------------
 .../datasources/FileFormatWriter.scala          | 189 +++++++++----------
 1 file changed, 90 insertions(+), 99 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/776b8f17/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index be13cbc..6443584 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 
 /** A helper object for writing FileFormat data out to a location. */
@@ -64,9 +63,9 @@ object FileFormatWriter extends Logging {
       val serializableHadoopConf: SerializableConfiguration,
       val outputWriterFactory: OutputWriterFactory,
       val allColumns: Seq[Attribute],
-      val partitionColumns: Seq[Attribute],
       val dataColumns: Seq[Attribute],
-      val bucketSpec: Option[BucketSpec],
+      val partitionColumns: Seq[Attribute],
+      val bucketIdExpression: Option[Expression],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
       val maxRecordsPerFile: Long)
@@ -108,9 +107,21 @@ object FileFormatWriter extends Logging {
     job.setOutputValueClass(classOf[InternalRow])
     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
 
+    val allColumns = queryExecution.logical.output
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
 
+    val bucketIdExpression = bucketSpec.map { spec =>
+      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+      // guarantee the data distribution is same between shuffle and bucketed data source, which
+      // enables us to only shuffle one side when join a bucketed table and a normal one.
+      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+    }
+    val sortColumns = bucketSpec.toSeq.flatMap {
+      spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
+    }
+
     // Note: prepareWrite has side effect. It sets "job".
     val outputWriterFactory =
       fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
@@ -119,23 +130,45 @@ object FileFormatWriter extends Logging {
       uuid = UUID.randomUUID().toString,
       serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
       outputWriterFactory = outputWriterFactory,
-      allColumns = queryExecution.logical.output,
-      partitionColumns = partitionColumns,
+      allColumns = allColumns,
       dataColumns = dataColumns,
-      bucketSpec = bucketSpec,
+      partitionColumns = partitionColumns,
+      bucketIdExpression = bucketIdExpression,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
         .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
     )
 
+    // We should first sort by partition columns, then bucket id, and finally sorting columns.
+    val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
+    // the sort order doesn't matter
+    val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
+    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
+      false
+    } else {
+      requiredOrdering.zip(actualOrdering).forall {
+        case (requiredOrder, childOutputOrder) =>
+          requiredOrder.semanticEquals(childOutputOrder)
+      }
+    }
+
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
       // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
       committer.setupJob(job)
 
       try {
-        val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd,
+        val rdd = if (orderingMatched) {
+          queryExecution.toRdd
+        } else {
+          SortExec(
+            requiredOrdering.map(SortOrder(_, Ascending)),
+            global = false,
+            child = queryExecution.executedPlan).execute()
+        }
+
+        val ret = sparkSession.sparkContext.runJob(rdd,
           (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
             executeTask(
               description = description,
@@ -189,7 +222,7 @@ object FileFormatWriter extends Logging {
     committer.setupTask(taskAttemptContext)
 
     val writeTask =
-      if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
+      if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
         new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
       } else {
         new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
@@ -287,31 +320,16 @@ object FileFormatWriter extends Logging {
    * multiple directories (partitions) or files (bucketing).
    */
   private class DynamicPartitionWriteTask(
-      description: WriteJobDescription,
+      desc: WriteJobDescription,
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {
 
     // currentWriter is initialized whenever we see a new key
     private var currentWriter: OutputWriter = _
 
-    private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec =>
-      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
-      // guarantee the data distribution is same between shuffle and bucketed data source, which
-      // enables us to only shuffle one side when join a bucketed table and a normal one.
-      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
-    }
-
-    /** Expressions that given a partition key build a string like: col1=val/col2=val/... */
-    private def partitionStringExpression: Seq[Expression] = {
-      description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+    /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */
+    private def partitionPathExpression: Seq[Expression] = {
+      desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
         // TODO: use correct timezone for partition values.
         val escaped = ScalaUDF(
           ExternalCatalogUtils.escapePathName _,
@@ -325,35 +343,46 @@ object FileFormatWriter extends Logging {
     }
 
     /**
-     * Open and returns a new OutputWriter given a partition key and optional bucket id.
+     * Opens a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
      * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
      *
-     * @param key vaues for fields consisting of partition keys for the current row
-     * @param partString a function that projects the partition values into a string
+     * @param partColsAndBucketId a row consisting of partition columns and a bucket id for the
+     *                            current row.
+     * @param getPartitionPath a function that projects the partition values into a path string.
      * @param fileCounter the number of files that have been written in the past for this specific
      *                    partition. This is used to limit the max number of records written for a
      *                    single file. The value should start from 0.
+     * @param updatedPartitions the set of updated partition paths, we should add the new partition
+     *                          path of this writer to it.
      */
     private def newOutputWriter(
-        key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
-      val partDir =
-        if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
+        partColsAndBucketId: InternalRow,
+        getPartitionPath: UnsafeProjection,
+        fileCounter: Int,
+        updatedPartitions: mutable.Set[String]): Unit = {
+      val partDir = if (desc.partitionColumns.isEmpty) {
+        None
+      } else {
+        Option(getPartitionPath(partColsAndBucketId).getString(0))
+      }
+      partDir.foreach(updatedPartitions.add)
 
-      // If the bucket spec is defined, the bucket column is right after the partition columns
-      val bucketId = if (description.bucketSpec.isDefined) {
-        BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length))
+      // If the bucketId expression is defined, the bucketId column is right after the partition
+      // columns.
+      val bucketId = if (desc.bucketIdExpression.isDefined) {
+        BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length))
       } else {
         ""
       }
 
       // This must be in a form that matches our bucketing format. See BucketingUtils.
       val ext = f"$bucketId.c$fileCounter%03d" +
-        description.outputWriterFactory.getFileExtension(taskAttemptContext)
+        desc.outputWriterFactory.getFileExtension(taskAttemptContext)
 
       val customPath = partDir match {
         case Some(dir) =>
-          description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+          desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
         case _ =>
           None
       }
@@ -363,80 +392,42 @@ object FileFormatWriter extends Logging {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
       }
 
-      currentWriter = description.outputWriterFactory.newInstance(
+      currentWriter = desc.outputWriterFactory.newInstance(
         path = path,
-        dataSchema = description.dataColumns.toStructType,
+        dataSchema = desc.dataColumns.toStructType,
         context = taskAttemptContext)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
-      val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ sortColumns
-      val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
-
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
-      })
-
-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(
-        description.dataColumns, description.allColumns)
-
-      // Returns the partition path given a partition key.
-      val getPartitionStringFunc = UnsafeProjection.create(
-        Seq(Concat(partitionStringExpression)), description.partitionColumns)
-
-      // Sorts the data before write, so that we only need one writer at the same time.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(description.dataColumns),
-        SparkEnv.get.blockManager,
-        SparkEnv.get.serializerManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes,
-        SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
-
-      while (iter.hasNext) {
-        val currentRow = iter.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
+      val getPartitionColsAndBucketId = UnsafeProjection.create(
+        desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns)
 
-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+      // Generates the partition path given the row generated by `getPartitionColsAndBucketId`.
+      val getPartPath = UnsafeProjection.create(
+        Seq(Concat(partitionPathExpression)), desc.partitionColumns)
 
-      val sortedIterator = sorter.sortedIterator()
+      // Returns the data columns to be written given an input row
+      val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)
 
       // If anything below fails, we should abort the task.
       var recordsInFile: Long = 0L
       var fileCounter = 0
-      var currentKey: UnsafeRow = null
+      var currentPartColsAndBucketId: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()
-      while (sortedIterator.next()) {
-        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-        if (currentKey != nextKey) {
-          // See a new key - write to a new partition (new file).
-          currentKey = nextKey.copy()
-          logDebug(s"Writing partition: $currentKey")
+      for (row <- iter) {
+        val nextPartColsAndBucketId = getPartitionColsAndBucketId(row)
+        if (currentPartColsAndBucketId != nextPartColsAndBucketId) {
+          // See a new partition or bucket - write to a new partition dir (or a new bucket file).
+          currentPartColsAndBucketId = nextPartColsAndBucketId.copy()
+          logDebug(s"Writing partition: $currentPartColsAndBucketId")
 
           recordsInFile = 0
           fileCounter = 0
 
           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
-          val partitionPath = getPartitionStringFunc(currentKey).getString(0)
-          if (partitionPath.nonEmpty) {
-            updatedPartitions.add(partitionPath)
-          }
-        } else if (description.maxRecordsPerFile > 0 &&
-            recordsInFile >= description.maxRecordsPerFile) {
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
+        } else if (desc.maxRecordsPerFile > 0 &&
+            recordsInFile >= desc.maxRecordsPerFile) {
           // Exceeded the threshold in terms of the number of records per file.
           // Create a new file by increasing the file counter.
           recordsInFile = 0
@@ -445,10 +436,10 @@ object FileFormatWriter extends Logging {
             s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
 
           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
         }
 
-        currentWriter.write(sortedIterator.getValue)
+        currentWriter.write(getOutputRow(row))
         recordsInFile += 1
       }
       releaseResources()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
