This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new be909b6475 [GLUTEN-8836][CH] Support partition values with escape char (#8840)
be909b6475 is described below

commit be909b647502272f116cae0f07d811778b2a2539
Author: Wenzheng Liu <lwz9...@163.com>
AuthorDate: Wed Mar 5 14:03:21 2025 +0800

    [GLUTEN-8836][CH] Support partition values with escape char (#8840)
---
 .../execution/GlutenMergeTreePartition.scala       |  22 ++-
 .../delta/files/MergeTreeFileCommitProtocol.scala  |   2 +-
 .../v2/clickhouse/metadata/AddFileTags.scala       |   3 +-
 .../GlutenClickHouseNativeWriteTableSuite.scala    | 160 ++++++++-------------
 ...GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala |  51 ++++++-
 .../apache/spark/gluten/NativeWriteChecker.scala   |  85 +++++++++++
 .../Functions/SparkPartitionEscape.cpp             | 109 ++++++++++++++
 .../local-engine/Functions/SparkPartitionEscape.h  |  57 ++++++++
 .../CommonScalarFunctionParser.cpp                 |   1 +
 .../Storages/MergeTree/SparkMergeTreeMeta.cpp      |   4 +-
 .../Storages/Output/NormalFileWriter.h             |  14 +-
 cpp-ch/local-engine/tests/CMakeLists.txt           |   1 +
 .../benchmark_spark_partition_escape_function.cpp  |  53 +++++++
 13 files changed, 453 insertions(+), 109 deletions(-)
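
The change makes the ClickHouse backend escape special characters in partition directory names the way vanilla Spark does (Hive-style percent-encoding) and URI-decodes MergeTree part paths read back from the Delta log. As orientation for the diff below, here is a minimal Scala sketch of the escaping rule; the object and method names are illustrative only and not part of the commit (Spark's own implementation lives in ExternalCatalogUtils.escapePathName):

    object PartitionEscapeSketch {
      // Control characters plus the set listed in SparkPartitionEscape::ESCAPE_CHAR_LIST below.
      private val toEscape: Set[Char] =
        ('\u0001' to '\u001F').toSet ++
          Set('"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', '{', '[', ']', '^')

      // Percent-encode every character that may not appear verbatim in a partition directory name.
      def escapePathName(s: String): String =
        s.flatMap(c => if (toEscape(c)) f"%%${c.toInt}%02X" else c.toString)
    }

    // PartitionEscapeSketch.escapePathName("2024=01/01") == "2024%3D01%2F01"
    // PartitionEscapeSketch.escapePathName("a b")        == "a b"   (space is not escaped)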

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala
index a4394740f8..cc7114dd72 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala
@@ -19,6 +19,10 @@ package org.apache.gluten.execution
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.types.StructType
 
+import org.apache.hadoop.fs.Path
+
+import java.net.URI
+
 case class MergeTreePartRange(
     name: String,
     dirName: String,
@@ -32,7 +36,7 @@ case class MergeTreePartRange(
   }
 }
 
-case class MergeTreePartSplit(
+case class MergeTreePartSplit private (
     name: String,
     dirName: String,
     targetNode: String,
@@ -44,6 +48,22 @@ case class MergeTreePartSplit(
   }
 }
 
+object MergeTreePartSplit {
+  def apply(
+      name: String,
+      dirName: String,
+      targetNode: String,
+      start: Long,
+      length: Long,
+      bytesOnDisk: Long
+  ): MergeTreePartSplit = {
+    // Ref to org.apache.spark.sql.delta.files.TahoeFileIndex.absolutePath
+    val uriDecodeName = new Path(new URI(name)).toString
+    val uriDecodeDirName = new Path(new URI(dirName)).toString
+    new MergeTreePartSplit(uriDecodeName, uriDecodeDirName, targetNode, start, length, bytesOnDisk)
+  }
+}
+
 case class GlutenMergeTreePartition(
     index: Int,
     engine: String,
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala
index 13a9efa359..a8d572c93f 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala
@@ -52,7 +52,7 @@ trait MergeTreeFileCommitProtocol extends FileCommitProtocol {
       dir: Option[String],
       ext: String): String = {
 
-    val partitionStr = dir.map(p => new Path(p).toUri.toString)
+    val partitionStr = dir.map(p => new Path(p).toString)
     val bucketIdStr = ext.split("\\.").headOption.filter(_.startsWith("_")).map(_.substring(1))
     val split = taskContext.getTaskAttemptID.getTaskID.getId
 
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala
index c4c971633a..df79b161cf 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala
@@ -152,7 +152,8 @@ object AddFileTags {
     rootNode.put("nullCount", "")
     // Add the `stats` into delta meta log
     val metricsStats = mapper.writeValueAsString(rootNode)
-    AddFile(name, partitionValues, bytesOnDisk, modificationTime, dataChange, metricsStats, tags)
+    val uriName = new Path(name).toUri.toString
+    AddFile(uriName, partitionValues, bytesOnDisk, modificationTime, dataChange, metricsStats, tags)
   }
 
   def addFileToAddMergeTreeParts(addFile: AddFile): AddMergeTreeParts = {
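
Taken together, the Scala changes above form a round trip: AddFileTags now stores the URI-encoded part path in the Delta log, and MergeTreePartSplit.apply (further up) decodes it again when a split is built, mirroring TahoeFileIndex.absolutePath. A small sketch with a hypothetical part path containing a space:

    import org.apache.hadoop.fs.Path
    import java.net.URI

    val onDisk  = "part_col=a b/part-00000"           // hypothetical relative part path
    val inLog   = new Path(onDisk).toUri.toString     // "part_col=a%20b/part-00000" (AddFileTags)
    val decoded = new Path(new URI(inLog)).toString   // "part_col=a b/part-00000"   (MergeTreePartSplit.apply)
    assert(decoded == onDisk)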
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
index 1ee0b18b11..7e8ca2236c 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
@@ -23,12 +23,14 @@ import org.apache.gluten.test.AllDataTypesWithComplexType.genTestData
 
 import org.apache.spark.SparkConf
 import org.apache.spark.gluten.NativeWriteChecker
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.delta.DeltaLog
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
 import org.apache.spark.sql.types._
 
-import scala.reflect.runtime.universe.TypeTag
+import java.io.File
+import java.sql.Date
 
 class GlutenClickHouseNativeWriteTableSuite
   extends GlutenClickHouseWholeStageTransformerSuite
@@ -67,12 +69,6 @@ class GlutenClickHouseNativeWriteTableSuite
       .setMaster("local[1]")
   }
 
-  private def getWarehouseDir = {
-    // test non-ascii path, by the way
-    // scalastyle:off nonascii
-    basePath + "/中文/spark-warehouse"
-  }
-
   private val table_name_template = "hive_%s_test"
   private val table_name_vanilla_template = "hive_%s_test_written_by_vanilla"
 
@@ -81,58 +77,7 @@ class GlutenClickHouseNativeWriteTableSuite
     super.afterAll()
   }
 
-  def getColumnName(s: String): String = {
-    s.replaceAll("\\(", "_").replaceAll("\\)", "_")
-  }
-
   import collection.immutable.ListMap
-
-  import java.io.File
-
-  def compareSource(original_table: String, table_name: String, fields: Seq[String]): Unit = {
-    val rowsFromOriginTable =
-      spark.sql(s"select ${fields.mkString(",")} from $original_table").collect()
-    val dfFromWriteTable =
-      spark.sql(
-        s"select " +
-          s"${fields
-              .map(getColumnName)
-              .mkString(",")} " +
-          s"from $table_name")
-    checkAnswer(dfFromWriteTable, rowsFromOriginTable)
-  }
-  def writeAndCheckRead(
-      original_table: String,
-      table_name: String,
-      fields: Seq[String],
-      checkNative: Boolean = true)(write: Seq[String] => Unit): Unit =
-    withDestinationTable(table_name) {
-      withNativeWriteCheck(checkNative) {
-        write(fields)
-      }
-      compareSource(original_table, table_name, fields)
-    }
-
-  def recursiveListFiles(f: File): Array[File] = {
-    val these = f.listFiles
-    these ++ these.filter(_.isDirectory).flatMap(recursiveListFiles)
-  }
-
-  def getSignature(format: String, filesOfNativeWriter: Array[File]): Array[(Long, Long)] = {
-    filesOfNativeWriter.map(
-      f => {
-        val df = if (format.equals("parquet")) {
-          spark.read.parquet(f.getAbsolutePath)
-        } else {
-          spark.read.orc(f.getAbsolutePath)
-        }
-        (
-          df.count(),
-          df.agg(("int_field", 
"sum")).collect().apply(0).apply(0).asInstanceOf[Long]
-        )
-      })
-  }
-
   private val fields_ = ListMap(
     ("string_field", "string"),
     ("int_field", "int"),
@@ -146,22 +91,6 @@ class GlutenClickHouseNativeWriteTableSuite
     ("date_field", "date")
   )
 
-  def nativeWrite2(
-      f: String => (String, String, String),
-      extraCheck: (String, String) => Unit = null,
-      checkNative: Boolean = true): Unit = nativeWrite {
-    format =>
-      val (table_name, table_create_sql, insert_sql) = f(format)
-      withDestinationTable(table_name, Option(table_create_sql)) {
-        checkInsertQuery(insert_sql, checkNative)
-        Option(extraCheck).foreach(_(table_name, format))
-      }
-  }
-
-  def withSource[A <: Product: TypeTag](data: Seq[A], viewName: String, pairs: (String, String)*)(
-      block: => Unit): Unit =
-    withSource(spark.createDataFrame(data), viewName, pairs: _*)(block)
-
   private lazy val supplierSchema = StructType.apply(
     Seq(
       StructField.apply("s_suppkey", LongType, nullable = true),
@@ -618,18 +547,7 @@ class GlutenClickHouseNativeWriteTableSuite
                 .saveAsTable(table_name_vanilla)
             }
           }
-          val sigsOfNativeWriter =
-            getSignature(
-              format,
-              recursiveListFiles(new File(getWarehouseDir + "/" + table_name))
-                .filter(_.getName.endsWith(s".$format"))).sorted
-          val sigsOfVanillaWriter =
-            getSignature(
-              format,
-              recursiveListFiles(new File(getWarehouseDir + "/" + table_name_vanilla))
-                .filter(_.getName.endsWith(s".$format"))).sorted
-
-          assertResult(sigsOfVanillaWriter)(sigsOfNativeWriter)
+          compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(int_field)")
       }
     }
   }
@@ -680,18 +598,7 @@ class GlutenClickHouseNativeWriteTableSuite
                 .bucketBy(10, "byte_field", "string_field")
                 .saveAsTable(table_name_vanilla)
             }
-            val sigsOfNativeWriter =
-              getSignature(
-                format,
-                recursiveListFiles(new File(getWarehouseDir + "/" + table_name))
-                  .filter(_.getName.endsWith(s".$format"))).sorted
-            val sigsOfVanillaWriter =
-              getSignature(
-                format,
-                recursiveListFiles(new File(getWarehouseDir + "/" + table_name_vanilla))
-                  .filter(_.getName.endsWith(s".$format"))).sorted
-
-            assertResult(sigsOfVanillaWriter)(sigsOfNativeWriter)
+            compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(int_field)")
           }
       }
     }
@@ -754,6 +661,63 @@ class GlutenClickHouseNativeWriteTableSuite
     }
   }
 
+  test("test partitioned with escaped characters") {
+
+    val schema = StructType(
+      Seq(
+        StructField.apply("id", IntegerType, nullable = true),
+        StructField.apply("escape", StringType, nullable = true),
+        StructField.apply("bucket/col", StringType, nullable = true),
+        StructField.apply("part=col1", DateType, nullable = true),
+        StructField.apply("part_col2", StringType, nullable = true)
+      ))
+
+    val data: Seq[Row] = Seq(
+      Row(1, "=", "00000", Date.valueOf("2024-01-01"), "2024=01/01"),
+      Row(2, "/", "00000", Date.valueOf("2024-01-01"), "2024=01/01"),
+      Row(3, "#", "00000", Date.valueOf("2024-01-01"), "2024#01:01"),
+      Row(4, ":", "00001", Date.valueOf("2024-01-02"), "2024#01:01"),
+      Row(5, "\\", "00001", Date.valueOf("2024-01-02"), "2024\\01\u000101"),
+      Row(6, "\u0001", "000001", Date.valueOf("2024-01-02"), 
"2024\\01\u000101"),
+      Row(7, "", "000002", null, null)
+    )
+
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    df.createOrReplaceTempView("origin_table")
+    spark.sql("select * from origin_table").show()
+
+    nativeWrite {
+      format =>
+        val table_name = table_name_template.format(format)
+        spark.sql(s"drop table IF EXISTS $table_name")
+        writeAndCheckRead("origin_table", table_name, schema.fieldNames.map(f 
=> s"`$f`")) {
+          _ =>
+            spark
+              .table("origin_table")
+              .write
+              .format(format)
+              .partitionBy("part=col1", "part_col2")
+              .bucketBy(2, "bucket/col")
+              .saveAsTable(table_name)
+        }
+
+        val table_name_vanilla = table_name_vanilla_template.format(format)
+        spark.sql(s"drop table IF EXISTS $table_name_vanilla")
+        withSQLConf((GlutenConfig.NATIVE_WRITER_ENABLED.key, "false")) {
+          withNativeWriteCheck(checkNative = false) {
+            spark
+              .table("origin_table")
+              .write
+              .format(format)
+              .partitionBy("part=col1", "part_col2")
+              .bucketBy(2, "bucket/col")
+              .saveAsTable(table_name_vanilla)
+          }
+          compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(id)")
+        }
+    }
+  }
+
   test("test bucketed by constant") {
     nativeWrite {
       format =>
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala
index 6d404fe3aa..faffe19136 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala
@@ -21,12 +21,13 @@ import org.apache.gluten.config.GlutenConfig
 import org.apache.gluten.execution.{FileSourceScanExecTransformer, GlutenClickHouseTPCHAbstractSuite}
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.{Row, SaveMode}
 import org.apache.spark.sql.delta.catalog.ClickHouseTableV2
 import org.apache.spark.sql.delta.files.TahoeFileIndex
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.mergetree.StorageMeta
 import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts
+import org.apache.spark.sql.types._
 
 import org.apache.commons.io.FileUtils
 import org.apache.hadoop.conf.Configuration
@@ -359,6 +360,54 @@ class GlutenClickHouseMergeTreeWriteOnHDFSSuite
     spark.sql("drop table lineitem_mergetree_partition_hdfs")
   }
 
+  test("test partition values with escape chars") {
+
+    val schema = StructType(
+      Seq(
+        StructField.apply("id", IntegerType, nullable = true),
+        StructField.apply("escape", StringType, nullable = true)
+      ))
+
+    // scalastyle:off nonascii
+    val data: Seq[Row] = Seq(
+      Row(1, "="),
+      Row(2, "/"),
+      Row(3, "#"),
+      Row(4, ":"),
+      Row(5, "\\"),
+      Row(6, "\u0001"),
+      Row(7, "中文"),
+      Row(8, " "),
+      Row(9, "a b")
+    )
+    // scalastyle:on nonascii
+
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    df.createOrReplaceTempView("origin_table")
+
+    // spark.conf.set("spark.gluten.enabled", "false")
+    spark.sql(s"""
+                 |DROP TABLE IF EXISTS partition_escape;
+                 |""".stripMargin)
+
+    spark.sql(s"""
+                 |CREATE TABLE IF NOT EXISTS partition_escape
+                 |(
+                 | c1  int,
+                 | c2  string
+                 |)
+                 |USING clickhouse
+                 |PARTITIONED BY (c2)
+                 |TBLPROPERTIES (storage_policy='__hdfs_main',
+                 |               orderByKey='c1',
+                 |               primaryKey='c1')
+                 |LOCATION '$HDFS_URL/test/partition_escape'
+                 |""".stripMargin)
+
+    spark.sql("insert into partition_escape select * from origin_table")
+    spark.sql("select * from partition_escape").show()
+  }
+
   testSparkVersionLE33("test mergetree write with bucket table") {
     spark.sql(s"""
                  |DROP TABLE IF EXISTS lineitem_mergetree_bucket_hdfs;
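
Given the escape list in SparkPartitionEscape.cpp further down, the values inserted by the test above are expected to land in partition directories roughly like the following under .../test/partition_escape (an illustrative Scala mapping, not an assertion made by the test):

    // raw c2 value -> expected partition directory name
    val expectedDirs = Seq(
      "="      -> "c2=%3D",
      "/"      -> "c2=%2F",
      "#"      -> "c2=%23",
      ":"      -> "c2=%3A",
      "\\"     -> "c2=%5C",
      "\u0001" -> "c2=%01",
      "中文"   -> "c2=中文", // non-ASCII characters are not escaped
      " "      -> "c2= ",    // spaces are escaped only on Windows in the native list
      "a b"    -> "c2=a b"
    )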
diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
index 384780e7d2..481e340d87 100644
--- a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
@@ -25,12 +25,21 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.FakeRowAdaptor
 import org.apache.spark.sql.util.QueryExecutionListener
 
+import java.io.File
+
+import scala.reflect.runtime.universe.TypeTag
+
 trait NativeWriteChecker
   extends GlutenClickHouseWholeStageTransformerSuite
   with AdaptiveSparkPlanHelper {
 
   private val formats: Seq[String] = Seq("orc", "parquet")
 
+  // test non-ascii path, by the way
+  // scalastyle:off nonascii
+  protected def getWarehouseDir: String = basePath + "/中文/spark-warehouse"
+  // scalastyle:on nonascii
+
   def withNativeWriteCheck(checkNative: Boolean)(block: => Unit): Unit = {
     var nativeUsed = false
 
@@ -82,6 +91,22 @@ trait NativeWriteChecker
     }
   }
 
+  def nativeWrite2(
+      f: String => (String, String, String),
+      extraCheck: (String, String) => Unit = null,
+      checkNative: Boolean = true): Unit = nativeWrite {
+    format =>
+      val (table_name, table_create_sql, insert_sql) = f(format)
+      withDestinationTable(table_name, Option(table_create_sql)) {
+        checkInsertQuery(insert_sql, checkNative)
+        Option(extraCheck).foreach(_(table_name, format))
+      }
+  }
+
+  def withSource[A <: Product: TypeTag](data: Seq[A], viewName: String, pairs: (String, String)*)(
+      block: => Unit): Unit =
+    withSource(spark.createDataFrame(data), viewName, pairs: _*)(block)
+
   def withSource(df: Dataset[Row], viewName: String, pairs: (String, String)*)(
       block: => Unit): Unit = {
     withSQLConf(pairs: _*) {
@@ -91,4 +116,64 @@ trait NativeWriteChecker
       }
     }
   }
+
+  def getColumnName(col: String): String = {
+    col.replaceAll("\\(", "_").replaceAll("\\)", "_")
+  }
+
+  def compareSource(originTable: String, table: String, fields: Seq[String]): Unit = {
+    def query(table: String, selectFields: Seq[String]): String = {
+      s"select ${selectFields.mkString(",")} from $table"
+    }
+    val expectedRows = spark.sql(query(originTable, fields)).collect()
+    val actual = spark.sql(query(table, fields.map(getColumnName)))
+    checkAnswer(actual, expectedRows)
+  }
+
+  def writeAndCheckRead(
+      original_table: String,
+      table_name: String,
+      fields: Seq[String],
+      checkNative: Boolean = true)(write: Seq[String] => Unit): Unit = {
+    withDestinationTable(table_name) {
+      withNativeWriteCheck(checkNative) {
+        write(fields)
+      }
+      compareSource(original_table, table_name, fields)
+    }
+  }
+
+  def compareWriteFilesSignature(
+      format: String,
+      table: String,
+      vanillaTable: String,
+      sigExpr: String): Unit = {
+    val tableFiles = recursiveListFiles(new File(getWarehouseDir + "/" + table))
+      .filter(_.getName.endsWith(s".$format"))
+    val sigsOfNativeWriter = getSignature(format, tableFiles, sigExpr).sorted
+    val vanillaTableFiles = recursiveListFiles(new File(getWarehouseDir + "/" + vanillaTable))
+      .filter(_.getName.endsWith(s".$format"))
+    val sigsOfVanillaWriter = getSignature(format, vanillaTableFiles, sigExpr).sorted
+    assertResult(sigsOfVanillaWriter)(sigsOfNativeWriter)
+  }
+
+  def recursiveListFiles(f: File): Array[File] = {
+    val these = f.listFiles
+    these ++ these.filter(_.isDirectory).flatMap(recursiveListFiles)
+  }
+
+  def getSignature(
+      format: String,
+      writeFiles: Array[File],
+      sigExpr: String): Array[(Long, Long)] = {
+    writeFiles.map(
+      f => {
+        val df = if (format.equals("parquet")) {
+          spark.read.parquet(f.getAbsolutePath)
+        } else {
+          spark.read.orc(f.getAbsolutePath)
+        }
+        (df.count(), df.selectExpr(sigExpr).collect().apply(0).apply(0).asInstanceOf[Long])
+      })
+  }
 }
diff --git a/cpp-ch/local-engine/Functions/SparkPartitionEscape.cpp b/cpp-ch/local-engine/Functions/SparkPartitionEscape.cpp
new file mode 100644
index 0000000000..522f9ddee2
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkPartitionEscape.cpp
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "SparkPartitionEscape.h"
+#include <Functions/FunctionFactory.h>
+#include <Common/Exception.h>
+#include <DataTypes/IDataType.h>
+#include <DataTypes/DataTypeString.h>
+#include <sstream>
+#include <iomanip>
+#include <string>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+
+const std::vector<char> SparkPartitionEscape::ESCAPE_CHAR_LIST = {
+    '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009',
+    '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013',
+    '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C',
+    '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F',
+    '{', '[', ']', '^'
+};
+
+const std::bitset<128> SparkPartitionEscape::ESCAPE_BITSET = []()
+{
+    std::bitset<128> bitset;
+    for (char c : SparkPartitionEscape::ESCAPE_CHAR_LIST)
+    {
+        bitset.set(c);
+    }
+#ifdef _WIN32
+    bitset.set(' ');
+    bitset.set('<');
+    bitset.set('>');
+    bitset.set('|');
+#endif
+    return bitset;
+}();
+
+static bool needsEscaping(char c) {
+    return c >= 0 && c < SparkPartitionEscape::ESCAPE_BITSET.size()
+        && SparkPartitionEscape::ESCAPE_BITSET.test(c);
+}
+
+static std::string escapePathName(const std::string & path) {
+    std::ostringstream builder;
+    for (char c : path) {
+        if (needsEscaping(c)) {
+            builder << '%' << std::uppercase << std::setw(2) << std::setfill('0') << std::hex << (int)c;
+        } else {
+            builder << c;
+        }
+    }
+
+    return builder.str();
+}
+
+DB::DataTypePtr SparkPartitionEscape::getReturnTypeImpl(const DB::DataTypes & arguments) const
+{
+    if (arguments.size() != 1)
+        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} argument size must be 1", name);
+    
+    if (!isString(arguments[0]))
+        throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument of function {} must be String", getName());
+
+    return std::make_shared<DataTypeString>();
+}
+
+DB::ColumnPtr SparkPartitionEscape::executeImpl(
+   const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const
+{
+   auto result = result_type->createColumn();
+   result->reserve(input_rows_count);
+
+   for (size_t i = 0; i < input_rows_count; ++i)
+   {
+       auto escaped_name = escapePathName(arguments[0].column->getDataAt(i).toString());
+       result->insertData(escaped_name.c_str(), escaped_name.size());
+   }
+   return result;
+}
+
+REGISTER_FUNCTION(SparkPartitionEscape)
+{
+   factory.registerFunction<SparkPartitionEscape>();
+}
+}
diff --git a/cpp-ch/local-engine/Functions/SparkPartitionEscape.h b/cpp-ch/local-engine/Functions/SparkPartitionEscape.h
new file mode 100644
index 0000000000..916134506a
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkPartitionEscape.h
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <Columns/IColumn.h>
+#include <Core/ColumnsWithTypeAndName.h>
+#include <DataTypes/DataTypeDate.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/IDataType.h>
+#include <Functions/IFunction.h>
+#include <Interpreters/Context.h>
+#include <bitset>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+using namespace DB;
+
+namespace local_engine
+{
+
+class SparkPartitionEscape : public DB::IFunction
+{
+public:
+    static const std::vector<char> ESCAPE_CHAR_LIST;
+    static const std::bitset<128> ESCAPE_BITSET;
+    static constexpr auto name = "sparkPartitionEscape";
+    static FunctionPtr create(ContextPtr /*context*/) { return std::make_shared<SparkPartitionEscape>(); }
+    SparkPartitionEscape() = default;
+    ~SparkPartitionEscape() override = default;
+    String getName() const override { return name; }
+    size_t getNumberOfArguments() const override { return 1; }
+    bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+    DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & /*arguments*/) const override;
+    DB::ColumnPtr executeImpl(
+        const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t /*input_rows_count*/) const override;
+};
+
+}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
index ba2129deb1..841e51c00a 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
@@ -130,6 +130,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Uuid, uuid, generateUUIDv4);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Levenshtein, levenshtein, editDistanceUTF8);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(FormatString, format_string, printf);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(SoundEx, soundex, soundex);
+REGISTER_COMMON_SCALAR_FUNCTION_PARSER(PartitionEscape, partition_escape, sparkPartitionEscape);
 
 // hash functions
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Crc32, crc32, CRC32);
diff --git a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp
index ef4d5504ff..b8e1a3cbd4 100644
--- a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp
+++ b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp
@@ -212,9 +212,7 @@ MergeTreeTableInstance::MergeTreeTableInstance(const std::string & info) : Merge
     while (!in.eof())
     {
         MergeTreePart part;
-        std::string encoded_name;
-        readString(encoded_name, in);
-        Poco::URI::decode(encoded_name, part.name);
+        readString(part.name, in);
         assertChar('\n', in);
         readIntText(part.begin, in);
         assertChar('\n', in);
diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
index c0c762906a..77096f3f49 100644
--- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
+++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
@@ -416,17 +416,23 @@ public:
         for (const auto & column : partition_columns)
         {
             // partition_column=
-            std::string key = add_slash ? fmt::format("/{}=", column) : fmt::format("{}=", column);
+            auto column_name = std::make_shared<DB::ASTLiteral>(column);
+            auto escaped_name = makeASTFunction("sparkPartitionEscape", DB::ASTs{column_name});
+            if (add_slash)
+                arguments.emplace_back(std::make_shared<DB::ASTLiteral>("/"));
             add_slash = true;
-            arguments.emplace_back(std::make_shared<DB::ASTLiteral>(key));
+            arguments.emplace_back(escaped_name);
+            arguments.emplace_back(std::make_shared<DB::ASTLiteral>("="));
 
             // ifNull(toString(partition_column), DEFAULT_PARTITION_NAME)
             // FIXME if toString(partition_column) is empty
-            auto column_ast = std::make_shared<DB::ASTIdentifier>(column);
+            auto column_ast = makeASTFunction("toString", 
DB::ASTs{std::make_shared<DB::ASTIdentifier>(column)});
+            auto escaped_value = makeASTFunction("sparkPartitionEscape", 
DB::ASTs{column_ast});
             DB::ASTs if_null_args{
-                makeASTFunction("toString", DB::ASTs{column_ast}), 
std::make_shared<DB::ASTLiteral>(DEFAULT_PARTITION_NAME)};
+                makeASTFunction("toString", DB::ASTs{escaped_value}), 
std::make_shared<DB::ASTLiteral>(DEFAULT_PARTITION_NAME)};
             arguments.emplace_back(makeASTFunction("ifNull", 
std::move(if_null_args)));
         }
+
         if (isBucketedWrite(input_header))
         {
             DB::ASTs args{std::make_shared<DB::ASTLiteral>("%05d"), std::make_shared<DB::ASTIdentifier>(BUCKET_COLUMN_NAME)};
diff --git a/cpp-ch/local-engine/tests/CMakeLists.txt b/cpp-ch/local-engine/tests/CMakeLists.txt
index 09ca32a01a..9c18d70b0f 100644
--- a/cpp-ch/local-engine/tests/CMakeLists.txt
+++ b/cpp-ch/local-engine/tests/CMakeLists.txt
@@ -107,6 +107,7 @@ if(ENABLE_BENCHMARKS)
     benchmark_spark_row.cpp
     benchmark_unix_timestamp_function.cpp
     benchmark_spark_functions.cpp
+    benchmark_spark_partition_escape_function.cpp
     benchmark_cast_float_function.cpp
     benchmark_to_datetime_function.cpp
     benchmark_spark_divide_function.cpp
diff --git a/cpp-ch/local-engine/tests/benchmark_spark_partition_escape_function.cpp b/cpp-ch/local-engine/tests/benchmark_spark_partition_escape_function.cpp
new file mode 100644
index 0000000000..3299eb9fee
--- /dev/null
+++ b/cpp-ch/local-engine/tests/benchmark_spark_partition_escape_function.cpp
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <Core/Block.h>
+#include <DataTypes/DataTypeFactory.h>
+#include <Functions/FunctionFactory.h>
+#include <Parser/FunctionParser.h>
+#include <benchmark/benchmark.h>
+#include <Common/QueryContext.h>
+
+using namespace DB;
+
+static Block createDataBlock(size_t rows)
+{
+   auto type = DataTypeFactory::instance().get("String");
+   auto column = type->createColumn();
+   for (size_t i = 0; i < rows; ++i)
+   {
+       char ch = static_cast<char>(i % 128);
+       std::string str = "escape_" + ch;
+       column->insert(str);
+   }
+   Block block;
+   block.insert(ColumnWithTypeAndName(std::move(column), type, "d"));
+   return std::move(block);
+}
+
+static void BM_CHSparkPartitionEscape(benchmark::State & state)
+{
+   using namespace DB;
+   auto & factory = FunctionFactory::instance();
+   auto function = factory.get("sparkPartitionEscape", 
local_engine::QueryContext::globalContext());
+   Block block = createDataBlock(1000000);
+   auto executable = function->build(block.getColumnsWithTypeAndName());
+   for (auto _ : state) [[maybe_unused]]
+       auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
+}
+
+BENCHMARK(BM_CHSparkPartitionEscape)->Unit(benchmark::kMillisecond)->Iterations(50);
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org
For additional commands, e-mail: commits-h...@gluten.apache.org
