This is an automated email from the ASF dual-hosted git repository.

yihua pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new 17ea14ab6d6 [HUDI-7378] Fix Spark SQL DML with custom key generator (#10615)
17ea14ab6d6 is described below

commit 17ea14ab6d6a8ca7ecef2cfcdbc67b0c87f23987
Author: Y Ethan Guo <ethan.guoyi...@gmail.com>
AuthorDate: Fri Apr 12 22:51:03 2024 -0700

    [HUDI-7378] Fix Spark SQL DML with custom key generator (#10615)
---
 .../factory/HoodieSparkKeyGeneratorFactory.java    |   4 +
 .../org/apache/hudi/util/SparkKeyGenUtils.scala    |  16 +-
 .../scala/org/apache/hudi/HoodieWriterUtils.scala  |  20 +-
 .../spark/sql/hudi/ProvidesHoodieConfig.scala      |  60 ++-
 .../spark/sql/hudi/TestProvidesHoodieConfig.scala  |  79 +++
 .../hudi/command/MergeIntoHoodieTableCommand.scala |   5 +-
 .../TestSparkSqlWithCustomKeyGenerator.scala       | 571 +++++++++++++++++++++
 7 files changed, 742 insertions(+), 13 deletions(-)
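
For context, a minimal sketch (not part of the patch) of the Spark SQL setup this change targets, modeled on the new tests below. It assumes a Spark session with Hudi support bound to `spark`; the table name and path are hypothetical, the Hudi table is assumed to already exist at that path, and the key generator settings mirror TS_KEY_GEN_CONFIGS in TestSparkSqlWithCustomKeyGenerator:

    // With CustomKeyGenerator, the write config carries a key generator type per field
    // ("ts:timestamp,segment:simple"), while hoodie.table.partition.fields in
    // hoodie.properties stores only the field names ("ts,segment").
    spark.sql(
      s"""
         | CREATE TABLE hudi_custom_keygen_tbl USING HUDI
         | LOCATION '/tmp/hudi_custom_keygen_tbl'
         | TBLPROPERTIES (
         |   hoodie.datasource.write.partitionpath.field = 'ts:timestamp,segment:simple',
         |   hoodie.keygen.timebased.timestamp.type = 'SCALAR',
         |   hoodie.keygen.timebased.output.dateformat = 'yyyyMM',
         |   hoodie.keygen.timebased.timestamp.scalar.time.unit = 'seconds'
         | )
         | """.stripMargin)
    // With this fix, INSERT INTO / UPDATE / MERGE INTO / DELETE / DROP PARTITION
    // against such a table resolve the partition path fields correctly.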

diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java
index 1ea5adcd6b4..dcc2eaec9eb 100644
--- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java
+++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java
@@ -79,6 +79,10 @@ public class HoodieSparkKeyGeneratorFactory {
 
   public static KeyGenerator createKeyGenerator(TypedProperties props) throws IOException {
     String keyGeneratorClass = getKeyGeneratorClassName(props);
+    return createKeyGenerator(keyGeneratorClass, props);
+  }
+
+  public static KeyGenerator createKeyGenerator(String keyGeneratorClass, TypedProperties props) throws IOException {
     boolean autoRecordKeyGen = KeyGenUtils.isAutoGeneratedRecordKeysEnabled(props)
         //Need to prevent overwriting the keygen for spark sql merge into because we need to extract
         //the recordkey from the meta cols if it exists. Sql keygen will use pkless keygen if needed.
diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala
index 7b91ae5a728..bd094464096 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala
@@ -21,8 +21,8 @@ import org.apache.hudi.common.config.TypedProperties
 import org.apache.hudi.common.util.StringUtils
 import org.apache.hudi.common.util.ValidationUtils.checkArgument
 import org.apache.hudi.keygen.constant.KeyGeneratorOptions
-import org.apache.hudi.keygen.{AutoRecordKeyGeneratorWrapper, AutoRecordGenWrapperKeyGenerator, CustomAvroKeyGenerator, CustomKeyGenerator, GlobalAvroDeleteKeyGenerator, GlobalDeleteKeyGenerator, KeyGenerator, NonpartitionedAvroKeyGenerator, NonpartitionedKeyGenerator}
 import org.apache.hudi.keygen.factory.HoodieSparkKeyGeneratorFactory
+import org.apache.hudi.keygen.{AutoRecordKeyGeneratorWrapper, CustomAvroKeyGenerator, CustomKeyGenerator, GlobalAvroDeleteKeyGenerator, GlobalDeleteKeyGenerator, KeyGenerator, NonpartitionedAvroKeyGenerator, NonpartitionedKeyGenerator}
 
 object SparkKeyGenUtils {
 
@@ -35,6 +35,20 @@ object SparkKeyGenUtils {
     getPartitionColumns(keyGenerator, props)
   }
 
+  /**
+   * @param KeyGenClassNameOption key generator class name if present.
+   * @param props                 config properties.
+   * @return partition column names only, concatenated by ","
+   */
+  def getPartitionColumns(KeyGenClassNameOption: Option[String], props: TypedProperties): String = {
+    val keyGenerator = if (KeyGenClassNameOption.isEmpty) {
+      HoodieSparkKeyGeneratorFactory.createKeyGenerator(props)
+    } else {
+      HoodieSparkKeyGeneratorFactory.createKeyGenerator(KeyGenClassNameOption.get, props)
+    }
+    getPartitionColumns(keyGenerator, props)
+  }
+
   /**
    * @param keyGen key generator class name
    * @return partition columns
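
A rough illustration (not part of the patch) of the new overload above, mirroring its use in prepareTableWithKeyGenerator in the new functional test; the property values are hypothetical:

    import org.apache.hudi.common.config.TypedProperties
    import org.apache.hudi.util.SparkKeyGenUtils

    val props = new TypedProperties()
    props.put("hoodie.datasource.write.partitionpath.field", "ts:timestamp,segment:simple")
    props.put("hoodie.keygen.timebased.timestamp.type", "SCALAR")
    props.put("hoodie.keygen.timebased.output.dateformat", "yyyyMM")
    props.put("hoodie.keygen.timebased.timestamp.scalar.time.unit", "seconds")

    // Passing the class name explicitly avoids re-deriving it from props; for
    // CustomKeyGenerator the result should be the bare field names, e.g. "ts,segment".
    val partitionColumns = SparkKeyGenUtils.getPartitionColumns(
      Option("org.apache.hudi.keygen.CustomKeyGenerator"), props)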
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala
index 63495b0eede..5df773542d6 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala
@@ -201,8 +201,26 @@ object HoodieWriterUtils {
           diffConfigs.append(s"KeyGenerator:\t$datasourceKeyGen\t$tableConfigKeyGen\n")
         }
 
+        // Please note that the validation of partition path fields needs the key generator class
+        // for the table, since the custom key generator expects a different format of
+        // the value of the write config "hoodie.datasource.write.partitionpath.field"
+        // e.g., "col:simple,ts:timestamp", whereas the table config "hoodie.table.partition.fields"
+        // in hoodie.properties stores "col,ts".
+        // The "params" here may only contain the write config of partition path field,
+        // so we need to pass in the validated key generator class name.
+        val validatedKeyGenClassName = if (tableConfigKeyGen != null) {
+          Option(tableConfigKeyGen)
+        } else if (datasourceKeyGen != null) {
+          Option(datasourceKeyGen)
+        } else {
+          None
+        }
         val datasourcePartitionFields = params.getOrElse(PARTITIONPATH_FIELD.key(), null)
-        val currentPartitionFields = if (datasourcePartitionFields == null) null else SparkKeyGenUtils.getPartitionColumns(TypedProperties.fromMap(params))
+        val currentPartitionFields = if (datasourcePartitionFields == null) {
+          null
+        } else {
+          SparkKeyGenUtils.getPartitionColumns(validatedKeyGenClassName, TypedProperties.fromMap(params))
+        }
         val tableConfigPartitionFields = tableConfig.getString(HoodieTableConfig.PARTITION_FIELDS)
         if (null != datasourcePartitionFields && null != tableConfigPartitionFields
           && currentPartitionFields != tableConfigPartitionFields) {
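
To make the comment above concrete, a toy sketch (not the actual CustomKeyGenerator parsing) of why the validated key generator class is needed before the two configs can be compared; the field names are hypothetical:

    // Write config, as the user sets it for CustomKeyGenerator:
    val writeConfigValue = "col:simple,ts:timestamp"  // hoodie.datasource.write.partitionpath.field
    // Table config, as stored in hoodie.properties:
    val tableConfigValue = "col,ts"                   // hoodie.table.partition.fields
    // getPartitionColumns(validatedKeyGenClassName, ...) strips the ":simple"/":timestamp"
    // suffixes, roughly like this, so both sides compare as "col,ts":
    val normalized = writeConfigValue.split(",").map(_.split(":")(0)).mkString(",")
    assert(normalized == tableConfigValue)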
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala
index a4003bbd480..7d35c490fd4 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala
@@ -18,34 +18,36 @@
 package org.apache.spark.sql.hudi
 
 import org.apache.hudi.AutoRecordKeyGenerationUtils.shouldAutoGenerateRecordKeys
-import org.apache.hudi.{DataSourceWriteOptions, HoodieFileIndex}
 import org.apache.hudi.DataSourceWriteOptions._
 import org.apache.hudi.HoodieConversionUtils.toProperties
 import org.apache.hudi.common.config.{DFSPropertiesConfiguration, TypedProperties}
 import org.apache.hudi.common.model.{DefaultHoodieRecordPayload, WriteOperationType}
 import org.apache.hudi.common.table.HoodieTableConfig
+import org.apache.hudi.common.util.{ReflectionUtils, StringUtils}
 import org.apache.hudi.config.HoodieWriteConfig.TBL_NAME
 import org.apache.hudi.config.{HoodieIndexConfig, HoodieInternalConfig, HoodieWriteConfig}
 import org.apache.hudi.hive.ddl.HiveSyncMode
 import org.apache.hudi.hive.{HiveSyncConfig, HiveSyncConfigHolder, MultiPartKeysValueExtractor}
-import org.apache.hudi.keygen.ComplexKeyGenerator
+import org.apache.hudi.keygen.{ComplexKeyGenerator, CustomAvroKeyGenerator, CustomKeyGenerator}
 import org.apache.hudi.sql.InsertMode
 import org.apache.hudi.sync.common.HoodieSyncConfig
+import org.apache.hudi.{DataSourceWriteOptions, HoodieFileIndex}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
 import org.apache.spark.sql.execution.datasources.FileStatusCache
 import org.apache.spark.sql.hive.HiveExternalCatalog
 import org.apache.spark.sql.hudi.HoodieOptionConfig.mapSqlOptionsToDataSourceWriteConfigs
 import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{isHoodieConfigKey, isUsingHiveCatalog}
-import org.apache.spark.sql.hudi.ProvidesHoodieConfig.combineOptions
+import org.apache.spark.sql.hudi.ProvidesHoodieConfig.{combineOptions, getPartitionPathFieldWriteConfig}
 import org.apache.spark.sql.hudi.command.{SqlKeyGenerator, ValidateDuplicateKeyPayload}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.PARTITION_OVERWRITE_MODE
 import org.apache.spark.sql.types.StructType
-import java.util.Locale
+import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.slf4j.LoggerFactory
 
+import java.util.Locale
 import scala.collection.JavaConverters._
 
 trait ProvidesHoodieConfig extends Logging {
@@ -82,7 +84,8 @@ trait ProvidesHoodieConfig extends Logging {
       PRECOMBINE_FIELD.key -> preCombineField,
       HIVE_STYLE_PARTITIONING.key -> tableConfig.getHiveStylePartitioningEnable,
       URL_ENCODE_PARTITIONING.key -> tableConfig.getUrlEncodePartitioning,
-      PARTITIONPATH_FIELD.key -> tableConfig.getPartitionFieldProp
+      PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig(
+        tableConfig.getKeyGeneratorClassName, tableConfig.getPartitionFieldProp, hoodieCatalogTable)
     )
 
     combineOptions(hoodieCatalogTable, tableConfig, sparkSession.sqlContext.conf,
@@ -313,7 +316,8 @@ trait ProvidesHoodieConfig extends Logging {
       URL_ENCODE_PARTITIONING.key -> urlEncodePartitioning,
       RECORDKEY_FIELD.key -> recordKeyConfigValue,
       PRECOMBINE_FIELD.key -> preCombineField,
-      PARTITIONPATH_FIELD.key -> partitionFieldsStr
+      PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig(
+        keyGeneratorClassName, partitionFieldsStr, hoodieCatalogTable)
     ) ++ overwriteTableOpts ++ getDropDupsConfig(useLegacyInsertModeFlow, combinedOpts) ++ staticOverwritePartitionPathOptions
 
     combineOptions(hoodieCatalogTable, tableConfig, sparkSession.sqlContext.conf,
@@ -405,7 +409,8 @@ trait ProvidesHoodieConfig extends Logging {
       PARTITIONS_TO_DELETE.key -> partitionsToDrop,
       RECORDKEY_FIELD.key -> hoodieCatalogTable.primaryKeys.mkString(","),
       PRECOMBINE_FIELD.key -> hoodieCatalogTable.preCombineKey.getOrElse(""),
-      PARTITIONPATH_FIELD.key -> partitionFields,
+      PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig(
+        tableConfig.getKeyGeneratorClassName, partitionFields, hoodieCatalogTable),
       HoodieSyncConfig.META_SYNC_ENABLED.key -> hiveSyncConfig.getString(HoodieSyncConfig.META_SYNC_ENABLED.key),
       HiveSyncConfigHolder.HIVE_SYNC_ENABLED.key -> hiveSyncConfig.getString(HiveSyncConfigHolder.HIVE_SYNC_ENABLED.key),
       HiveSyncConfigHolder.HIVE_SYNC_MODE.key -> hiveSyncConfig.getStringOrDefault(HiveSyncConfigHolder.HIVE_SYNC_MODE, HiveSyncMode.HMS.name()),
@@ -451,7 +456,8 @@ trait ProvidesHoodieConfig extends Logging {
       HIVE_STYLE_PARTITIONING.key -> tableConfig.getHiveStylePartitioningEnable,
       URL_ENCODE_PARTITIONING.key -> tableConfig.getUrlEncodePartitioning,
       OPERATION.key -> DataSourceWriteOptions.DELETE_OPERATION_OPT_VAL,
-      PARTITIONPATH_FIELD.key -> tableConfig.getPartitionFieldProp
+      PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig(
+        tableConfig.getKeyGeneratorClassName, tableConfig.getPartitionFieldProp, hoodieCatalogTable)
     )
 
     combineOptions(hoodieCatalogTable, tableConfig, sparkSession.sqlContext.conf,
@@ -496,6 +502,8 @@ trait ProvidesHoodieConfig extends Logging {
 
 object ProvidesHoodieConfig {
 
+  private val log = LoggerFactory.getLogger(getClass)
+
   // NOTE: PLEASE READ CAREFULLY BEFORE CHANGING
   //
   // Spark SQL operations configuration might be coming from a variety of diverse sources
@@ -530,6 +538,40 @@ object ProvidesHoodieConfig {
       filterNullValues(overridingOpts)
   }
 
+  /**
+   * @param tableConfigKeyGeneratorClassName     key generator class name in the table config.
+   * @param partitionFieldNamesWithoutKeyGenType partition field names without key generator types
+   *                                             from the table config.
+   * @param catalogTable                         HoodieCatalogTable instance to fetch table properties.
+   * @return the write config value to set for "hoodie.datasource.write.partitionpath.field".
+   */
+  def getPartitionPathFieldWriteConfig(tableConfigKeyGeneratorClassName: String,
+                                       partitionFieldNamesWithoutKeyGenType: String,
+                                       catalogTable: HoodieCatalogTable): String = {
+    if (StringUtils.isNullOrEmpty(tableConfigKeyGeneratorClassName)) {
+      partitionFieldNamesWithoutKeyGenType
+    } else {
+      val writeConfigPartitionField = catalogTable.catalogProperties.get(PARTITIONPATH_FIELD.key())
+      val keyGenClass = ReflectionUtils.getClass(tableConfigKeyGeneratorClassName)
+      if (classOf[CustomKeyGenerator].equals(keyGenClass)
+        || classOf[CustomAvroKeyGenerator].equals(keyGenClass)) {
+        // For custom key generator, we have to take the write config value from
+        // "hoodie.datasource.write.partitionpath.field" which contains the key generator
+        // type, whereas the table config only contains the partition field names without
+        // key generator types.
+        if (writeConfigPartitionField.isDefined) {
+          writeConfigPartitionField.get
+        } else {
+          log.warn("Write config \"hoodie.datasource.write.partitionpath.field\" is not set for "
+            + "custom key generator. This may fail the write operation.")
+          partitionFieldNamesWithoutKeyGenType
+        }
+      } else {
+        partitionFieldNamesWithoutKeyGenType
+      }
+    }
+  }
+
   private def filterNullValues(opts: Map[String, String]): Map[String, String] 
=
     opts.filter { case (_, v) => v != null }
 
diff --git a/hudi-spark-datasource/hudi-spark-common/src/test/scala/org/apache/spark/sql/hudi/TestProvidesHoodieConfig.scala b/hudi-spark-datasource/hudi-spark-common/src/test/scala/org/apache/spark/sql/hudi/TestProvidesHoodieConfig.scala
new file mode 100644
index 00000000000..8414e41ca6c
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark-common/src/test/scala/org/apache/spark/sql/hudi/TestProvidesHoodieConfig.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.hudi
+
+import org.apache.hudi.DataSourceWriteOptions.PARTITIONPATH_FIELD
+import org.apache.hudi.keygen.{ComplexKeyGenerator, CustomKeyGenerator}
+
+import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Test
+import org.mockito.Mockito
+import org.mockito.Mockito.when
+
+/**
+ * Tests {@link ProvidesHoodieConfig}
+ */
+class TestProvidesHoodieConfig {
+  @Test
+  def testGetPartitionPathFieldWriteConfig(): Unit = {
+    val mockTable = Mockito.mock(classOf[HoodieCatalogTable])
+    val partitionFieldNames = "ts,segment"
+    val customKeyGenPartitionFieldWriteConfig = "ts:timestamp,segment:simple"
+
+    mockPartitionWriteConfigInCatalogProps(mockTable, None)
+    assertEquals(
+      partitionFieldNames,
+      ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+        "", partitionFieldNames, mockTable))
+    assertEquals(
+      partitionFieldNames,
+      ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+        classOf[ComplexKeyGenerator].getName, partitionFieldNames, mockTable))
+    assertEquals(
+      partitionFieldNames,
+      ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+        classOf[CustomKeyGenerator].getName, partitionFieldNames, mockTable))
+
+    mockPartitionWriteConfigInCatalogProps(mockTable, Option(customKeyGenPartitionFieldWriteConfig))
+    assertEquals(
+      partitionFieldNames,
+      ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+        "", partitionFieldNames, mockTable))
+    assertEquals(
+      partitionFieldNames,
+      ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+        classOf[ComplexKeyGenerator].getName, partitionFieldNames, mockTable))
+    assertEquals(
+      customKeyGenPartitionFieldWriteConfig,
+      ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+        classOf[CustomKeyGenerator].getName, partitionFieldNames, mockTable))
+  }
+
+  private def mockPartitionWriteConfigInCatalogProps(mockTable: HoodieCatalogTable,
+                                                     value: Option[String]): Unit = {
+    val props = if (value.isDefined) {
+      Map(PARTITIONPATH_FIELD.key() -> value.get)
+    } else {
+      Map[String, String]()
+    }
+    when(mockTable.catalogProperties).thenReturn(props)
+  }
+}
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
index 18403872f4a..79cd2646e08 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.LeftOuter
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._
 import org.apache.spark.sql.hudi.ProvidesHoodieConfig
-import org.apache.spark.sql.hudi.ProvidesHoodieConfig.combineOptions
+import org.apache.spark.sql.hudi.ProvidesHoodieConfig.{combineOptions, getPartitionPathFieldWriteConfig}
 import org.apache.spark.sql.hudi.analysis.HoodieAnalysis.failAnalysis
 import org.apache.spark.sql.hudi.command.MergeIntoHoodieTableCommand.{CoercedAttributeReference, encodeAsBase64String, stripCasting, toStructType}
 import org.apache.spark.sql.hudi.command.PartialAssignmentMode.PartialAssignmentMode
@@ -729,7 +729,8 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie
       RECORDKEY_FIELD.key -> tableConfig.getRawRecordKeyFieldProp,
       PRECOMBINE_FIELD.key -> preCombineField,
       TBL_NAME.key -> hoodieCatalogTable.tableName,
-      PARTITIONPATH_FIELD.key -> tableConfig.getPartitionFieldProp,
+      PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig(
+        tableConfig.getKeyGeneratorClassName, tableConfig.getPartitionFieldProp, hoodieCatalogTable),
       HIVE_STYLE_PARTITIONING.key -> tableConfig.getHiveStylePartitioningEnable,
       URL_ENCODE_PARTITIONING.key -> tableConfig.getUrlEncodePartitioning,
       KEYGENERATOR_CLASS_NAME.key -> keyGeneratorClassName,
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestSparkSqlWithCustomKeyGenerator.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestSparkSqlWithCustomKeyGenerator.scala
new file mode 100644
index 00000000000..c85eb40bca7
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestSparkSqlWithCustomKeyGenerator.scala
@@ -0,0 +1,571 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.hudi.functional
+
+import org.apache.hudi.HoodieSparkUtils
+import org.apache.hudi.common.config.TypedProperties
+import org.apache.hudi.common.table.HoodieTableMetaClient
+import org.apache.hudi.common.util.StringUtils
+import org.apache.hudi.exception.HoodieException
+import org.apache.hudi.functional.TestSparkSqlWithCustomKeyGenerator._
+import org.apache.hudi.util.SparkKeyGenUtils
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.hudi.common.HoodieSparkSqlTestBase
+import org.joda.time.DateTime
+import org.joda.time.format.DateTimeFormat
+import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
+import org.slf4j.LoggerFactory
+
+import java.io.IOException
+
+/**
+ * Tests Spark SQL DML with custom key generator and write configs.
+ */
+class TestSparkSqlWithCustomKeyGenerator extends HoodieSparkSqlTestBase {
+  private val LOG = LoggerFactory.getLogger(getClass)
+
+  test("Test Spark SQL DML with custom key generator") {
+    withTempDir { tmp =>
+      Seq(
+        Seq("COPY_ON_WRITE", "ts:timestamp,segment:simple",
+          "(ts=202401, segment='cat2')", "202401/cat2",
+          Seq("202312/cat2", "202312/cat4", "202401/cat1", "202401/cat3", 
"202402/cat1", "202402/cat3", "202402/cat5"),
+          TS_FORMATTER_FUNC,
+          (ts: Integer, segment: String) => TS_FORMATTER_FUNC.apply(ts) + "/" 
+ segment),
+        Seq("MERGE_ON_READ", "segment:simple",
+          "(segment='cat3')", "cat3",
+          Seq("cat1", "cat2", "cat4", "cat5"),
+          TS_TO_STRING_FUNC,
+          (_: Integer, segment: String) => segment),
+        Seq("MERGE_ON_READ", "ts:timestamp",
+          "(ts=202312)", "202312",
+          Seq("202401", "202402"),
+          TS_FORMATTER_FUNC,
+          (ts: Integer, _: String) => TS_FORMATTER_FUNC.apply(ts)),
+        Seq("MERGE_ON_READ", "ts:timestamp,segment:simple",
+          "(ts=202401, segment='cat2')", "202401/cat2",
+          Seq("202312/cat2", "202312/cat4", "202401/cat1", "202401/cat3", 
"202402/cat1", "202402/cat3", "202402/cat5"),
+          TS_FORMATTER_FUNC,
+          (ts: Integer, segment: String) => TS_FORMATTER_FUNC.apply(ts) + "/" 
+ segment)
+      ).foreach { testParams =>
+        withTable(generateTableName) { tableName =>
+          LOG.warn("Testing with parameters: " + testParams)
+          val tableType = testParams(0).asInstanceOf[String]
+          val writePartitionFields = testParams(1).asInstanceOf[String]
+          val dropPartitionStatement = testParams(2).asInstanceOf[String]
+          val droppedPartition = testParams(3).asInstanceOf[String]
+          val expectedPartitions = testParams(4).asInstanceOf[Seq[String]]
+          val tsGenFunc = testParams(5).asInstanceOf[Integer => String]
+          val partitionGenFunc = testParams(6).asInstanceOf[(Integer, String) 
=> String]
+          val tablePath = tmp.getCanonicalPath + "/" + tableName
+          val timestampKeyGeneratorConfig = if 
(writePartitionFields.contains("timestamp")) {
+            TS_KEY_GEN_CONFIGS
+          } else {
+            Map[String, String]()
+          }
+          val timestampKeyGenProps = if (timestampKeyGeneratorConfig.nonEmpty) 
{
+            ", " + timestampKeyGeneratorConfig.map(e => e._1 + " = '" + e._2 + 
"'").mkString(", ")
+          } else {
+            ""
+          }
+
+          prepareTableWithKeyGenerator(
+            tableName, tablePath, tableType,
+            CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields, 
timestampKeyGeneratorConfig)
+
+          // SQL CTAS with table properties containing key generator write 
configs
+          createTableWithSql(tableName, tablePath,
+            s"hoodie.datasource.write.partitionpath.field = 
'$writePartitionFields'" + timestampKeyGenProps)
+
+          // Prepare source and test SQL INSERT INTO
+          val sourceTableName = tableName + "_source"
+          prepareParquetSource(sourceTableName, Seq(
+            "(7, 'a7', 1399.0, 1706800227, 'cat1')",
+            "(8, 'a8', 26.9, 1706800227, 'cat3')",
+            "(9, 'a9', 299.0, 1701443427, 'cat4')"))
+          spark.sql(
+            s"""
+               | INSERT INTO $tableName
+               | SELECT * from ${tableName}_source
+               | """.stripMargin)
+          validateResults(
+            tableName,
+            s"SELECT id, name, cast(price as string), cast(ts as string), 
segment from $tableName",
+            tsGenFunc,
+            partitionGenFunc,
+            Seq(),
+            Seq(1, "a1", "1.6", 1704121827, "cat1"),
+            Seq(2, "a2", "10.8", 1704121827, "cat1"),
+            Seq(3, "a3", "30.0", 1706800227, "cat1"),
+            Seq(4, "a4", "103.4", 1701443427, "cat2"),
+            Seq(5, "a5", "1999.0", 1704121827, "cat2"),
+            Seq(6, "a6", "80.0", 1704121827, "cat3"),
+            Seq(7, "a7", "1399.0", 1706800227, "cat1"),
+            Seq(8, "a8", "26.9", 1706800227, "cat3"),
+            Seq(9, "a9", "299.0", 1701443427, "cat4")
+          )
+
+          // Test SQL UPDATE
+          spark.sql(
+            s"""
+               | UPDATE $tableName
+               | SET price = price + 10.0
+               | WHERE id between 4 and 7
+               | """.stripMargin)
+          validateResults(
+            tableName,
+            s"SELECT id, name, cast(price as string), cast(ts as string), 
segment from $tableName",
+            tsGenFunc,
+            partitionGenFunc,
+            Seq(),
+            Seq(1, "a1", "1.6", 1704121827, "cat1"),
+            Seq(2, "a2", "10.8", 1704121827, "cat1"),
+            Seq(3, "a3", "30.0", 1706800227, "cat1"),
+            Seq(4, "a4", "113.4", 1701443427, "cat2"),
+            Seq(5, "a5", "2009.0", 1704121827, "cat2"),
+            Seq(6, "a6", "90.0", 1704121827, "cat3"),
+            Seq(7, "a7", "1409.0", 1706800227, "cat1"),
+            Seq(8, "a8", "26.9", 1706800227, "cat3"),
+            Seq(9, "a9", "299.0", 1701443427, "cat4")
+          )
+
+          // Test SQL MERGE INTO
+          spark.sql(
+            s"""
+               | MERGE INTO $tableName as target
+               | USING (
+               |   SELECT 1 as id, 'a1' as name, 1.6 as price, 1704121827 as 
ts, 'cat1' as segment, 'delete' as flag
+               |   UNION
+               |   SELECT 2 as id, 'a2' as name, 11.9 as price, 1704121827 as 
ts, 'cat1' as segment, '' as flag
+               |   UNION
+               |   SELECT 6 as id, 'a6' as name, 99.0 as price, 1704121827 as 
ts, 'cat3' as segment, '' as flag
+               |   UNION
+               |   SELECT 8 as id, 'a8' as name, 24.9 as price, 1706800227 as 
ts, 'cat3' as segment, '' as flag
+               |   UNION
+               |   SELECT 10 as id, 'a10' as name, 888.8 as price, 1706800227 
as ts, 'cat5' as segment, '' as flag
+               | ) source
+               | on target.id = source.id
+               | WHEN MATCHED AND flag != 'delete' THEN UPDATE SET
+               |   id = source.id, name = source.name, price = source.price, 
ts = source.ts, segment = source.segment
+               | WHEN MATCHED AND flag = 'delete' THEN DELETE
+               | WHEN NOT MATCHED THEN INSERT (id, name, price, ts, segment)
+               |   values (source.id, source.name, source.price, source.ts, 
source.segment)
+               | """.stripMargin)
+          validateResults(
+            tableName,
+            s"SELECT id, name, cast(price as string), cast(ts as string), 
segment from $tableName",
+            tsGenFunc,
+            partitionGenFunc,
+            Seq(),
+            Seq(2, "a2", "11.9", 1704121827, "cat1"),
+            Seq(3, "a3", "30.0", 1706800227, "cat1"),
+            Seq(4, "a4", "113.4", 1701443427, "cat2"),
+            Seq(5, "a5", "2009.0", 1704121827, "cat2"),
+            Seq(6, "a6", "99.0", 1704121827, "cat3"),
+            Seq(7, "a7", "1409.0", 1706800227, "cat1"),
+            Seq(8, "a8", "24.9", 1706800227, "cat3"),
+            Seq(9, "a9", "299.0", 1701443427, "cat4"),
+            Seq(10, "a10", "888.8", 1706800227, "cat5")
+          )
+
+          // Test SQL DELETE
+          spark.sql(
+            s"""
+               | DELETE FROM $tableName
+               | WHERE id = 7
+               | """.stripMargin)
+          validateResults(
+            tableName,
+            s"SELECT id, name, cast(price as string), cast(ts as string), 
segment from $tableName",
+            tsGenFunc,
+            partitionGenFunc,
+            Seq(),
+            Seq(2, "a2", "11.9", 1704121827, "cat1"),
+            Seq(3, "a3", "30.0", 1706800227, "cat1"),
+            Seq(4, "a4", "113.4", 1701443427, "cat2"),
+            Seq(5, "a5", "2009.0", 1704121827, "cat2"),
+            Seq(6, "a6", "99.0", 1704121827, "cat3"),
+            Seq(8, "a8", "24.9", 1706800227, "cat3"),
+            Seq(9, "a9", "299.0", 1701443427, "cat4"),
+            Seq(10, "a10", "888.8", 1706800227, "cat5")
+          )
+
+          // Test DROP PARTITION
+          
assertTrue(getSortedTablePartitions(tableName).contains(droppedPartition))
+          spark.sql(
+            s"""
+               | ALTER TABLE $tableName DROP PARTITION $dropPartitionStatement
+               |""".stripMargin)
+          validatePartitions(tableName, Seq(droppedPartition), 
expectedPartitions)
+
+          if (HoodieSparkUtils.isSpark3) {
+            // Test INSERT OVERWRITE, only supported in Spark 3.x
+            spark.sql(
+              s"""
+                 | INSERT OVERWRITE $tableName
+                 | SELECT 100 as id, 'a100' as name, 299.0 as price, 
1706800227 as ts, 'cat10' as segment
+                 | """.stripMargin)
+            validateResults(
+              tableName,
+              s"SELECT id, name, cast(price as string), cast(ts as string), 
segment from $tableName",
+              tsGenFunc,
+              partitionGenFunc,
+              Seq(),
+              Seq(100, "a100", "299.0", 1706800227, "cat10")
+            )
+          }
+        }
+      }
+    }
+  }
+
+  test("Test table property isolation for partition path field config "
+    + "with custom key generator for Spark 3.1 and above") {
+    // Only testing Spark 3.1 and above as lower Spark versions do not support
+    // ALTER TABLE .. SET TBLPROPERTIES .. to store table-level properties in 
Hudi Catalog
+    if (HoodieSparkUtils.gteqSpark3_1) {
+      withTempDir { tmp => {
+        val tableNameNonPartitioned = generateTableName
+        val tableNameSimpleKey = generateTableName
+        val tableNameCustom1 = generateTableName
+        val tableNameCustom2 = generateTableName
+
+        val tablePathNonPartitioned = tmp.getCanonicalPath + "/" + 
tableNameNonPartitioned
+        val tablePathSimpleKey = tmp.getCanonicalPath + "/" + 
tableNameSimpleKey
+        val tablePathCustom1 = tmp.getCanonicalPath + "/" + tableNameCustom1
+        val tablePathCustom2 = tmp.getCanonicalPath + "/" + tableNameCustom2
+
+        val tableType = "MERGE_ON_READ"
+        val writePartitionFields1 = "segment:simple"
+        val writePartitionFields2 = "ts:timestamp,segment:simple"
+
+        prepareTableWithKeyGenerator(
+          tableNameNonPartitioned, tablePathNonPartitioned, tableType,
+          NONPARTITIONED_KEY_GEN_CLASS_NAME, "", Map())
+        prepareTableWithKeyGenerator(
+          tableNameSimpleKey, tablePathSimpleKey, tableType,
+          SIMPLE_KEY_GEN_CLASS_NAME, "segment", Map())
+        prepareTableWithKeyGenerator(
+          tableNameCustom1, tablePathCustom1, tableType,
+          CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields1, Map())
+        prepareTableWithKeyGenerator(
+          tableNameCustom2, tablePathCustom2, tableType,
+          CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields2, TS_KEY_GEN_CONFIGS)
+
+        // Non-partitioned table does not require additional partition path 
field write config
+        createTableWithSql(tableNameNonPartitioned, tablePathNonPartitioned, 
"")
+        // Partitioned table with simple key generator does not require 
additional partition path field write config
+        createTableWithSql(tableNameSimpleKey, tablePathSimpleKey, "")
+        // Partitioned table with custom key generator requires additional 
partition path field write config
+        // Without that, right now the SQL DML fails
+        createTableWithSql(tableNameCustom1, tablePathCustom1, "")
+        createTableWithSql(tableNameCustom2, tablePathCustom2,
+          s"hoodie.datasource.write.partitionpath.field = 
'$writePartitionFields2', "
+            + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 + 
"'").mkString(", "))
+
+        val segmentPartitionFunc = (_: Integer, segment: String) => segment
+        val customPartitionFunc = (ts: Integer, segment: String) => 
TS_FORMATTER_FUNC.apply(ts) + "/" + segment
+
+        testFirstRoundInserts(tableNameNonPartitioned, TS_TO_STRING_FUNC, (_, 
_) => "")
+        testFirstRoundInserts(tableNameSimpleKey, TS_TO_STRING_FUNC, 
segmentPartitionFunc)
+        // INSERT INTO should fail for tableNameCustom1
+        val sourceTableName = tableNameCustom1 + "_source"
+        prepareParquetSource(sourceTableName, Seq("(7, 'a7', 1399.0, 
1706800227, 'cat1')"))
+        assertThrows[IOException] {
+          spark.sql(
+            s"""
+               | INSERT INTO $tableNameCustom1
+               | SELECT * from $sourceTableName
+               | """.stripMargin)
+        }
+        testFirstRoundInserts(tableNameCustom2, TS_FORMATTER_FUNC, 
customPartitionFunc)
+
+        // Now add the missing partition path field write config for 
tableNameCustom1
+        spark.sql(
+          s"""ALTER TABLE $tableNameCustom1
+             | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field 
= '$writePartitionFields1')
+             | """.stripMargin)
+
+        // All tables should be able to do INSERT INTO without any problem,
+        // since the scope of the added write config is at the catalog table 
level
+        testSecondRoundInserts(tableNameNonPartitioned, TS_TO_STRING_FUNC, (_, 
_) => "")
+        testSecondRoundInserts(tableNameSimpleKey, TS_TO_STRING_FUNC, 
segmentPartitionFunc)
+        testFirstRoundInserts(tableNameCustom1, TS_TO_STRING_FUNC, 
segmentPartitionFunc)
+        testSecondRoundInserts(tableNameCustom2, TS_FORMATTER_FUNC, 
customPartitionFunc)
+      }
+      }
+    }
+  }
+
+  test("Test wrong partition path field write config with custom key 
generator") {
+    withTempDir { tmp => {
+      val tableName = generateTableName
+      val tablePath = tmp.getCanonicalPath + "/" + tableName
+      val tableType = "MERGE_ON_READ"
+      val writePartitionFields = "segment:simple,ts:timestamp"
+      val wrongWritePartitionFields = "segment:simple"
+      val customPartitionFunc = (ts: Integer, segment: String) => segment + 
"/" + TS_FORMATTER_FUNC.apply(ts)
+
+      prepareTableWithKeyGenerator(
+        tableName, tablePath, "MERGE_ON_READ",
+        CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields, TS_KEY_GEN_CONFIGS)
+
+      // CREATE TABLE should fail due to config conflict
+      assertThrows[HoodieException] {
+        createTableWithSql(tableName, tablePath,
+          s"hoodie.datasource.write.partitionpath.field = 
'$wrongWritePartitionFields', "
+            + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 + 
"'").mkString(", "))
+      }
+
+      createTableWithSql(tableName, tablePath,
+        s"hoodie.datasource.write.partitionpath.field = 
'$writePartitionFields', "
+          + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 + 
"'").mkString(", "))
+      // Set wrong write config
+      spark.sql(
+        s"""ALTER TABLE $tableName
+           | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field = 
'$wrongWritePartitionFields')
+           | """.stripMargin)
+
+      // INSERT INTO should fail due to conflict between write and table 
config of partition path fields
+      val sourceTableName = tableName + "_source"
+      prepareParquetSource(sourceTableName, Seq("(7, 'a7', 1399.0, 1706800227, 
'cat1')"))
+      assertThrows[HoodieException] {
+        spark.sql(
+          s"""
+             | INSERT INTO $tableName
+             | SELECT * from $sourceTableName
+             | """.stripMargin)
+      }
+
+      // Only testing Spark 3.1 and above as lower Spark versions do not 
support
+      // ALTER TABLE .. SET TBLPROPERTIES .. to store table-level properties 
in Hudi Catalog
+      if (HoodieSparkUtils.gteqSpark3_1) {
+        // Now fix the partition path field write config for tableName
+        spark.sql(
+          s"""ALTER TABLE $tableName
+             | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field 
= '$writePartitionFields')
+             | """.stripMargin)
+
+        // INSERT INTO should succeed now
+        testFirstRoundInserts(tableName, TS_FORMATTER_FUNC, 
customPartitionFunc)
+      }
+    }
+    }
+  }
+
+  private def testFirstRoundInserts(tableName: String,
+                                    tsGenFunc: Integer => String,
+                                    partitionGenFunc: (Integer, String) => 
String): Unit = {
+    val sourceTableName = tableName + "_source1"
+    prepareParquetSource(sourceTableName, Seq("(7, 'a7', 1399.0, 1706800227, 
'cat1')"))
+    spark.sql(
+      s"""
+         | INSERT INTO $tableName
+         | SELECT * from $sourceTableName
+         | """.stripMargin)
+    validateResults(
+      tableName,
+      s"SELECT id, name, cast(price as string), cast(ts as string), segment 
from $tableName",
+      tsGenFunc,
+      partitionGenFunc,
+      Seq(),
+      Seq(1, "a1", "1.6", 1704121827, "cat1"),
+      Seq(2, "a2", "10.8", 1704121827, "cat1"),
+      Seq(3, "a3", "30.0", 1706800227, "cat1"),
+      Seq(4, "a4", "103.4", 1701443427, "cat2"),
+      Seq(5, "a5", "1999.0", 1704121827, "cat2"),
+      Seq(6, "a6", "80.0", 1704121827, "cat3"),
+      Seq(7, "a7", "1399.0", 1706800227, "cat1")
+    )
+  }
+
+  private def testSecondRoundInserts(tableName: String,
+                                     tsGenFunc: Integer => String,
+                                     partitionGenFunc: (Integer, String) => 
String): Unit = {
+    val sourceTableName = tableName + "_source2"
+    prepareParquetSource(sourceTableName, Seq("(8, 'a8', 26.9, 1706800227, 
'cat3')"))
+    spark.sql(
+      s"""
+         | INSERT INTO $tableName
+         | SELECT * from $sourceTableName
+         | """.stripMargin)
+    validateResults(
+      tableName,
+      s"SELECT id, name, cast(price as string), cast(ts as string), segment 
from $tableName",
+      tsGenFunc,
+      partitionGenFunc,
+      Seq(),
+      Seq(1, "a1", "1.6", 1704121827, "cat1"),
+      Seq(2, "a2", "10.8", 1704121827, "cat1"),
+      Seq(3, "a3", "30.0", 1706800227, "cat1"),
+      Seq(4, "a4", "103.4", 1701443427, "cat2"),
+      Seq(5, "a5", "1999.0", 1704121827, "cat2"),
+      Seq(6, "a6", "80.0", 1704121827, "cat3"),
+      Seq(7, "a7", "1399.0", 1706800227, "cat1"),
+      Seq(8, "a8", "26.9", 1706800227, "cat3")
+    )
+  }
+
+  private def prepareTableWithKeyGenerator(tableName: String,
+                                           tablePath: String,
+                                           tableType: String,
+                                           keyGenClassName: String,
+                                           writePartitionFields: String,
+                                           timestampKeyGeneratorConfig: 
Map[String, String]): Unit = {
+    val df = spark.sql(
+      s"""SELECT 1 as id, 'a1' as name, 1.6 as price, 1704121827 as ts, 'cat1' 
as segment
+         | UNION
+         | SELECT 2 as id, 'a2' as name, 10.8 as price, 1704121827 as ts, 
'cat1' as segment
+         | UNION
+         | SELECT 3 as id, 'a3' as name, 30.0 as price, 1706800227 as ts, 
'cat1' as segment
+         | UNION
+         | SELECT 4 as id, 'a4' as name, 103.4 as price, 1701443427 as ts, 
'cat2' as segment
+         | UNION
+         | SELECT 5 as id, 'a5' as name, 1999.0 as price, 1704121827 as ts, 
'cat2' as segment
+         | UNION
+         | SELECT 6 as id, 'a6' as name, 80.0 as price, 1704121827 as ts, 
'cat3' as segment
+         |""".stripMargin)
+
+    df.write.format("hudi")
+      .option("hoodie.datasource.write.table.type", tableType)
+      .option("hoodie.datasource.write.keygenerator.class", keyGenClassName)
+      .option("hoodie.datasource.write.partitionpath.field", 
writePartitionFields)
+      .option("hoodie.datasource.write.recordkey.field", "id")
+      .option("hoodie.datasource.write.precombine.field", "name")
+      .option("hoodie.table.name", tableName)
+      .option("hoodie.insert.shuffle.parallelism", "1")
+      .option("hoodie.upsert.shuffle.parallelism", "1")
+      .option("hoodie.bulkinsert.shuffle.parallelism", "1")
+      .options(timestampKeyGeneratorConfig)
+      .mode(SaveMode.Overwrite)
+      .save(tablePath)
+
+    // Validate that the generated table has expected table configs of key 
generator and partition path fields
+    val metaClient = HoodieTableMetaClient.builder()
+      .setConf(spark.sparkContext.hadoopConfiguration)
+      .setBasePath(tablePath)
+      .build()
+    assertEquals(keyGenClassName, metaClient.getTableConfig.getKeyGeneratorClassName)
+    // Validate that the partition path fields in the table config always
+    // contain the field names only (no key generator type like "segment:simple")
+    if (CUSTOM_KEY_GEN_CLASS_NAME.equals(keyGenClassName)) {
+      val props = new TypedProperties()
+      props.put("hoodie.datasource.write.partitionpath.field", 
writePartitionFields)
+      timestampKeyGeneratorConfig.foreach(e => {
+        props.put(e._1, e._2)
+      })
+      // For custom key generator, the 
"hoodie.datasource.write.partitionpath.field"
+      // contains the key generator type, like "ts:timestamp,segment:simple",
+      // whereas the partition path fields in table config is "ts,segment"
+      assertEquals(
+        
SparkKeyGenUtils.getPartitionColumns(Option(CUSTOM_KEY_GEN_CLASS_NAME), props),
+        metaClient.getTableConfig.getPartitionFieldProp)
+    } else {
+      assertEquals(writePartitionFields, 
metaClient.getTableConfig.getPartitionFieldProp)
+    }
+  }
+
+  private def createTableWithSql(tableName: String,
+                                 tablePath: String,
+                                 tblProps: String): Unit = {
+    val tblPropsStatement = if (StringUtils.isNullOrEmpty(tblProps)) {
+      ""
+    } else {
+      "TBLPROPERTIES (\n" + tblProps + "\n)"
+    }
+    spark.sql(
+      s"""
+         | CREATE TABLE $tableName USING HUDI
+         | location '$tablePath'
+         | $tblPropsStatement
+         | """.stripMargin)
+  }
+
+  private def prepareParquetSource(sourceTableName: String,
+                                   rows: Seq[String]): Unit = {
+    spark.sql(
+      s"""CREATE TABLE $sourceTableName
+         | (id int, name string, price decimal(5, 1), ts int, segment string)
+         | USING PARQUET
+         |""".stripMargin)
+    spark.sql(
+      s"""
+         | INSERT INTO $sourceTableName values
+         | ${rows.mkString(", ")}
+         | """.stripMargin)
+  }
+
+  private def validateResults(tableName: String,
+                              sql: String,
+                              tsGenFunc: Integer => String,
+                              partitionGenFunc: (Integer, String) => String,
+                              droppedPartitions: Seq[String],
+                              expects: Seq[Any]*): Unit = {
+    checkAnswer(sql)(
+      expects.map(e => Seq(e(0), e(1), e(2), 
tsGenFunc.apply(e(3).asInstanceOf[Integer]), e(4))): _*
+    )
+    val expectedPartitions: Seq[String] = expects
+      .map(e => partitionGenFunc.apply(e(3).asInstanceOf[Integer], 
e(4).asInstanceOf[String]))
+      .distinct.sorted
+    validatePartitions(tableName, droppedPartitions, expectedPartitions)
+  }
+
+  private def getSortedTablePartitions(tableName: String): Seq[String] = {
+    spark.sql(s"SHOW PARTITIONS $tableName").collect()
+      .map(row => row.getString(0))
+      .sorted.toSeq
+  }
+
+  private def validatePartitions(tableName: String,
+                                 droppedPartitions: Seq[String],
+                                 expectedPartitions: Seq[String]): Unit = {
+    val actualPartitions: Seq[String] = getSortedTablePartitions(tableName)
+    if (expectedPartitions.size == 1 && expectedPartitions.head.isEmpty) {
+      assertTrue(actualPartitions.isEmpty)
+    } else {
+      assertEquals(expectedPartitions, actualPartitions)
+    }
+    droppedPartitions.foreach(dropped => 
assertFalse(actualPartitions.contains(dropped)))
+  }
+}
+
+object TestSparkSqlWithCustomKeyGenerator {
+  val SIMPLE_KEY_GEN_CLASS_NAME = "org.apache.hudi.keygen.SimpleKeyGenerator"
+  val NONPARTITIONED_KEY_GEN_CLASS_NAME = 
"org.apache.hudi.keygen.NonpartitionedKeyGenerator"
+  val CUSTOM_KEY_GEN_CLASS_NAME = "org.apache.hudi.keygen.CustomKeyGenerator"
+  val DATE_FORMAT_PATTERN = "yyyyMM"
+  val TS_KEY_GEN_CONFIGS = Map(
+    "hoodie.keygen.timebased.timestamp.type" -> "SCALAR",
+    "hoodie.keygen.timebased.output.dateformat" -> DATE_FORMAT_PATTERN,
+    "hoodie.keygen.timebased.timestamp.scalar.time.unit" -> "seconds"
+  )
+  val TS_TO_STRING_FUNC = (tsSeconds: Integer) => tsSeconds.toString
+  val TS_FORMATTER_FUNC = (tsSeconds: Integer) => {
+    new DateTime(tsSeconds * 
1000L).toString(DateTimeFormat.forPattern(DATE_FORMAT_PATTERN))
+  }
+
+  def getTimestampKeyGenConfigs: Map[String, String] = {
+    Map(
+      "hoodie.keygen.timebased.timestamp.type" -> "SCALAR",
+      "hoodie.keygen.timebased.output.dateformat" -> DATE_FORMAT_PATTERN,
+      "hoodie.keygen.timebased.timestamp.scalar.time.unit" -> "seconds"
+    )
+  }
+}
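
As a worked example of the scalar timestamp settings above (assuming the JVM default time zone does not shift the month relative to UTC), the ts values used in the test data map to partitions as follows:

    // TS_KEY_GEN_CONFIGS: SCALAR timestamps in seconds, output format "yyyyMM"
    TS_FORMATTER_FUNC.apply(1701443427)  // "202312" (2023-12-01)
    TS_FORMATTER_FUNC.apply(1704121827)  // "202401" (2024-01-01)
    TS_FORMATTER_FUNC.apply(1706800227)  // "202402" (2024-02-01)
    // These match the 202312/202401/202402 partitions asserted in the tests above.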

