This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4a34db9  [SPARK-32709][SQL] Support writing Hive bucketed table (Parquet/ORC format with Hive hash)
4a34db9 is described below

commit 4a34db9a17c69922f007738fdbe61fe2b8de688b
Author: Cheng Su <[email protected]>
AuthorDate: Fri Sep 17 14:28:51 2021 +0800

    [SPARK-32709][SQL] Support writing Hive bucketed table (Parquet/ORC format with Hive hash)
    
    ### What changes were proposed in this pull request?
    
    This is a re-work of https://github.com/apache/spark/pull/30003. Here we add support for
    writing Hive bucketed tables with the Parquet/ORC file formats (data source v1 write path and
    Hive hash as the hash function). Support for Hive's other file formats will be added in a
    follow-up PR.
    
    The changes are mostly in:
    
    * `HiveMetastoreCatalog.scala`: When converting a Hive table relation to a data source
    relation, pass the bucket info (`BucketSpec`) and other Hive-related info as options into
    `HadoopFsRelation` and `LogicalRelation`, which can later be accessed by `FileFormatWriter`
    to customize the bucket id and file name.
    
    * `FileFormatWriter.scala`: Use `HiveHash` for the `bucketIdExpression` when writing to a Hive
    bucketed table. In addition, the Spark output file names should follow the Hive/Presto/Trino
    bucketed file naming convention. Introduce a new parameter, `bucketFileNamePrefix`, which
    entails corresponding changes in `FileFormatDataWriter` (a simplified sketch of this logic
    follows the list).
    
    * `HadoopMapReduceCommitProtocol`: Implement the new file name APIs introduced in
    https://github.com/apache/spark/pull/33012, and change its subclass `PathOutputCommitProtocol`
    so that writing Hive bucketed tables works with all commit protocols (including the S3A
    commit protocol).
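    
    As a rough illustration of the Hive-compatible bucket id and file name logic described above,
    here is a simplified sketch derived from the diff below (not the exact patch code; the helper
    names are illustrative):
    
    ```scala
    import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, Expression, HiveHash, Literal, Pmod}
    
    // Hive-compatible bucket id: the bitwise-and with Int.MaxValue keeps the hash non-negative,
    // matching Hive's ObjectInspectorUtils#getBucketNumber().
    def hiveBucketIdExpression(bucketColumns: Seq[Expression], numBuckets: Int): Expression =
      Pmod(BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue)), Literal(numBuckets))
    
    // Hive/Presto/Trino-style bucket file name prefix, e.g. bucket 3 -> "00003_0_".
    def hiveBucketFileNamePrefix(bucketId: Int): String = f"$bucketId%05d_0_"
    ```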
    
    ### Why are the changes needed?
    
    To make Spark write bucketed tables that are compatible with other SQL engines. Currently,
    Spark bucketed tables cannot be leveraged by other SQL engines such as Hive and Presto,
    because Spark uses a different hash function (Spark murmur3hash) and a different file naming
    scheme. With this PR, Hive bucketed tables written by Spark can be efficiently read by Presto
    and Hive for bucket filter pruning, joins, group-bys, etc. This was and is blocking several
    companies (confirmed from Facebook, Lyft, etc) migrate buc [...]
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Any Hive bucketed table (in Parquet/ORC format) written by Spark is properly bucketed
    and can be efficiently processed by Hive and Presto/Trino.
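    
    For example (a minimal usage sketch mirroring the new test; it assumes a `spark` session with
    Hive support enabled, and the table/column names are illustrative):
    
    ```scala
    import spark.implicits._
    
    // Define a Hive bucketed table via Hive DDL (Parquet or ORC).
    spark.sql(
      """CREATE TABLE hive_bucketed_table (i INT, j STRING)
        |PARTITIONED BY (k STRING)
        |CLUSTERED BY (i, j) SORTED BY (i) INTO 8 BUCKETS
        |STORED AS PARQUET""".stripMargin)
    
    // With this change, the insert below produces Hive-compatible bucket files
    // (HiveHash bucket ids and "<bucketId>_0_"-prefixed file names).
    (0 until 50).map(i => (i % 13, i.toString, (i % 5).toString)).toDF("i", "j", "k")
      .write.mode("overwrite").insertInto("hive_bucketed_table")
    ```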
    
    ### How was this patch tested?
    
    * Added a unit test in BucketedWriteWithHiveSupportSuite.scala to verify that the bucket file
    names are correct and that each row is written into the proper bucket (a sketch of the
    file-name check follows).
    * Tested by the Lyft Spark team (Shashank Pedamallu) by reading the Spark-written bucketed
    table from Trino, Spark, and Hive.
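    
    For reference, the Hive-style bucket file names verified by the test start with a zero-padded
    bucket id (e.g. `00003_0_...`); a minimal sketch of the file-name check used there:
    
    ```scala
    // Extract the bucket id from a Hive-style bucket file name such as "00003_0_<rest>".
    val hiveBucketedFileName = """^(\d+)_0_.*$""".r
    
    def getBucketIdFromFileName(fileName: String): Option[Int] = fileName match {
      case hiveBucketedFileName(bucketId) => Some(bucketId.toInt)
      case _ => None
    }
    ```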
    
    Closes #33432 from c21/hive-bucket-v1.
    
    Authored-by: Cheng Su <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../io/HadoopMapReduceCommitProtocol.scala         | 18 ++++++--
 .../io/cloud/PathOutputCommitProtocol.scala        |  9 ++--
 .../sql/execution/datasources/BucketingUtils.scala |  3 ++
 .../datasources/FileFormatDataWriter.scala         | 34 +++++++++++----
 .../execution/datasources/FileFormatWriter.scala   | 39 ++++++++++++++----
 .../sql/execution/datasources/v2/FileWrite.scala   |  2 +-
 .../spark/sql/sources/BucketedWriteSuite.scala     | 34 ++++++++-------
 .../spark/sql/hive/HiveMetastoreCatalog.scala      | 33 ++++++++++-----
 .../org/apache/spark/sql/hive/HiveStrategies.scala |  4 +-
 .../execution/CreateHiveTableAsSelectCommand.scala |  2 +-
 .../BucketedWriteWithHiveSupportSuite.scala        | 48 ++++++++++++++++++++++
 11 files changed, 173 insertions(+), 53 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
index c061d61..a39e9ab 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -118,7 +118,12 @@ class HadoopMapReduceCommitProtocol(
 
   override def newTaskTempFile(
       taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
-    val filename = getFilename(taskContext, ext)
+    newTaskTempFile(taskContext, dir, FileNameSpec("", ext))
+  }
+
+  override def newTaskTempFile(
+      taskContext: TaskAttemptContext, dir: Option[String], spec: FileNameSpec): String = {
+    val filename = getFilename(taskContext, spec)
 
     val stagingDir: Path = committer match {
       // For FileOutputCommitter it has its own staging path called "work path".
@@ -141,7 +146,12 @@ class HadoopMapReduceCommitProtocol(
 
   override def newTaskTempFileAbsPath(
       taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
-    val filename = getFilename(taskContext, ext)
+    newTaskTempFileAbsPath(taskContext, absoluteDir, FileNameSpec("", ext))
+  }
+
+  override def newTaskTempFileAbsPath(
+      taskContext: TaskAttemptContext, absoluteDir: String, spec: FileNameSpec): String = {
+    val filename = getFilename(taskContext, spec)
     val absOutputPath = new Path(absoluteDir, filename).toString
 
     // Include a UUID here to prevent file collisions for one task writing to different dirs.
@@ -152,12 +162,12 @@ class HadoopMapReduceCommitProtocol(
     tmpOutputPath
   }
 
-  protected def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
+  protected def getFilename(taskContext: TaskAttemptContext, spec: FileNameSpec): String = {
     // The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet
     // Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
     // the file name is fine and won't overflow.
     val split = taskContext.getTaskAttemptID.getTaskID.getId
-    f"part-$split%05d-$jobId$ext"
+    f"${spec.prefix}part-$split%05d-$jobId${spec.suffix}"
   }
 
   override def setupJob(jobContext: JobContext): Unit = {
diff --git a/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala b/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala
index 2ca5087..fc5d0a0 100644
--- a/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala
+++ b/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.TaskAttemptContext
 import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, PathOutputCommitter, PathOutputCommitterFactory}
 
+import org.apache.spark.internal.io.FileNameSpec
 import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
 
 /**
@@ -122,20 +123,20 @@ class PathOutputCommitProtocol(
    *
    * @param taskContext task context
    * @param dir         optional subdirectory
-   * @param ext         file extension
+   * @param spec        file naming specification
    * @return a path as a string
    */
   override def newTaskTempFile(
       taskContext: TaskAttemptContext,
       dir: Option[String],
-      ext: String): String = {
+      spec: FileNameSpec): String = {
 
     val workDir = committer.getWorkPath
     val parent = dir.map {
       d => new Path(workDir, d)
     }.getOrElse(workDir)
-    val file = new Path(parent, getFilename(taskContext, ext))
-    logTrace(s"Creating task file $file for dir $dir and ext $ext")
+    val file = new Path(parent, getFilename(taskContext, spec))
+    logTrace(s"Creating task file $file for dir $dir and spec $spec")
     file.toString
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
index a776fc3..0622586 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
@@ -33,6 +33,9 @@ object BucketingUtils {
   //   part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
   private val bucketedFileName = """.*_(\d+)(?:\..*)?$""".r
 
+  // The reserved option name for data source to write Hive-compatible bucketed table
+  val optionForHiveCompatibleBucketWrite = "__hive_compatible_bucketed_table_insertion__"
+
   def getBucketId(fileName: String): Option[Int] = fileName match {
     case bucketedFileName(bucketId) => Some(bucketId.toInt)
     case other => None
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index 815d8ac..0b1b616 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.TaskAttemptContext
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec}
 import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -193,7 +193,7 @@ abstract class BaseDynamicPartitionDataWriter(
   protected val isPartitioned = description.partitionColumns.nonEmpty
 
   /** Flag saying whether or not the data to be written out is bucketed. */
-  protected val isBucketed = description.bucketIdExpression.isDefined
+  protected val isBucketed = description.bucketSpec.isDefined
 
   assert(isPartitioned || isBucketed,
     s"""DynamicPartitionWriteTask should be used for writing out data that's 
either
@@ -238,7 +238,8 @@ abstract class BaseDynamicPartitionDataWriter(
   /** Given an input row, returns the corresponding `bucketId` */
   protected lazy val getBucketId: InternalRow => Int = {
     val proj =
-      UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns)
+      UnsafeProjection.create(Seq(description.bucketSpec.get.bucketIdExpression),
+        description.allColumns)
     row => proj(row).getInt(0)
   }
 
@@ -271,17 +272,24 @@ abstract class BaseDynamicPartitionDataWriter(
 
     val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
 
-    // This must be in a form that matches our bucketing format. See BucketingUtils.
-    val ext = f"$bucketIdStr.c$fileCounter%03d" +
+    // The prefix and suffix must be in a form that matches our bucketing format. See BucketingUtils
+    // for details. The prefix is required to represent bucket id when writing Hive-compatible
+    // bucketed table.
+    val prefix = bucketId match {
+      case Some(id) => description.bucketSpec.get.bucketFileNamePrefix(id)
+      case _ => ""
+    }
+    val suffix = f"$bucketIdStr.c$fileCounter%03d" +
       description.outputWriterFactory.getFileExtension(taskAttemptContext)
+    val fileNameSpec = FileNameSpec(prefix, suffix)
 
     val customPath = partDir.flatMap { dir =>
       description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
     }
     val currentPath = if (customPath.isDefined) {
-      committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
+      committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, fileNameSpec)
     } else {
-      committer.newTaskTempFile(taskAttemptContext, partDir, ext)
+      committer.newTaskTempFile(taskAttemptContext, partDir, fileNameSpec)
     }
 
     currentWriter = description.outputWriterFactory.newInstance(
@@ -554,6 +562,16 @@ class DynamicPartitionDataConcurrentWriter(
   }
 }
 
+/**
+ * Bucketing specification for all the write tasks.
+ *
+ * @param bucketIdExpression Expression to calculate bucket id based on bucket column(s).
+ * @param bucketFileNamePrefix Prefix of output file name based on bucket id.
+ */
+case class WriterBucketSpec(
+  bucketIdExpression: Expression,
+  bucketFileNamePrefix: Int => String)
+
 /** A shared job description for all the write tasks. */
 class WriteJobDescription(
     val uuid: String, // prevent collision between different (appending) write jobs
@@ -562,7 +580,7 @@ class WriteJobDescription(
     val allColumns: Seq[Attribute],
     val dataColumns: Seq[Attribute],
     val partitionColumns: Seq[Attribute],
-    val bucketIdExpression: Option[Expression],
+    val bucketSpec: Option[WriterBucketSpec],
     val path: String,
     val customPartitionLocations: Map[TablePartitionSpec, String],
     val maxRecordsPerFile: Long,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 2e36837..409e334 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -122,12 +122,34 @@ object FileFormatWriter extends Logging {
     }
     val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan
 
-    val bucketIdExpression = bucketSpec.map { spec =>
+    val writerBucketSpec = bucketSpec.map { spec =>
       val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
-      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
-      // guarantee the data distribution is same between shuffle and bucketed data source, which
-      // enables us to only shuffle one side when join a bucketed table and a normal one.
-      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+
+      if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") ==
+        "true") {
+        // Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression.
+        // Without the extra bitwise-and operation, we can get wrong bucket id when hash value of
+        // columns is negative. See Hive implementation in
+        // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
+        val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
+        val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))
+
+        // The bucket file name prefix is following Hive, Presto and Trino conversion, so this
+        // makes sure Hive bucketed table written by Spark, can be read by other SQL engines.
+        //
+        // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
+        // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
+        val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
+        WriterBucketSpec(bucketIdExpression, fileNamePrefix)
+      } else {
+        // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id
+        // expression, so that we can guarantee the data distribution is same between shuffle and
+        // bucketed data source, which enables us to only shuffle one side when join a bucketed
+        // table and a normal one.
+        val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets)
+          .partitionIdExpression
+        WriterBucketSpec(bucketIdExpression, (_: Int) => "")
+      }
     }
     val sortColumns = bucketSpec.toSeq.flatMap {
       spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
@@ -148,7 +170,7 @@ object FileFormatWriter extends Logging {
       allColumns = outputSpec.outputColumns,
       dataColumns = dataColumns,
       partitionColumns = partitionColumns,
-      bucketIdExpression = bucketIdExpression,
+      bucketSpec = writerBucketSpec,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
@@ -159,7 +181,8 @@ object FileFormatWriter extends Logging {
     )
 
     // We should first sort by partition columns, then bucket id, and finally sorting columns.
-    val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
+    val requiredOrdering =
+      partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
     // the sort order doesn't matter
     val actualOrdering = empty2NullPlan.outputOrdering.map(_.child)
     val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
@@ -286,7 +309,7 @@ object FileFormatWriter extends Logging {
       if (sparkPartitionId != 0 && !iterator.hasNext) {
         // In case of empty job, leave first partition to save meta for file format like parquet.
         new EmptyDirectoryDataWriter(description, taskAttemptContext, committer)
-      } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
+      } else if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
         new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
       } else {
         concurrentOutputWriterSpec match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala
index 7f66a09..ccc467a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala
@@ -133,7 +133,7 @@ trait FileWrite extends Write {
       allColumns = allColumns,
       dataColumns = allColumns,
       partitionColumns = Seq.empty,
-      bucketIdExpression = None,
+      bucketSpec = None,
       path = pathName,
       customPartitionLocations = Map.empty,
       maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 0a5feda..ae35f29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.sources
 
 import java.io.File
 
-import org.apache.spark.sql.{AnalysisException, QueryTest}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.datasources.BucketingUtils
 import org.apache.spark.sql.functions._
@@ -136,29 +136,35 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
     (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
   }
 
-  def tableDir: File = {
-    val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table")
+  def tableDir(table: String = "bucketed_table"): File = {
+    val identifier = spark.sessionState.sqlParser.parseTableIdentifier(table)
     new File(spark.sessionState.catalog.defaultTablePath(identifier))
   }
 
+  private def bucketIdExpression(expressions: Seq[Expression], numBuckets: Int): Expression =
+    HashPartitioning(expressions, numBuckets).partitionIdExpression
+
   /**
    * A helper method to check the bucket write functionality in low level, i.e. check the written
    * bucket files to see if the data are correct.  User should pass in a data dir that these bucket
    * files are written to, and the format of data(parquet, json, etc.), and the bucketing
    * information.
    */
-  private def testBucketing(
+  protected def testBucketing(
       dataDir: File,
       source: String,
       numBuckets: Int,
       bucketCols: Seq[String],
-      sortCols: Seq[String] = Nil): Unit = {
+      sortCols: Seq[String] = Nil,
+      inputDF: DataFrame = df,
+      bucketIdExpression: (Seq[Expression], Int) => Expression = bucketIdExpression,
+      getBucketIdFromFileName: String => Option[Int] = BucketingUtils.getBucketId): Unit = {
     val allBucketFiles = dataDir.listFiles().filterNot(f =>
       f.getName.startsWith(".") || f.getName.startsWith("_")
     )
 
     for (bucketFile <- allBucketFiles) {
-      val bucketId = BucketingUtils.getBucketId(bucketFile.getName).getOrElse {
+      val bucketId = getBucketIdFromFileName(bucketFile.getName).getOrElse {
         fail(s"Unable to find the related bucket files.")
       }
 
@@ -167,7 +173,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
       val selectedColumns = (bucketCols ++ sortCols).distinct
       // We may lose the type information after write(e.g. json format doesn't keep schema
       // information), here we get the types from the original dataframe.
-      val types = df.select(selectedColumns.map(col): _*).schema.map(_.dataType)
+      val types = inputDF.select(selectedColumns.map(col): _*).schema.map(_.dataType)
       val columns = selectedColumns.zip(types).map {
         case (colName, dt) => col(colName).cast(dt)
       }
@@ -188,7 +194,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
       val qe = readBack.select(bucketCols.map(col): _*).queryExecution
       val rows = qe.toRdd.map(_.copy()).collect()
       val getBucketId = UnsafeProjection.create(
-        HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil,
+        bucketIdExpression(qe.analyzed.output, numBuckets) :: Nil,
         qe.analyzed.output)
 
       for (row <- rows) {
@@ -208,7 +214,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
           .saveAsTable("bucketed_table")
 
         for (i <- 0 until 5) {
-          testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k"))
+          testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j", 
"k"))
         }
       }
     }
@@ -225,7 +231,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
           .saveAsTable("bucketed_table")
 
         for (i <- 0 until 5) {
-          testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), 
Seq("k"))
+          testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j"), 
Seq("k"))
         }
       }
     }
@@ -255,7 +261,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
           .bucketBy(8, "i", "j")
           .saveAsTable("bucketed_table")
 
-        testBucketing(tableDir, source, 8, Seq("i", "j"))
+        testBucketing(tableDir(), source, 8, Seq("i", "j"))
       }
     }
   }
@@ -269,7 +275,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
           .sortBy("k")
           .saveAsTable("bucketed_table")
 
-        testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k"))
+        testBucketing(tableDir(), source, 8, Seq("i", "j"), Seq("k"))
       }
     }
   }
@@ -286,7 +292,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
             .saveAsTable("bucketed_table")
 
           for (i <- 0 until 5) {
-            testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", 
"k"))
+            testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j", 
"k"))
           }
         }
       }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 05aa648..c905a52 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -125,7 +125,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
   private def isParquetProperty(key: String) =
     key.startsWith("parquet.") || key.contains(".parquet.")
 
-  def convert(relation: HiveTableRelation): LogicalRelation = {
+  def convert(relation: HiveTableRelation, isWrite: Boolean): LogicalRelation = {
     val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT)
 
     // Consider table and storage properties. For properties existing in both sides, storage
@@ -134,7 +134,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
       val options = relation.tableMeta.properties.filterKeys(isParquetProperty).toMap ++
         relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA ->
         SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString)
-        convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet")
+        convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet", isWrite)
     } else {
       val options = relation.tableMeta.properties.filterKeys(isOrcProperty).toMap ++
         relation.tableMeta.storage.properties
@@ -143,13 +143,15 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
           relation,
           options,
           classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat],
-          "orc")
+          "orc",
+          isWrite)
       } else {
         convertToLogicalRelation(
           relation,
           options,
           classOf[org.apache.spark.sql.hive.orc.OrcFileFormat],
-          "orc")
+          "orc",
+          isWrite)
       }
     }
   }
@@ -158,7 +160,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
       relation: HiveTableRelation,
       options: Map[String, String],
       fileFormatClass: Class[_ <: FileFormat],
-      fileType: String): LogicalRelation = {
+      fileType: String,
+      isWrite: Boolean): LogicalRelation = {
     val metastoreSchema = relation.tableMeta.schema
     val tableIdentifier =
       QualifiedTableName(relation.tableMeta.database, relation.tableMeta.identifier.table)
@@ -166,6 +169,14 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
     val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions
     val tablePath = new Path(relation.tableMeta.location)
     val fileFormat = fileFormatClass.getConstructor().newInstance()
+    val bucketSpec = relation.tableMeta.bucketSpec
+    val (hiveOptions, hiveBucketSpec) =
+      if (isWrite) {
+        (options.updated(BucketingUtils.optionForHiveCompatibleBucketWrite, "true"),
+          bucketSpec)
+      } else {
+        (options, None)
+      }
 
     val result = if (relation.isPartitioned) {
       val partitionSchema = relation.tableMeta.partitionSchema
@@ -207,16 +218,16 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
             }
           }
 
-          val updatedTable = inferIfNeeded(relation, options, fileFormat, Option(fileIndex))
+          val updatedTable = inferIfNeeded(relation, hiveOptions, fileFormat, Option(fileIndex))
 
           // Spark SQL's data source table now support static and dynamic partition insert. Source
           // table converted from Hive table should always use dynamic.
-          val enableDynamicPartition = options.updated("partitionOverwriteMode", "dynamic")
+          val enableDynamicPartition = hiveOptions.updated("partitionOverwriteMode", "dynamic")
           val fsRelation = HadoopFsRelation(
             location = fileIndex,
             partitionSchema = partitionSchema,
             dataSchema = updatedTable.dataSchema,
-            bucketSpec = None,
+            bucketSpec = hiveBucketSpec,
             fileFormat = fileFormat,
             options = enableDynamicPartition)(sparkSession = sparkSession)
           val created = LogicalRelation(fsRelation, updatedTable)
@@ -236,17 +247,17 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
           fileFormatClass,
           None)
         val logicalRelation = cached.getOrElse {
-          val updatedTable = inferIfNeeded(relation, options, fileFormat)
+          val updatedTable = inferIfNeeded(relation, hiveOptions, fileFormat)
           val created =
             LogicalRelation(
               DataSource(
                 sparkSession = sparkSession,
                 paths = rootPath.toString :: Nil,
                 userSpecifiedSchema = Option(updatedTable.dataSchema),
-                bucketSpec = None,
+                bucketSpec = hiveBucketSpec,
                 // Do not interpret the 'path' option at all when tables are read using the Hive
                 // source, since the URIs will already have been read from the table's LOCATION.
-                options = options.filter { case (k, _) => !k.equalsIgnoreCase("path") },
+                options = hiveOptions.filter { case (k, _) => !k.equalsIgnoreCase("path") },
                 className = fileType).resolveRelation(),
               table = updatedTable)
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index c8a5c03..0da7465 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -211,13 +211,13 @@ case class RelationConversions(
           if query.resolved && DDLUtils.isHiveTable(r.tableMeta) &&
             (!r.isPartitioned || conf.getConf(HiveUtils.CONVERT_INSERTING_PARTITIONED_TABLE))
             && isConvertible(r) =>
-        InsertIntoStatement(metastoreCatalog.convert(r), partition, cols,
+        InsertIntoStatement(metastoreCatalog.convert(r, isWrite = true), partition, cols,
           query, overwrite, ifPartitionNotExists)
 
       // Read path
       case relation: HiveTableRelation
           if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) =>
-        metastoreCatalog.convert(relation)
+        metastoreCatalog.convert(relation, isWrite = false)
 
       // CTAS
       case CreateTable(tableDesc, mode, Some(query))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index beaebb5..cef2a36 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -161,7 +161,7 @@ case class OptimizedCreateHiveTableAsSelectCommand(
     val metastoreCatalog = catalog.asInstanceOf[HiveSessionCatalog].metastoreCatalog
     val hiveTable = DDLUtils.readHiveTable(tableDesc)
 
-    val hadoopRelation = metastoreCatalog.convert(hiveTable) match {
+    val hadoopRelation = metastoreCatalog.convert(hiveTable, isWrite = true) match {
       case LogicalRelation(t: HadoopFsRelation, _, _, _) => t
       case _ => throw QueryCompilationErrors.tableIdentifierNotConvertedToHadoopFsRelationError(
         tableIdentifier)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala
index bdbdcc2..c12caaa 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala
@@ -17,6 +17,11 @@
 
 package org.apache.spark.sql.sources
 
+import java.io.File
+
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, Expression, HiveHash, Literal, Pmod}
+import org.apache.spark.sql.hive.test.TestHive.implicits._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 
@@ -27,4 +32,47 @@ class BucketedWriteWithHiveSupportSuite extends BucketedWriteSuite with TestHive
   }
 
   override protected def fileFormatsToTest: Seq[String] = Seq("parquet", "orc")
+
+  test("write hive bucketed table") {
+    def bucketIdExpression(expressions: Seq[Expression], numBuckets: Int): Expression =
+      Pmod(BitwiseAnd(HiveHash(expressions), Literal(Int.MaxValue)), Literal(8))
+
+    def getBucketIdFromFileName(fileName: String): Option[Int] = {
+      val hiveBucketedFileName = """^(\d+)_0_.*$""".r
+      fileName match {
+        case hiveBucketedFileName(bucketId) => Some(bucketId.toInt)
+        case _ => None
+      }
+    }
+
+    val table = "hive_bucketed_table"
+
+    fileFormatsToTest.foreach { format =>
+      withTable(table) {
+        sql(
+          s"""
+             |CREATE TABLE IF NOT EXISTS $table (i int, j string)
+             |PARTITIONED BY(k string)
+             |CLUSTERED BY (i, j) SORTED BY (i) INTO 8 BUCKETS
+             |STORED AS $format
+           """.stripMargin)
+
+        val df =
+          (0 until 50).map(i => (i % 13, i.toString, i % 5)).toDF("i", "j", 
"k")
+        df.write.mode(SaveMode.Overwrite).insertInto(table)
+
+        for (k <- 0 until 5) {
+          testBucketing(
+            new File(tableDir(table), s"k=$k"),
+            format,
+            8,
+            Seq("i", "j"),
+            Seq("i"),
+            df,
+            bucketIdExpression,
+            getBucketIdFromFileName)
+        }
+      }
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
