This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 2346584a72 [GLUTEN-7028][CH][Part-11] Support write parquet files with
bucket (#8052)
2346584a72 is described below
commit 2346584a72cae67494a073a4f0141386bad723db
Author: Wenzheng Liu <[email protected]>
AuthorDate: Tue Dec 3 15:57:26 2024 +0800
[GLUTEN-7028][CH][Part-11] Support write parquet files with bucket (#8052)
* [GLUTEN-7028][CH] Support write parquet files with bucket
* [GLUTEN-7028][CH] Fix comment
---
.../sql/execution/FileDeltaColumnarWrite.scala | 6 +-
.../gluten/backendsapi/clickhouse/CHBackend.scala | 11 +--
.../backendsapi/clickhouse/CHIteratorApi.scala | 13 +--
.../gluten/backendsapi/clickhouse/CHRuleApi.scala | 1 +
.../backendsapi/clickhouse/RuntimeSettings.scala | 6 --
.../extension/WriteFilesWithBucketValue.scala | 76 ++++++++++++++
.../spark/sql/execution/CHColumnarWrite.scala | 46 +++++++--
.../GlutenClickHouseNativeWriteTableSuite.scala | 15 +--
.../Parser/RelParsers/WriteRelParser.cpp | 14 ++-
.../Parser/RelParsers/WriteRelParser.h | 1 -
.../Storages/MergeTree/SparkMergeTreeSink.h | 6 ++
.../Storages/Output/NormalFileWriter.cpp | 4 +-
.../Storages/Output/NormalFileWriter.h | 110 ++++++++++++++++++---
.../Storages/Output/OutputFormatFile.cpp | 1 -
cpp-ch/local-engine/tests/gtest_write_pipeline.cpp | 6 +-
15 files changed, 242 insertions(+), 74 deletions(-)
diff --git
a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala
b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala
index 784614152f..bf6b0c0074 100644
---
a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala
+++
b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala
@@ -108,13 +108,15 @@ case class FileDeltaColumnarWrite(
* {{{
* part-00000-7d672b28-c079-4b00-bb0a-196c15112918-c000.snappy.parquet
* =>
- * part-00000-{}.snappy.parquet
+ * part-00000-{id}.snappy.parquet
* }}}
*/
val guidPattern =
""".*-([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})(?:-c(\d+)\..*)?$""".r
val fileNamePattern =
- guidPattern.replaceAllIn(writeFileName, m =>
writeFileName.replace(m.group(1), "{}"))
+ guidPattern.replaceAllIn(
+ writeFileName,
+ m => writeFileName.replace(m.group(1), FileNamePlaceHolder.ID))
logDebug(s"Native staging write path: $writePath and with pattern:
$fileNamePattern")
val settings =
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index c6c8acf705..e5eb91b69b 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -246,20 +246,11 @@ object CHBackendSettings extends BackendSettingsApi with
Logging {
}
}
- def validateBucketSpec(): Option[String] = {
- if (bucketSpec.nonEmpty) {
- Some("Unsupported native write: bucket write is not supported.")
- } else {
- None
- }
- }
-
validateCompressionCodec()
.orElse(validateFileFormat())
.orElse(validateFieldMetadata())
.orElse(validateDateTypes())
- .orElse(validateWriteFilesOptions())
- .orElse(validateBucketSpec()) match {
+ .orElse(validateWriteFilesOptions()) match {
case Some(reason) => ValidationResult.failed(reason)
case _ => ValidationResult.succeeded
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
index ff268b95d8..878e27a5b8 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
@@ -26,7 +26,7 @@ import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.plan.PlanNode
import org.apache.gluten.substrait.rel._
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
-import org.apache.gluten.vectorized.{BatchIterator,
CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator,
NativeExpressionEvaluator}
+import org.apache.gluten.vectorized.{BatchIterator,
CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator}
import org.apache.spark.{InterruptibleIterator, SparkConf, TaskContext}
import org.apache.spark.affinity.CHAffinity
@@ -322,17 +322,6 @@ class CHIteratorApi extends IteratorApi with Logging with
LogLevelUtil {
createNativeIterator(splitInfoByteArray, wsPlan, materializeInput,
inputIterators))
}
- /**
- * This function used to inject the staging write path before initializing
the native plan.Only
- * used in a pipeline model (spark 3.5) for writing parquet or orc files.
- */
- override def injectWriteFilesTempPath(path: String, fileName: String): Unit
= {
- val settings =
- Map(
- RuntimeSettings.TASK_WRITE_TMP_DIR.key -> path,
- RuntimeSettings.TASK_WRITE_FILENAME.key -> fileName)
- NativeExpressionEvaluator.updateQueryRuntimeSettings(settings)
- }
}
class CollectMetricIterator(
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index edf7a48025..98cfa0e754 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -79,6 +79,7 @@ object CHRuleApi {
injector.injectPreTransform(_ => RewriteSubqueryBroadcast())
injector.injectPreTransform(c =>
FallbackBroadcastHashJoin.apply(c.session))
injector.injectPreTransform(c =>
MergeTwoPhasesHashBaseAggregate.apply(c.session))
+ injector.injectPreTransform(_ => WriteFilesWithBucketValue)
// Legacy: The legacy transform rule.
val validatorBuilder: GlutenConfig => Validator = conf =>
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala
index b59bb32392..c2747cf1eb 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala
@@ -35,12 +35,6 @@ object RuntimeSettings {
.stringConf
.createWithDefault("")
- val TASK_WRITE_FILENAME =
- buildConf(runtimeSettings("gluten.task_write_filename"))
- .doc("The temporary file name for writing data")
- .stringConf
- .createWithDefault("")
-
val TASK_WRITE_FILENAME_PATTERN =
buildConf(runtimeSettings("gluten.task_write_filename_pattern"))
.doc("The pattern to generate file name for writing delta parquet in
spark 3.5")
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/WriteFilesWithBucketValue.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/WriteFilesWithBucketValue.scala
new file mode 100644
index 0000000000..8ab78dcff9
--- /dev/null
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/WriteFilesWithBucketValue.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.GlutenConfig
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, BitwiseAnd,
Expression, HiveHash, Literal, Pmod}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.WriteFilesExec
+
+/**
+ * Appends a `__bucket_value__` column to the child of a bucketed WriteFilesExec so that the
+ * native writer can name the bucket file; the native writer drops it from the final output.
+ */
+object WriteFilesWithBucketValue extends Rule[SparkPlan] {
+
+ val optionForHiveCompatibleBucketWrite =
"__hive_compatible_bucketed_table_insertion__"
+
+ override def apply(plan: SparkPlan): SparkPlan = {
+ if (
+ GlutenConfig.getConf.enableGluten
+ && GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
+ ) {
+ plan.transformDown {
+ case writeFiles: WriteFilesExec if writeFiles.bucketSpec.isDefined =>
+ val bucketIdExp = getWriterBucketIdExp(writeFiles)
+ val wrapBucketValue = ProjectExec(
+ writeFiles.child.output :+ Alias(bucketIdExp,
"__bucket_value__")(),
+ writeFiles.child)
+ writeFiles.copy(child = wrapBucketValue)
+ }
+ } else {
+ plan
+ }
+ }
+
+ private def getWriterBucketIdExp(writeFilesExec: WriteFilesExec): Expression
= {
+ val partitionColumns = writeFilesExec.partitionColumns
+ val outputColumns = writeFilesExec.child.output
+ val dataColumns = outputColumns.filterNot(partitionColumns.contains)
+ val bucketSpec = writeFilesExec.bucketSpec.get
+ val bucketColumns = bucketSpec.bucketColumnNames.map(c =>
dataColumns.find(_.name == c).get)
+ if (writeFilesExec.options.getOrElse(optionForHiveCompatibleBucketWrite,
"false") == "true") {
+ val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
+ Pmod(hashId, Literal(bucketSpec.numBuckets))
+ // The bucket file name prefix follows the Hive, Presto and Trino
convention, so this
+ // makes sure a Hive bucketed table written by Spark can be read by other
SQL engines.
+ //
+ // Hive:
`org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
+ // Trino:
`io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
+
+ } else {
+ // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as
bucket id
+ // expression, so that we can guarantee the data distribution is same
between shuffle and
+ // bucketed data source, which enables us to only shuffle one side when
join a bucketed
+ // table and a normal one.
+ HashPartitioning(bucketColumns,
bucketSpec.numBuckets).partitionIdExpression
+ }
+ }
+}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala
b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala
index 6c7877cc02..1342e25043 100644
---
a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala
@@ -16,7 +16,8 @@
*/
package org.apache.spark.sql.execution
-import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.backendsapi.clickhouse.RuntimeSettings
+import org.apache.gluten.vectorized.NativeExpressionEvaluator
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
@@ -25,11 +26,11 @@ import
org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.delta.stats.DeltaJobStatisticsTracker
-import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker,
BasicWriteTaskStats, ExecutedWriteSummary, PartitioningUtils,
WriteJobDescription, WriteTaskResult, WriteTaskStatsTracker}
+import org.apache.spark.sql.execution.datasources._
import org.apache.spark.util.Utils
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.mapreduce.{JobID, OutputCommitter,
TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
+import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
@@ -102,6 +103,12 @@ object CreateFileNameSpec {
}
}
+// More details in local_engine::FileNameGenerator in NormalFileWriter.cpp
+object FileNamePlaceHolder {
+ val ID = "{id}"
+ val BUCKET = "{bucket}"
+}
+
/** [[HadoopMapReduceAdapter]] for [[HadoopMapReduceCommitProtocol]]. */
case class HadoopMapReduceAdapter(sparkCommitter:
HadoopMapReduceCommitProtocol) {
private lazy val committer: OutputCommitter = {
@@ -132,12 +139,26 @@ case class HadoopMapReduceAdapter(sparkCommitter:
HadoopMapReduceCommitProtocol)
GetFilename.invoke(sparkCommitter, taskContext, spec).asInstanceOf[String]
}
- def getTaskAttemptTempPathAndFilename(
+ def getTaskAttemptTempPathAndFilePattern(
taskContext: TaskAttemptContext,
description: WriteJobDescription): (String, String) = {
val stageDir = newTaskAttemptTempPath(description.path)
- val filename = getFilename(taskContext, CreateFileNameSpec(taskContext,
description))
- (stageDir, filename)
+
+ if (isBucketWrite(description)) {
+ val filePart = getFilename(taskContext, FileNameSpec("", ""))
+ val fileSuffix = CreateFileNameSpec(taskContext, description).suffix
+ (stageDir, s"${filePart}_${FileNamePlaceHolder.BUCKET}$fileSuffix")
+ } else {
+ val filename = getFilename(taskContext, CreateFileNameSpec(taskContext,
description))
+ (stageDir, filename)
+ }
+ }
+
+ private def isBucketWrite(desc: WriteJobDescription): Boolean = {
+ // In Spark 3.2, bucketSpec is not defined; it uses bucketIdExpression instead.
+ val bucketSpecField: Field = desc.getClass.getDeclaredField("bucketSpec")
+ bucketSpecField.setAccessible(true)
+ bucketSpecField.get(desc).asInstanceOf[Option[_]].isDefined
}
}
@@ -234,10 +255,15 @@ case class HadoopMapReduceCommitProtocolWrite(
* initializing the native plan and collect native write files metrics for
each backend.
*/
override def doSetupNativeTask(): Unit = {
- val (writePath, writeFileName) =
- adapter.getTaskAttemptTempPathAndFilename(taskAttemptContext,
description)
- logDebug(s"Native staging write path: $writePath and file name:
$writeFileName")
-
BackendsApiManager.getIteratorApiInstance.injectWriteFilesTempPath(writePath,
writeFileName)
+ val (writePath, writeFilePattern) =
+ adapter.getTaskAttemptTempPathAndFilePattern(taskAttemptContext,
description)
+ logDebug(s"Native staging write path: $writePath and file pattern:
$writeFilePattern")
+
+ val settings =
+ Map(
+ RuntimeSettings.TASK_WRITE_TMP_DIR.key -> writePath,
+ RuntimeSettings.TASK_WRITE_FILENAME_PATTERN.key -> writeFilePattern)
+ NativeExpressionEvaluator.updateQueryRuntimeSettings(settings)
}
def doCollectNativeResult(stats: Seq[InternalRow]): Option[WriteTaskResult]
= {
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
index 03d27f33b1..16ed302a02 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
@@ -553,7 +553,7 @@ class GlutenClickHouseNativeWriteTableSuite
// spark write does not support bucketed table
// https://issues.apache.org/jira/browse/SPARK-19256
val table_name = table_name_template.format(format)
- writeAndCheckRead(origin_table, table_name, fields_.keys.toSeq,
isSparkVersionLE("3.3")) {
+ writeAndCheckRead(origin_table, table_name, fields_.keys.toSeq) {
fields =>
spark
.table("origin_table")
@@ -589,8 +589,9 @@ class GlutenClickHouseNativeWriteTableSuite
("byte_field", "byte"),
("boolean_field", "boolean"),
("decimal_field", "decimal(23,12)"),
- ("date_field", "date"),
- ("timestamp_field", "timestamp")
+ ("date_field", "date")
+ // ("timestamp_field", "timestamp")
+ // FIXME https://github.com/apache/incubator-gluten/issues/8053
)
val origin_table = "origin_table"
withSource(genTestData(), origin_table) {
@@ -598,7 +599,7 @@ class GlutenClickHouseNativeWriteTableSuite
format =>
val table_name = table_name_template.format(format)
val testFields = fields.keys.toSeq
- writeAndCheckRead(origin_table, table_name, testFields,
isSparkVersionLE("3.3")) {
+ writeAndCheckRead(origin_table, table_name, testFields) {
fields =>
spark
.table(origin_table)
@@ -658,7 +659,7 @@ class GlutenClickHouseNativeWriteTableSuite
nativeWrite {
format =>
val table_name = table_name_template.format(format)
- writeAndCheckRead(origin_table, table_name, fields.keys.toSeq,
isSparkVersionLE("3.3")) {
+ writeAndCheckRead(origin_table, table_name, fields.keys.toSeq) {
fields =>
spark
.table("origin_table")
@@ -762,7 +763,7 @@ class GlutenClickHouseNativeWriteTableSuite
format =>
val table_name = table_name_template.format(format)
spark.sql(s"drop table IF EXISTS $table_name")
- withNativeWriteCheck(checkNative = isSparkVersionLE("3.3")) {
+ withNativeWriteCheck(checkNative = true) {
spark
.range(10000000)
.selectExpr("id", "cast('2020-01-01' as date) as p")
@@ -798,7 +799,7 @@ class GlutenClickHouseNativeWriteTableSuite
format =>
val table_name = table_name_template.format(format)
spark.sql(s"drop table IF EXISTS $table_name")
- withNativeWriteCheck(checkNative = isSparkVersionLE("3.3")) {
+ withNativeWriteCheck(checkNative = true) {
spark
.range(30000)
.selectExpr("id", "cast(null as string) as p")
diff --git a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp
b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp
index 2dacb39918..a76b4d398d 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp
@@ -56,10 +56,11 @@ DB::ProcessorPtr make_sink(
const std::string & format_hint,
const std::shared_ptr<WriteStats> & stats)
{
- if (partition_by.empty())
+ bool no_bucketed =
!SparkPartitionedBaseSink::isBucketedWrite(input_header);
+ if (partition_by.empty() && no_bucketed)
{
return std::make_shared<SubstraitFileSink>(
- context, base_path, "", generator.generate(), format_hint,
input_header, stats, DeltaStats{input_header.columns()});
+ context, base_path, "", false, generator.generate(), format_hint,
input_header, stats, DeltaStats{input_header.columns()});
}
return std::make_shared<SubstraitPartitionedFileSink>(
@@ -184,13 +185,10 @@ void addNormalFileWriterSinkTransform(
if (write_settings.task_write_tmp_dir.empty())
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline
need inject temp directory.");
- if (write_settings.task_write_filename.empty() &&
write_settings.task_write_filename_pattern.empty())
- throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline
need inject file name or file name pattern.");
+ if (write_settings.task_write_filename_pattern.empty())
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline
need inject file pattern.");
- FileNameGenerator generator{
- .pattern = write_settings.task_write_filename.empty(),
- .filename_or_pattern
- = write_settings.task_write_filename.empty() ?
write_settings.task_write_filename_pattern :
write_settings.task_write_filename};
+ FileNameGenerator generator(write_settings.task_write_filename_pattern);
auto stats = WriteStats::create(output, partitionCols);
diff --git a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h
b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h
index 0c9bc11f1f..01e0dabaaa 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h
+++ b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h
@@ -44,7 +44,6 @@ DB::Names collect_partition_cols(const DB::Block & header,
const substrait::Name
#define WRITE_RELATED_SETTINGS(M, ALIAS) \
M(String, task_write_tmp_dir, , "The temporary directory for writing
data") \
- M(String, task_write_filename, , "The filename for writing data") \
M(String, task_write_filename_pattern, , "The pattern to generate file
name for writing delta parquet in spark 3.5")
DECLARE_GLUTEN_SETTINGS(GlutenWriteSettings, WRITE_RELATED_SETTINGS)
diff --git a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h
b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h
index 38f574ea98..b551d86d1d 100644
--- a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h
+++ b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h
@@ -278,6 +278,12 @@ public:
return SparkMergeTreeSink::create(
table, write_settings, context_->getGlobalContext(),
{std::dynamic_pointer_cast<MergeTreeStats>(stats_)});
}
+
+ // TODO implement with bucket
+ DB::SinkPtr createSinkForPartition(const String & partition_id, const
String & bucket) override
+ {
+ return createSinkForPartition(partition_id);
+ }
};
}
diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
index ad2e3abf7b..2d70380a89 100644
--- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
+++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
@@ -30,6 +30,8 @@ using namespace DB;
const std::string SubstraitFileSink::NO_PARTITION_ID{"__NO_PARTITION_ID__"};
const std::string
SparkPartitionedBaseSink::DEFAULT_PARTITION_NAME{"__HIVE_DEFAULT_PARTITION__"};
+const std::string
SparkPartitionedBaseSink::BUCKET_COLUMN_NAME{"__bucket_value__"};
+const std::vector<std::string> FileNameGenerator::SUPPORT_PLACEHOLDERS{"{id}",
"{bucket}"};
/// For Nullable(Map(K, V)) or Nullable(Array(T)), if the i-th row is null, we
must make sure its nested data is empty.
/// It is for ORC/Parquet writing compatiability. For more details, refer to
@@ -168,7 +170,7 @@ void NormalFileWriter::write(DB::Block & block)
const auto & preferred_schema = file->getPreferredSchema();
for (auto & column : block)
{
- if (column.name.starts_with("__bucket_value__"))
+ if
(column.name.starts_with(SparkPartitionedBaseSink::BUCKET_COLUMN_NAME))
continue;
const auto & preferred_column =
preferred_schema.getByPosition(index++);
diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
index 8cfe079d92..998f8d6247 100644
--- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
+++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
@@ -230,20 +230,57 @@ public:
struct FileNameGenerator
{
- const bool pattern;
- const std::string filename_or_pattern;
+ // Align with org.apache.spark.sql.execution.FileNamePlaceHolder
+ static const std::vector<std::string> SUPPORT_PLACEHOLDERS;
+ // need_to_replace[i] is true iff file_pattern contains SUPPORT_PLACEHOLDERS[i]
+ const std::vector<bool> need_to_replace;
+ const std::string file_pattern;
+
+ FileNameGenerator(const std::string & file_pattern)
+ : file_pattern(file_pattern),
need_to_replace(compute_need_to_replace(file_pattern))
+ {
+ }
+
+ std::vector<bool> compute_need_to_replace(const std::string & file_pattern)
+ {
+ std::vector<bool> result;
+ for(const std::string& placeholder: SUPPORT_PLACEHOLDERS)
+ {
+ if (file_pattern.find(placeholder) != std::string::npos)
+ result.push_back(true);
+ else
+ result.push_back(false);
+ }
+ return result;
+ }
+
+ std::string generate(const std::string & bucket = "") const
+ {
+ std::string result = file_pattern;
+ if (need_to_replace[0]) // {id}
+ result = pattern_format(SUPPORT_PLACEHOLDERS[0],
toString(DB::UUIDHelpers::generateV4()));
+ if (need_to_replace[1]) // {bucket}
+ result = pattern_format(SUPPORT_PLACEHOLDERS[1], bucket);
+ return result;
+ }
- std::string generate() const
+ std::string pattern_format(const std::string & arg, const std::string &
replacement) const
{
- if (pattern)
- return fmt::vformat(filename_or_pattern,
fmt::make_format_args(toString(DB::UUIDHelpers::generateV4())));
- return filename_or_pattern;
+ std::string format_str = file_pattern;
+ size_t pos = format_str.find(arg);
+ while (pos != std::string::npos)
+ {
+ format_str.replace(pos, arg.length(), replacement);
+ pos = format_str.find(arg, pos + arg.length());
+ }
+ return format_str;
}
};
class SubstraitFileSink final : public DB::SinkToStorage
{
const std::string partition_id_;
+ const bool bucketed_write_;
const std::string relative_path_;
OutputFormatFilePtr format_file_;
OutputFormatFile::OutputFormatPtr output_format_;
@@ -265,6 +302,7 @@ public:
const DB::ContextPtr & context,
const std::string & base_path,
const std::string & partition_id,
+ const bool bucketed_write,
const std::string & relative,
const std::string & format_hint,
const DB::Block & header,
@@ -272,6 +310,7 @@ public:
const DeltaStats & delta_stats)
: SinkToStorage(header)
, partition_id_(partition_id.empty() ? NO_PARTITION_ID : partition_id)
+ , bucketed_write_(bucketed_write)
, relative_path_(relative)
, format_file_(createOutputFormatFile(context,
makeAbsoluteFilename(base_path, partition_id, relative), header, format_hint))
, stats_(std::dynamic_pointer_cast<WriteStats>(stats))
@@ -287,7 +326,18 @@ protected:
delta_stats_.update(chunk);
if (!output_format_) [[unlikely]]
output_format_ = format_file_->createOutputFormat();
-
output_format_->output->write(materializeBlock(getHeader().cloneWithColumns(chunk.detachColumns())));
+
+ const DB::Block & input_header = getHeader();
+ if (bucketed_write_)
+ {
+ chunk.erase(input_header.columns() - 1);
+ const DB::ColumnsWithTypeAndName & cols =
input_header.getColumnsWithTypeAndName();
+ DB::ColumnsWithTypeAndName without_bucket_cols(cols.begin(),
cols.end() - 1);
+ DB::Block without_bucket_header = DB::Block(without_bucket_cols);
+
output_format_->output->write(materializeBlock(without_bucket_header.cloneWithColumns(chunk.detachColumns())));
+ }
+ else
+
output_format_->output->write(materializeBlock(input_header.cloneWithColumns(chunk.detachColumns())));
}
void onFinish() override
{
@@ -303,11 +353,19 @@ protected:
class SparkPartitionedBaseSink : public DB::PartitionedSink
{
- static const std::string DEFAULT_PARTITION_NAME;
public:
+ static const std::string DEFAULT_PARTITION_NAME;
+ static const std::string BUCKET_COLUMN_NAME;
+
+ static bool isBucketedWrite(const DB::Block & input_header)
+ {
+ return input_header.has(BUCKET_COLUMN_NAME) &&
+ input_header.getPositionByName(BUCKET_COLUMN_NAME) ==
input_header.columns() - 1;
+ }
+
/// visible for UTs
- static DB::ASTPtr make_partition_expression(const DB::Names &
partition_columns)
+ static DB::ASTPtr make_partition_expression(const DB::Names &
partition_columns, const DB::Block & input_header)
{
/// Parse the following expression into ASTs
/// cancat('/col_name=', 'toString(col_name)')
@@ -327,13 +385,35 @@ public:
makeASTFunction("toString", DB::ASTs{column_ast}),
std::make_shared<DB::ASTLiteral>(DEFAULT_PARTITION_NAME)};
arguments.emplace_back(makeASTFunction("ifNull",
std::move(if_null_args)));
}
+ if (isBucketedWrite(input_header))
+ {
+ DB::ASTs args {std::make_shared<DB::ASTLiteral>("%05d"),
std::make_shared<DB::ASTIdentifier>(BUCKET_COLUMN_NAME)};
+ arguments.emplace_back(DB::makeASTFunction("printf",
std::move(args)));
+ }
+ assert(!arguments.empty());
+ if (arguments.size() == 1)
+ return arguments[0];
return DB::makeASTFunction("concat", std::move(arguments));
}
+ DB::SinkPtr createSinkForPartition(const String & partition_id) override
+ {
+ if (bucketed_write_)
+ {
+ std::string bucket_val = partition_id.substr(partition_id.length()
- 5, 5);
+ std::string real_partition_id = partition_id.substr(0,
partition_id.length() - 5);
+ return createSinkForPartition(real_partition_id, bucket_val);
+ }
+ return createSinkForPartition(partition_id, "");
+ }
+
+ virtual DB::SinkPtr createSinkForPartition(const String & partition_id,
const String & bucket) = 0;
+
protected:
DB::ContextPtr context_;
std::shared_ptr<WriteStatsBase> stats_;
DeltaStats empty_delta_stats_;
+ bool bucketed_write_;
public:
SparkPartitionedBaseSink(
@@ -341,9 +421,10 @@ public:
const DB::Names & partition_by,
const DB::Block & input_header,
const std::shared_ptr<WriteStatsBase> & stats)
- : PartitionedSink(make_partition_expression(partition_by), context,
input_header)
+ : PartitionedSink(make_partition_expression(partition_by,
input_header), context, input_header)
, context_(context)
, stats_(stats)
+ , bucketed_write_(isBucketedWrite(input_header))
, empty_delta_stats_(DeltaStats::create(input_header, partition_by))
{
}
@@ -353,6 +434,7 @@ class SubstraitPartitionedFileSink final : public
SparkPartitionedBaseSink
{
const std::string base_path_;
const FileNameGenerator generator_;
+ const DB::Block input_header_;
const DB::Block sample_block_;
const std::string format_hint_;
@@ -370,18 +452,20 @@ public:
, base_path_(base_path)
, generator_(generator)
, sample_block_(sample_block)
+ , input_header_(input_header)
, format_hint_(format_hint)
{
}
- DB::SinkPtr createSinkForPartition(const String & partition_id) override
+ DB::SinkPtr createSinkForPartition(const String & partition_id, const
String & bucket) override
{
assert(stats_);
- std::string filename = generator_.generate();
+ bool bucketed_write = !bucket.empty();
+ std::string filename = bucketed_write ? generator_.generate(bucket) :
generator_.generate();
const auto partition_path = fmt::format("{}/{}", partition_id,
filename);
validatePartitionKey(partition_path, true);
return std::make_shared<SubstraitFileSink>(
- context_, base_path_, partition_id, filename, format_hint_,
sample_block_, stats_, empty_delta_stats_);
+ context_, base_path_, partition_id, bucketed_write, filename,
format_hint_, sample_block_, stats_, empty_delta_stats_);
}
String getName() const override { return "SubstraitPartitionedFileSink"; }
};
diff --git a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp
b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp
index 194d997ddf..d5ed430943 100644
--- a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp
+++ b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp
@@ -61,7 +61,6 @@ Block OutputFormatFile::createHeaderWithPreferredSchema(const
Block & header)
ColumnWithTypeAndName column(preferred_column.type->createColumn(),
preferred_column.type, preferred_column.name);
columns.emplace_back(std::move(column));
}
- assert(preferred_schema.columns() == index);
return {std::move(columns)};
}
diff --git a/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp
b/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp
index 00f2da20c5..b764f62f54 100644
--- a/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp
+++ b/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp
@@ -206,14 +206,14 @@ TEST(WritePipeline, SubstraitPartitionedFileSink)
TEST(WritePipeline, ComputePartitionedExpression)
{
const auto context =
DB::Context::createCopy(QueryContext::globalContext());
-
- auto partition_by =
SubstraitPartitionedFileSink::make_partition_expression({"s_nationkey",
"name"});
+
+ Block sample_block{{STRING(), "name"}, {UINT(), "s_nationkey"}};
+ auto partition_by =
SubstraitPartitionedFileSink::make_partition_expression({"s_nationkey",
"name"}, sample_block);
// auto partition_by = printColumn("s_nationkey");
ASTs arguments(1, partition_by);
ASTPtr partition_by_string = makeASTFunction("toString",
std::move(arguments));
- Block sample_block{{STRING(), "name"}, {UINT(), "s_nationkey"}};
auto syntax_result = TreeRewriter(context).analyze(partition_by_string,
sample_block.getNamesAndTypesList());
auto partition_by_expr = ExpressionAnalyzer(partition_by_string,
syntax_result, context).getActions(false);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]