This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push: new be909b6475 [GLUTEN-8836][CH] Support partition values with escape char (#8840) be909b6475 is described below commit be909b647502272f116cae0f07d811778b2a2539 Author: Wenzheng Liu <lwz9...@163.com> AuthorDate: Wed Mar 5 14:03:21 2025 +0800 [GLUTEN-8836][CH] Support partition values with escape char (#8840) --- .../execution/GlutenMergeTreePartition.scala | 22 ++- .../delta/files/MergeTreeFileCommitProtocol.scala | 2 +- .../v2/clickhouse/metadata/AddFileTags.scala | 3 +- .../GlutenClickHouseNativeWriteTableSuite.scala | 160 ++++++++------------- ...GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala | 51 ++++++- .../apache/spark/gluten/NativeWriteChecker.scala | 85 +++++++++++ .../Functions/SparkPartitionEscape.cpp | 109 ++++++++++++++ .../local-engine/Functions/SparkPartitionEscape.h | 57 ++++++++ .../CommonScalarFunctionParser.cpp | 1 + .../Storages/MergeTree/SparkMergeTreeMeta.cpp | 4 +- .../Storages/Output/NormalFileWriter.h | 14 +- cpp-ch/local-engine/tests/CMakeLists.txt | 1 + .../benchmark_spark_partition_escape_function.cpp | 53 +++++++ 13 files changed, 453 insertions(+), 109 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala index a4394740f8..cc7114dd72 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala @@ -19,6 +19,10 @@ package org.apache.gluten.execution import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.types.StructType +import org.apache.hadoop.fs.Path + +import java.net.URI + case class MergeTreePartRange( name: String, dirName: String, @@ -32,7 +36,7 @@ case class MergeTreePartRange( } } -case class MergeTreePartSplit( +case class MergeTreePartSplit private ( name: String, dirName: String, targetNode: String, @@ -44,6 +48,22 @@ case class MergeTreePartSplit( } } +object MergeTreePartSplit { + def apply( + name: String, + dirName: String, + targetNode: String, + start: Long, + length: Long, + bytesOnDisk: Long + ): MergeTreePartSplit = { + // Ref to org.apache.spark.sql.delta.files.TahoeFileIndex.absolutePath + val uriDecodeName = new Path(new URI(name)).toString + val uriDecodeDirName = new Path(new URI(dirName)).toString + new MergeTreePartSplit(uriDecodeName, uriDecodeDirName, targetNode, start, length, bytesOnDisk) + } +} + case class GlutenMergeTreePartition( index: Int, engine: String, diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala index 13a9efa359..a8d572c93f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala @@ -52,7 +52,7 @@ trait MergeTreeFileCommitProtocol extends FileCommitProtocol { dir: Option[String], ext: String): String = { - val partitionStr = dir.map(p => new Path(p).toUri.toString) + val partitionStr = dir.map(p => new Path(p).toString) val bucketIdStr = ext.split("\\.").headOption.filter(_.startsWith("_")).map(_.substring(1)) val split = taskContext.getTaskAttemptID.getTaskID.getId diff --git 
a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala index c4c971633a..df79b161cf 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala @@ -152,7 +152,8 @@ object AddFileTags { rootNode.put("nullCount", "") // Add the `stats` into delta meta log val metricsStats = mapper.writeValueAsString(rootNode) - AddFile(name, partitionValues, bytesOnDisk, modificationTime, dataChange, metricsStats, tags) + val uriName = new Path(name).toUri.toString + AddFile(uriName, partitionValues, bytesOnDisk, modificationTime, dataChange, metricsStats, tags) } def addFileToAddMergeTreeParts(addFile: AddFile): AddMergeTreeParts = { diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala index 1ee0b18b11..7e8ca2236c 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala @@ -23,12 +23,14 @@ import org.apache.gluten.test.AllDataTypesWithComplexType.genTestData import org.apache.spark.SparkConf import org.apache.spark.gluten.NativeWriteChecker +import org.apache.spark.sql.Row import org.apache.spark.sql.delta.DeltaLog import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig import org.apache.spark.sql.types._ -import scala.reflect.runtime.universe.TypeTag +import java.io.File +import java.sql.Date class GlutenClickHouseNativeWriteTableSuite extends GlutenClickHouseWholeStageTransformerSuite @@ -67,12 +69,6 @@ class GlutenClickHouseNativeWriteTableSuite .setMaster("local[1]") } - private def getWarehouseDir = { - // test non-ascii path, by the way - // scalastyle:off nonascii - basePath + "/中文/spark-warehouse" - } - private val table_name_template = "hive_%s_test" private val table_name_vanilla_template = "hive_%s_test_written_by_vanilla" @@ -81,58 +77,7 @@ class GlutenClickHouseNativeWriteTableSuite super.afterAll() } - def getColumnName(s: String): String = { - s.replaceAll("\\(", "_").replaceAll("\\)", "_") - } - import collection.immutable.ListMap - - import java.io.File - - def compareSource(original_table: String, table_name: String, fields: Seq[String]): Unit = { - val rowsFromOriginTable = - spark.sql(s"select ${fields.mkString(",")} from $original_table").collect() - val dfFromWriteTable = - spark.sql( - s"select " + - s"${fields - .map(getColumnName) - .mkString(",")} " + - s"from $table_name") - checkAnswer(dfFromWriteTable, rowsFromOriginTable) - } - def writeAndCheckRead( - original_table: String, - table_name: String, - fields: Seq[String], - checkNative: Boolean = true)(write: Seq[String] => Unit): Unit = - withDestinationTable(table_name) { - withNativeWriteCheck(checkNative) { - write(fields) - } - compareSource(original_table, table_name, fields) - } - - def recursiveListFiles(f: File): Array[File] = { - val these = f.listFiles - these ++ 
these.filter(_.isDirectory).flatMap(recursiveListFiles) - } - - def getSignature(format: String, filesOfNativeWriter: Array[File]): Array[(Long, Long)] = { - filesOfNativeWriter.map( - f => { - val df = if (format.equals("parquet")) { - spark.read.parquet(f.getAbsolutePath) - } else { - spark.read.orc(f.getAbsolutePath) - } - ( - df.count(), - df.agg(("int_field", "sum")).collect().apply(0).apply(0).asInstanceOf[Long] - ) - }) - } - private val fields_ = ListMap( ("string_field", "string"), ("int_field", "int"), @@ -146,22 +91,6 @@ class GlutenClickHouseNativeWriteTableSuite ("date_field", "date") ) - def nativeWrite2( - f: String => (String, String, String), - extraCheck: (String, String) => Unit = null, - checkNative: Boolean = true): Unit = nativeWrite { - format => - val (table_name, table_create_sql, insert_sql) = f(format) - withDestinationTable(table_name, Option(table_create_sql)) { - checkInsertQuery(insert_sql, checkNative) - Option(extraCheck).foreach(_(table_name, format)) - } - } - - def withSource[A <: Product: TypeTag](data: Seq[A], viewName: String, pairs: (String, String)*)( - block: => Unit): Unit = - withSource(spark.createDataFrame(data), viewName, pairs: _*)(block) - private lazy val supplierSchema = StructType.apply( Seq( StructField.apply("s_suppkey", LongType, nullable = true), @@ -618,18 +547,7 @@ class GlutenClickHouseNativeWriteTableSuite .saveAsTable(table_name_vanilla) } } - val sigsOfNativeWriter = - getSignature( - format, - recursiveListFiles(new File(getWarehouseDir + "/" + table_name)) - .filter(_.getName.endsWith(s".$format"))).sorted - val sigsOfVanillaWriter = - getSignature( - format, - recursiveListFiles(new File(getWarehouseDir + "/" + table_name_vanilla)) - .filter(_.getName.endsWith(s".$format"))).sorted - - assertResult(sigsOfVanillaWriter)(sigsOfNativeWriter) + compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(int_field)") } } } @@ -680,18 +598,7 @@ class GlutenClickHouseNativeWriteTableSuite .bucketBy(10, "byte_field", "string_field") .saveAsTable(table_name_vanilla) } - val sigsOfNativeWriter = - getSignature( - format, - recursiveListFiles(new File(getWarehouseDir + "/" + table_name)) - .filter(_.getName.endsWith(s".$format"))).sorted - val sigsOfVanillaWriter = - getSignature( - format, - recursiveListFiles(new File(getWarehouseDir + "/" + table_name_vanilla)) - .filter(_.getName.endsWith(s".$format"))).sorted - - assertResult(sigsOfVanillaWriter)(sigsOfNativeWriter) + compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(int_field)") } } } @@ -754,6 +661,63 @@ class GlutenClickHouseNativeWriteTableSuite } } + test("test partitioned with escaped characters") { + + val schema = StructType( + Seq( + StructField.apply("id", IntegerType, nullable = true), + StructField.apply("escape", StringType, nullable = true), + StructField.apply("bucket/col", StringType, nullable = true), + StructField.apply("part=col1", DateType, nullable = true), + StructField.apply("part_col2", StringType, nullable = true) + )) + + val data: Seq[Row] = Seq( + Row(1, "=", "00000", Date.valueOf("2024-01-01"), "2024=01/01"), + Row(2, "/", "00000", Date.valueOf("2024-01-01"), "2024=01/01"), + Row(3, "#", "00000", Date.valueOf("2024-01-01"), "2024#01:01"), + Row(4, ":", "00001", Date.valueOf("2024-01-02"), "2024#01:01"), + Row(5, "\\", "00001", Date.valueOf("2024-01-02"), "2024\\01\u000101"), + Row(6, "\u0001", "000001", Date.valueOf("2024-01-02"), "2024\\01\u000101"), + Row(7, "", "000002", null, null) + ) + + val df = 
spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.createOrReplaceTempView("origin_table") + spark.sql("select * from origin_table").show() + + nativeWrite { + format => + val table_name = table_name_template.format(format) + spark.sql(s"drop table IF EXISTS $table_name") + writeAndCheckRead("origin_table", table_name, schema.fieldNames.map(f => s"`$f`")) { + _ => + spark + .table("origin_table") + .write + .format(format) + .partitionBy("part=col1", "part_col2") + .bucketBy(2, "bucket/col") + .saveAsTable(table_name) + } + + val table_name_vanilla = table_name_vanilla_template.format(format) + spark.sql(s"drop table IF EXISTS $table_name_vanilla") + withSQLConf((GlutenConfig.NATIVE_WRITER_ENABLED.key, "false")) { + withNativeWriteCheck(checkNative = false) { + spark + .table("origin_table") + .write + .format(format) + .partitionBy("part=col1", "part_col2") + .bucketBy(2, "bucket/col") + .saveAsTable(table_name_vanilla) + } + compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(id)") + } + } + } + test("test bucketed by constant") { nativeWrite { format => diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala index 6d404fe3aa..faffe19136 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala @@ -21,12 +21,13 @@ import org.apache.gluten.config.GlutenConfig import org.apache.gluten.execution.{FileSourceScanExecTransformer, GlutenClickHouseTPCHAbstractSuite} import org.apache.spark.SparkConf -import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.delta.catalog.ClickHouseTableV2 import org.apache.spark.sql.delta.files.TahoeFileIndex import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.mergetree.StorageMeta import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts +import org.apache.spark.sql.types._ import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration @@ -359,6 +360,54 @@ class GlutenClickHouseMergeTreeWriteOnHDFSSuite spark.sql("drop table lineitem_mergetree_partition_hdfs") } + test("test partition values with escape chars") { + + val schema = StructType( + Seq( + StructField.apply("id", IntegerType, nullable = true), + StructField.apply("escape", StringType, nullable = true) + )) + + // scalastyle:off nonascii + val data: Seq[Row] = Seq( + Row(1, "="), + Row(2, "/"), + Row(3, "#"), + Row(4, ":"), + Row(5, "\\"), + Row(6, "\u0001"), + Row(7, "中文"), + Row(8, " "), + Row(9, "a b") + ) + // scalastyle:on nonascii + + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.createOrReplaceTempView("origin_table") + + // spark.conf.set("spark.gluten.enabled", "false") + spark.sql(s""" + |DROP TABLE IF EXISTS partition_escape; + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE IF NOT EXISTS partition_escape + |( + | c1 int, + | c2 string + |) + |USING clickhouse + |PARTITIONED BY (c2) + |TBLPROPERTIES (storage_policy='__hdfs_main', + | orderByKey='c1', + | primaryKey='c1') + |LOCATION '$HDFS_URL/test/partition_escape' + |""".stripMargin) + + 
spark.sql("insert into partition_escape select * from origin_table") + spark.sql("select * from partition_escape").show() + } + testSparkVersionLE33("test mergetree write with bucket table") { spark.sql(s""" |DROP TABLE IF EXISTS lineitem_mergetree_bucket_hdfs; diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala index 384780e7d2..481e340d87 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala @@ -25,12 +25,21 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.FakeRowAdaptor import org.apache.spark.sql.util.QueryExecutionListener +import java.io.File + +import scala.reflect.runtime.universe.TypeTag + trait NativeWriteChecker extends GlutenClickHouseWholeStageTransformerSuite with AdaptiveSparkPlanHelper { private val formats: Seq[String] = Seq("orc", "parquet") + // test non-ascii path, by the way + // scalastyle:off nonascii + protected def getWarehouseDir: String = basePath + "/中文/spark-warehouse" + // scalastyle:on nonascii + def withNativeWriteCheck(checkNative: Boolean)(block: => Unit): Unit = { var nativeUsed = false @@ -82,6 +91,22 @@ trait NativeWriteChecker } } + def nativeWrite2( + f: String => (String, String, String), + extraCheck: (String, String) => Unit = null, + checkNative: Boolean = true): Unit = nativeWrite { + format => + val (table_name, table_create_sql, insert_sql) = f(format) + withDestinationTable(table_name, Option(table_create_sql)) { + checkInsertQuery(insert_sql, checkNative) + Option(extraCheck).foreach(_(table_name, format)) + } + } + + def withSource[A <: Product: TypeTag](data: Seq[A], viewName: String, pairs: (String, String)*)( + block: => Unit): Unit = + withSource(spark.createDataFrame(data), viewName, pairs: _*)(block) + def withSource(df: Dataset[Row], viewName: String, pairs: (String, String)*)( block: => Unit): Unit = { withSQLConf(pairs: _*) { @@ -91,4 +116,64 @@ trait NativeWriteChecker } } } + + def getColumnName(col: String): String = { + col.replaceAll("\\(", "_").replaceAll("\\)", "_") + } + + def compareSource(originTable: String, table: String, fields: Seq[String]): Unit = { + def query(table: String, selectFields: Seq[String]): String = { + s"select ${selectFields.mkString(",")} from $table" + } + val expectedRows = spark.sql(query(originTable, fields)).collect() + val actual = spark.sql(query(table, fields.map(getColumnName))) + checkAnswer(actual, expectedRows) + } + + def writeAndCheckRead( + original_table: String, + table_name: String, + fields: Seq[String], + checkNative: Boolean = true)(write: Seq[String] => Unit): Unit = { + withDestinationTable(table_name) { + withNativeWriteCheck(checkNative) { + write(fields) + } + compareSource(original_table, table_name, fields) + } + } + + def compareWriteFilesSignature( + format: String, + table: String, + vanillaTable: String, + sigExpr: String): Unit = { + val tableFiles = recursiveListFiles(new File(getWarehouseDir + "/" + table)) + .filter(_.getName.endsWith(s".$format")) + val sigsOfNativeWriter = getSignature(format, tableFiles, sigExpr).sorted + val vanillaTableFiles = recursiveListFiles(new File(getWarehouseDir + "/" + vanillaTable)) + .filter(_.getName.endsWith(s".$format")) + val sigsOfVanillaWriter = getSignature(format, vanillaTableFiles, 
sigExpr).sorted + assertResult(sigsOfVanillaWriter)(sigsOfNativeWriter) + } + + def recursiveListFiles(f: File): Array[File] = { + val these = f.listFiles + these ++ these.filter(_.isDirectory).flatMap(recursiveListFiles) + } + + def getSignature( + format: String, + writeFiles: Array[File], + sigExpr: String): Array[(Long, Long)] = { + writeFiles.map( + f => { + val df = if (format.equals("parquet")) { + spark.read.parquet(f.getAbsolutePath) + } else { + spark.read.orc(f.getAbsolutePath) + } + (df.count(), df.selectExpr(sigExpr).collect().apply(0).apply(0).asInstanceOf[Long]) + }) + } } diff --git a/cpp-ch/local-engine/Functions/SparkPartitionEscape.cpp b/cpp-ch/local-engine/Functions/SparkPartitionEscape.cpp new file mode 100644 index 0000000000..522f9ddee2 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkPartitionEscape.cpp @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "SparkPartitionEscape.h" +#include <Functions/FunctionFactory.h> +#include <Common/Exception.h> +#include <DataTypes/IDataType.h> +#include <DataTypes/DataTypeString.h> +#include <sstream> +#include <iomanip> +#include <string> + +namespace DB +{ +namespace ErrorCodes +{ +extern const int ILLEGAL_TYPE_OF_ARGUMENT; +extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ + +const std::vector<char> SparkPartitionEscape::ESCAPE_CHAR_LIST = { + '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', + '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', + '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', + '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', + '{', '[', ']', '^' +}; + +const std::bitset<128> SparkPartitionEscape::ESCAPE_BITSET = []() +{ + std::bitset<128> bitset; + for (char c : SparkPartitionEscape::ESCAPE_CHAR_LIST) + { + bitset.set(c); + } +#ifdef _WIN32 + bitset.set(' '); + bitset.set('<'); + bitset.set('>'); + bitset.set('|'); +#endif + return bitset; +}(); + +static bool needsEscaping(char c) { + return c >= 0 && c < SparkPartitionEscape::ESCAPE_BITSET.size() + && SparkPartitionEscape::ESCAPE_BITSET.test(c); +} + +static std::string escapePathName(const std::string & path) { + std::ostringstream builder; + for (char c : path) { + if (needsEscaping(c)) { + builder << '%' << std::uppercase << std::setw(2) << std::setfill('0') << std::hex << (int)c; + } else { + builder << c; + } + } + + return builder.str(); +} + +DB::DataTypePtr SparkPartitionEscape::getReturnTypeImpl(const DB::DataTypes & arguments) const +{ + if (arguments.size() != 1) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} argument size must be 1", 
name); + + if (!isString(arguments[0])) + throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument of function {} must be String", getName()); + + return std::make_shared<DataTypeString>(); +} + +DB::ColumnPtr SparkPartitionEscape::executeImpl( + const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const +{ + auto result = result_type->createColumn(); + result->reserve(input_rows_count); + + for (size_t i = 0; i < input_rows_count; ++i) + { + auto escaped_name = escapePathName(arguments[0].column->getDataAt(i).toString()); + result->insertData(escaped_name.c_str(), escaped_name.size()); + } + return result; +} + +REGISTER_FUNCTION(SparkPartitionEscape) +{ + factory.registerFunction<SparkPartitionEscape>(); +} +} diff --git a/cpp-ch/local-engine/Functions/SparkPartitionEscape.h b/cpp-ch/local-engine/Functions/SparkPartitionEscape.h new file mode 100644 index 0000000000..916134506a --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkPartitionEscape.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include <Columns/IColumn.h> +#include <Core/ColumnsWithTypeAndName.h> +#include <DataTypes/DataTypeDate.h> +#include <DataTypes/DataTypeNullable.h> +#include <DataTypes/IDataType.h> +#include <Functions/IFunction.h> +#include <Interpreters/Context.h> +#include <bitset> + +namespace DB +{ +namespace ErrorCodes +{ +extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +using namespace DB; + +namespace local_engine +{ + +class SparkPartitionEscape : public DB::IFunction +{ +public: + static const std::vector<char> ESCAPE_CHAR_LIST; + static const std::bitset<128> ESCAPE_BITSET; + static constexpr auto name = "sparkPartitionEscape"; + static FunctionPtr create(ContextPtr /*context*/) { return std::make_shared<SparkPartitionEscape>(); } + SparkPartitionEscape() = default; + ~SparkPartitionEscape() override = default; + String getName() const override { return name; } + size_t getNumberOfArguments() const override { return 1; } + bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; } + DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & /*arguments*/) const override; + DB::ColumnPtr executeImpl( + const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t /*input_rows_count*/) const override; +}; + +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp index ba2129deb1..841e51c00a 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp @@ -130,6 +130,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Uuid, uuid, generateUUIDv4); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Levenshtein, levenshtein, editDistanceUTF8); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(FormatString, format_string, printf); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(SoundEx, soundex, soundex); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(PartitionEscape, partition_escape, sparkPartitionEscape); // hash functions REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Crc32, crc32, CRC32); diff --git a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp index ef4d5504ff..b8e1a3cbd4 100644 --- a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp +++ b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp @@ -212,9 +212,7 @@ MergeTreeTableInstance::MergeTreeTableInstance(const std::string & info) : Merge while (!in.eof()) { MergeTreePart part; - std::string encoded_name; - readString(encoded_name, in); - Poco::URI::decode(encoded_name, part.name); + readString(part.name, in); assertChar('\n', in); readIntText(part.begin, in); assertChar('\n', in); diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h index c0c762906a..77096f3f49 100644 --- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h +++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h @@ -416,17 +416,23 @@ public: for (const auto & column : partition_columns) { // partition_column= - std::string key = add_slash ? 
fmt::format("/{}=", column) : fmt::format("{}=", column); + auto column_name = std::make_shared<DB::ASTLiteral>(column); + auto escaped_name = makeASTFunction("sparkPartitionEscape", DB::ASTs{column_name}); + if (add_slash) + arguments.emplace_back(std::make_shared<DB::ASTLiteral>("/")); add_slash = true; - arguments.emplace_back(std::make_shared<DB::ASTLiteral>(key)); + arguments.emplace_back(escaped_name); + arguments.emplace_back(std::make_shared<DB::ASTLiteral>("=")); // ifNull(toString(partition_column), DEFAULT_PARTITION_NAME) // FIXME if toString(partition_column) is empty - auto column_ast = std::make_shared<DB::ASTIdentifier>(column); + auto column_ast = makeASTFunction("toString", DB::ASTs{std::make_shared<DB::ASTIdentifier>(column)}); + auto escaped_value = makeASTFunction("sparkPartitionEscape", DB::ASTs{column_ast}); DB::ASTs if_null_args{ - makeASTFunction("toString", DB::ASTs{column_ast}), std::make_shared<DB::ASTLiteral>(DEFAULT_PARTITION_NAME)}; + makeASTFunction("toString", DB::ASTs{escaped_value}), std::make_shared<DB::ASTLiteral>(DEFAULT_PARTITION_NAME)}; arguments.emplace_back(makeASTFunction("ifNull", std::move(if_null_args))); } + if (isBucketedWrite(input_header)) { DB::ASTs args{std::make_shared<DB::ASTLiteral>("%05d"), std::make_shared<DB::ASTIdentifier>(BUCKET_COLUMN_NAME)}; diff --git a/cpp-ch/local-engine/tests/CMakeLists.txt b/cpp-ch/local-engine/tests/CMakeLists.txt index 09ca32a01a..9c18d70b0f 100644 --- a/cpp-ch/local-engine/tests/CMakeLists.txt +++ b/cpp-ch/local-engine/tests/CMakeLists.txt @@ -107,6 +107,7 @@ if(ENABLE_BENCHMARKS) benchmark_spark_row.cpp benchmark_unix_timestamp_function.cpp benchmark_spark_functions.cpp + benchmark_spark_partition_escape_function.cpp benchmark_cast_float_function.cpp benchmark_to_datetime_function.cpp benchmark_spark_divide_function.cpp diff --git a/cpp-ch/local-engine/tests/benchmark_spark_partition_escape_function.cpp b/cpp-ch/local-engine/tests/benchmark_spark_partition_escape_function.cpp new file mode 100644 index 0000000000..3299eb9fee --- /dev/null +++ b/cpp-ch/local-engine/tests/benchmark_spark_partition_escape_function.cpp @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <Core/Block.h> +#include <DataTypes/DataTypeFactory.h> +#include <Functions/FunctionFactory.h> +#include <Parser/FunctionParser.h> +#include <benchmark/benchmark.h> +#include <Common/QueryContext.h> + +using namespace DB; + +static Block createDataBlock(size_t rows) +{ + auto type = DataTypeFactory::instance().get("String"); + auto column = type->createColumn(); + for (size_t i = 0; i < rows; ++i) + { + char ch = static_cast<char>(i % 128); + std::string str = "escape_" + ch; + column->insert(str); + } + Block block; + block.insert(ColumnWithTypeAndName(std::move(column), type, "d")); + return std::move(block); +} + +static void BM_CHSparkPartitionEscape(benchmark::State & state) +{ + using namespace DB; + auto & factory = FunctionFactory::instance(); + auto function = factory.get("sparkPartitionEscape", local_engine::QueryContext::globalContext()); + Block block = createDataBlock(1000000); + auto executable = function->build(block.getColumnsWithTypeAndName()); + for (auto _ : state) [[maybe_unused]] + auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false); +} + +BENCHMARK(BM_CHSparkPartitionEscape)->Unit(benchmark::kMillisecond)->Iterations(50); \ No newline at end of file --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org For additional commands, e-mail: commits-h...@gluten.apache.org
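
For context on the escaping rule that the new sparkPartitionEscape function implements: partition values are percent-encoded Hive-style when they become directory names, so a value such as 2024=01/01 is written as 2024%3D01%2F01 and the native writer lays out the same paths as vanilla Spark. The standalone Scala sketch below illustrates that rule; it mirrors the character list in SparkPartitionEscape.cpp, but it is only an illustration, not the Gluten or Spark implementation, and the object and method names are invented for the example.

object PartitionPathEscapeSketch {
  // Characters that must be percent-encoded in partition directory names,
  // mirroring ESCAPE_CHAR_LIST in SparkPartitionEscape.cpp: ASCII control
  // characters plus characters that are special in paths or Hive metadata.
  private val escaped: Set[Char] =
    (0x01.toChar to 0x1F.toChar).toSet ++
      Set('"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', '{', '[', ']', '^')

  private def needsEscaping(c: Char): Boolean = c < 128 && escaped.contains(c)

  // Replace each escaped character with %XX (two uppercase hex digits).
  def escapePathName(path: String): String =
    path.map(c => if (needsEscaping(c)) f"%%${c.toInt}%02X" else c.toString).mkString

  def main(args: Array[String]): Unit = {
    // "2024=01/01" becomes "2024%3D01%2F01", the directory name vanilla Spark
    // writes for a partition value containing '=' and '/'.
    println(escapePathName("2024=01/01"))
  }
}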
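On the metadata side, AddFileTags now stores the part path URI-encoded (new Path(name).toUri.toString), and MergeTreePartSplit.apply decodes it back with new Path(new URI(name)).toString, following TahoeFileIndex.absolutePath. The short Scala sketch below shows that round trip for a partition directory containing escaped characters; it assumes hadoop-common on the classpath, the object name is invented for the example, and the sample value matches the part_col2 test data above.

import org.apache.hadoop.fs.Path
import java.net.URI

object PartPathRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Directory name as written on disk, with Hive-style escaping already applied.
    val dirOnDisk = "part_col2=2024%3D01%2F01"

    // Building a Path from the raw string quotes the '%' itself once more;
    // this is the form written into the Delta log by AddFileTags.
    val stored = new Path(dirOnDisk).toUri.toString // part_col2=2024%253D01%252F01

    // Parsing it back through java.net.URI decodes one level, recovering the
    // on-disk name that the native reader needs (see MergeTreePartSplit.apply).
    val decoded = new Path(new URI(stored)).toString // part_col2=2024%3D01%2F01

    println(stored)
    println(decoded)
  }
}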