This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 2346584a72 [GLUTEN-7028][CH][Part-11] Support write parquet files with bucket (#8052)
2346584a72 is described below

commit 2346584a72cae67494a073a4f0141386bad723db
Author: Wenzheng Liu <[email protected]>
AuthorDate: Tue Dec 3 15:57:26 2024 +0800

    [GLUTEN-7028][CH][Part-11] Support write parquet files with bucket (#8052)
    
    * [GLUTEN-7028][CH] Support write parquet files with bucket
    
    * [GLUTEN-7028][CH] Fix comment
---
 .../sql/execution/FileDeltaColumnarWrite.scala     |   6 +-
 .../gluten/backendsapi/clickhouse/CHBackend.scala  |  11 +--
 .../backendsapi/clickhouse/CHIteratorApi.scala     |  13 +--
 .../gluten/backendsapi/clickhouse/CHRuleApi.scala  |   1 +
 .../backendsapi/clickhouse/RuntimeSettings.scala   |   6 --
 .../extension/WriteFilesWithBucketValue.scala      |  76 ++++++++++++++
 .../spark/sql/execution/CHColumnarWrite.scala      |  46 +++++++--
 .../GlutenClickHouseNativeWriteTableSuite.scala    |  15 +--
 .../Parser/RelParsers/WriteRelParser.cpp           |  14 ++-
 .../Parser/RelParsers/WriteRelParser.h             |   1 -
 .../Storages/MergeTree/SparkMergeTreeSink.h        |   6 ++
 .../Storages/Output/NormalFileWriter.cpp           |   4 +-
 .../Storages/Output/NormalFileWriter.h             | 110 ++++++++++++++++++---
 .../Storages/Output/OutputFormatFile.cpp           |   1 -
 cpp-ch/local-engine/tests/gtest_write_pipeline.cpp |   6 +-
 15 files changed, 242 insertions(+), 74 deletions(-)

diff --git a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala
index 784614152f..bf6b0c0074 100644
--- a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala
+++ b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala
@@ -108,13 +108,15 @@ case class FileDeltaColumnarWrite(
      * {{{
      *   part-00000-7d672b28-c079-4b00-bb0a-196c15112918-c000.snappy.parquet
      *     =>
-     *   part-00000-{}.snappy.parquet
+     *   part-00000-{id}.snappy.parquet
      * }}}
      */
     val guidPattern =
       """.*-([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})(?:-c(\d+)\..*)?$""".r
     val fileNamePattern =
-      guidPattern.replaceAllIn(writeFileName, m => writeFileName.replace(m.group(1), "{}"))
+      guidPattern.replaceAllIn(
+        writeFileName,
+        m => writeFileName.replace(m.group(1), FileNamePlaceHolder.ID))
 
     logDebug(s"Native staging write path: $writePath and with pattern: $fileNamePattern")
     val settings =
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index c6c8acf705..e5eb91b69b 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -246,20 +246,11 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
       }
     }
 
-    def validateBucketSpec(): Option[String] = {
-      if (bucketSpec.nonEmpty) {
-        Some("Unsupported native write: bucket write is not supported.")
-      } else {
-        None
-      }
-    }
-
     validateCompressionCodec()
       .orElse(validateFileFormat())
       .orElse(validateFieldMetadata())
       .orElse(validateDateTypes())
-      .orElse(validateWriteFilesOptions())
-      .orElse(validateBucketSpec()) match {
+      .orElse(validateWriteFilesOptions()) match {
       case Some(reason) => ValidationResult.failed(reason)
       case _ => ValidationResult.succeeded
     }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
index ff268b95d8..878e27a5b8 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
@@ -26,7 +26,7 @@ import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.substrait.plan.PlanNode
 import org.apache.gluten.substrait.rel._
 import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
-import org.apache.gluten.vectorized.{BatchIterator, CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator, NativeExpressionEvaluator}
+import org.apache.gluten.vectorized.{BatchIterator, CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator}
 
 import org.apache.spark.{InterruptibleIterator, SparkConf, TaskContext}
 import org.apache.spark.affinity.CHAffinity
@@ -322,17 +322,6 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
       createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators))
   }
 
-  /**
-   * This function used to inject the staging write path before initializing the native plan.Only
-   * used in a pipeline model (spark 3.5) for writing parquet or orc files.
-   */
-  override def injectWriteFilesTempPath(path: String, fileName: String): Unit = {
-    val settings =
-      Map(
-        RuntimeSettings.TASK_WRITE_TMP_DIR.key -> path,
-        RuntimeSettings.TASK_WRITE_FILENAME.key -> fileName)
-    NativeExpressionEvaluator.updateQueryRuntimeSettings(settings)
-  }
 }
 
 class CollectMetricIterator(
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index edf7a48025..98cfa0e754 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -79,6 +79,7 @@ object CHRuleApi {
     injector.injectPreTransform(_ => RewriteSubqueryBroadcast())
     injector.injectPreTransform(c => FallbackBroadcastHashJoin.apply(c.session))
     injector.injectPreTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.session))
+    injector.injectPreTransform(_ => WriteFilesWithBucketValue)
 
     // Legacy: The legacy transform rule.
     val validatorBuilder: GlutenConfig => Validator = conf =>
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala
index b59bb32392..c2747cf1eb 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala
@@ -35,12 +35,6 @@ object RuntimeSettings {
       .stringConf
       .createWithDefault("")
 
-  val TASK_WRITE_FILENAME =
-    buildConf(runtimeSettings("gluten.task_write_filename"))
-      .doc("The temporary file name for writing data")
-      .stringConf
-      .createWithDefault("")
-
   val TASK_WRITE_FILENAME_PATTERN =
     buildConf(runtimeSettings("gluten.task_write_filename_pattern"))
       .doc("The pattern to generate file name for writing delta parquet in spark 3.5")
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/WriteFilesWithBucketValue.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/WriteFilesWithBucketValue.scala
new file mode 100644
index 0000000000..8ab78dcff9
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/WriteFilesWithBucketValue.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.GlutenConfig
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, BitwiseAnd, Expression, HiveHash, Literal, Pmod}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.WriteFilesExec
+
+/**
+ * Wrap with bucket value to specify the bucket file name in native write. Native writer will remove
+ * this value in the final output.
+ */
+object WriteFilesWithBucketValue extends Rule[SparkPlan] {
+
+  val optionForHiveCompatibleBucketWrite = "__hive_compatible_bucketed_table_insertion__"
+
+  override def apply(plan: SparkPlan): SparkPlan = {
+    if (
+      GlutenConfig.getConf.enableGluten
+      && GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
+    ) {
+      plan.transformDown {
+        case writeFiles: WriteFilesExec if writeFiles.bucketSpec.isDefined =>
+          val bucketIdExp = getWriterBucketIdExp(writeFiles)
+          val wrapBucketValue = ProjectExec(
+            writeFiles.child.output :+ Alias(bucketIdExp, "__bucket_value__")(),
+            writeFiles.child)
+          writeFiles.copy(child = wrapBucketValue)
+      }
+    } else {
+      plan
+    }
+  }
+
+  private def getWriterBucketIdExp(writeFilesExec: WriteFilesExec): Expression = {
+    val partitionColumns = writeFilesExec.partitionColumns
+    val outputColumns = writeFilesExec.child.output
+    val dataColumns = outputColumns.filterNot(partitionColumns.contains)
+    val bucketSpec = writeFilesExec.bucketSpec.get
+    val bucketColumns = bucketSpec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+    if (writeFilesExec.options.getOrElse(optionForHiveCompatibleBucketWrite, "false") == "true") {
+      val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
+      Pmod(hashId, Literal(bucketSpec.numBuckets))
+      // The bucket file name prefix is following Hive, Presto and Trino conversion, so this
+      // makes sure Hive bucketed table written by Spark, can be read by other SQL engines.
+      //
+      // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
+      // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
+
+    } else {
+      // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id
+      // expression, so that we can guarantee the data distribution is same between shuffle and
+      // bucketed data source, which enables us to only shuffle one side when join a bucketed
+      // table and a normal one.
+      HashPartitioning(bucketColumns, bucketSpec.numBuckets).partitionIdExpression
+    }
+  }
+}
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala
index 6c7877cc02..1342e25043 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala
@@ -16,7 +16,8 @@
  */
 package org.apache.spark.sql.execution
 
-import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.backendsapi.clickhouse.RuntimeSettings
+import org.apache.gluten.vectorized.NativeExpressionEvaluator
 
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
@@ -25,11 +26,11 @@ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.delta.stats.DeltaJobStatisticsTracker
-import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, BasicWriteTaskStats, ExecutedWriteSummary, PartitioningUtils, WriteJobDescription, WriteTaskResult, WriteTaskStatsTracker}
+import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.util.Utils
 
 import org.apache.hadoop.fs.Path
-import org.apache.hadoop.mapreduce.{JobID, OutputCommitter, TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
+import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 
@@ -102,6 +103,12 @@ object CreateFileNameSpec {
   }
 }
 
+// More details in local_engine::FileNameGenerator in NormalFileWriter.cpp
+object FileNamePlaceHolder {
+  val ID = "{id}"
+  val BUCKET = "{bucket}"
+}
+
 /** [[HadoopMapReduceAdapter]] for [[HadoopMapReduceCommitProtocol]]. */
 case class HadoopMapReduceAdapter(sparkCommitter: HadoopMapReduceCommitProtocol) {
   private lazy val committer: OutputCommitter = {
@@ -132,12 +139,26 @@ case class HadoopMapReduceAdapter(sparkCommitter: HadoopMapReduceCommitProtocol)
     GetFilename.invoke(sparkCommitter, taskContext, spec).asInstanceOf[String]
   }
 
-  def getTaskAttemptTempPathAndFilename(
+  def getTaskAttemptTempPathAndFilePattern(
       taskContext: TaskAttemptContext,
       description: WriteJobDescription): (String, String) = {
     val stageDir = newTaskAttemptTempPath(description.path)
-    val filename = getFilename(taskContext, CreateFileNameSpec(taskContext, description))
-    (stageDir, filename)
+
+    if (isBucketWrite(description)) {
+      val filePart = getFilename(taskContext, FileNameSpec("", ""))
+      val fileSuffix = CreateFileNameSpec(taskContext, description).suffix
+      (stageDir, s"${filePart}_${FileNamePlaceHolder.BUCKET}$fileSuffix")
+    } else {
+      val filename = getFilename(taskContext, CreateFileNameSpec(taskContext, description))
+      (stageDir, filename)
+    }
+  }
+
+  private def isBucketWrite(desc: WriteJobDescription): Boolean = {
+    // In Spark 3.2, bucketSpec is not defined, instead, it uses bucketIdExpression.
+    val bucketSpecField: Field = desc.getClass.getDeclaredField("bucketSpec")
+    bucketSpecField.setAccessible(true)
+    bucketSpecField.get(desc).asInstanceOf[Option[_]].isDefined
   }
 }
 
@@ -234,10 +255,15 @@ case class HadoopMapReduceCommitProtocolWrite(
    * initializing the native plan and collect native write files metrics for each backend.
    */
   override def doSetupNativeTask(): Unit = {
-    val (writePath, writeFileName) =
-      adapter.getTaskAttemptTempPathAndFilename(taskAttemptContext, description)
-    logDebug(s"Native staging write path: $writePath and file name: $writeFileName")
-    BackendsApiManager.getIteratorApiInstance.injectWriteFilesTempPath(writePath, writeFileName)
+    val (writePath, writeFilePattern) =
+      adapter.getTaskAttemptTempPathAndFilePattern(taskAttemptContext, description)
+    logDebug(s"Native staging write path: $writePath and file pattern: $writeFilePattern")
+
+    val settings =
+      Map(
+        RuntimeSettings.TASK_WRITE_TMP_DIR.key -> writePath,
+        RuntimeSettings.TASK_WRITE_FILENAME_PATTERN.key -> writeFilePattern)
+    NativeExpressionEvaluator.updateQueryRuntimeSettings(settings)
   }
 
   def doCollectNativeResult(stats: Seq[InternalRow]): Option[WriteTaskResult] = {
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
index 03d27f33b1..16ed302a02 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
@@ -553,7 +553,7 @@ class GlutenClickHouseNativeWriteTableSuite
           // spark write does not support bucketed table
           // https://issues.apache.org/jira/browse/SPARK-19256
           val table_name = table_name_template.format(format)
-          writeAndCheckRead(origin_table, table_name, fields_.keys.toSeq, isSparkVersionLE("3.3")) {
+          writeAndCheckRead(origin_table, table_name, fields_.keys.toSeq) {
             fields =>
               spark
                 .table("origin_table")
@@ -589,8 +589,9 @@ class GlutenClickHouseNativeWriteTableSuite
       ("byte_field", "byte"),
       ("boolean_field", "boolean"),
       ("decimal_field", "decimal(23,12)"),
-      ("date_field", "date"),
-      ("timestamp_field", "timestamp")
+      ("date_field", "date")
+      // ("timestamp_field", "timestamp")
+      // FIXME https://github.com/apache/incubator-gluten/issues/8053
     )
     val origin_table = "origin_table"
     withSource(genTestData(), origin_table) {
@@ -598,7 +599,7 @@ class GlutenClickHouseNativeWriteTableSuite
         format =>
           val table_name = table_name_template.format(format)
           val testFields = fields.keys.toSeq
-          writeAndCheckRead(origin_table, table_name, testFields, isSparkVersionLE("3.3")) {
+          writeAndCheckRead(origin_table, table_name, testFields) {
             fields =>
               spark
                 .table(origin_table)
@@ -658,7 +659,7 @@ class GlutenClickHouseNativeWriteTableSuite
       nativeWrite {
         format =>
           val table_name = table_name_template.format(format)
-          writeAndCheckRead(origin_table, table_name, fields.keys.toSeq, isSparkVersionLE("3.3")) {
+          writeAndCheckRead(origin_table, table_name, fields.keys.toSeq) {
             fields =>
               spark
                 .table("origin_table")
@@ -762,7 +763,7 @@ class GlutenClickHouseNativeWriteTableSuite
       format =>
         val table_name = table_name_template.format(format)
         spark.sql(s"drop table IF EXISTS $table_name")
-        withNativeWriteCheck(checkNative = isSparkVersionLE("3.3")) {
+        withNativeWriteCheck(checkNative = true) {
           spark
             .range(10000000)
             .selectExpr("id", "cast('2020-01-01' as date) as p")
@@ -798,7 +799,7 @@ class GlutenClickHouseNativeWriteTableSuite
       format =>
         val table_name = table_name_template.format(format)
         spark.sql(s"drop table IF EXISTS $table_name")
-        withNativeWriteCheck(checkNative = isSparkVersionLE("3.3")) {
+        withNativeWriteCheck(checkNative = true) {
           spark
             .range(30000)
             .selectExpr("id", "cast(null as string) as p")
diff --git a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp
index 2dacb39918..a76b4d398d 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp
@@ -56,10 +56,11 @@ DB::ProcessorPtr make_sink(
     const std::string & format_hint,
     const std::shared_ptr<WriteStats> & stats)
 {
-    if (partition_by.empty())
+    bool no_bucketed = !SparkPartitionedBaseSink::isBucketedWrite(input_header);
+    if (partition_by.empty() && no_bucketed)
     {
         return std::make_shared<SubstraitFileSink>(
-            context, base_path, "", generator.generate(), format_hint, input_header, stats, DeltaStats{input_header.columns()});
+            context, base_path, "", false, generator.generate(), format_hint, input_header, stats, DeltaStats{input_header.columns()});
     }
 
     return std::make_shared<SubstraitPartitionedFileSink>(
@@ -184,13 +185,10 @@ void addNormalFileWriterSinkTransform(
     if (write_settings.task_write_tmp_dir.empty())
         throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline need inject temp directory.");
 
-    if (write_settings.task_write_filename.empty() && write_settings.task_write_filename_pattern.empty())
-        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline need inject file name or file name pattern.");
+    if (write_settings.task_write_filename_pattern.empty())
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline need inject file pattern.");
 
-    FileNameGenerator generator{
-        .pattern = write_settings.task_write_filename.empty(),
-        .filename_or_pattern
-        = write_settings.task_write_filename.empty() ? write_settings.task_write_filename_pattern : write_settings.task_write_filename};
+    FileNameGenerator generator(write_settings.task_write_filename_pattern);
 
     auto stats = WriteStats::create(output, partitionCols);
 
diff --git a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h
index 0c9bc11f1f..01e0dabaaa 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h
+++ b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h
@@ -44,7 +44,6 @@ DB::Names collect_partition_cols(const DB::Block & header, const substrait::Name
 
 #define WRITE_RELATED_SETTINGS(M, ALIAS) \
     M(String, task_write_tmp_dir, , "The temporary directory for writing data") \
-    M(String, task_write_filename, , "The filename for writing data") \
     M(String, task_write_filename_pattern, , "The pattern to generate file name for writing delta parquet in spark 3.5")
 
 DECLARE_GLUTEN_SETTINGS(GlutenWriteSettings, WRITE_RELATED_SETTINGS)
diff --git a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h
index 38f574ea98..b551d86d1d 100644
--- a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h
+++ b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h
@@ -278,6 +278,12 @@ public:
         return SparkMergeTreeSink::create(
             table, write_settings, context_->getGlobalContext(), {std::dynamic_pointer_cast<MergeTreeStats>(stats_)});
     }
+
+    // TODO implement with bucket
+    DB::SinkPtr createSinkForPartition(const String & partition_id, const String & bucket) override
+    {
+        return createSinkForPartition(partition_id);
+    }
 };
 
 }
diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
index ad2e3abf7b..2d70380a89 100644
--- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
+++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
@@ -30,6 +30,8 @@ using namespace DB;
 
 const std::string SubstraitFileSink::NO_PARTITION_ID{"__NO_PARTITION_ID__"};
 const std::string SparkPartitionedBaseSink::DEFAULT_PARTITION_NAME{"__HIVE_DEFAULT_PARTITION__"};
+const std::string SparkPartitionedBaseSink::BUCKET_COLUMN_NAME{"__bucket_value__"};
+const std::vector<std::string> FileNameGenerator::SUPPORT_PLACEHOLDERS{"{id}", "{bucket}"};
 
 /// For Nullable(Map(K, V)) or Nullable(Array(T)), if the i-th row is null, we must make sure its nested data is empty.
 /// It is for ORC/Parquet writing compatiability. For more details, refer to
@@ -168,7 +170,7 @@ void NormalFileWriter::write(DB::Block & block)
     const auto & preferred_schema = file->getPreferredSchema();
     for (auto & column : block)
     {
-        if (column.name.starts_with("__bucket_value__"))
+        if (column.name.starts_with(SparkPartitionedBaseSink::BUCKET_COLUMN_NAME))
             continue;
 
         const auto & preferred_column = preferred_schema.getByPosition(index++);
diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
index 8cfe079d92..998f8d6247 100644
--- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
+++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
@@ -230,20 +230,57 @@ public:
 
 struct FileNameGenerator
 {
-    const bool pattern;
-    const std::string filename_or_pattern;
+    // Align with org.apache.spark.sql.execution.FileNamePlaceHolder
+    static const std::vector<std::string> SUPPORT_PLACEHOLDERS;
+    // Align with placeholders above
+    const std::vector<bool> need_to_replace;
+    const std::string file_pattern;
+
+    FileNameGenerator(const std::string & file_pattern)
+        : file_pattern(file_pattern), need_to_replace(compute_need_to_replace(file_pattern))
+    {
+    }
+
+    std::vector<bool> compute_need_to_replace(const std::string & file_pattern)
+    {
+        std::vector<bool> result;
+        for(const std::string& placeholder: SUPPORT_PLACEHOLDERS)
+        {
+            if (file_pattern.find(placeholder) != std::string::npos)
+                result.push_back(true);
+            else
+                result.push_back(false);
+        }
+        return result;
+    }
+
+    std::string generate(const std::string & bucket = "") const
+    {
+        std::string result = file_pattern;
+        if (need_to_replace[0]) // {id}
+            result = pattern_format(SUPPORT_PLACEHOLDERS[0], toString(DB::UUIDHelpers::generateV4()));
+        if (need_to_replace[1]) // {bucket}
+            result = pattern_format(SUPPORT_PLACEHOLDERS[1], bucket);
+        return result;
+    }
 
-    std::string generate() const
+    std::string pattern_format(const std::string & arg, const std::string & replacement) const
     {
-        if (pattern)
-            return fmt::vformat(filename_or_pattern, fmt::make_format_args(toString(DB::UUIDHelpers::generateV4())));
-        return filename_or_pattern;
+        std::string format_str = file_pattern;
+        size_t pos = format_str.find(arg);
+        while (pos != std::string::npos)
+        {
+            format_str.replace(pos, arg.length(), replacement);
+            pos = format_str.find(arg, pos + arg.length());
+        }
+        return format_str;
     }
 };
 
 class SubstraitFileSink final : public DB::SinkToStorage
 {
     const std::string partition_id_;
+    const bool bucketed_write_;
     const std::string relative_path_;
     OutputFormatFilePtr format_file_;
     OutputFormatFile::OutputFormatPtr output_format_;
@@ -265,6 +302,7 @@ public:
         const DB::ContextPtr & context,
         const std::string & base_path,
         const std::string & partition_id,
+        const bool bucketed_write,
         const std::string & relative,
         const std::string & format_hint,
         const DB::Block & header,
@@ -272,6 +310,7 @@ public:
         const DeltaStats & delta_stats)
         : SinkToStorage(header)
         , partition_id_(partition_id.empty() ? NO_PARTITION_ID : partition_id)
+        , bucketed_write_(bucketed_write)
         , relative_path_(relative)
         , format_file_(createOutputFormatFile(context, makeAbsoluteFilename(base_path, partition_id, relative), header, format_hint))
         , stats_(std::dynamic_pointer_cast<WriteStats>(stats))
@@ -287,7 +326,18 @@ protected:
         delta_stats_.update(chunk);
         if (!output_format_) [[unlikely]]
             output_format_ = format_file_->createOutputFormat();
-        output_format_->output->write(materializeBlock(getHeader().cloneWithColumns(chunk.detachColumns())));
+
+        const DB::Block & input_header = getHeader();
+        if (bucketed_write_)
+        {
+            chunk.erase(input_header.columns() - 1);
+            const DB::ColumnsWithTypeAndName & cols = input_header.getColumnsWithTypeAndName();
+            DB::ColumnsWithTypeAndName without_bucket_cols(cols.begin(), cols.end() - 1);
+            DB::Block without_bucket_header = DB::Block(without_bucket_cols);
+            output_format_->output->write(materializeBlock(without_bucket_header.cloneWithColumns(chunk.detachColumns())));
+        }
+        else
+            output_format_->output->write(materializeBlock(input_header.cloneWithColumns(chunk.detachColumns())));
     }
     void onFinish() override
     {
@@ -303,11 +353,19 @@ protected:
 
 class SparkPartitionedBaseSink : public DB::PartitionedSink
 {
-    static const std::string DEFAULT_PARTITION_NAME;
 
 public:
+    static const std::string DEFAULT_PARTITION_NAME;
+    static const std::string BUCKET_COLUMN_NAME;
+
+    static bool isBucketedWrite(const DB::Block & input_header)
+    {
+        return input_header.has(BUCKET_COLUMN_NAME) &&
+            input_header.getPositionByName(BUCKET_COLUMN_NAME) == input_header.columns() - 1;
+    }
+
     /// visible for UTs
-    static DB::ASTPtr make_partition_expression(const DB::Names & partition_columns)
+    static DB::ASTPtr make_partition_expression(const DB::Names & partition_columns, const DB::Block & input_header)
     {
         /// Parse the following expression into ASTs
         /// cancat('/col_name=', 'toString(col_name)')
@@ -327,13 +385,35 @@ public:
                 makeASTFunction("toString", DB::ASTs{column_ast}), std::make_shared<DB::ASTLiteral>(DEFAULT_PARTITION_NAME)};
             arguments.emplace_back(makeASTFunction("ifNull", std::move(if_null_args)));
         }
+        if (isBucketedWrite(input_header))
+        {
+            DB::ASTs args {std::make_shared<DB::ASTLiteral>("%05d"), std::make_shared<DB::ASTIdentifier>(BUCKET_COLUMN_NAME)};
+            arguments.emplace_back(DB::makeASTFunction("printf", std::move(args)));
+        }
+        assert(!arguments.empty());
+        if (arguments.size() == 1)
+            return arguments[0];
         return DB::makeASTFunction("concat", std::move(arguments));
     }
 
+    DB::SinkPtr createSinkForPartition(const String & partition_id) override
+    {
+        if (bucketed_write_)
+        {
+            std::string bucket_val = partition_id.substr(partition_id.length() - 5, 5);
+            std::string real_partition_id = partition_id.substr(0, partition_id.length() - 5);
+            return createSinkForPartition(real_partition_id, bucket_val);
+        }
+        return createSinkForPartition(partition_id, "");
+    }
+
+    virtual DB::SinkPtr createSinkForPartition(const String & partition_id, const String & bucket) = 0;
+
 protected:
     DB::ContextPtr context_;
     std::shared_ptr<WriteStatsBase> stats_;
     DeltaStats empty_delta_stats_;
+    bool bucketed_write_;
 
 public:
     SparkPartitionedBaseSink(
@@ -341,9 +421,10 @@ public:
         const DB::Names & partition_by,
         const DB::Block & input_header,
         const std::shared_ptr<WriteStatsBase> & stats)
-        : PartitionedSink(make_partition_expression(partition_by), context, input_header)
+        : PartitionedSink(make_partition_expression(partition_by, input_header), context, input_header)
         , context_(context)
         , stats_(stats)
+        , bucketed_write_(isBucketedWrite(input_header))
         , empty_delta_stats_(DeltaStats::create(input_header, partition_by))
     {
     }
@@ -353,6 +434,7 @@ class SubstraitPartitionedFileSink final : public SparkPartitionedBaseSink
 {
     const std::string base_path_;
     const FileNameGenerator generator_;
+    const DB::Block input_header_;
     const DB::Block sample_block_;
     const std::string format_hint_;
 
@@ -370,18 +452,20 @@ public:
         , base_path_(base_path)
         , generator_(generator)
         , sample_block_(sample_block)
+        , input_header_(input_header)
         , format_hint_(format_hint)
     {
     }
 
-    DB::SinkPtr createSinkForPartition(const String & partition_id) override
+    DB::SinkPtr createSinkForPartition(const String & partition_id, const String & bucket) override
     {
         assert(stats_);
-        std::string filename = generator_.generate();
+        bool bucketed_write = !bucket.empty();
+        std::string filename = bucketed_write ? generator_.generate(bucket) : generator_.generate();
         const auto partition_path = fmt::format("{}/{}", partition_id, filename);
         validatePartitionKey(partition_path, true);
         return std::make_shared<SubstraitFileSink>(
-            context_, base_path_, partition_id, filename, format_hint_, sample_block_, stats_, empty_delta_stats_);
+            context_, base_path_, partition_id, bucketed_write, filename, format_hint_, sample_block_, stats_, empty_delta_stats_);
     }
     String getName() const override { return "SubstraitPartitionedFileSink"; }
 };
diff --git a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp
index 194d997ddf..d5ed430943 100644
--- a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp
+++ b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp
@@ -61,7 +61,6 @@ Block OutputFormatFile::createHeaderWithPreferredSchema(const Block & header)
         ColumnWithTypeAndName column(preferred_column.type->createColumn(), preferred_column.type, preferred_column.name);
         columns.emplace_back(std::move(column));
     }
-    assert(preferred_schema.columns() == index);
     return {std::move(columns)};
 }
 
diff --git a/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp b/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp
index 00f2da20c5..b764f62f54 100644
--- a/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp
+++ b/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp
@@ -206,14 +206,14 @@ TEST(WritePipeline, SubstraitPartitionedFileSink)
 TEST(WritePipeline, ComputePartitionedExpression)
 {
     const auto context = DB::Context::createCopy(QueryContext::globalContext());
-
-    auto partition_by = SubstraitPartitionedFileSink::make_partition_expression({"s_nationkey", "name"});
"name"});
+    
+    Block sample_block{{STRING(), "name"}, {UINT(), "s_nationkey"}};
+    auto partition_by = SubstraitPartitionedFileSink::make_partition_expression({"s_nationkey", "name"}, sample_block);
"name"}, sample_block);
     // auto partition_by = printColumn("s_nationkey");
 
     ASTs arguments(1, partition_by);
     ASTPtr partition_by_string = makeASTFunction("toString", std::move(arguments));
 
-    Block sample_block{{STRING(), "name"}, {UINT(), "s_nationkey"}};
 
     auto syntax_result = TreeRewriter(context).analyze(partition_by_string, sample_block.getNamesAndTypesList());
     auto partition_by_expr = ExpressionAnalyzer(partition_by_string, syntax_result, context).getActions(false);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to