This is an automated email from the ASF dual-hosted git repository.

yihua pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new cc797953866e [HUDI-9606] Support user provided key generator class (#13656)
cc797953866e is described below

commit cc797953866e8834072a814aa7fff8c073f943be
Author: Lin Liu <[email protected]>
AuthorDate: Mon Sep 29 21:18:48 2025 -0700

    [HUDI-9606] Support user provided key generator class (#13656)
    
    Co-authored-by: Y Ethan Guo <[email protected]>
---
 .../hudi/keygen/MockUserProvidedKeyGenerator.java  | 128 +++++++++++++++++++++
 .../hudi/common/table/HoodieTableMetaClient.java   |  13 ++-
 .../hudi/keygen/constant/KeyGeneratorType.java     |  47 +++++---
 .../hudi/keygen/constant/TestKeyGeneratorType.java |  49 ++++++++
 .../apache/hudi/functional/TestCOWDataSource.scala |  97 +++++++++++++++-
 .../apache/hudi/functional/TestMORDataSource.scala |  98 +++++++++++++++-
 6 files changed, 408 insertions(+), 24 deletions(-)

diff --git 
a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/keygen/MockUserProvidedKeyGenerator.java
 
b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/keygen/MockUserProvidedKeyGenerator.java
new file mode 100644
index 000000000000..0c310d454fa4
--- /dev/null
+++ 
b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/keygen/MockUserProvidedKeyGenerator.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.keygen;
+
+import org.apache.hudi.common.config.TypedProperties;
+import org.apache.hudi.common.util.Option;
+import org.apache.hudi.keygen.constant.KeyGeneratorOptions;
+
+import org.apache.avro.generic.GenericRecord;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import java.util.Collections;
+
+import static org.apache.hudi.common.util.ValidationUtils.checkArgument;
+
+/**
+ * This class is intended for testing purposes only and must never be used
+ * in any other context.
+ */
+public class MockUserProvidedKeyGenerator extends BuiltinKeyGenerator {
+  private static final String PREFIX = "MOCK_";
+
+  private final SimpleAvroKeyGenerator simpleAvroKeyGenerator;
+
+  public MockUserProvidedKeyGenerator(TypedProperties props) {
+    this(props, 
Option.ofNullable(props.getString(KeyGeneratorOptions.RECORDKEY_FIELD_NAME.key(),
 null)),
+        props.getString(KeyGeneratorOptions.PARTITIONPATH_FIELD_NAME.key()));
+  }
+
+  public MockUserProvidedKeyGenerator(TypedProperties props, String 
partitionPathField) {
+    this(props, Option.empty(), partitionPathField);
+  }
+
+  public MockUserProvidedKeyGenerator(TypedProperties props, Option<String> 
recordKeyField, String partitionPathField) {
+    super(props);
+    // Make sure key-generator is configured properly
+    validateRecordKey(recordKeyField);
+    validatePartitionPath(partitionPathField);
+
+    this.recordKeyFields = !recordKeyField.isPresent() ? 
Collections.emptyList() : Collections.singletonList(recordKeyField.get());
+    this.partitionPathFields = partitionPathField == null ? 
Collections.emptyList() : Collections.singletonList(partitionPathField);
+    this.simpleAvroKeyGenerator = new SimpleAvroKeyGenerator(props, 
recordKeyField, partitionPathField);
+  }
+
+  @Override
+  public String getRecordKey(GenericRecord record) {
+    return PREFIX + simpleAvroKeyGenerator.getRecordKey(record);
+  }
+
+  @Override
+  public String getPartitionPath(GenericRecord record) {
+    return PREFIX + simpleAvroKeyGenerator.getPartitionPath(record);
+  }
+
+  @Override
+  public String getRecordKey(Row row) {
+    tryInitRowAccessor(row.schema());
+
+    Object[] recordKeys = rowAccessor.getRecordKeyParts(row);
+    // NOTE: [[SimpleKeyGenerator]] is restricted to allow only primitive (non-composite)
+    //       record-key field
+    if (recordKeys[0] == null) {
+      return handleNullRecordKey(null);
+    } else {
+      return requireNonNullNonEmptyKey(recordKeys[0].toString());
+    }
+  }
+
+  @Override
+  public UTF8String getRecordKey(InternalRow internalRow, StructType schema) {
+    tryInitRowAccessor(schema);
+
+    Object[] recordKeyValues = rowAccessor.getRecordKeyParts(internalRow);
+    // NOTE: [[SimpleKeyGenerator]] is restricted to allow only primitive (non-composite)
+    //       record-key field
+    if (recordKeyValues[0] == null) {
+      return handleNullRecordKey(null);
+    } else if (recordKeyValues[0] instanceof UTF8String) {
+      return requireNonNullNonEmptyKey((UTF8String) recordKeyValues[0]);
+    } else {
+      return 
requireNonNullNonEmptyKey(UTF8String.fromString(recordKeyValues[0].toString()));
+    }
+  }
+
+  @Override
+  public String getPartitionPath(Row row) {
+    tryInitRowAccessor(row.schema());
+    return combinePartitionPath(rowAccessor.getRecordPartitionPathValues(row));
+  }
+
+  @Override
+  public UTF8String getPartitionPath(InternalRow row, StructType schema) {
+    tryInitRowAccessor(schema);
+    return 
combinePartitionPathUnsafe(rowAccessor.getRecordPartitionPathValues(row));
+  }
+
+  private static void validatePartitionPath(String partitionPathField) {
+    checkArgument(partitionPathField == null || !partitionPathField.isEmpty(),
+        "Partition-path field has to be non-empty!");
+    checkArgument(partitionPathField == null || 
!partitionPathField.contains(FIELDS_SEP),
+        String.format("Single partition-path field is expected; provided 
(%s)", partitionPathField));
+  }
+
+  private void validateRecordKey(Option<String> recordKeyField) {
+    checkArgument(!recordKeyField.isPresent() || 
!recordKeyField.get().isEmpty(),
+        "Record key field has to be non-empty!");
+    checkArgument(!recordKeyField.isPresent() || 
!recordKeyField.get().contains(FIELDS_SEP),
+        String.format("Single record-key field is expected; provided (%s)", 
recordKeyField));
+  }
+}
diff --git 
a/hudi-common/src/main/java/org/apache/hudi/common/table/HoodieTableMetaClient.java
 
b/hudi-common/src/main/java/org/apache/hudi/common/table/HoodieTableMetaClient.java
index 75198a7176e9..a7e91f69ef00 100644
--- 
a/hudi-common/src/main/java/org/apache/hudi/common/table/HoodieTableMetaClient.java
+++ 
b/hudi-common/src/main/java/org/apache/hudi/common/table/HoodieTableMetaClient.java
@@ -104,6 +104,7 @@ import static 
org.apache.hudi.common.util.StringUtils.getUTF8Bytes;
 import static org.apache.hudi.common.util.ValidationUtils.checkArgument;
 import static org.apache.hudi.common.util.ValidationUtils.checkState;
 import static org.apache.hudi.io.storage.HoodieIOFactory.getIOFactory;
+import static org.apache.hudi.keygen.constant.KeyGeneratorType.USER_PROVIDED;
 import static 
org.apache.hudi.metadata.HoodieIndexVersion.isValidIndexDefinition;
 
 /**
@@ -1563,9 +1564,17 @@ public class HoodieTableMetaClient implements 
Serializable {
         tableConfig.setValue(HoodieTableConfig.POPULATE_META_FIELDS, 
Boolean.toString(populateMetaFields));
       }
       if (null != keyGeneratorClassProp) {
-        tableConfig.setValue(HoodieTableConfig.KEY_GENERATOR_TYPE, 
KeyGeneratorType.fromClassName(keyGeneratorClassProp).name());
+        KeyGeneratorType type = 
KeyGeneratorType.fromClassName(keyGeneratorClassProp);
+        tableConfig.setValue(HoodieTableConfig.KEY_GENERATOR_TYPE, 
type.name());
+        if (USER_PROVIDED == type) {
+          tableConfig.setValue(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME, 
keyGeneratorClassProp);
+        }
       } else if (null != keyGeneratorType) {
-        tableConfig.setValue(HoodieTableConfig.KEY_GENERATOR_TYPE, 
keyGeneratorType);
+        checkArgument(!keyGeneratorType.equals(USER_PROVIDED.name()),
+            String.format("When key generator type is %s, the key generator 
class must be set properly",
+                USER_PROVIDED.name()));
+        KeyGeneratorType type = KeyGeneratorType.valueOf(keyGeneratorType);
+        tableConfig.setValue(HoodieTableConfig.KEY_GENERATOR_TYPE, 
type.name());
       }
       if (null != hiveStylePartitioningEnable) {
         tableConfig.setValue(HoodieTableConfig.HIVE_STYLE_PARTITIONING_ENABLE, 
Boolean.toString(hiveStylePartitioningEnable));
diff --git 
a/hudi-common/src/main/java/org/apache/hudi/keygen/constant/KeyGeneratorType.java
 
b/hudi-common/src/main/java/org/apache/hudi/keygen/constant/KeyGeneratorType.java
index b84cba3e28ac..6ea63da66db8 100644
--- 
a/hudi-common/src/main/java/org/apache/hudi/keygen/constant/KeyGeneratorType.java
+++ 
b/hudi-common/src/main/java/org/apache/hudi/keygen/constant/KeyGeneratorType.java
@@ -23,6 +23,10 @@ import org.apache.hudi.common.config.EnumFieldDescription;
 import org.apache.hudi.common.config.HoodieConfig;
 import org.apache.hudi.common.config.TypedProperties;
 import org.apache.hudi.common.util.ConfigUtils;
+import org.apache.hudi.common.util.StringUtils;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
 
@@ -96,10 +100,11 @@ public enum KeyGeneratorType {
   @EnumFieldDescription("Meant to be used internally for the spark sql MERGE 
INTO command.")
   
SPARK_SQL_MERGE_INTO("org.apache.spark.sql.hudi.command.MergeIntoKeyGenerator"),
 
-  @EnumFieldDescription("A test KeyGenerator for deltastreamer tests.")
-  
STREAMER_TEST("org.apache.hudi.utilities.deltastreamer.TestHoodieDeltaStreamer$TestGenerator");
+  @EnumFieldDescription("A KeyGenerator specified from the configuration.")
+  USER_PROVIDED(StringUtils.EMPTY_STRING);
 
-  private final String className;
+  private String className;
+  private static final Logger LOG = 
LoggerFactory.getLogger(KeyGeneratorType.class);
 
   KeyGeneratorType(String className) {
     this.className = className;
@@ -110,12 +115,15 @@ public enum KeyGeneratorType {
   }
 
   public static KeyGeneratorType fromClassName(String className) {
+    if (StringUtils.isNullOrEmpty(className)) {
+      throw new IllegalArgumentException("Invalid keyGenerator class: " + 
className);
+    }
     for (KeyGeneratorType type : KeyGeneratorType.values()) {
       if (type.getClassName().equals(className)) {
         return type;
       }
     }
-    throw new IllegalArgumentException("No KeyGeneratorType found for class 
name: " + className);
+    return USER_PROVIDED;
   }
 
   public static List<String> getNames() {
@@ -125,30 +133,37 @@ public enum KeyGeneratorType {
     return names;
   }
 
-  @Nullable
-  public static String getKeyGeneratorClassName(HoodieConfig config) {
-    return getKeyGeneratorClassName(config.getProps());
-  }
-
   @Nullable
   public static String getKeyGeneratorClassName(TypedProperties props) {
+    // An explicitly configured key generator class takes precedence; it is the only way to resolve the USER_PROVIDED type.
     if (ConfigUtils.containsConfigProperty(props, KEY_GENERATOR_CLASS_NAME)) {
       return ConfigUtils.getStringWithAltKeys(props, KEY_GENERATOR_CLASS_NAME);
     }
+    // For other types.
+    KeyGeneratorType keyGeneratorType;
     if (ConfigUtils.containsConfigProperty(props, KEY_GENERATOR_TYPE)) {
-      return KeyGeneratorType.valueOf(ConfigUtils.getStringWithAltKeys(props, 
KEY_GENERATOR_TYPE)).getClassName();
+      keyGeneratorType = 
KeyGeneratorType.valueOf(ConfigUtils.getStringWithAltKeys(props, 
KEY_GENERATOR_TYPE));
+      // For USER_PROVIDED type, the key generator class has to be provided.
+      if (USER_PROVIDED == keyGeneratorType) {
+        throw new IllegalArgumentException("No key generator class is provided 
properly for type: " + USER_PROVIDED.name());
+      }
+      return keyGeneratorType.getClassName();
     }
+    // No key generator information is provided.
+    LOG.warn("No key generator type is set properly");
     return null;
   }
 
+  @Nullable
+  public static String getKeyGeneratorClassName(HoodieConfig config) {
+    return getKeyGeneratorClassName(config.getProps());
+  }
+
   @Nullable
   public static String getKeyGeneratorClassName(Map<String, String> config) {
-    if (config.containsKey(KEY_GENERATOR_CLASS_NAME.key())) {
-      return config.get(KEY_GENERATOR_CLASS_NAME.key());
-    } else if (config.containsKey(KEY_GENERATOR_TYPE.key())) {
-      return 
KeyGeneratorType.valueOf(config.get(KEY_GENERATOR_TYPE.key())).getClassName();
-    }
-    return null;
+    TypedProperties props = new TypedProperties();
+    config.forEach(props::setProperty);
+    return getKeyGeneratorClassName(props);
   }
 
   public static boolean isComplexKeyGenerator(HoodieConfig config) {
diff --git 
a/hudi-common/src/test/java/org/apache/hudi/keygen/constant/TestKeyGeneratorType.java
 
b/hudi-common/src/test/java/org/apache/hudi/keygen/constant/TestKeyGeneratorType.java
index 8c9db7cae797..acd5fc739c1a 100644
--- 
a/hudi-common/src/test/java/org/apache/hudi/keygen/constant/TestKeyGeneratorType.java
+++ 
b/hudi-common/src/test/java/org/apache/hudi/keygen/constant/TestKeyGeneratorType.java
@@ -22,11 +22,18 @@ package org.apache.hudi.keygen.constant;
 import org.apache.hudi.common.config.HoodieConfig;
 
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.util.stream.Stream;
 
 import static 
org.apache.hudi.common.table.HoodieTableConfig.KEY_GENERATOR_CLASS_NAME;
 import static 
org.apache.hudi.common.table.HoodieTableConfig.KEY_GENERATOR_TYPE;
+import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 
 public class TestKeyGeneratorType {
   @Test
@@ -89,4 +96,46 @@ public class TestKeyGeneratorType {
 
     assertFalse(KeyGeneratorType.isComplexKeyGenerator(config));
   }
+
+  private static Stream<Arguments> testFromClassNameParams() {
+    return Stream.of(
+        Arguments.of("org.apache.hudi.keygen.SimpleKeyGenerator", 
KeyGeneratorType.SIMPLE),
+        Arguments.of("org.apache.hudi.keygen.SimpleAvroKeyGenerator", 
KeyGeneratorType.SIMPLE_AVRO),
+        Arguments.of("org.apache.hudi.keygen.ComplexKeyGenerator", 
KeyGeneratorType.COMPLEX),
+        Arguments.of("org.apache.hudi.keygen.ComplexAvroKeyGenerator", 
KeyGeneratorType.COMPLEX_AVRO),
+        Arguments.of("org.apache.hudi.keygen.TimestampBasedKeyGenerator", 
KeyGeneratorType.TIMESTAMP),
+        Arguments.of("org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator", 
KeyGeneratorType.TIMESTAMP_AVRO),
+        Arguments.of("org.apache.hudi.keygen.CustomKeyGenerator", 
KeyGeneratorType.CUSTOM),
+        Arguments.of("org.apache.hudi.keygen.CustomAvroKeyGenerator", 
KeyGeneratorType.CUSTOM_AVRO),
+        Arguments.of("org.apache.hudi.keygen.NonpartitionedKeyGenerator", 
KeyGeneratorType.NON_PARTITION),
+        Arguments.of("org.apache.hudi.keygen.NonpartitionedAvroKeyGenerator", 
KeyGeneratorType.NON_PARTITION_AVRO),
+        Arguments.of("org.apache.hudi.keygen.GlobalDeleteKeyGenerator", 
KeyGeneratorType.GLOBAL_DELETE),
+        Arguments.of("org.apache.hudi.keygen.GlobalAvroDeleteKeyGenerator", 
KeyGeneratorType.GLOBAL_DELETE_AVRO),
+        
Arguments.of("org.apache.hudi.keygen.AutoRecordGenWrapperKeyGenerator", 
KeyGeneratorType.AUTO_RECORD),
+        
Arguments.of("org.apache.hudi.keygen.AutoRecordGenWrapperAvroKeyGenerator", 
KeyGeneratorType.AUTO_RECORD_AVRO),
+        
Arguments.of("org.apache.hudi.metadata.HoodieTableMetadataKeyGenerator", 
KeyGeneratorType.HOODIE_TABLE_METADATA),
+        Arguments.of("org.apache.spark.sql.hudi.command.SqlKeyGenerator", 
KeyGeneratorType.SPARK_SQL),
+        Arguments.of("org.apache.spark.sql.hudi.command.UuidKeyGenerator", 
KeyGeneratorType.SPARK_SQL_UUID),
+        
Arguments.of("org.apache.spark.sql.hudi.command.MergeIntoKeyGenerator", 
KeyGeneratorType.SPARK_SQL_MERGE_INTO),
+        Arguments.of("org.apache.hudi.keygen.CustomUserProvidedKeyGenerator", 
KeyGeneratorType.USER_PROVIDED),
+        Arguments.of("com.example.CustomKeyGenerator", 
KeyGeneratorType.USER_PROVIDED)
+    );
+  }
+
+  @ParameterizedTest
+  @MethodSource("testFromClassNameParams")
+  void testFromClassName(String className, KeyGeneratorType expectedType) {
+    KeyGeneratorType result = KeyGeneratorType.fromClassName(className);
+    assertEquals(expectedType, result);
+  }
+
+  @Test
+  void testFromClassNameWithNull() {
+    assertThrows(IllegalArgumentException.class, () -> 
KeyGeneratorType.fromClassName(null));
+  }
+
+  @Test
+  void testFromClassNameWithEmpty() {
+    assertThrows(IllegalArgumentException.class, () -> 
KeyGeneratorType.fromClassName(""));
+  }
 }
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
index 96d9685638a1..873803018939 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
@@ -41,7 +41,7 @@ import org.apache.hudi.exception.{HoodieException, 
SchemaBackwardsCompatibilityE
 import org.apache.hudi.exception.ExceptionUtil.getRootCause
 import org.apache.hudi.hive.HiveSyncConfigHolder
 import org.apache.hudi.keygen.{ComplexKeyGenerator, CustomKeyGenerator, 
GlobalDeleteKeyGenerator, NonpartitionedKeyGenerator, SimpleKeyGenerator, 
TimestampBasedKeyGenerator}
-import org.apache.hudi.keygen.constant.KeyGeneratorOptions
+import org.apache.hudi.keygen.constant.{KeyGeneratorOptions, KeyGeneratorType}
 import org.apache.hudi.metrics.{Metrics, MetricsReporterType}
 import org.apache.hudi.storage.{StoragePath, StoragePathFilter}
 import org.apache.hudi.table.HoodieSparkTable
@@ -436,10 +436,11 @@ class TestCOWDataSource extends HoodieSparkClientTestBase 
with ScalaAssertionSup
     assertLastCommitIsUpsert()
   }
 
-  private def writeToHudi(opts: Map[String, String], df: Dataset[Row]): Unit = 
{
+  private def writeToHudi(opts: Map[String, String], df: Dataset[Row],
+                          operation: String = 
DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL): Unit = {
     df.write.format("hudi")
       .options(opts)
-      .option(DataSourceWriteOptions.OPERATION.key, 
DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
+      .option(DataSourceWriteOptions.OPERATION.key, operation)
       .mode(SaveMode.Append)
       .save(basePath)
   }
@@ -2112,7 +2113,6 @@ class TestCOWDataSource extends HoodieSparkClientTestBase 
with ScalaAssertionSup
       .save(basePath)
   }
 
-
   @Test
   def testReadOfAnEmptyTable(): Unit = {
     val (writeOpts, _) = getWriterReaderOpts(HoodieRecordType.AVRO)
@@ -2412,6 +2412,67 @@ class TestCOWDataSource extends 
HoodieSparkClientTestBase with ScalaAssertionSup
     }
   }
 
+  @ParameterizedTest
+  @MethodSource(Array("provideParamsForKeyGenTest"))
+  def testUserProvidedKeyGeneratorClass(keyGenClass: Option[String],
+                                        keyGenType: Option[String]): Unit = {
+    val recordType = HoodieRecordType.AVRO
+    var opts: Map[String, String] = Map()
+    if (keyGenClass.isPresent) {
+      opts = opts ++ Map(HoodieWriteConfig.KEYGENERATOR_CLASS_NAME.key -> 
keyGenClass.get)
+    }
+    if (keyGenType.isPresent) {
+      opts = opts ++ Map(HoodieWriteConfig.KEYGENERATOR_TYPE.key -> 
keyGenType.get)
+    }
+    val (writeOpts, readOpts) = getWriterReaderOpts(
+      recordType,
+      CommonOptionUtils.commonOpts ++ opts ++ Map(
+        DataSourceWriteOptions.PARTITIONPATH_FIELD.key -> "partition")
+    )
+    val expectedKeyGenType = if (keyGenClass.isEmpty) {
+      // By default SIMPLE type is returned when no class is provided.
+      KeyGeneratorType.SIMPLE.name
+    } else {
+      KeyGeneratorType.fromClassName(keyGenClass.get()).name
+    }
+
+    // Insert.
+    val records = recordsToStrings(dataGen.generateInserts("000", 
10)).asScala.toList
+    val inputDF = spark.read.json(spark.sparkContext.parallelize(records, 2))
+    writeToHudi(writeOpts, inputDF)
+    var actualDF = spark.read.format("hudi").options(readOpts).load(basePath)
+
+    // Transform the input keys based on the key generator to match the expected format
+    val inputKeyDF = TestCOWDataSource.transformRecordKeyColumn(
+        inputDF.select("_row_key"), "_row_key", 
KeyGeneratorType.valueOf(expectedKeyGenType))
+      .sort("_row_key")
+    var actualKeyDF = 
actualDF.select("_hoodie_record_key").sort("_hoodie_record_key")
+    assertTrue(inputKeyDF.except(actualKeyDF).isEmpty && 
actualKeyDF.except(inputKeyDF).isEmpty)
+    val metaClient = getHoodieMetaClient(storageConf, basePath)
+    val actualKeyGenType = metaClient.getTableConfig
+      .getProps.getString(HoodieTableConfig.KEY_GENERATOR_TYPE.key, null)
+    assertEquals(expectedKeyGenType, actualKeyGenType)
+    // For USER_PROVIDED type, the class should exist in table config.
+    if (KeyGeneratorType.USER_PROVIDED.name == actualKeyGenType) {
+      assertEquals(keyGenClass.get(), 
metaClient.getTableConfig.getKeyGeneratorClassName)
+    }
+
+    // First update.
+    val firstUpdate = 
recordsToStrings(dataGen.generateUpdatesForAllRecords("001")).asScala.toList
+    val firstUpdateDF = 
spark.read.json(spark.sparkContext.parallelize(firstUpdate, 2))
+    writeToHudi(writeOpts, firstUpdateDF, 
DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
+    actualDF = spark.read.format("hudi").options(readOpts).load(basePath)
+    actualKeyDF = 
actualDF.select("_hoodie_record_key").sort("_hoodie_record_key")
+    assertTrue(inputKeyDF.except(actualKeyDF).isEmpty && 
actualKeyDF.except(inputKeyDF).isEmpty)
+
+    // Second update.
+    // Change keyGenerator class name to generate exception.
+    val opt = writeOpts ++ Map(
+      "hoodie.datasource.write.keygenerator.class" -> 
"org.apache.hudi.keygen.SqlKeyGenerator")
+    assertThrows(classOf[HoodieException])({
+      writeToHudi(opt, firstUpdateDF, 
DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
+    })
+  }
 }
 
 object TestCOWDataSource {
@@ -2423,6 +2484,17 @@ object TestCOWDataSource {
     }
   }
 
+  def transformRecordKeyColumn(df: DataFrame, columnName: String, keyGenType: 
KeyGeneratorType): DataFrame = {
+    keyGenType match {
+      case KeyGeneratorType.USER_PROVIDED =>
+        df.withColumn(columnName, concat(lit(s"MOCK_"), col(columnName)))
+      case KeyGeneratorType.COMPLEX =>
+        df.withColumn(columnName, concat(lit(s"_row_key:"), col(columnName)))
+      case _ =>
+        df
+    }
+  }
+
   def tableVersionCreationTestCases = {
     val autoUpgradeValues = Array("true", "false")
     val targetVersions = Array("1", "2", "3", "4", "5", "6", "7", "8", "9", 
"null")
@@ -2430,4 +2502,21 @@ object TestCOWDataSource {
       (autoUpgrade: String) => targetVersions.map(
         (targetVersion: String) => Arguments.of(autoUpgrade, targetVersion)))
   }
+
+  def provideParamsForKeyGenTest(): java.util.List[Arguments] = {
+    java.util.Arrays.asList(
+      Arguments.of(
+        Option.of("org.apache.hudi.keygen.MockUserProvidedKeyGenerator"),
+        Option.of(KeyGeneratorType.USER_PROVIDED.name())),
+      Arguments.of(
+        Option.empty(),
+        Option.of(KeyGeneratorType.SIMPLE.name())),
+      Arguments.of(
+        Option.of("org.apache.hudi.keygen.SimpleAvroKeyGenerator"),
+        Option.of(KeyGeneratorType.SIMPLE.name())),
+      Arguments.of(
+        Option.of("org.apache.hudi.keygen.ComplexKeyGenerator"),
+        Option.empty())
+    )
+  }
 }
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
index b381c5570b17..39179e966222 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
@@ -31,8 +31,9 @@ import 
org.apache.hudi.common.testutils.HoodieTestDataGenerator.recordsToStrings
 import org.apache.hudi.common.util.Option
 import org.apache.hudi.common.util.StringUtils.isNullOrEmpty
 import org.apache.hudi.config.{HoodieCleanConfig, HoodieCompactionConfig, 
HoodieIndexConfig, HoodieWriteConfig}
-import org.apache.hudi.exception.HoodieUpgradeDowngradeException
+import org.apache.hudi.exception.{HoodieException, 
HoodieUpgradeDowngradeException}
 import org.apache.hudi.index.HoodieIndex.IndexType
+import org.apache.hudi.keygen.constant.KeyGeneratorType
 import 
org.apache.hudi.metadata.HoodieTableMetadataUtil.{metadataPartitionExists, 
PARTITION_NAME_SECONDARY_INDEX_PREFIX}
 import org.apache.hudi.storage.{StoragePath, StoragePathInfo}
 import org.apache.hudi.table.action.compact.CompactionTriggerStrategy
@@ -49,7 +50,7 @@ import org.apache.spark.sql.types.{LongType, StringType, 
StructField, StructType
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
 import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.{CsvSource, EnumSource, ValueSource}
+import org.junit.jupiter.params.provider.{Arguments, CsvSource, EnumSource, 
MethodSource, ValueSource}
 
 import java.io.File
 import java.nio.file.Files
@@ -2195,6 +2196,71 @@ class TestMORDataSource extends 
HoodieSparkClientTestBase with SparkDatasetMixin
     }
   }
 
+  @ParameterizedTest
+  @MethodSource(Array("provideParamsForKeyGenTest"))
+  def testUserProvidedKeyGeneratorClass(keyGenClass: Option[String],
+                                        keyGenType: Option[String]): Unit = {
+    val recordType = HoodieRecordType.AVRO
+    var opts: Map[String, String] = Map(
+      DataSourceWriteOptions.TABLE_TYPE.key -> 
DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL,
+      HoodieWriteConfig.MERGE_SMALL_FILE_GROUP_CANDIDATES_LIMIT.key -> "0"
+    )
+    if (keyGenClass.isPresent) {
+      opts = opts ++ Map(HoodieWriteConfig.KEYGENERATOR_CLASS_NAME.key -> 
keyGenClass.get)
+    }
+    if (keyGenType.isPresent) {
+      opts = opts ++ Map(HoodieWriteConfig.KEYGENERATOR_TYPE.key -> 
keyGenType.get)
+    }
+    val (writeOpts, readOpts) = getWriterReaderOpts(
+      recordType,
+      CommonOptionUtils.commonOpts ++ opts ++ Map(
+        DataSourceWriteOptions.PARTITIONPATH_FIELD.key -> "partition")
+    )
+    val expectedKeyGenType = if (keyGenClass.isEmpty) {
+      // By default SIMPLE type is returned when no class is provided.
+      KeyGeneratorType.SIMPLE.name
+    } else {
+      KeyGeneratorType.fromClassName(keyGenClass.get()).name
+    }
+
+    // Insert.
+    val records = recordsToStrings(dataGen.generateInserts("000", 
10)).asScala.toList
+    val inputDF = spark.read.json(spark.sparkContext.parallelize(records, 2))
+    writeToHudi(writeOpts, inputDF)
+    var actualDF = spark.read.format("hudi").options(readOpts).load(basePath)
+
+    // Transform the input keys based on the key generator to match the expected format
+    val inputKeyDF = TestCOWDataSource.transformRecordKeyColumn(
+        inputDF.select("_row_key"), "_row_key", 
KeyGeneratorType.valueOf(expectedKeyGenType))
+      .sort("_row_key")
+    var actualKeyDF = 
actualDF.select("_hoodie_record_key").sort("_hoodie_record_key")
+    assertTrue(inputKeyDF.except(actualKeyDF).isEmpty && 
actualKeyDF.except(inputKeyDF).isEmpty)
+    val metaClient = getHoodieMetaClient(storageConf, basePath)
+    val actualKeyGenType = metaClient.getTableConfig
+      .getProps.getString(HoodieTableConfig.KEY_GENERATOR_TYPE.key, null)
+    assertEquals(expectedKeyGenType, actualKeyGenType)
+    // For USER_PROVIDED type, the class should exist in table config.
+    if (KeyGeneratorType.USER_PROVIDED.name == actualKeyGenType) {
+      assertEquals(keyGenClass.get(), 
metaClient.getTableConfig.getKeyGeneratorClassName)
+    }
+
+    // First update.
+    val firstUpdate = 
recordsToStrings(dataGen.generateUpdatesForAllRecords("001")).asScala.toList
+    val firstUpdateDF = 
spark.read.json(spark.sparkContext.parallelize(firstUpdate, 2))
+    writeToHudi(writeOpts, firstUpdateDF, 
DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
+    actualDF = spark.read.format("hudi").options(readOpts).load(basePath)
+    actualKeyDF = 
actualDF.select("_hoodie_record_key").sort("_hoodie_record_key")
+    assertTrue(inputKeyDF.except(actualKeyDF).isEmpty && 
actualKeyDF.except(inputKeyDF).isEmpty)
+
+    // Second update.
+    // Change keyGenerator class name to generate exception.
+    val opt = writeOpts ++ Map(
+      "hoodie.datasource.write.keygenerator.class" -> 
"org.apache.hudi.keygen.SqlKeyGenerator")
+    assertThrows(classOf[HoodieException])({
+      writeToHudi(opt, firstUpdateDF, 
DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
+    })
+  }
+
   private def loadFixtureTable(testBasePath: String, version: 
HoodieTableVersion): HoodieTableMetaClient = {
     val fixtureName = getFixtureName(version, "")
     val resourcePath = 
s"/upgrade-downgrade-fixtures/unsupported-upgrade-tables/$fixtureName"
@@ -2251,4 +2317,32 @@ class TestMORDataSource extends 
HoodieSparkClientTestBase with SparkDatasetMixin
       Row("id11", "TestUser3", 11000L, "2023-01-06")
     )
   }
+
+  private def writeToHudi(opts: Map[String, String], df: Dataset[Row],
+                          operation: String = 
DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL): Unit = {
+    df.write.format("hudi")
+      .options(opts)
+      .option(DataSourceWriteOptions.OPERATION.key, operation)
+      .mode(SaveMode.Append)
+      .save(basePath)
+  }
+}
+
+object TestMORDataSource {
+  def provideParamsForKeyGenTest(): java.util.List[Arguments] = {
+    java.util.Arrays.asList(
+      Arguments.of(
+        Option.of("org.apache.hudi.keygen.MockUserProvidedKeyGenerator"),
+        Option.of(KeyGeneratorType.USER_PROVIDED.name())),
+      Arguments.of(
+        Option.empty(),
+        Option.of(KeyGeneratorType.SIMPLE.name())),
+      Arguments.of(
+        Option.of("org.apache.hudi.keygen.SimpleAvroKeyGenerator"),
+        Option.of(KeyGeneratorType.SIMPLE.name())),
+      Arguments.of(
+        Option.of("org.apache.hudi.keygen.ComplexKeyGenerator"),
+        Option.empty())
+    )
+  }
 }

Reply via email to