This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4a34db9  [SPARK-32709][SQL] Support writing Hive bucketed table (Parquet/ORC format with Hive hash)
4a34db9 is described below
commit 4a34db9a17c69922f007738fdbe61fe2b8de688b
Author: Cheng Su <[email protected]>
AuthorDate: Fri Sep 17 14:28:51 2021 +0800
[SPARK-32709][SQL] Support writing Hive bucketed table (Parquet/ORC format with Hive hash)
### What changes were proposed in this pull request?
This is a re-work of https://github.com/apache/spark/pull/30003. Here we add
support for writing Hive bucketed tables with the Parquet/ORC file formats (data
source v1 write path, with Hive hash as the hash function). Support for Hive's
other file formats will be added in a follow-up PR.
The changes are mostly on:
* `HiveMetastoreCatalog.scala`: When converting a Hive table relation to a data
source relation, pass the bucket info (`BucketSpec`) and other Hive-related info as
options into `HadoopFsRelation` and `LogicalRelation`, which can later be
accessed by `FileFormatWriter` to customize the bucket id and file name.
* `FileFormatWriter.scala`: Use `HiveHash` for the `bucketIdExpression` when
writing to a Hive bucketed table. In addition, the Spark output file name should
follow the Hive/Presto/Trino bucketed file naming convention. A new parameter
`bucketFileNamePrefix` is introduced for this, with corresponding changes in
`FileFormatDataWriter` (see the sketch after this list).
* `HadoopMapReduceCommitProtocol`: Implement the new file name APIs
introduced in https://github.com/apache/spark/pull/33012, and change its
subclass `PathOutputCommitProtocol`, so that Hive bucketed table writing works
with all commit protocols (including the S3A commit protocol).
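
As a rough, REPL-style sketch of the write-path behavior described above (it mirrors
the Catalyst expressions used in this PR but simplifies the surrounding code;
`bucketColumns` and `numBuckets` stand in for values normally taken from the table's
`BucketSpec`), the bucket id expression switches between Spark's murmur3-based
`HashPartitioning` and Hive hash as follows:

```scala
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, Expression, HiveHash, Literal, Pmod}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

// Sketch only: `bucketColumns` are the resolved bucket column expressions and
// `numBuckets` comes from the table's BucketSpec.
def bucketIdExpression(
    bucketColumns: Seq[Expression],
    numBuckets: Int,
    hiveCompatible: Boolean): Expression = {
  if (hiveCompatible) {
    // Hive bucketed table: HiveHash, masked with Int.MaxValue so negative hash
    // values still map to a valid bucket, then modulo the bucket count.
    Pmod(BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue)), Literal(numBuckets))
  } else {
    // Spark bucketed table: reuse HashPartitioning's partition id expression so
    // shuffle and bucketed files share the same distribution.
    HashPartitioning(bucketColumns, numBuckets).partitionIdExpression
  }
}

// Hive/Presto/Trino expect bucket files named like "00002_0_...", so the prefix
// for bucket id 2 is produced by:
val bucketFileNamePrefix: Int => String = (bucketId: Int) => f"$bucketId%05d_0_"
```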
### Why are the changes needed?
To make Spark write bucketed tables that are compatible with other SQL engines.
Currently, Spark bucketed tables cannot be leveraged by other SQL engines like
Hive and Presto, because Spark uses a different hash function (Spark murmur3
hash) and a different file naming scheme. With this PR, a Hive bucketed table
written by Spark can be efficiently read by Presto and Hive for bucket filter
pruning, joins, group-by, etc. This was and is blocking several companies
(confirmed from Facebook, Lyft, etc) migrate buc [...]
### Does this PR introduce _any_ user-facing change?
Yes, any Hive bucketed table (with Parquet/ORC format) written by Spark is
properly bucketed and can be efficiently processed by Hive and Presto/Trino.
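
For example (a minimal sketch based on the new test added in this commit; the
table name and data are illustrative), writing such a table from Spark looks like:

```scala
import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder()
  .appName("hive-bucketed-write-example")
  .enableHiveSupport()
  .getOrCreate()
import spark.implicits._

// Hive bucketed table definition: the bucket/sort metadata lives in the Hive metastore.
spark.sql(
  """
    |CREATE TABLE IF NOT EXISTS hive_bucketed_table (i int, j string)
    |PARTITIONED BY (k string)
    |CLUSTERED BY (i, j) SORTED BY (i) INTO 8 BUCKETS
    |STORED AS PARQUET
  """.stripMargin)

// With this change, the insert uses Hive hash for the bucket id and names files
// like 00000_0_..., so Hive and Presto/Trino can read the buckets directly.
val df = (0 until 50).map(i => (i % 13, i.toString, (i % 5).toString)).toDF("i", "j", "k")
df.write.mode(SaveMode.Overwrite).insertInto("hive_bucketed_table")
```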
### How was this patch tested?
* Added a unit test in BucketedWriteWithHiveSupportSuite.scala to verify that
bucket file names and each row in each bucket are written properly (see the
REPL-style sketch after this list).
* Tested by the Lyft Spark team (Shashank Pedamallu) by reading a Spark-written
bucketed table from Trino, Spark and Hive.
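
As a rough illustration of part of what the new unit test checks (a simplified,
REPL-style sketch; the real assertions live in
BucketedWriteWithHiveSupportSuite.scala), the bucket id is recovered from the
Hive/Presto/Trino file naming convention before the per-row hash check:

```scala
// Hive/Presto/Trino name bucket files like "00003_0_<uuid>"; extract the bucket id.
val hiveBucketedFileName = """^(\d+)_0_.*$""".r

def getBucketIdFromFileName(fileName: String): Option[Int] = fileName match {
  case hiveBucketedFileName(bucketId) => Some(bucketId.toInt)
  case _ => None
}

assert(getBucketIdFromFileName("00003_0_part-00000-abc.parquet") == Some(3))
// Spark's own naming scheme does not match the Hive pattern.
assert(getBucketIdFromFileName("part-00000-abc_00003.parquet").isEmpty)
```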
Closes #33432 from c21/hive-bucket-v1.
Authored-by: Cheng Su <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../io/HadoopMapReduceCommitProtocol.scala | 18 ++++++--
.../io/cloud/PathOutputCommitProtocol.scala | 9 ++--
.../sql/execution/datasources/BucketingUtils.scala | 3 ++
.../datasources/FileFormatDataWriter.scala | 34 +++++++++++----
.../execution/datasources/FileFormatWriter.scala | 39 ++++++++++++++----
.../sql/execution/datasources/v2/FileWrite.scala | 2 +-
.../spark/sql/sources/BucketedWriteSuite.scala | 34 ++++++++-------
.../spark/sql/hive/HiveMetastoreCatalog.scala | 33 ++++++++++-----
.../org/apache/spark/sql/hive/HiveStrategies.scala | 4 +-
.../execution/CreateHiveTableAsSelectCommand.scala | 2 +-
.../BucketedWriteWithHiveSupportSuite.scala | 48 ++++++++++++++++++++++
11 files changed, 173 insertions(+), 53 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
index c061d61..a39e9ab 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -118,7 +118,12 @@ class HadoopMapReduceCommitProtocol(
override def newTaskTempFile(
taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
- val filename = getFilename(taskContext, ext)
+ newTaskTempFile(taskContext, dir, FileNameSpec("", ext))
+ }
+
+ override def newTaskTempFile(
+ taskContext: TaskAttemptContext, dir: Option[String], spec: FileNameSpec): String = {
+ val filename = getFilename(taskContext, spec)
val stagingDir: Path = committer match {
// For FileOutputCommitter it has its own staging path called "work path".
@@ -141,7 +146,12 @@ class HadoopMapReduceCommitProtocol(
override def newTaskTempFileAbsPath(
taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
- val filename = getFilename(taskContext, ext)
+ newTaskTempFileAbsPath(taskContext, absoluteDir, FileNameSpec("", ext))
+ }
+
+ override def newTaskTempFileAbsPath(
+ taskContext: TaskAttemptContext, absoluteDir: String, spec: FileNameSpec): String = {
+ val filename = getFilename(taskContext, spec)
val absOutputPath = new Path(absoluteDir, filename).toString
// Include a UUID here to prevent file collisions for one task writing to different dirs.
@@ -152,12 +162,12 @@ class HadoopMapReduceCommitProtocol(
tmpOutputPath
}
- protected def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
+ protected def getFilename(taskContext: TaskAttemptContext, spec: FileNameSpec): String = {
// The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet
// Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
// the file name is fine and won't overflow.
val split = taskContext.getTaskAttemptID.getTaskID.getId
- f"part-$split%05d-$jobId$ext"
+ f"${spec.prefix}part-$split%05d-$jobId${spec.suffix}"
}
override def setupJob(jobContext: JobContext): Unit = {
diff --git a/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala b/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala
index 2ca5087..fc5d0a0 100644
--- a/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala
+++ b/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, PathOutputCommitter, PathOutputCommitterFactory}
+import org.apache.spark.internal.io.FileNameSpec
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
/**
@@ -122,20 +123,20 @@ class PathOutputCommitProtocol(
*
* @param taskContext task context
* @param dir optional subdirectory
- * @param ext file extension
+ * @param spec file naming specification
* @return a path as a string
*/
override def newTaskTempFile(
taskContext: TaskAttemptContext,
dir: Option[String],
- ext: String): String = {
+ spec: FileNameSpec): String = {
val workDir = committer.getWorkPath
val parent = dir.map {
d => new Path(workDir, d)
}.getOrElse(workDir)
- val file = new Path(parent, getFilename(taskContext, ext))
- logTrace(s"Creating task file $file for dir $dir and ext $ext")
+ val file = new Path(parent, getFilename(taskContext, spec))
+ logTrace(s"Creating task file $file for dir $dir and spec $spec")
file.toString
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
index a776fc3..0622586 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
@@ -33,6 +33,9 @@ object BucketingUtils {
// part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
private val bucketedFileName = """.*_(\d+)(?:\..*)?$""".r
+ // The reserved option name for data source to write Hive-compatible bucketed table
+ val optionForHiveCompatibleBucketWrite = "__hive_compatible_bucketed_table_insertion__"
+
def getBucketId(fileName: String): Option[Int] = fileName match {
case bucketedFileName(bucketId) => Some(bucketId.toInt)
case other => None
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index 815d8ac..0b1b616 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.spark.internal.Logging
-import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec}
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -193,7 +193,7 @@ abstract class BaseDynamicPartitionDataWriter(
protected val isPartitioned = description.partitionColumns.nonEmpty
/** Flag saying whether or not the data to be written out is bucketed. */
- protected val isBucketed = description.bucketIdExpression.isDefined
+ protected val isBucketed = description.bucketSpec.isDefined
assert(isPartitioned || isBucketed,
s"""DynamicPartitionWriteTask should be used for writing out data that's
either
@@ -238,7 +238,8 @@ abstract class BaseDynamicPartitionDataWriter(
/** Given an input row, returns the corresponding `bucketId` */
protected lazy val getBucketId: InternalRow => Int = {
val proj =
- UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns)
+ UnsafeProjection.create(Seq(description.bucketSpec.get.bucketIdExpression),
+ description.allColumns)
row => proj(row).getInt(0)
}
@@ -271,17 +272,24 @@ abstract class BaseDynamicPartitionDataWriter(
val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
- // This must be in a form that matches our bucketing format. See BucketingUtils.
- val ext = f"$bucketIdStr.c$fileCounter%03d" +
+ // The prefix and suffix must be in a form that matches our bucketing format. See BucketingUtils
+ // for details. The prefix is required to represent bucket id when writing Hive-compatible
+ // bucketed table.
+ val prefix = bucketId match {
+ case Some(id) => description.bucketSpec.get.bucketFileNamePrefix(id)
+ case _ => ""
+ }
+ val suffix = f"$bucketIdStr.c$fileCounter%03d" +
description.outputWriterFactory.getFileExtension(taskAttemptContext)
+ val fileNameSpec = FileNameSpec(prefix, suffix)
val customPath = partDir.flatMap { dir =>
description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
}
val currentPath = if (customPath.isDefined) {
- committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
+ committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, fileNameSpec)
} else {
- committer.newTaskTempFile(taskAttemptContext, partDir, ext)
+ committer.newTaskTempFile(taskAttemptContext, partDir, fileNameSpec)
}
currentWriter = description.outputWriterFactory.newInstance(
@@ -554,6 +562,16 @@ class DynamicPartitionDataConcurrentWriter(
}
}
+/**
+ * Bucketing specification for all the write tasks.
+ *
+ * @param bucketIdExpression Expression to calculate bucket id based on bucket column(s).
+ * @param bucketFileNamePrefix Prefix of output file name based on bucket id.
+ */
+case class WriterBucketSpec(
+ bucketIdExpression: Expression,
+ bucketFileNamePrefix: Int => String)
+
/** A shared job description for all the write tasks. */
class WriteJobDescription(
val uuid: String, // prevent collision between different (appending) write jobs
@@ -562,7 +580,7 @@ class WriteJobDescription(
val allColumns: Seq[Attribute],
val dataColumns: Seq[Attribute],
val partitionColumns: Seq[Attribute],
- val bucketIdExpression: Option[Expression],
+ val bucketSpec: Option[WriterBucketSpec],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 2e36837..409e334 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -122,12 +122,34 @@ object FileFormatWriter extends Logging {
}
val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan
- val bucketIdExpression = bucketSpec.map { spec =>
+ val writerBucketSpec = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
- // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
- // guarantee the data distribution is same between shuffle and bucketed data source, which
- // enables us to only shuffle one side when join a bucketed table and a normal one.
- HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+
+ if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") ==
+ "true") {
+ // Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression.
+ // Without the extra bitwise-and operation, we can get wrong bucket id when hash value of
+ // columns is negative. See Hive implementation in
+ // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
+ val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
+ val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))
+
+ // The bucket file name prefix is following Hive, Presto and Trino conversion, so this
+ // makes sure Hive bucketed table written by Spark, can be read by other SQL engines.
+ //
+ // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
+ // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
+ val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
+ WriterBucketSpec(bucketIdExpression, fileNamePrefix)
+ } else {
+ // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id
+ // expression, so that we can guarantee the data distribution is same between shuffle and
+ // bucketed data source, which enables us to only shuffle one side when join a bucketed
+ // table and a normal one.
+ val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets)
+ .partitionIdExpression
+ WriterBucketSpec(bucketIdExpression, (_: Int) => "")
+ }
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
@@ -148,7 +170,7 @@ object FileFormatWriter extends Logging {
allColumns = outputSpec.outputColumns,
dataColumns = dataColumns,
partitionColumns = partitionColumns,
- bucketIdExpression = bucketIdExpression,
+ bucketSpec = writerBucketSpec,
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations,
maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
@@ -159,7 +181,8 @@ object FileFormatWriter extends Logging {
)
// We should first sort by partition columns, then bucket id, and finally sorting columns.
- val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
+ val requiredOrdering =
+ partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
// the sort order doesn't matter
val actualOrdering = empty2NullPlan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
@@ -286,7 +309,7 @@ object FileFormatWriter extends Logging {
if (sparkPartitionId != 0 && !iterator.hasNext) {
// In case of empty job, leave first partition to save meta for file format like parquet.
new EmptyDirectoryDataWriter(description, taskAttemptContext, committer)
- } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
+ } else if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
} else {
concurrentOutputWriterSpec match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala
index 7f66a09..ccc467a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala
@@ -133,7 +133,7 @@ trait FileWrite extends Write {
allColumns = allColumns,
dataColumns = allColumns,
partitionColumns = Seq.empty,
- bucketIdExpression = None,
+ bucketSpec = None,
path = pathName,
customPartitionLocations = Map.empty,
maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 0a5feda..ae35f29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.sources
import java.io.File
-import org.apache.spark.sql.{AnalysisException, QueryTest}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.datasources.BucketingUtils
import org.apache.spark.sql.functions._
@@ -136,29 +136,35 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
(0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
}
- def tableDir: File = {
- val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table")
+ def tableDir(table: String = "bucketed_table"): File = {
+ val identifier = spark.sessionState.sqlParser.parseTableIdentifier(table)
new File(spark.sessionState.catalog.defaultTablePath(identifier))
}
+ private def bucketIdExpression(expressions: Seq[Expression], numBuckets: Int): Expression =
+ HashPartitioning(expressions, numBuckets).partitionIdExpression
+
/**
* A helper method to check the bucket write functionality in low level, i.e. check the written
* bucket files to see if the data are correct. User should pass in a data dir that these bucket
* files are written to, and the format of data(parquet, json, etc.), and the bucketing
* information.
*/
- private def testBucketing(
+ protected def testBucketing(
dataDir: File,
source: String,
numBuckets: Int,
bucketCols: Seq[String],
- sortCols: Seq[String] = Nil): Unit = {
+ sortCols: Seq[String] = Nil,
+ inputDF: DataFrame = df,
+ bucketIdExpression: (Seq[Expression], Int) => Expression = bucketIdExpression,
+ getBucketIdFromFileName: String => Option[Int] = BucketingUtils.getBucketId): Unit = {
val allBucketFiles = dataDir.listFiles().filterNot(f =>
f.getName.startsWith(".") || f.getName.startsWith("_")
)
for (bucketFile <- allBucketFiles) {
- val bucketId = BucketingUtils.getBucketId(bucketFile.getName).getOrElse {
+ val bucketId = getBucketIdFromFileName(bucketFile.getName).getOrElse {
fail(s"Unable to find the related bucket files.")
}
@@ -167,7 +173,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
val selectedColumns = (bucketCols ++ sortCols).distinct
// We may lose the type information after write(e.g. json format doesn't keep schema
// information), here we get the types from the original dataframe.
- val types = df.select(selectedColumns.map(col): _*).schema.map(_.dataType)
+ val types = inputDF.select(selectedColumns.map(col): _*).schema.map(_.dataType)
val columns = selectedColumns.zip(types).map {
case (colName, dt) => col(colName).cast(dt)
}
@@ -188,7 +194,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
val qe = readBack.select(bucketCols.map(col): _*).queryExecution
val rows = qe.toRdd.map(_.copy()).collect()
val getBucketId = UnsafeProjection.create(
- HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil,
+ bucketIdExpression(qe.analyzed.output, numBuckets) :: Nil,
qe.analyzed.output)
for (row <- rows) {
@@ -208,7 +214,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
.saveAsTable("bucketed_table")
for (i <- 0 until 5) {
- testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k"))
+ testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j", "k"))
}
}
}
@@ -225,7 +231,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
.saveAsTable("bucketed_table")
for (i <- 0 until 5) {
- testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k"))
+ testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j"), Seq("k"))
}
}
}
@@ -255,7 +261,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
.bucketBy(8, "i", "j")
.saveAsTable("bucketed_table")
- testBucketing(tableDir, source, 8, Seq("i", "j"))
+ testBucketing(tableDir(), source, 8, Seq("i", "j"))
}
}
}
@@ -269,7 +275,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
.sortBy("k")
.saveAsTable("bucketed_table")
- testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k"))
+ testBucketing(tableDir(), source, 8, Seq("i", "j"), Seq("k"))
}
}
}
@@ -286,7 +292,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
.saveAsTable("bucketed_table")
for (i <- 0 until 5) {
- testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k"))
+ testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j", "k"))
}
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 05aa648..c905a52 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -125,7 +125,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
private def isParquetProperty(key: String) =
key.startsWith("parquet.") || key.contains(".parquet.")
- def convert(relation: HiveTableRelation): LogicalRelation = {
+ def convert(relation: HiveTableRelation, isWrite: Boolean): LogicalRelation = {
val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT)
// Consider table and storage properties. For properties existing in both sides, storage
@@ -134,7 +134,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
val options = relation.tableMeta.properties.filterKeys(isParquetProperty).toMap ++
relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA ->
SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString)
- convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet")
+ convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet", isWrite)
} else {
val options = relation.tableMeta.properties.filterKeys(isOrcProperty).toMap ++
relation.tableMeta.storage.properties
@@ -143,13 +143,15 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
relation,
options,
classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat],
- "orc")
+ "orc",
+ isWrite)
} else {
convertToLogicalRelation(
relation,
options,
classOf[org.apache.spark.sql.hive.orc.OrcFileFormat],
- "orc")
+ "orc",
+ isWrite)
}
}
}
@@ -158,7 +160,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
relation: HiveTableRelation,
options: Map[String, String],
fileFormatClass: Class[_ <: FileFormat],
- fileType: String): LogicalRelation = {
+ fileType: String,
+ isWrite: Boolean): LogicalRelation = {
val metastoreSchema = relation.tableMeta.schema
val tableIdentifier =
QualifiedTableName(relation.tableMeta.database, relation.tableMeta.identifier.table)
@@ -166,6 +169,14 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions
val tablePath = new Path(relation.tableMeta.location)
val fileFormat = fileFormatClass.getConstructor().newInstance()
+ val bucketSpec = relation.tableMeta.bucketSpec
+ val (hiveOptions, hiveBucketSpec) =
+ if (isWrite) {
+ (options.updated(BucketingUtils.optionForHiveCompatibleBucketWrite, "true"),
+ bucketSpec)
+ } else {
+ (options, None)
+ }
val result = if (relation.isPartitioned) {
val partitionSchema = relation.tableMeta.partitionSchema
@@ -207,16 +218,16 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
}
}
- val updatedTable = inferIfNeeded(relation, options, fileFormat, Option(fileIndex))
+ val updatedTable = inferIfNeeded(relation, hiveOptions, fileFormat, Option(fileIndex))
// Spark SQL's data source table now support static and dynamic partition insert. Source
// table converted from Hive table should always use dynamic.
- val enableDynamicPartition = options.updated("partitionOverwriteMode", "dynamic")
+ val enableDynamicPartition = hiveOptions.updated("partitionOverwriteMode", "dynamic")
val fsRelation = HadoopFsRelation(
location = fileIndex,
partitionSchema = partitionSchema,
dataSchema = updatedTable.dataSchema,
- bucketSpec = None,
+ bucketSpec = hiveBucketSpec,
fileFormat = fileFormat,
options = enableDynamicPartition)(sparkSession = sparkSession)
val created = LogicalRelation(fsRelation, updatedTable)
@@ -236,17 +247,17 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
fileFormatClass,
None)
val logicalRelation = cached.getOrElse {
- val updatedTable = inferIfNeeded(relation, options, fileFormat)
+ val updatedTable = inferIfNeeded(relation, hiveOptions, fileFormat)
val created =
LogicalRelation(
DataSource(
sparkSession = sparkSession,
paths = rootPath.toString :: Nil,
userSpecifiedSchema = Option(updatedTable.dataSchema),
- bucketSpec = None,
+ bucketSpec = hiveBucketSpec,
// Do not interpret the 'path' option at all when tables are read using the Hive
// source, since the URIs will already have been read from the table's LOCATION.
- options = options.filter { case (k, _) => !k.equalsIgnoreCase("path") },
+ options = hiveOptions.filter { case (k, _) => !k.equalsIgnoreCase("path") },
className = fileType).resolveRelation(),
table = updatedTable)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index c8a5c03..0da7465 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -211,13 +211,13 @@ case class RelationConversions(
if query.resolved && DDLUtils.isHiveTable(r.tableMeta) &&
(!r.isPartitioned || conf.getConf(HiveUtils.CONVERT_INSERTING_PARTITIONED_TABLE))
&& isConvertible(r) =>
- InsertIntoStatement(metastoreCatalog.convert(r), partition, cols,
+ InsertIntoStatement(metastoreCatalog.convert(r, isWrite = true), partition, cols,
query, overwrite, ifPartitionNotExists)
// Read path
case relation: HiveTableRelation
if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) =>
- metastoreCatalog.convert(relation)
+ metastoreCatalog.convert(relation, isWrite = false)
// CTAS
case CreateTable(tableDesc, mode, Some(query))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index beaebb5..cef2a36 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -161,7 +161,7 @@ case class OptimizedCreateHiveTableAsSelectCommand(
val metastoreCatalog = catalog.asInstanceOf[HiveSessionCatalog].metastoreCatalog
val hiveTable = DDLUtils.readHiveTable(tableDesc)
- val hadoopRelation = metastoreCatalog.convert(hiveTable) match {
+ val hadoopRelation = metastoreCatalog.convert(hiveTable, isWrite = true) match {
case LogicalRelation(t: HadoopFsRelation, _, _, _) => t
case _ => throw QueryCompilationErrors.tableIdentifierNotConvertedToHadoopFsRelationError(
tableIdentifier)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala
index bdbdcc2..c12caaa 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala
@@ -17,6 +17,11 @@
package org.apache.spark.sql.sources
+import java.io.File
+
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, Expression, HiveHash, Literal, Pmod}
+import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
@@ -27,4 +32,47 @@ class BucketedWriteWithHiveSupportSuite extends BucketedWriteSuite with TestHive
}
override protected def fileFormatsToTest: Seq[String] = Seq("parquet", "orc")
+
+ test("write hive bucketed table") {
+ def bucketIdExpression(expressions: Seq[Expression], numBuckets: Int): Expression =
+ Pmod(BitwiseAnd(HiveHash(expressions), Literal(Int.MaxValue)), Literal(8))
+
+ def getBucketIdFromFileName(fileName: String): Option[Int] = {
+ val hiveBucketedFileName = """^(\d+)_0_.*$""".r
+ fileName match {
+ case hiveBucketedFileName(bucketId) => Some(bucketId.toInt)
+ case _ => None
+ }
+ }
+
+ val table = "hive_bucketed_table"
+
+ fileFormatsToTest.foreach { format =>
+ withTable(table) {
+ sql(
+ s"""
+ |CREATE TABLE IF NOT EXISTS $table (i int, j string)
+ |PARTITIONED BY(k string)
+ |CLUSTERED BY (i, j) SORTED BY (i) INTO 8 BUCKETS
+ |STORED AS $format
+ """.stripMargin)
+
+ val df =
+ (0 until 50).map(i => (i % 13, i.toString, i % 5)).toDF("i", "j", "k")
+ df.write.mode(SaveMode.Overwrite).insertInto(table)
+
+ for (k <- 0 until 5) {
+ testBucketing(
+ new File(tableDir(table), s"k=$k"),
+ format,
+ 8,
+ Seq("i", "j"),
+ Seq("i"),
+ df,
+ bucketIdExpression,
+ getBucketIdFromFileName)
+ }
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]