This is an automated email from the ASF dual-hosted git repository.

yihua pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git
The following commit(s) were added to refs/heads/master by this push: new 17ea14ab6d6 [HUDI-7378] Fix Spark SQL DML with custom key generator (#10615) 17ea14ab6d6 is described below commit 17ea14ab6d6a8ca7ecef2cfcdbc67b0c87f23987 Author: Y Ethan Guo <ethan.guoyi...@gmail.com> AuthorDate: Fri Apr 12 22:51:03 2024 -0700 [HUDI-7378] Fix Spark SQL DML with custom key generator (#10615) --- .../factory/HoodieSparkKeyGeneratorFactory.java | 4 + .../org/apache/hudi/util/SparkKeyGenUtils.scala | 16 +- .../scala/org/apache/hudi/HoodieWriterUtils.scala | 20 +- .../spark/sql/hudi/ProvidesHoodieConfig.scala | 60 ++- .../spark/sql/hudi/TestProvidesHoodieConfig.scala | 79 +++ .../hudi/command/MergeIntoHoodieTableCommand.scala | 5 +- .../TestSparkSqlWithCustomKeyGenerator.scala | 571 +++++++++++++++++++++ 7 files changed, 742 insertions(+), 13 deletions(-) diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java index 1ea5adcd6b4..dcc2eaec9eb 100644 --- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java +++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java @@ -79,6 +79,10 @@ public class HoodieSparkKeyGeneratorFactory { public static KeyGenerator createKeyGenerator(TypedProperties props) throws IOException { String keyGeneratorClass = getKeyGeneratorClassName(props); + return createKeyGenerator(keyGeneratorClass, props); + } + + public static KeyGenerator createKeyGenerator(String keyGeneratorClass, TypedProperties props) throws IOException { boolean autoRecordKeyGen = KeyGenUtils.isAutoGeneratedRecordKeysEnabled(props) //Need to prevent overwriting the keygen for spark sql merge into because we need to extract //the recordkey from the meta cols if it exists. Sql keygen will use pkless keygen if needed. diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala index 7b91ae5a728..bd094464096 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala @@ -21,8 +21,8 @@ import org.apache.hudi.common.config.TypedProperties import org.apache.hudi.common.util.StringUtils import org.apache.hudi.common.util.ValidationUtils.checkArgument import org.apache.hudi.keygen.constant.KeyGeneratorOptions -import org.apache.hudi.keygen.{AutoRecordKeyGeneratorWrapper, AutoRecordGenWrapperKeyGenerator, CustomAvroKeyGenerator, CustomKeyGenerator, GlobalAvroDeleteKeyGenerator, GlobalDeleteKeyGenerator, KeyGenerator, NonpartitionedAvroKeyGenerator, NonpartitionedKeyGenerator} import org.apache.hudi.keygen.factory.HoodieSparkKeyGeneratorFactory +import org.apache.hudi.keygen.{AutoRecordKeyGeneratorWrapper, CustomAvroKeyGenerator, CustomKeyGenerator, GlobalAvroDeleteKeyGenerator, GlobalDeleteKeyGenerator, KeyGenerator, NonpartitionedAvroKeyGenerator, NonpartitionedKeyGenerator} object SparkKeyGenUtils { @@ -35,6 +35,20 @@ object SparkKeyGenUtils { getPartitionColumns(keyGenerator, props) } + /** + * @param KeyGenClassNameOption key generator class name if present. + * @param props config properties. 
+ * @return partition column names only, concatenated by "," + */ + def getPartitionColumns(KeyGenClassNameOption: Option[String], props: TypedProperties): String = { + val keyGenerator = if (KeyGenClassNameOption.isEmpty) { + HoodieSparkKeyGeneratorFactory.createKeyGenerator(props) + } else { + HoodieSparkKeyGeneratorFactory.createKeyGenerator(KeyGenClassNameOption.get, props) + } + getPartitionColumns(keyGenerator, props) + } + /** * @param keyGen key generator class name * @return partition columns diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala index 63495b0eede..5df773542d6 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala @@ -201,8 +201,26 @@ object HoodieWriterUtils { diffConfigs.append(s"KeyGenerator:\t$datasourceKeyGen\t$tableConfigKeyGen\n") } + // Please note that the validation of partition path fields needs the key generator class + // for the table, since the custom key generator expects a different format of + // the value of the write config "hoodie.datasource.write.partitionpath.field" + // e.g., "col:simple,ts:timestamp", whereas the table config "hoodie.table.partition.fields" + // in hoodie.properties stores "col,ts". + // The "params" here may only contain the write config of partition path field, + // so we need to pass in the validated key generator class name. + val validatedKeyGenClassName = if (tableConfigKeyGen != null) { + Option(tableConfigKeyGen) + } else if (datasourceKeyGen != null) { + Option(datasourceKeyGen) + } else { + None + } val datasourcePartitionFields = params.getOrElse(PARTITIONPATH_FIELD.key(), null) - val currentPartitionFields = if (datasourcePartitionFields == null) null else SparkKeyGenUtils.getPartitionColumns(TypedProperties.fromMap(params)) + val currentPartitionFields = if (datasourcePartitionFields == null) { + null + } else { + SparkKeyGenUtils.getPartitionColumns(validatedKeyGenClassName, TypedProperties.fromMap(params)) + } val tableConfigPartitionFields = tableConfig.getString(HoodieTableConfig.PARTITION_FIELDS) if (null != datasourcePartitionFields && null != tableConfigPartitionFields && currentPartitionFields != tableConfigPartitionFields) { diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala index a4003bbd480..7d35c490fd4 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala @@ -18,34 +18,36 @@ package org.apache.spark.sql.hudi import org.apache.hudi.AutoRecordKeyGenerationUtils.shouldAutoGenerateRecordKeys -import org.apache.hudi.{DataSourceWriteOptions, HoodieFileIndex} import org.apache.hudi.DataSourceWriteOptions._ import org.apache.hudi.HoodieConversionUtils.toProperties import org.apache.hudi.common.config.{DFSPropertiesConfiguration, TypedProperties} import org.apache.hudi.common.model.{DefaultHoodieRecordPayload, WriteOperationType} import org.apache.hudi.common.table.HoodieTableConfig +import org.apache.hudi.common.util.{ReflectionUtils, StringUtils} import 
org.apache.hudi.config.HoodieWriteConfig.TBL_NAME import org.apache.hudi.config.{HoodieIndexConfig, HoodieInternalConfig, HoodieWriteConfig} import org.apache.hudi.hive.ddl.HiveSyncMode import org.apache.hudi.hive.{HiveSyncConfig, HiveSyncConfigHolder, MultiPartKeysValueExtractor} -import org.apache.hudi.keygen.ComplexKeyGenerator +import org.apache.hudi.keygen.{ComplexKeyGenerator, CustomAvroKeyGenerator, CustomKeyGenerator} import org.apache.hudi.sql.InsertMode import org.apache.hudi.sync.common.HoodieSyncConfig +import org.apache.hudi.{DataSourceWriteOptions, HoodieFileIndex} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{SaveMode, SparkSession} import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} import org.apache.spark.sql.execution.datasources.FileStatusCache import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hudi.HoodieOptionConfig.mapSqlOptionsToDataSourceWriteConfigs import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{isHoodieConfigKey, isUsingHiveCatalog} -import org.apache.spark.sql.hudi.ProvidesHoodieConfig.combineOptions +import org.apache.spark.sql.hudi.ProvidesHoodieConfig.{combineOptions, getPartitionPathFieldWriteConfig} import org.apache.spark.sql.hudi.command.{SqlKeyGenerator, ValidateDuplicateKeyPayload} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PARTITION_OVERWRITE_MODE import org.apache.spark.sql.types.StructType -import java.util.Locale +import org.apache.spark.sql.{SaveMode, SparkSession} +import org.slf4j.LoggerFactory +import java.util.Locale import scala.collection.JavaConverters._ trait ProvidesHoodieConfig extends Logging { @@ -82,7 +84,8 @@ trait ProvidesHoodieConfig extends Logging { PRECOMBINE_FIELD.key -> preCombineField, HIVE_STYLE_PARTITIONING.key -> tableConfig.getHiveStylePartitioningEnable, URL_ENCODE_PARTITIONING.key -> tableConfig.getUrlEncodePartitioning, - PARTITIONPATH_FIELD.key -> tableConfig.getPartitionFieldProp + PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig( + tableConfig.getKeyGeneratorClassName, tableConfig.getPartitionFieldProp, hoodieCatalogTable) ) combineOptions(hoodieCatalogTable, tableConfig, sparkSession.sqlContext.conf, @@ -313,7 +316,8 @@ trait ProvidesHoodieConfig extends Logging { URL_ENCODE_PARTITIONING.key -> urlEncodePartitioning, RECORDKEY_FIELD.key -> recordKeyConfigValue, PRECOMBINE_FIELD.key -> preCombineField, - PARTITIONPATH_FIELD.key -> partitionFieldsStr + PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig( + keyGeneratorClassName, partitionFieldsStr, hoodieCatalogTable) ) ++ overwriteTableOpts ++ getDropDupsConfig(useLegacyInsertModeFlow, combinedOpts) ++ staticOverwritePartitionPathOptions combineOptions(hoodieCatalogTable, tableConfig, sparkSession.sqlContext.conf, @@ -405,7 +409,8 @@ trait ProvidesHoodieConfig extends Logging { PARTITIONS_TO_DELETE.key -> partitionsToDrop, RECORDKEY_FIELD.key -> hoodieCatalogTable.primaryKeys.mkString(","), PRECOMBINE_FIELD.key -> hoodieCatalogTable.preCombineKey.getOrElse(""), - PARTITIONPATH_FIELD.key -> partitionFields, + PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig( + tableConfig.getKeyGeneratorClassName, partitionFields, hoodieCatalogTable), HoodieSyncConfig.META_SYNC_ENABLED.key -> hiveSyncConfig.getString(HoodieSyncConfig.META_SYNC_ENABLED.key), HiveSyncConfigHolder.HIVE_SYNC_ENABLED.key -> 
hiveSyncConfig.getString(HiveSyncConfigHolder.HIVE_SYNC_ENABLED.key), HiveSyncConfigHolder.HIVE_SYNC_MODE.key -> hiveSyncConfig.getStringOrDefault(HiveSyncConfigHolder.HIVE_SYNC_MODE, HiveSyncMode.HMS.name()), @@ -451,7 +456,8 @@ trait ProvidesHoodieConfig extends Logging { HIVE_STYLE_PARTITIONING.key -> tableConfig.getHiveStylePartitioningEnable, URL_ENCODE_PARTITIONING.key -> tableConfig.getUrlEncodePartitioning, OPERATION.key -> DataSourceWriteOptions.DELETE_OPERATION_OPT_VAL, - PARTITIONPATH_FIELD.key -> tableConfig.getPartitionFieldProp + PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig( + tableConfig.getKeyGeneratorClassName, tableConfig.getPartitionFieldProp, hoodieCatalogTable) ) combineOptions(hoodieCatalogTable, tableConfig, sparkSession.sqlContext.conf, @@ -496,6 +502,8 @@ trait ProvidesHoodieConfig extends Logging { object ProvidesHoodieConfig { + private val log = LoggerFactory.getLogger(getClass) + // NOTE: PLEASE READ CAREFULLY BEFORE CHANGING // // Spark SQL operations configuration might be coming from a variety of diverse sources @@ -530,6 +538,40 @@ object ProvidesHoodieConfig { filterNullValues(overridingOpts) } + /** + * @param tableConfigKeyGeneratorClassName key generator class name in the table config. + * @param partitionFieldNamesWithoutKeyGenType partition field names without key generator types + * from the table config. + * @param catalogTable HoodieCatalogTable instance to fetch table properties. + * @return the write config value to set for "hoodie.datasource.write.partitionpath.field". + */ + def getPartitionPathFieldWriteConfig(tableConfigKeyGeneratorClassName: String, + partitionFieldNamesWithoutKeyGenType: String, + catalogTable: HoodieCatalogTable): String = { + if (StringUtils.isNullOrEmpty(tableConfigKeyGeneratorClassName)) { + partitionFieldNamesWithoutKeyGenType + } else { + val writeConfigPartitionField = catalogTable.catalogProperties.get(PARTITIONPATH_FIELD.key()) + val keyGenClass = ReflectionUtils.getClass(tableConfigKeyGeneratorClassName) + if (classOf[CustomKeyGenerator].equals(keyGenClass) + || classOf[CustomAvroKeyGenerator].equals(keyGenClass)) { + // For custom key generator, we have to take the write config value from + // "hoodie.datasource.write.partitionpath.field" which contains the key generator + // type, whereas the table config only contains the partition field names without + // key generator types. + if (writeConfigPartitionField.isDefined) { + writeConfigPartitionField.get + } else { + log.warn("Write config \"hoodie.datasource.write.partitionpath.field\" is not set for " + + "custom key generator. This may fail the write operation.") + partitionFieldNamesWithoutKeyGenType + } + } else { + partitionFieldNamesWithoutKeyGenType + } + } + } + private def filterNullValues(opts: Map[String, String]): Map[String, String] = opts.filter { case (_, v) => v != null } diff --git a/hudi-spark-datasource/hudi-spark-common/src/test/scala/org/apache/spark/sql/hudi/TestProvidesHoodieConfig.scala b/hudi-spark-datasource/hudi-spark-common/src/test/scala/org/apache/spark/sql/hudi/TestProvidesHoodieConfig.scala new file mode 100644 index 00000000000..8414e41ca6c --- /dev/null +++ b/hudi-spark-datasource/hudi-spark-common/src/test/scala/org/apache/spark/sql/hudi/TestProvidesHoodieConfig.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.hudi + +import org.apache.hudi.DataSourceWriteOptions.PARTITIONPATH_FIELD +import org.apache.hudi.keygen.{ComplexKeyGenerator, CustomKeyGenerator} + +import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.mockito.Mockito +import org.mockito.Mockito.when + +/** + * Tests {@link ProvidesHoodieConfig} + */ +class TestProvidesHoodieConfig { + @Test + def testGetPartitionPathFieldWriteConfig(): Unit = { + val mockTable = Mockito.mock(classOf[HoodieCatalogTable]) + val partitionFieldNames = "ts,segment" + val customKeyGenPartitionFieldWriteConfig = "ts:timestamp,segment:simple" + + mockPartitionWriteConfigInCatalogProps(mockTable, None) + assertEquals( + partitionFieldNames, + ProvidesHoodieConfig.getPartitionPathFieldWriteConfig( + "", partitionFieldNames, mockTable)) + assertEquals( + partitionFieldNames, + ProvidesHoodieConfig.getPartitionPathFieldWriteConfig( + classOf[ComplexKeyGenerator].getName, partitionFieldNames, mockTable)) + assertEquals( + partitionFieldNames, + ProvidesHoodieConfig.getPartitionPathFieldWriteConfig( + classOf[CustomKeyGenerator].getName, partitionFieldNames, mockTable)) + + mockPartitionWriteConfigInCatalogProps(mockTable, Option(customKeyGenPartitionFieldWriteConfig)) + assertEquals( + partitionFieldNames, + ProvidesHoodieConfig.getPartitionPathFieldWriteConfig( + "", partitionFieldNames, mockTable)) + assertEquals( + partitionFieldNames, + ProvidesHoodieConfig.getPartitionPathFieldWriteConfig( + classOf[ComplexKeyGenerator].getName, partitionFieldNames, mockTable)) + assertEquals( + customKeyGenPartitionFieldWriteConfig, + ProvidesHoodieConfig.getPartitionPathFieldWriteConfig( + classOf[CustomKeyGenerator].getName, partitionFieldNames, mockTable)) + } + + private def mockPartitionWriteConfigInCatalogProps(mockTable: HoodieCatalogTable, + value: Option[String]): Unit = { + val props = if (value.isDefined) { + Map(PARTITIONPATH_FIELD.key() -> value.get) + } else { + Map[String, String]() + } + when(mockTable.catalogProperties).thenReturn(props) + } +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala index 18403872f4a..79cd2646e08 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._ import org.apache.spark.sql.hudi.ProvidesHoodieConfig -import 
org.apache.spark.sql.hudi.ProvidesHoodieConfig.combineOptions +import org.apache.spark.sql.hudi.ProvidesHoodieConfig.{combineOptions, getPartitionPathFieldWriteConfig} import org.apache.spark.sql.hudi.analysis.HoodieAnalysis.failAnalysis import org.apache.spark.sql.hudi.command.MergeIntoHoodieTableCommand.{CoercedAttributeReference, encodeAsBase64String, stripCasting, toStructType} import org.apache.spark.sql.hudi.command.PartialAssignmentMode.PartialAssignmentMode @@ -729,7 +729,8 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie RECORDKEY_FIELD.key -> tableConfig.getRawRecordKeyFieldProp, PRECOMBINE_FIELD.key -> preCombineField, TBL_NAME.key -> hoodieCatalogTable.tableName, - PARTITIONPATH_FIELD.key -> tableConfig.getPartitionFieldProp, + PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig( + tableConfig.getKeyGeneratorClassName, tableConfig.getPartitionFieldProp, hoodieCatalogTable), HIVE_STYLE_PARTITIONING.key -> tableConfig.getHiveStylePartitioningEnable, URL_ENCODE_PARTITIONING.key -> tableConfig.getUrlEncodePartitioning, KEYGENERATOR_CLASS_NAME.key -> keyGeneratorClassName, diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestSparkSqlWithCustomKeyGenerator.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestSparkSqlWithCustomKeyGenerator.scala new file mode 100644 index 00000000000..c85eb40bca7 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestSparkSqlWithCustomKeyGenerator.scala @@ -0,0 +1,571 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hudi.functional + +import org.apache.hudi.HoodieSparkUtils +import org.apache.hudi.common.config.TypedProperties +import org.apache.hudi.common.table.HoodieTableMetaClient +import org.apache.hudi.common.util.StringUtils +import org.apache.hudi.exception.HoodieException +import org.apache.hudi.functional.TestSparkSqlWithCustomKeyGenerator._ +import org.apache.hudi.util.SparkKeyGenUtils +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.hudi.common.HoodieSparkSqlTestBase +import org.joda.time.DateTime +import org.joda.time.format.DateTimeFormat +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue} +import org.slf4j.LoggerFactory + +import java.io.IOException + +/** + * Tests Spark SQL DML with custom key generator and write configs. 
+ */ +class TestSparkSqlWithCustomKeyGenerator extends HoodieSparkSqlTestBase { + private val LOG = LoggerFactory.getLogger(getClass) + + test("Test Spark SQL DML with custom key generator") { + withTempDir { tmp => + Seq( + Seq("COPY_ON_WRITE", "ts:timestamp,segment:simple", + "(ts=202401, segment='cat2')", "202401/cat2", + Seq("202312/cat2", "202312/cat4", "202401/cat1", "202401/cat3", "202402/cat1", "202402/cat3", "202402/cat5"), + TS_FORMATTER_FUNC, + (ts: Integer, segment: String) => TS_FORMATTER_FUNC.apply(ts) + "/" + segment), + Seq("MERGE_ON_READ", "segment:simple", + "(segment='cat3')", "cat3", + Seq("cat1", "cat2", "cat4", "cat5"), + TS_TO_STRING_FUNC, + (_: Integer, segment: String) => segment), + Seq("MERGE_ON_READ", "ts:timestamp", + "(ts=202312)", "202312", + Seq("202401", "202402"), + TS_FORMATTER_FUNC, + (ts: Integer, _: String) => TS_FORMATTER_FUNC.apply(ts)), + Seq("MERGE_ON_READ", "ts:timestamp,segment:simple", + "(ts=202401, segment='cat2')", "202401/cat2", + Seq("202312/cat2", "202312/cat4", "202401/cat1", "202401/cat3", "202402/cat1", "202402/cat3", "202402/cat5"), + TS_FORMATTER_FUNC, + (ts: Integer, segment: String) => TS_FORMATTER_FUNC.apply(ts) + "/" + segment) + ).foreach { testParams => + withTable(generateTableName) { tableName => + LOG.warn("Testing with parameters: " + testParams) + val tableType = testParams(0).asInstanceOf[String] + val writePartitionFields = testParams(1).asInstanceOf[String] + val dropPartitionStatement = testParams(2).asInstanceOf[String] + val droppedPartition = testParams(3).asInstanceOf[String] + val expectedPartitions = testParams(4).asInstanceOf[Seq[String]] + val tsGenFunc = testParams(5).asInstanceOf[Integer => String] + val partitionGenFunc = testParams(6).asInstanceOf[(Integer, String) => String] + val tablePath = tmp.getCanonicalPath + "/" + tableName + val timestampKeyGeneratorConfig = if (writePartitionFields.contains("timestamp")) { + TS_KEY_GEN_CONFIGS + } else { + Map[String, String]() + } + val timestampKeyGenProps = if (timestampKeyGeneratorConfig.nonEmpty) { + ", " + timestampKeyGeneratorConfig.map(e => e._1 + " = '" + e._2 + "'").mkString(", ") + } else { + "" + } + + prepareTableWithKeyGenerator( + tableName, tablePath, tableType, + CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields, timestampKeyGeneratorConfig) + + // SQL CTAS with table properties containing key generator write configs + createTableWithSql(tableName, tablePath, + s"hoodie.datasource.write.partitionpath.field = '$writePartitionFields'" + timestampKeyGenProps) + + // Prepare source and test SQL INSERT INTO + val sourceTableName = tableName + "_source" + prepareParquetSource(sourceTableName, Seq( + "(7, 'a7', 1399.0, 1706800227, 'cat1')", + "(8, 'a8', 26.9, 1706800227, 'cat3')", + "(9, 'a9', 299.0, 1701443427, 'cat4')")) + spark.sql( + s""" + | INSERT INTO $tableName + | SELECT * from ${tableName}_source + | """.stripMargin) + validateResults( + tableName, + s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName", + tsGenFunc, + partitionGenFunc, + Seq(), + Seq(1, "a1", "1.6", 1704121827, "cat1"), + Seq(2, "a2", "10.8", 1704121827, "cat1"), + Seq(3, "a3", "30.0", 1706800227, "cat1"), + Seq(4, "a4", "103.4", 1701443427, "cat2"), + Seq(5, "a5", "1999.0", 1704121827, "cat2"), + Seq(6, "a6", "80.0", 1704121827, "cat3"), + Seq(7, "a7", "1399.0", 1706800227, "cat1"), + Seq(8, "a8", "26.9", 1706800227, "cat3"), + Seq(9, "a9", "299.0", 1701443427, "cat4") + ) + + // Test SQL UPDATE + spark.sql( + s""" + | UPDATE $tableName + | SET 
price = price + 10.0 + | WHERE id between 4 and 7 + | """.stripMargin) + validateResults( + tableName, + s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName", + tsGenFunc, + partitionGenFunc, + Seq(), + Seq(1, "a1", "1.6", 1704121827, "cat1"), + Seq(2, "a2", "10.8", 1704121827, "cat1"), + Seq(3, "a3", "30.0", 1706800227, "cat1"), + Seq(4, "a4", "113.4", 1701443427, "cat2"), + Seq(5, "a5", "2009.0", 1704121827, "cat2"), + Seq(6, "a6", "90.0", 1704121827, "cat3"), + Seq(7, "a7", "1409.0", 1706800227, "cat1"), + Seq(8, "a8", "26.9", 1706800227, "cat3"), + Seq(9, "a9", "299.0", 1701443427, "cat4") + ) + + // Test SQL MERGE INTO + spark.sql( + s""" + | MERGE INTO $tableName as target + | USING ( + | SELECT 1 as id, 'a1' as name, 1.6 as price, 1704121827 as ts, 'cat1' as segment, 'delete' as flag + | UNION + | SELECT 2 as id, 'a2' as name, 11.9 as price, 1704121827 as ts, 'cat1' as segment, '' as flag + | UNION + | SELECT 6 as id, 'a6' as name, 99.0 as price, 1704121827 as ts, 'cat3' as segment, '' as flag + | UNION + | SELECT 8 as id, 'a8' as name, 24.9 as price, 1706800227 as ts, 'cat3' as segment, '' as flag + | UNION + | SELECT 10 as id, 'a10' as name, 888.8 as price, 1706800227 as ts, 'cat5' as segment, '' as flag + | ) source + | on target.id = source.id + | WHEN MATCHED AND flag != 'delete' THEN UPDATE SET + | id = source.id, name = source.name, price = source.price, ts = source.ts, segment = source.segment + | WHEN MATCHED AND flag = 'delete' THEN DELETE + | WHEN NOT MATCHED THEN INSERT (id, name, price, ts, segment) + | values (source.id, source.name, source.price, source.ts, source.segment) + | """.stripMargin) + validateResults( + tableName, + s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName", + tsGenFunc, + partitionGenFunc, + Seq(), + Seq(2, "a2", "11.9", 1704121827, "cat1"), + Seq(3, "a3", "30.0", 1706800227, "cat1"), + Seq(4, "a4", "113.4", 1701443427, "cat2"), + Seq(5, "a5", "2009.0", 1704121827, "cat2"), + Seq(6, "a6", "99.0", 1704121827, "cat3"), + Seq(7, "a7", "1409.0", 1706800227, "cat1"), + Seq(8, "a8", "24.9", 1706800227, "cat3"), + Seq(9, "a9", "299.0", 1701443427, "cat4"), + Seq(10, "a10", "888.8", 1706800227, "cat5") + ) + + // Test SQL DELETE + spark.sql( + s""" + | DELETE FROM $tableName + | WHERE id = 7 + | """.stripMargin) + validateResults( + tableName, + s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName", + tsGenFunc, + partitionGenFunc, + Seq(), + Seq(2, "a2", "11.9", 1704121827, "cat1"), + Seq(3, "a3", "30.0", 1706800227, "cat1"), + Seq(4, "a4", "113.4", 1701443427, "cat2"), + Seq(5, "a5", "2009.0", 1704121827, "cat2"), + Seq(6, "a6", "99.0", 1704121827, "cat3"), + Seq(8, "a8", "24.9", 1706800227, "cat3"), + Seq(9, "a9", "299.0", 1701443427, "cat4"), + Seq(10, "a10", "888.8", 1706800227, "cat5") + ) + + // Test DROP PARTITION + assertTrue(getSortedTablePartitions(tableName).contains(droppedPartition)) + spark.sql( + s""" + | ALTER TABLE $tableName DROP PARTITION $dropPartitionStatement + |""".stripMargin) + validatePartitions(tableName, Seq(droppedPartition), expectedPartitions) + + if (HoodieSparkUtils.isSpark3) { + // Test INSERT OVERWRITE, only supported in Spark 3.x + spark.sql( + s""" + | INSERT OVERWRITE $tableName + | SELECT 100 as id, 'a100' as name, 299.0 as price, 1706800227 as ts, 'cat10' as segment + | """.stripMargin) + validateResults( + tableName, + s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName", + 
tsGenFunc, + partitionGenFunc, + Seq(), + Seq(100, "a100", "299.0", 1706800227, "cat10") + ) + } + } + } + } + } + + test("Test table property isolation for partition path field config " + + "with custom key generator for Spark 3.1 and above") { + // Only testing Spark 3.1 and above as lower Spark versions do not support + // ALTER TABLE .. SET TBLPROPERTIES .. to store table-level properties in Hudi Catalog + if (HoodieSparkUtils.gteqSpark3_1) { + withTempDir { tmp => { + val tableNameNonPartitioned = generateTableName + val tableNameSimpleKey = generateTableName + val tableNameCustom1 = generateTableName + val tableNameCustom2 = generateTableName + + val tablePathNonPartitioned = tmp.getCanonicalPath + "/" + tableNameNonPartitioned + val tablePathSimpleKey = tmp.getCanonicalPath + "/" + tableNameSimpleKey + val tablePathCustom1 = tmp.getCanonicalPath + "/" + tableNameCustom1 + val tablePathCustom2 = tmp.getCanonicalPath + "/" + tableNameCustom2 + + val tableType = "MERGE_ON_READ" + val writePartitionFields1 = "segment:simple" + val writePartitionFields2 = "ts:timestamp,segment:simple" + + prepareTableWithKeyGenerator( + tableNameNonPartitioned, tablePathNonPartitioned, tableType, + NONPARTITIONED_KEY_GEN_CLASS_NAME, "", Map()) + prepareTableWithKeyGenerator( + tableNameSimpleKey, tablePathSimpleKey, tableType, + SIMPLE_KEY_GEN_CLASS_NAME, "segment", Map()) + prepareTableWithKeyGenerator( + tableNameCustom1, tablePathCustom1, tableType, + CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields1, Map()) + prepareTableWithKeyGenerator( + tableNameCustom2, tablePathCustom2, tableType, + CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields2, TS_KEY_GEN_CONFIGS) + + // Non-partitioned table does not require additional partition path field write config + createTableWithSql(tableNameNonPartitioned, tablePathNonPartitioned, "") + // Partitioned table with simple key generator does not require additional partition path field write config + createTableWithSql(tableNameSimpleKey, tablePathSimpleKey, "") + // Partitioned table with custom key generator requires additional partition path field write config + // Without that, right now the SQL DML fails + createTableWithSql(tableNameCustom1, tablePathCustom1, "") + createTableWithSql(tableNameCustom2, tablePathCustom2, + s"hoodie.datasource.write.partitionpath.field = '$writePartitionFields2', " + + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 + "'").mkString(", ")) + + val segmentPartitionFunc = (_: Integer, segment: String) => segment + val customPartitionFunc = (ts: Integer, segment: String) => TS_FORMATTER_FUNC.apply(ts) + "/" + segment + + testFirstRoundInserts(tableNameNonPartitioned, TS_TO_STRING_FUNC, (_, _) => "") + testFirstRoundInserts(tableNameSimpleKey, TS_TO_STRING_FUNC, segmentPartitionFunc) + // INSERT INTO should fail for tableNameCustom1 + val sourceTableName = tableNameCustom1 + "_source" + prepareParquetSource(sourceTableName, Seq("(7, 'a7', 1399.0, 1706800227, 'cat1')")) + assertThrows[IOException] { + spark.sql( + s""" + | INSERT INTO $tableNameCustom1 + | SELECT * from $sourceTableName + | """.stripMargin) + } + testFirstRoundInserts(tableNameCustom2, TS_FORMATTER_FUNC, customPartitionFunc) + + // Now add the missing partition path field write config for tableNameCustom1 + spark.sql( + s"""ALTER TABLE $tableNameCustom1 + | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field = '$writePartitionFields1') + | """.stripMargin) + + // All tables should be able to do INSERT INTO without any problem, + // since the scope of the added 
write config is at the catalog table level + testSecondRoundInserts(tableNameNonPartitioned, TS_TO_STRING_FUNC, (_, _) => "") + testSecondRoundInserts(tableNameSimpleKey, TS_TO_STRING_FUNC, segmentPartitionFunc) + testFirstRoundInserts(tableNameCustom1, TS_TO_STRING_FUNC, segmentPartitionFunc) + testSecondRoundInserts(tableNameCustom2, TS_FORMATTER_FUNC, customPartitionFunc) + } + } + } + } + + test("Test wrong partition path field write config with custom key generator") { + withTempDir { tmp => { + val tableName = generateTableName + val tablePath = tmp.getCanonicalPath + "/" + tableName + val tableType = "MERGE_ON_READ" + val writePartitionFields = "segment:simple,ts:timestamp" + val wrongWritePartitionFields = "segment:simple" + val customPartitionFunc = (ts: Integer, segment: String) => segment + "/" + TS_FORMATTER_FUNC.apply(ts) + + prepareTableWithKeyGenerator( + tableName, tablePath, "MERGE_ON_READ", + CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields, TS_KEY_GEN_CONFIGS) + + // CREATE TABLE should fail due to config conflict + assertThrows[HoodieException] { + createTableWithSql(tableName, tablePath, + s"hoodie.datasource.write.partitionpath.field = '$wrongWritePartitionFields', " + + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 + "'").mkString(", ")) + } + + createTableWithSql(tableName, tablePath, + s"hoodie.datasource.write.partitionpath.field = '$writePartitionFields', " + + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 + "'").mkString(", ")) + // Set wrong write config + spark.sql( + s"""ALTER TABLE $tableName + | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field = '$wrongWritePartitionFields') + | """.stripMargin) + + // INSERT INTO should fail due to conflict between write and table config of partition path fields + val sourceTableName = tableName + "_source" + prepareParquetSource(sourceTableName, Seq("(7, 'a7', 1399.0, 1706800227, 'cat1')")) + assertThrows[HoodieException] { + spark.sql( + s""" + | INSERT INTO $tableName + | SELECT * from $sourceTableName + | """.stripMargin) + } + + // Only testing Spark 3.1 and above as lower Spark versions do not support + // ALTER TABLE .. SET TBLPROPERTIES .. 
to store table-level properties in Hudi Catalog + if (HoodieSparkUtils.gteqSpark3_1) { + // Now fix the partition path field write config for tableName + spark.sql( + s"""ALTER TABLE $tableName + | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field = '$writePartitionFields') + | """.stripMargin) + + // INSERT INTO should succeed now + testFirstRoundInserts(tableName, TS_FORMATTER_FUNC, customPartitionFunc) + } + } + } + } + + private def testFirstRoundInserts(tableName: String, + tsGenFunc: Integer => String, + partitionGenFunc: (Integer, String) => String): Unit = { + val sourceTableName = tableName + "_source1" + prepareParquetSource(sourceTableName, Seq("(7, 'a7', 1399.0, 1706800227, 'cat1')")) + spark.sql( + s""" + | INSERT INTO $tableName + | SELECT * from $sourceTableName + | """.stripMargin) + validateResults( + tableName, + s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName", + tsGenFunc, + partitionGenFunc, + Seq(), + Seq(1, "a1", "1.6", 1704121827, "cat1"), + Seq(2, "a2", "10.8", 1704121827, "cat1"), + Seq(3, "a3", "30.0", 1706800227, "cat1"), + Seq(4, "a4", "103.4", 1701443427, "cat2"), + Seq(5, "a5", "1999.0", 1704121827, "cat2"), + Seq(6, "a6", "80.0", 1704121827, "cat3"), + Seq(7, "a7", "1399.0", 1706800227, "cat1") + ) + } + + private def testSecondRoundInserts(tableName: String, + tsGenFunc: Integer => String, + partitionGenFunc: (Integer, String) => String): Unit = { + val sourceTableName = tableName + "_source2" + prepareParquetSource(sourceTableName, Seq("(8, 'a8', 26.9, 1706800227, 'cat3')")) + spark.sql( + s""" + | INSERT INTO $tableName + | SELECT * from $sourceTableName + | """.stripMargin) + validateResults( + tableName, + s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName", + tsGenFunc, + partitionGenFunc, + Seq(), + Seq(1, "a1", "1.6", 1704121827, "cat1"), + Seq(2, "a2", "10.8", 1704121827, "cat1"), + Seq(3, "a3", "30.0", 1706800227, "cat1"), + Seq(4, "a4", "103.4", 1701443427, "cat2"), + Seq(5, "a5", "1999.0", 1704121827, "cat2"), + Seq(6, "a6", "80.0", 1704121827, "cat3"), + Seq(7, "a7", "1399.0", 1706800227, "cat1"), + Seq(8, "a8", "26.9", 1706800227, "cat3") + ) + } + + private def prepareTableWithKeyGenerator(tableName: String, + tablePath: String, + tableType: String, + keyGenClassName: String, + writePartitionFields: String, + timestampKeyGeneratorConfig: Map[String, String]): Unit = { + val df = spark.sql( + s"""SELECT 1 as id, 'a1' as name, 1.6 as price, 1704121827 as ts, 'cat1' as segment + | UNION + | SELECT 2 as id, 'a2' as name, 10.8 as price, 1704121827 as ts, 'cat1' as segment + | UNION + | SELECT 3 as id, 'a3' as name, 30.0 as price, 1706800227 as ts, 'cat1' as segment + | UNION + | SELECT 4 as id, 'a4' as name, 103.4 as price, 1701443427 as ts, 'cat2' as segment + | UNION + | SELECT 5 as id, 'a5' as name, 1999.0 as price, 1704121827 as ts, 'cat2' as segment + | UNION + | SELECT 6 as id, 'a6' as name, 80.0 as price, 1704121827 as ts, 'cat3' as segment + |""".stripMargin) + + df.write.format("hudi") + .option("hoodie.datasource.write.table.type", tableType) + .option("hoodie.datasource.write.keygenerator.class", keyGenClassName) + .option("hoodie.datasource.write.partitionpath.field", writePartitionFields) + .option("hoodie.datasource.write.recordkey.field", "id") + .option("hoodie.datasource.write.precombine.field", "name") + .option("hoodie.table.name", tableName) + .option("hoodie.insert.shuffle.parallelism", "1") + .option("hoodie.upsert.shuffle.parallelism", "1") 
+ .option("hoodie.bulkinsert.shuffle.parallelism", "1") + .options(timestampKeyGeneratorConfig) + .mode(SaveMode.Overwrite) + .save(tablePath) + + // Validate that the generated table has expected table configs of key generator and partition path fields + val metaClient = HoodieTableMetaClient.builder() + .setConf(spark.sparkContext.hadoopConfiguration) + .setBasePath(tablePath) + .build() + assertEquals(keyGenClassName, metaClient.getTableConfig.getKeyGeneratorClassName) + // Validate that that partition path fields in the table config should always + // contain the field names only (no key generator type like "segment:simple") + if (CUSTOM_KEY_GEN_CLASS_NAME.equals(keyGenClassName)) { + val props = new TypedProperties() + props.put("hoodie.datasource.write.partitionpath.field", writePartitionFields) + timestampKeyGeneratorConfig.foreach(e => { + props.put(e._1, e._2) + }) + // For custom key generator, the "hoodie.datasource.write.partitionpath.field" + // contains the key generator type, like "ts:timestamp,segment:simple", + // whereas the partition path fields in table config is "ts,segment" + assertEquals( + SparkKeyGenUtils.getPartitionColumns(Option(CUSTOM_KEY_GEN_CLASS_NAME), props), + metaClient.getTableConfig.getPartitionFieldProp) + } else { + assertEquals(writePartitionFields, metaClient.getTableConfig.getPartitionFieldProp) + } + } + + private def createTableWithSql(tableName: String, + tablePath: String, + tblProps: String): Unit = { + val tblPropsStatement = if (StringUtils.isNullOrEmpty(tblProps)) { + "" + } else { + "TBLPROPERTIES (\n" + tblProps + "\n)" + } + spark.sql( + s""" + | CREATE TABLE $tableName USING HUDI + | location '$tablePath' + | $tblPropsStatement + | """.stripMargin) + } + + private def prepareParquetSource(sourceTableName: String, + rows: Seq[String]): Unit = { + spark.sql( + s"""CREATE TABLE $sourceTableName + | (id int, name string, price decimal(5, 1), ts int, segment string) + | USING PARQUET + |""".stripMargin) + spark.sql( + s""" + | INSERT INTO $sourceTableName values + | ${rows.mkString(", ")} + | """.stripMargin) + } + + private def validateResults(tableName: String, + sql: String, + tsGenFunc: Integer => String, + partitionGenFunc: (Integer, String) => String, + droppedPartitions: Seq[String], + expects: Seq[Any]*): Unit = { + checkAnswer(sql)( + expects.map(e => Seq(e(0), e(1), e(2), tsGenFunc.apply(e(3).asInstanceOf[Integer]), e(4))): _* + ) + val expectedPartitions: Seq[String] = expects + .map(e => partitionGenFunc.apply(e(3).asInstanceOf[Integer], e(4).asInstanceOf[String])) + .distinct.sorted + validatePartitions(tableName, droppedPartitions, expectedPartitions) + } + + private def getSortedTablePartitions(tableName: String): Seq[String] = { + spark.sql(s"SHOW PARTITIONS $tableName").collect() + .map(row => row.getString(0)) + .sorted.toSeq + } + + private def validatePartitions(tableName: String, + droppedPartitions: Seq[String], + expectedPartitions: Seq[String]): Unit = { + val actualPartitions: Seq[String] = getSortedTablePartitions(tableName) + if (expectedPartitions.size == 1 && expectedPartitions.head.isEmpty) { + assertTrue(actualPartitions.isEmpty) + } else { + assertEquals(expectedPartitions, actualPartitions) + } + droppedPartitions.foreach(dropped => assertFalse(actualPartitions.contains(dropped))) + } +} + +object TestSparkSqlWithCustomKeyGenerator { + val SIMPLE_KEY_GEN_CLASS_NAME = "org.apache.hudi.keygen.SimpleKeyGenerator" + val NONPARTITIONED_KEY_GEN_CLASS_NAME = "org.apache.hudi.keygen.NonpartitionedKeyGenerator" + val 
CUSTOM_KEY_GEN_CLASS_NAME = "org.apache.hudi.keygen.CustomKeyGenerator" + val DATE_FORMAT_PATTERN = "yyyyMM" + val TS_KEY_GEN_CONFIGS = Map( + "hoodie.keygen.timebased.timestamp.type" -> "SCALAR", + "hoodie.keygen.timebased.output.dateformat" -> DATE_FORMAT_PATTERN, + "hoodie.keygen.timebased.timestamp.scalar.time.unit" -> "seconds" + ) + val TS_TO_STRING_FUNC = (tsSeconds: Integer) => tsSeconds.toString + val TS_FORMATTER_FUNC = (tsSeconds: Integer) => { + new DateTime(tsSeconds * 1000L).toString(DateTimeFormat.forPattern(DATE_FORMAT_PATTERN)) + } + + def getTimestampKeyGenConfigs: Map[String, String] = { + Map( + "hoodie.keygen.timebased.timestamp.type" -> "SCALAR", + "hoodie.keygen.timebased.output.dateformat" -> DATE_FORMAT_PATTERN, + "hoodie.keygen.timebased.timestamp.scalar.time.unit" -> "seconds" + ) + } +}