jonvex commented on code in PR #10615:
URL: https://github.com/apache/hudi/pull/10615#discussion_r1562569055


##########
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala:
##########
@@ -201,8 +201,26 @@ object HoodieWriterUtils {
           diffConfigs.append(s"KeyGenerator:\t$datasourceKeyGen\t$tableConfigKeyGen\n")
         }
 
+        // Please note that the validation of partition path fields needs the key generator class
+        // for the table, since the custom key generator expects a different format of
+        // the value of the write config "hoodie.datasource.write.partitionpath.field"
+        // e.g., "col:simple,ts:timestamp", whereas the table config "hoodie.table.partition.fields"
+        // in hoodie.properties stores "col,ts".
+        // The "params" here may only contain the write config of partition path field,
+        // so we need to pass in the validated key generator class name.
+        val validatedKeyGenClassName = if (tableConfigKeyGen != null) {

Review Comment:
   So when `hoodie.datasource.write.partitionpath.field` is set, we don't set `hoodie.table.partition.fields`?
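   For reference, a minimal sketch of the two representations the new code comment describes (the table and column names here are illustrative only):

   ```scala
   // Write config consumed by CustomKeyGenerator: each partition field
   // carries its key generator type.
   val writeOpts = Map(
     "hoodie.datasource.write.partitionpath.field" -> "col:simple,ts:timestamp"
   )

   // Table config persisted in .hoodie/hoodie.properties for the same table:
   // hoodie.table.partition.fields=col,ts
   ```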



##########
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala:
##########
@@ -530,6 +539,40 @@ object ProvidesHoodieConfig {
       filterNullValues(overridingOpts)
   }
 
+  /**
+   * @param tableConfigKeyGeneratorClassName     key generator class name in the table config.
+   * @param partitionFieldNamesWithoutKeyGenType partition field names without key generator types
+   *                                             from the table config.
+   * @param catalogTable                         HoodieCatalogTable instance to fetch table properties.
+   * @return the write config value to set for "hoodie.datasource.write.partitionpath.field".
+   */
+  def getPartitionPathFieldWriteConfig(tableConfigKeyGeneratorClassName: String,
+                                       partitionFieldNamesWithoutKeyGenType: String,
+                                       catalogTable: HoodieCatalogTable): String = {
+    if (StringUtils.isNullOrEmpty(tableConfigKeyGeneratorClassName)) {
+      partitionFieldNamesWithoutKeyGenType
+    } else {
+      val writeConfigPartitionField = catalogTable.catalogProperties.get(PARTITIONPATH_FIELD.key())
+      val keyGenClass = ReflectionUtils.getClass(tableConfigKeyGeneratorClassName)
+      if (classOf[CustomKeyGenerator].equals(keyGenClass)

Review Comment:
   Do we want to make this cover any classes that extend `CustomKeyGenerator` as well?
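   For instance (a sketch only, reusing the `keyGenClass` value computed above), `isAssignableFrom` would also match subclasses:

   ```scala
   // Covers CustomKeyGenerator, CustomAvroKeyGenerator, and any class extending either.
   if (classOf[CustomKeyGenerator].isAssignableFrom(keyGenClass)
     || classOf[CustomAvroKeyGenerator].isAssignableFrom(keyGenClass)) {
     // take the value from "hoodie.datasource.write.partitionpath.field"
   }
   ```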



##########
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala:
##########
@@ -528,6 +536,40 @@ object ProvidesHoodieConfig {
       filterNullValues(overridingOpts)
   }
 
+  /**
+   * @param tableConfigKeyGeneratorClassName     key generator class name in the table config.
+   * @param partitionFieldNamesWithoutKeyGenType partition field names without key generator types
+   *                                             from the table config.
+   * @param catalogTable                         HoodieCatalogTable instance to fetch table properties.
+   * @return the write config value to set for "hoodie.datasource.write.partitionpath.field".
+   */
+  def getPartitionPathFieldWriteConfig(tableConfigKeyGeneratorClassName: String,
+                                       partitionFieldNamesWithoutKeyGenType: String,
+                                       catalogTable: HoodieCatalogTable): String = {
+    if (StringUtils.isNullOrEmpty(tableConfigKeyGeneratorClassName)) {
+      partitionFieldNamesWithoutKeyGenType
+    } else {

Review Comment:
   So does this mean that it's still an issue for Flink and Hive, etc.?



##########
hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestSparkSqlWithCustomKeyGenerator.scala:
##########
@@ -0,0 +1,571 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.hudi.functional
+
+import org.apache.hudi.HoodieSparkUtils
+import org.apache.hudi.common.config.TypedProperties
+import org.apache.hudi.common.table.HoodieTableMetaClient
+import org.apache.hudi.common.util.StringUtils
+import org.apache.hudi.exception.HoodieException
+import org.apache.hudi.functional.TestSparkSqlWithCustomKeyGenerator._
+import org.apache.hudi.util.SparkKeyGenUtils
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.hudi.common.HoodieSparkSqlTestBase
+import org.joda.time.DateTime
+import org.joda.time.format.DateTimeFormat
+import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
+import org.slf4j.LoggerFactory
+
+import java.io.IOException
+
+/**
+ * Tests Spark SQL DML with custom key generator and write configs.
+ */
+class TestSparkSqlWithCustomKeyGenerator extends HoodieSparkSqlTestBase {
+  private val LOG = LoggerFactory.getLogger(getClass)
+
+  test("Test Spark SQL DML with custom key generator") {
+    withTempDir { tmp =>
+      Seq(
+        Seq("COPY_ON_WRITE", "ts:timestamp,segment:simple",
+          "(ts=202401, segment='cat2')", "202401/cat2",
+          Seq("202312/cat2", "202312/cat4", "202401/cat1", "202401/cat3", "202402/cat1", "202402/cat3", "202402/cat5"),
+          TS_FORMATTER_FUNC,
+          (ts: Integer, segment: String) => TS_FORMATTER_FUNC.apply(ts) + "/" + segment),
+        Seq("MERGE_ON_READ", "segment:simple",
+          "(segment='cat3')", "cat3",
+          Seq("cat1", "cat2", "cat4", "cat5"),
+          TS_TO_STRING_FUNC,
+          (_: Integer, segment: String) => segment),
+        Seq("MERGE_ON_READ", "ts:timestamp",
+          "(ts=202312)", "202312",
+          Seq("202401", "202402"),
+          TS_FORMATTER_FUNC,
+          (ts: Integer, _: String) => TS_FORMATTER_FUNC.apply(ts)),
+        Seq("MERGE_ON_READ", "ts:timestamp,segment:simple",
+          "(ts=202401, segment='cat2')", "202401/cat2",
+          Seq("202312/cat2", "202312/cat4", "202401/cat1", "202401/cat3", "202402/cat1", "202402/cat3", "202402/cat5"),
+          TS_FORMATTER_FUNC,
+          (ts: Integer, segment: String) => TS_FORMATTER_FUNC.apply(ts) + "/" + segment)
+      ).foreach { testParams =>
+        withTable(generateTableName) { tableName =>
+          LOG.warn("Testing with parameters: " + testParams)
+          val tableType = testParams(0).asInstanceOf[String]
+          val writePartitionFields = testParams(1).asInstanceOf[String]
+          val dropPartitionStatement = testParams(2).asInstanceOf[String]
+          val droppedPartition = testParams(3).asInstanceOf[String]
+          val expectedPartitions = testParams(4).asInstanceOf[Seq[String]]
+          val tsGenFunc = testParams(5).asInstanceOf[Integer => String]
+          val partitionGenFunc = testParams(6).asInstanceOf[(Integer, String) => String]
+          val tablePath = tmp.getCanonicalPath + "/" + tableName
+          val timestampKeyGeneratorConfig = if (writePartitionFields.contains("timestamp")) {
+            TS_KEY_GEN_CONFIGS
+          } else {
+            Map[String, String]()
+          }
+          val timestampKeyGenProps = if (timestampKeyGeneratorConfig.nonEmpty) {
+            ", " + timestampKeyGeneratorConfig.map(e => e._1 + " = '" + e._2 + "'").mkString(", ")
+          } else {
+            ""
+          }
+
+          prepareTableWithKeyGenerator(
+            tableName, tablePath, tableType,
+            CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields, timestampKeyGeneratorConfig)
+
+          // SQL CTAS with table properties containing key generator write configs
+          createTableWithSql(tableName, tablePath,
+            s"hoodie.datasource.write.partitionpath.field = '$writePartitionFields'" + timestampKeyGenProps)
+
+          // Prepare source and test SQL INSERT INTO
+          val sourceTableName = tableName + "_source"
+          prepareParquetSource(sourceTableName, Seq(
+            "(7, 'a7', 1399.0, 1706800227, 'cat1')",
+            "(8, 'a8', 26.9, 1706800227, 'cat3')",
+            "(9, 'a9', 299.0, 1701443427, 'cat4')"))
+          spark.sql(
+            s"""
+               | INSERT INTO $tableName
+               | SELECT * from ${tableName}_source
+               | """.stripMargin)
+          validateResults(
+            tableName,
+            s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName",
+            tsGenFunc,
+            partitionGenFunc,
+            Seq(),
+            Seq(1, "a1", "1.6", 1704121827, "cat1"),
+            Seq(2, "a2", "10.8", 1704121827, "cat1"),
+            Seq(3, "a3", "30.0", 1706800227, "cat1"),
+            Seq(4, "a4", "103.4", 1701443427, "cat2"),
+            Seq(5, "a5", "1999.0", 1704121827, "cat2"),
+            Seq(6, "a6", "80.0", 1704121827, "cat3"),
+            Seq(7, "a7", "1399.0", 1706800227, "cat1"),
+            Seq(8, "a8", "26.9", 1706800227, "cat3"),
+            Seq(9, "a9", "299.0", 1701443427, "cat4")
+          )
+
+          // Test SQL UPDATE
+          spark.sql(
+            s"""
+               | UPDATE $tableName
+               | SET price = price + 10.0
+               | WHERE id between 4 and 7
+               | """.stripMargin)
+          validateResults(
+            tableName,
+            s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName",
+            tsGenFunc,
+            partitionGenFunc,
+            Seq(),
+            Seq(1, "a1", "1.6", 1704121827, "cat1"),
+            Seq(2, "a2", "10.8", 1704121827, "cat1"),
+            Seq(3, "a3", "30.0", 1706800227, "cat1"),
+            Seq(4, "a4", "113.4", 1701443427, "cat2"),
+            Seq(5, "a5", "2009.0", 1704121827, "cat2"),
+            Seq(6, "a6", "90.0", 1704121827, "cat3"),
+            Seq(7, "a7", "1409.0", 1706800227, "cat1"),
+            Seq(8, "a8", "26.9", 1706800227, "cat3"),
+            Seq(9, "a9", "299.0", 1701443427, "cat4")
+          )
+
+          // Test SQL MERGE INTO
+          spark.sql(
+            s"""
+               | MERGE INTO $tableName as target
+               | USING (
+               |   SELECT 1 as id, 'a1' as name, 1.6 as price, 1704121827 as ts, 'cat1' as segment, 'delete' as flag
+               |   UNION
+               |   SELECT 2 as id, 'a2' as name, 11.9 as price, 1704121827 as ts, 'cat1' as segment, '' as flag
+               |   UNION
+               |   SELECT 6 as id, 'a6' as name, 99.0 as price, 1704121827 as ts, 'cat3' as segment, '' as flag
+               |   UNION
+               |   SELECT 8 as id, 'a8' as name, 24.9 as price, 1706800227 as ts, 'cat3' as segment, '' as flag
+               |   UNION
+               |   SELECT 10 as id, 'a10' as name, 888.8 as price, 1706800227 as ts, 'cat5' as segment, '' as flag
+               | ) source
+               | on target.id = source.id
+               | WHEN MATCHED AND flag != 'delete' THEN UPDATE SET
+               |   id = source.id, name = source.name, price = source.price, ts = source.ts, segment = source.segment
+               | WHEN MATCHED AND flag = 'delete' THEN DELETE
+               | WHEN NOT MATCHED THEN INSERT (id, name, price, ts, segment)
+               |   values (source.id, source.name, source.price, source.ts, source.segment)
+               | """.stripMargin)
+          validateResults(
+            tableName,
+            s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName",
+            tsGenFunc,
+            partitionGenFunc,
+            Seq(),
+            Seq(2, "a2", "11.9", 1704121827, "cat1"),
+            Seq(3, "a3", "30.0", 1706800227, "cat1"),
+            Seq(4, "a4", "113.4", 1701443427, "cat2"),
+            Seq(5, "a5", "2009.0", 1704121827, "cat2"),
+            Seq(6, "a6", "99.0", 1704121827, "cat3"),
+            Seq(7, "a7", "1409.0", 1706800227, "cat1"),
+            Seq(8, "a8", "24.9", 1706800227, "cat3"),
+            Seq(9, "a9", "299.0", 1701443427, "cat4"),
+            Seq(10, "a10", "888.8", 1706800227, "cat5")
+          )
+
+          // Test SQL DELETE
+          spark.sql(
+            s"""
+               | DELETE FROM $tableName
+               | WHERE id = 7
+               | """.stripMargin)
+          validateResults(
+            tableName,
+            s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName",
+            tsGenFunc,
+            partitionGenFunc,
+            Seq(),
+            Seq(2, "a2", "11.9", 1704121827, "cat1"),
+            Seq(3, "a3", "30.0", 1706800227, "cat1"),
+            Seq(4, "a4", "113.4", 1701443427, "cat2"),
+            Seq(5, "a5", "2009.0", 1704121827, "cat2"),
+            Seq(6, "a6", "99.0", 1704121827, "cat3"),
+            Seq(8, "a8", "24.9", 1706800227, "cat3"),
+            Seq(9, "a9", "299.0", 1701443427, "cat4"),
+            Seq(10, "a10", "888.8", 1706800227, "cat5")
+          )
+
+          // Test DROP PARTITION
+          assertTrue(getSortedTablePartitions(tableName).contains(droppedPartition))
+          spark.sql(
+            s"""
+               | ALTER TABLE $tableName DROP PARTITION $dropPartitionStatement
+               |""".stripMargin)
+          validatePartitions(tableName, Seq(droppedPartition), expectedPartitions)
+
+          if (HoodieSparkUtils.isSpark3) {
+            // Test INSERT OVERWRITE, only supported in Spark 3.x
+            spark.sql(
+              s"""
+                 | INSERT OVERWRITE $tableName
+                 | SELECT 100 as id, 'a100' as name, 299.0 as price, 1706800227 as ts, 'cat10' as segment
+                 | """.stripMargin)
+            validateResults(
+              tableName,
+              s"SELECT id, name, cast(price as string), cast(ts as string), segment from $tableName",
+              tsGenFunc,
+              partitionGenFunc,
+              Seq(),
+              Seq(100, "a100", "299.0", 1706800227, "cat10")
+            )
+          }
+        }
+      }
+    }
+  }
+
+  test("Test table property isolation for partition path field config "
+    + "with custom key generator for Spark 3.1 and above") {
+    // Only testing Spark 3.1 and above as lower Spark versions do not support
+    // ALTER TABLE .. SET TBLPROPERTIES .. to store table-level properties in Hudi Catalog
+    if (HoodieSparkUtils.gteqSpark3_1) {
+      withTempDir { tmp => {
+        val tableNameNonPartitioned = generateTableName
+        val tableNameSimpleKey = generateTableName
+        val tableNameCustom1 = generateTableName
+        val tableNameCustom2 = generateTableName
+
+        val tablePathNonPartitioned = tmp.getCanonicalPath + "/" + tableNameNonPartitioned
+        val tablePathSimpleKey = tmp.getCanonicalPath + "/" + tableNameSimpleKey
+        val tablePathCustom1 = tmp.getCanonicalPath + "/" + tableNameCustom1
+        val tablePathCustom2 = tmp.getCanonicalPath + "/" + tableNameCustom2
+
+        val tableType = "MERGE_ON_READ"
+        val writePartitionFields1 = "segment:simple"
+        val writePartitionFields2 = "ts:timestamp,segment:simple"
+
+        prepareTableWithKeyGenerator(
+          tableNameNonPartitioned, tablePathNonPartitioned, tableType,
+          NONPARTITIONED_KEY_GEN_CLASS_NAME, "", Map())
+        prepareTableWithKeyGenerator(
+          tableNameSimpleKey, tablePathSimpleKey, tableType,
+          SIMPLE_KEY_GEN_CLASS_NAME, "segment", Map())
+        prepareTableWithKeyGenerator(
+          tableNameCustom1, tablePathCustom1, tableType,
+          CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields1, Map())
+        prepareTableWithKeyGenerator(
+          tableNameCustom2, tablePathCustom2, tableType,
+          CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields2, TS_KEY_GEN_CONFIGS)
+
+        // Non-partitioned table does not require additional partition path field write config
+        createTableWithSql(tableNameNonPartitioned, tablePathNonPartitioned, "")
+        // Partitioned table with simple key generator does not require additional partition path field write config
+        createTableWithSql(tableNameSimpleKey, tablePathSimpleKey, "")
+        // Partitioned table with custom key generator requires additional partition path field write config
+        // Without that, right now the SQL DML fails
+        createTableWithSql(tableNameCustom1, tablePathCustom1, "")
+        createTableWithSql(tableNameCustom2, tablePathCustom2,
+          s"hoodie.datasource.write.partitionpath.field = '$writePartitionFields2', "
+            + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 + "'").mkString(", "))
+
+        val segmentPartitionFunc = (_: Integer, segment: String) => segment
+        val customPartitionFunc = (ts: Integer, segment: String) => TS_FORMATTER_FUNC.apply(ts) + "/" + segment
+
+        testInsertInto1(tableNameNonPartitioned, TS_TO_STRING_FUNC, (_, _) => "")
+        testInsertInto1(tableNameSimpleKey, TS_TO_STRING_FUNC, segmentPartitionFunc)
+        // INSERT INTO should fail for tableNameCustom1
+        val sourceTableName = tableNameCustom1 + "_source"
+        prepareParquetSource(sourceTableName, Seq("(7, 'a7', 1399.0, 1706800227, 'cat1')"))
+        assertThrows[IOException] {
+          spark.sql(
+            s"""
+               | INSERT INTO $tableNameCustom1
+               | SELECT * from $sourceTableName
+               | """.stripMargin)
+        }
+        testInsertInto1(tableNameCustom2, TS_FORMATTER_FUNC, customPartitionFunc)
+
+        // Now add the missing partition path field write config for tableNameCustom1
+        spark.sql(
+          s"""ALTER TABLE $tableNameCustom1
+             | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field = '$writePartitionFields1')
+             | """.stripMargin)
+
+        // All tables should be able to do INSERT INTO without any problem,
+        // since the scope of the added write config is at the catalog table level
+        testInsertInto2(tableNameNonPartitioned, TS_TO_STRING_FUNC, (_, _) => "")
+        testInsertInto2(tableNameSimpleKey, TS_TO_STRING_FUNC, segmentPartitionFunc)
+        testInsertInto1(tableNameCustom1, TS_TO_STRING_FUNC, segmentPartitionFunc)
+        testInsertInto2(tableNameCustom2, TS_FORMATTER_FUNC, customPartitionFunc)
+      }
+      }
+    }
+  }
+
+  test("Test wrong partition path field write config with custom key generator") {
+    withTempDir { tmp => {
+      val tableName = generateTableName
+      val tablePath = tmp.getCanonicalPath + "/" + tableName
+      val tableType = "MERGE_ON_READ"
+      val writePartitionFields = "segment:simple,ts:timestamp"
+      val wrongWritePartitionFields = "segment:simple"
+      val customPartitionFunc = (ts: Integer, segment: String) => segment + "/" + TS_FORMATTER_FUNC.apply(ts)
+
+      prepareTableWithKeyGenerator(
+        tableName, tablePath, "MERGE_ON_READ",
+        CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields, TS_KEY_GEN_CONFIGS)
+
+      // CREATE TABLE should fail due to config conflict
+      assertThrows[HoodieException] {
+        createTableWithSql(tableName, tablePath,
+          s"hoodie.datasource.write.partitionpath.field = '$wrongWritePartitionFields', "
+            + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 + "'").mkString(", "))
+      }
+
+      createTableWithSql(tableName, tablePath,
+        s"hoodie.datasource.write.partitionpath.field = '$writePartitionFields', "
+          + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 + "'").mkString(", "))
+      // Set wrong write config
+      spark.sql(
+        s"""ALTER TABLE $tableName
+           | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field = '$wrongWritePartitionFields')
+           | """.stripMargin)
+
+      // INSERT INTO should fail due to conflict between write and table config of partition path fields
+      val sourceTableName = tableName + "_source"
+      prepareParquetSource(sourceTableName, Seq("(7, 'a7', 1399.0, 1706800227, 'cat1')"))
+      assertThrows[HoodieException] {
+        spark.sql(
+          s"""
+             | INSERT INTO $tableName
+             | SELECT * from $sourceTableName
+             | """.stripMargin)
+      }
+
+      // Only testing Spark 3.1 and above as lower Spark versions do not support
+      // ALTER TABLE .. SET TBLPROPERTIES .. to store table-level properties in Hudi Catalog
+      if (HoodieSparkUtils.gteqSpark3_1) {
+        // Now fix the partition path field write config for tableName
+        spark.sql(
+          s"""ALTER TABLE $tableName
+             | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field = '$writePartitionFields')
+             | """.stripMargin)
+
+        // INSERT INTO should succeed now
+        testInsertInto1(tableName, TS_FORMATTER_FUNC, customPartitionFunc)
+      }
+    }
+    }
+  }
+
+  private def testInsertInto1(tableName: String,

Review Comment:
   Or can you at least add some more comments on what the helper functions are doing?
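   For example, even a short scaladoc per helper would help; this is only a sketch, and the wording is my guess at the intent based on how the helpers are used above:

   ```scala
   /**
    * Runs INSERT INTO on the target table with a fixed batch of source records and
    * validates the written rows and their partition paths.
    *
    * @param tableName        name of the Hudi table under test.
    * @param tsGenFunc        maps the ts value to its expected string form.
    * @param partitionGenFunc maps (ts, segment) to the expected partition path.
    */
   ```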



##########
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala:
##########
@@ -530,6 +539,40 @@ object ProvidesHoodieConfig {
       filterNullValues(overridingOpts)
   }
 
+  /**
+   * @param tableConfigKeyGeneratorClassName     key generator class name in 
the table config.
+   * @param partitionFieldNamesWithoutKeyGenType partition field names without 
key generator types
+   *                                             from the table config.
+   * @param catalogTable                         HoodieCatalogTable instance 
to fetch table properties.
+   * @return the write config value to set for 
"hoodie.datasource.write.partitionpath.field".
+   */
+  def getPartitionPathFieldWriteConfig(tableConfigKeyGeneratorClassName: 
String,
+                                       partitionFieldNamesWithoutKeyGenType: 
String,
+                                       catalogTable: HoodieCatalogTable): 
String = {
+    if (StringUtils.isNullOrEmpty(tableConfigKeyGeneratorClassName)) {
+      partitionFieldNamesWithoutKeyGenType
+    } else {
+      val writeConfigPartitionField = 
catalogTable.catalogProperties.get(PARTITIONPATH_FIELD.key())

Review Comment:
   Is HoodieCatalogTable persisted between sessions? When you add an existing Hudi table in Spark SQL, you usually only need the column names, right?



##########
hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestSparkSqlWithCustomKeyGenerator.scala:
##########
@@ -0,0 +1,571 @@
+  private def testInsertInto1(tableName: String,

Review Comment:
   This should be named better
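   e.g. (a naming sketch only; the parameter types are inferred from the call sites above), making the inserted batch explicit instead of the `1`/`2` suffix:

   ```scala
   // Hypothetical rename: the name says which batch of records is inserted and verified.
   private def testFirstBatchInsertInto(tableName: String,
                                        tsGenFunc: Integer => String,
                                        partitionGenFunc: (Integer, String) => String): Unit = {
     // same body as the current testInsertInto1
   }
   ```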



##########
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala:
##########
@@ -528,6 +536,40 @@ object ProvidesHoodieConfig {
       filterNullValues(overridingOpts)
   }
 
+  /**
+   * @param tableConfigKeyGeneratorClassName     key generator class name in the table config.
+   * @param partitionFieldNamesWithoutKeyGenType partition field names without key generator types
+   *                                             from the table config.
+   * @param catalogTable                         HoodieCatalogTable instance to fetch table properties.
+   * @return the write config value to set for "hoodie.datasource.write.partitionpath.field".
+   */
+  def getPartitionPathFieldWriteConfig(tableConfigKeyGeneratorClassName: String,
+                                       partitionFieldNamesWithoutKeyGenType: String,
+                                       catalogTable: HoodieCatalogTable): String = {
+    if (StringUtils.isNullOrEmpty(tableConfigKeyGeneratorClassName)) {
+      partitionFieldNamesWithoutKeyGenType
+    } else {
+      val writeConfigPartitionField = catalogTable.catalogProperties.get(PARTITIONPATH_FIELD.key())
+      val keyGenClass = ReflectionUtils.getClass(tableConfigKeyGeneratorClassName)
+      if (classOf[CustomKeyGenerator].equals(keyGenClass)
+        || classOf[CustomAvroKeyGenerator].equals(keyGenClass)) {
+        // For custom key generator, we have to take the write config value from
+        // "hoodie.datasource.write.partitionpath.field", which contains the key generator
+        // type, whereas the table config only contains the partition field names without
+        // key generator types.
+        if (writeConfigPartitionField.isDefined) {
+          writeConfigPartitionField.get
+        } else {
+          log.warn("Write config \"hoodie.datasource.write.partitionpath.field\" is not set for "
+            + "custom key generator. This may fail the write operation.")
+          partitionFieldNamesWithoutKeyGenType

Review Comment:
   Does it fail in an obvious way, so the user would know how to correct it?
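   If not, one option (a sketch only, reusing the names from the hunk above) is to fail fast with an explicit message instead of only warning:

   ```scala
   } else {
     throw new HoodieException(
       "Write config \"hoodie.datasource.write.partitionpath.field\" must be set when using "
         + tableConfigKeyGeneratorClassName + "; expected the form \"<field>:<type>,...\", "
         + "e.g. \"ts:timestamp,segment:simple\"")
   }
   ```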



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@hudi.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

