This is an automated email from the ASF dual-hosted git repository.
yihua pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git
The following commit(s) were added to refs/heads/master by this push:
new cc797953866e [HUDI-9606] Support user provided key generator class
(#13656)
cc797953866e is described below
commit cc797953866e8834072a814aa7fff8c073f943be
Author: Lin Liu <[email protected]>
AuthorDate: Mon Sep 29 21:18:48 2025 -0700
[HUDI-9606] Support user provided key generator class (#13656)
Co-authored-by: Y Ethan Guo <[email protected]>
---
.../hudi/keygen/MockUserProvidedKeyGenerator.java | 128 +++++++++++++++++++++
.../hudi/common/table/HoodieTableMetaClient.java | 13 ++-
.../hudi/keygen/constant/KeyGeneratorType.java | 47 +++++---
.../hudi/keygen/constant/TestKeyGeneratorType.java | 49 ++++++++
.../apache/hudi/functional/TestCOWDataSource.scala | 97 +++++++++++++++-
.../apache/hudi/functional/TestMORDataSource.scala | 98 +++++++++++++++-
6 files changed, 408 insertions(+), 24 deletions(-)
diff --git
a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/keygen/MockUserProvidedKeyGenerator.java
b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/keygen/MockUserProvidedKeyGenerator.java
new file mode 100644
index 000000000000..0c310d454fa4
--- /dev/null
+++
b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/keygen/MockUserProvidedKeyGenerator.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.keygen;
+
+import org.apache.hudi.common.config.TypedProperties;
+import org.apache.hudi.common.util.Option;
+import org.apache.hudi.keygen.constant.KeyGeneratorOptions;
+
+import org.apache.avro.generic.GenericRecord;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import java.util.Collections;
+
+import static org.apache.hudi.common.util.ValidationUtils.checkArgument;
+
+/**
+ * A mock key generator used only for testing; it must never be used
+ * in production code.
+ */
+public class MockUserProvidedKeyGenerator extends BuiltinKeyGenerator {
+ private static final String PREFIX = "MOCK_";
+
+ private final SimpleAvroKeyGenerator simpleAvroKeyGenerator;
+
+ public MockUserProvidedKeyGenerator(TypedProperties props) {
+ this(props,
Option.ofNullable(props.getString(KeyGeneratorOptions.RECORDKEY_FIELD_NAME.key(),
null)),
+ props.getString(KeyGeneratorOptions.PARTITIONPATH_FIELD_NAME.key()));
+ }
+
+ public MockUserProvidedKeyGenerator(TypedProperties props, String
partitionPathField) {
+ this(props, Option.empty(), partitionPathField);
+ }
+
+ public MockUserProvidedKeyGenerator(TypedProperties props, Option<String>
recordKeyField, String partitionPathField) {
+ super(props);
+ // Make sure key-generator is configured properly
+ validateRecordKey(recordKeyField);
+ validatePartitionPath(partitionPathField);
+
+ this.recordKeyFields = !recordKeyField.isPresent() ?
Collections.emptyList() : Collections.singletonList(recordKeyField.get());
+ this.partitionPathFields = partitionPathField == null ?
Collections.emptyList() : Collections.singletonList(partitionPathField);
+ this.simpleAvroKeyGenerator = new SimpleAvroKeyGenerator(props,
recordKeyField, partitionPathField);
+ }
+
+ @Override
+ public String getRecordKey(GenericRecord record) {
+ return PREFIX + simpleAvroKeyGenerator.getRecordKey(record);
+ }
+
+ @Override
+ public String getPartitionPath(GenericRecord record) {
+ return PREFIX + simpleAvroKeyGenerator.getPartitionPath(record);
+ }
+
+ @Override
+ public String getRecordKey(Row row) {
+ tryInitRowAccessor(row.schema());
+
+ Object[] recordKeys = rowAccessor.getRecordKeyParts(row);
+ // NOTE: [[SimpleKeyGenerator]] is restricted to allow only primitive
(non-composite)
+ // record-key field
+ if (recordKeys[0] == null) {
+ return handleNullRecordKey(null);
+ } else {
+ return requireNonNullNonEmptyKey(recordKeys[0].toString());
+ }
+ }
+
+ @Override
+ public UTF8String getRecordKey(InternalRow internalRow, StructType schema) {
+ tryInitRowAccessor(schema);
+
+ Object[] recordKeyValues = rowAccessor.getRecordKeyParts(internalRow);
+ // NOTE: [[SimpleKeyGenerator]] is restricted to allow only primitive
(non-composite)
+ // record-key field
+ if (recordKeyValues[0] == null) {
+ return handleNullRecordKey(null);
+ } else if (recordKeyValues[0] instanceof UTF8String) {
+ return requireNonNullNonEmptyKey((UTF8String) recordKeyValues[0]);
+ } else {
+ return
requireNonNullNonEmptyKey(UTF8String.fromString(recordKeyValues[0].toString()));
+ }
+ }
+
+ @Override
+ public String getPartitionPath(Row row) {
+ tryInitRowAccessor(row.schema());
+ return combinePartitionPath(rowAccessor.getRecordPartitionPathValues(row));
+ }
+
+ @Override
+ public UTF8String getPartitionPath(InternalRow row, StructType schema) {
+ tryInitRowAccessor(schema);
+ return
combinePartitionPathUnsafe(rowAccessor.getRecordPartitionPathValues(row));
+ }
+
+ private static void validatePartitionPath(String partitionPathField) {
+ checkArgument(partitionPathField == null || !partitionPathField.isEmpty(),
+ "Partition-path field has to be non-empty!");
+ checkArgument(partitionPathField == null ||
!partitionPathField.contains(FIELDS_SEP),
+ String.format("Single partition-path field is expected; provided
(%s)", partitionPathField));
+ }
+
+ private void validateRecordKey(Option<String> recordKeyField) {
+ checkArgument(!recordKeyField.isPresent() ||
!recordKeyField.get().isEmpty(),
+ "Record key field has to be non-empty!");
+ checkArgument(!recordKeyField.isPresent() ||
!recordKeyField.get().contains(FIELDS_SEP),
+ String.format("Single record-key field is expected; provided (%s)",
recordKeyField));
+ }
+}
diff --git
a/hudi-common/src/main/java/org/apache/hudi/common/table/HoodieTableMetaClient.java
b/hudi-common/src/main/java/org/apache/hudi/common/table/HoodieTableMetaClient.java
index 75198a7176e9..a7e91f69ef00 100644
---
a/hudi-common/src/main/java/org/apache/hudi/common/table/HoodieTableMetaClient.java
+++
b/hudi-common/src/main/java/org/apache/hudi/common/table/HoodieTableMetaClient.java
@@ -104,6 +104,7 @@ import static
org.apache.hudi.common.util.StringUtils.getUTF8Bytes;
import static org.apache.hudi.common.util.ValidationUtils.checkArgument;
import static org.apache.hudi.common.util.ValidationUtils.checkState;
import static org.apache.hudi.io.storage.HoodieIOFactory.getIOFactory;
+import static org.apache.hudi.keygen.constant.KeyGeneratorType.USER_PROVIDED;
import static
org.apache.hudi.metadata.HoodieIndexVersion.isValidIndexDefinition;
/**
@@ -1563,9 +1564,17 @@ public class HoodieTableMetaClient implements
Serializable {
tableConfig.setValue(HoodieTableConfig.POPULATE_META_FIELDS,
Boolean.toString(populateMetaFields));
}
if (null != keyGeneratorClassProp) {
- tableConfig.setValue(HoodieTableConfig.KEY_GENERATOR_TYPE,
KeyGeneratorType.fromClassName(keyGeneratorClassProp).name());
+ KeyGeneratorType type =
KeyGeneratorType.fromClassName(keyGeneratorClassProp);
+ tableConfig.setValue(HoodieTableConfig.KEY_GENERATOR_TYPE,
type.name());
+ if (USER_PROVIDED == type) {
+ tableConfig.setValue(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME,
keyGeneratorClassProp);
+ }
} else if (null != keyGeneratorType) {
- tableConfig.setValue(HoodieTableConfig.KEY_GENERATOR_TYPE,
keyGeneratorType);
+ checkArgument(!keyGeneratorType.equals(USER_PROVIDED.name()),
+ String.format("When key generator type is %s, the key generator
class must be set properly",
+ USER_PROVIDED.name()));
+ KeyGeneratorType type = KeyGeneratorType.valueOf(keyGeneratorType);
+ tableConfig.setValue(HoodieTableConfig.KEY_GENERATOR_TYPE,
type.name());
}
if (null != hiveStylePartitioningEnable) {
tableConfig.setValue(HoodieTableConfig.HIVE_STYLE_PARTITIONING_ENABLE,
Boolean.toString(hiveStylePartitioningEnable));
diff --git
a/hudi-common/src/main/java/org/apache/hudi/keygen/constant/KeyGeneratorType.java
b/hudi-common/src/main/java/org/apache/hudi/keygen/constant/KeyGeneratorType.java
index b84cba3e28ac..6ea63da66db8 100644
---
a/hudi-common/src/main/java/org/apache/hudi/keygen/constant/KeyGeneratorType.java
+++
b/hudi-common/src/main/java/org/apache/hudi/keygen/constant/KeyGeneratorType.java
@@ -23,6 +23,10 @@ import org.apache.hudi.common.config.EnumFieldDescription;
import org.apache.hudi.common.config.HoodieConfig;
import org.apache.hudi.common.config.TypedProperties;
import org.apache.hudi.common.util.ConfigUtils;
+import org.apache.hudi.common.util.StringUtils;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import javax.annotation.Nullable;
@@ -96,10 +100,11 @@ public enum KeyGeneratorType {
@EnumFieldDescription("Meant to be used internally for the spark sql MERGE
INTO command.")
SPARK_SQL_MERGE_INTO("org.apache.spark.sql.hudi.command.MergeIntoKeyGenerator"),
- @EnumFieldDescription("A test KeyGenerator for deltastreamer tests.")
-
STREAMER_TEST("org.apache.hudi.utilities.deltastreamer.TestHoodieDeltaStreamer$TestGenerator");
+ @EnumFieldDescription("A KeyGenerator specified from the configuration.")
+ USER_PROVIDED(StringUtils.EMPTY_STRING);
- private final String className;
+ private String className;
+ private static final Logger LOG =
LoggerFactory.getLogger(KeyGeneratorType.class);
KeyGeneratorType(String className) {
this.className = className;
@@ -110,12 +115,15 @@ public enum KeyGeneratorType {
}
public static KeyGeneratorType fromClassName(String className) {
+ if (StringUtils.isNullOrEmpty(className)) {
+ throw new IllegalArgumentException("Invalid keyGenerator class: " +
className);
+ }
for (KeyGeneratorType type : KeyGeneratorType.values()) {
if (type.getClassName().equals(className)) {
return type;
}
}
- throw new IllegalArgumentException("No KeyGeneratorType found for class
name: " + className);
+ return USER_PROVIDED;
}
public static List<String> getNames() {
@@ -125,30 +133,37 @@ public enum KeyGeneratorType {
return names;
}
- @Nullable
- public static String getKeyGeneratorClassName(HoodieConfig config) {
- return getKeyGeneratorClassName(config.getProps());
- }
-
@Nullable
public static String getKeyGeneratorClassName(TypedProperties props) {
+    // The key generator class is stored in the table config only for the
+    // USER_PROVIDED type, so check for the class name first.
if (ConfigUtils.containsConfigProperty(props, KEY_GENERATOR_CLASS_NAME)) {
return ConfigUtils.getStringWithAltKeys(props, KEY_GENERATOR_CLASS_NAME);
}
+ // For other types.
+ KeyGeneratorType keyGeneratorType;
if (ConfigUtils.containsConfigProperty(props, KEY_GENERATOR_TYPE)) {
- return KeyGeneratorType.valueOf(ConfigUtils.getStringWithAltKeys(props,
KEY_GENERATOR_TYPE)).getClassName();
+ keyGeneratorType =
KeyGeneratorType.valueOf(ConfigUtils.getStringWithAltKeys(props,
KEY_GENERATOR_TYPE));
+ // For USER_PROVIDED type, the key generator class has to be provided.
+ if (USER_PROVIDED == keyGeneratorType) {
+ throw new IllegalArgumentException("No key generator class is provided
properly for type: " + USER_PROVIDED.name());
+ }
+ return keyGeneratorType.getClassName();
}
+ // No key generator information is provided.
+ LOG.warn("No key generator type is set properly");
return null;
}
+ @Nullable
+ public static String getKeyGeneratorClassName(HoodieConfig config) {
+ return getKeyGeneratorClassName(config.getProps());
+ }
+
@Nullable
public static String getKeyGeneratorClassName(Map<String, String> config) {
- if (config.containsKey(KEY_GENERATOR_CLASS_NAME.key())) {
- return config.get(KEY_GENERATOR_CLASS_NAME.key());
- } else if (config.containsKey(KEY_GENERATOR_TYPE.key())) {
- return
KeyGeneratorType.valueOf(config.get(KEY_GENERATOR_TYPE.key())).getClassName();
- }
- return null;
+ TypedProperties props = new TypedProperties();
+ config.forEach(props::setProperty);
+ return getKeyGeneratorClassName(props);
}
public static boolean isComplexKeyGenerator(HoodieConfig config) {
diff --git
a/hudi-common/src/test/java/org/apache/hudi/keygen/constant/TestKeyGeneratorType.java
b/hudi-common/src/test/java/org/apache/hudi/keygen/constant/TestKeyGeneratorType.java
index 8c9db7cae797..acd5fc739c1a 100644
---
a/hudi-common/src/test/java/org/apache/hudi/keygen/constant/TestKeyGeneratorType.java
+++
b/hudi-common/src/test/java/org/apache/hudi/keygen/constant/TestKeyGeneratorType.java
@@ -22,11 +22,18 @@ package org.apache.hudi.keygen.constant;
import org.apache.hudi.common.config.HoodieConfig;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.util.stream.Stream;
import static
org.apache.hudi.common.table.HoodieTableConfig.KEY_GENERATOR_CLASS_NAME;
import static
org.apache.hudi.common.table.HoodieTableConfig.KEY_GENERATOR_TYPE;
+import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertThrows;
public class TestKeyGeneratorType {
@Test
@@ -89,4 +96,46 @@ public class TestKeyGeneratorType {
assertFalse(KeyGeneratorType.isComplexKeyGenerator(config));
}
+
+ private static Stream<Arguments> testFromClassNameParams() {
+ return Stream.of(
+ Arguments.of("org.apache.hudi.keygen.SimpleKeyGenerator",
KeyGeneratorType.SIMPLE),
+ Arguments.of("org.apache.hudi.keygen.SimpleAvroKeyGenerator",
KeyGeneratorType.SIMPLE_AVRO),
+ Arguments.of("org.apache.hudi.keygen.ComplexKeyGenerator",
KeyGeneratorType.COMPLEX),
+ Arguments.of("org.apache.hudi.keygen.ComplexAvroKeyGenerator",
KeyGeneratorType.COMPLEX_AVRO),
+ Arguments.of("org.apache.hudi.keygen.TimestampBasedKeyGenerator",
KeyGeneratorType.TIMESTAMP),
+ Arguments.of("org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator",
KeyGeneratorType.TIMESTAMP_AVRO),
+ Arguments.of("org.apache.hudi.keygen.CustomKeyGenerator",
KeyGeneratorType.CUSTOM),
+ Arguments.of("org.apache.hudi.keygen.CustomAvroKeyGenerator",
KeyGeneratorType.CUSTOM_AVRO),
+ Arguments.of("org.apache.hudi.keygen.NonpartitionedKeyGenerator",
KeyGeneratorType.NON_PARTITION),
+ Arguments.of("org.apache.hudi.keygen.NonpartitionedAvroKeyGenerator",
KeyGeneratorType.NON_PARTITION_AVRO),
+ Arguments.of("org.apache.hudi.keygen.GlobalDeleteKeyGenerator",
KeyGeneratorType.GLOBAL_DELETE),
+ Arguments.of("org.apache.hudi.keygen.GlobalAvroDeleteKeyGenerator",
KeyGeneratorType.GLOBAL_DELETE_AVRO),
+
Arguments.of("org.apache.hudi.keygen.AutoRecordGenWrapperKeyGenerator",
KeyGeneratorType.AUTO_RECORD),
+
Arguments.of("org.apache.hudi.keygen.AutoRecordGenWrapperAvroKeyGenerator",
KeyGeneratorType.AUTO_RECORD_AVRO),
+
Arguments.of("org.apache.hudi.metadata.HoodieTableMetadataKeyGenerator",
KeyGeneratorType.HOODIE_TABLE_METADATA),
+ Arguments.of("org.apache.spark.sql.hudi.command.SqlKeyGenerator",
KeyGeneratorType.SPARK_SQL),
+ Arguments.of("org.apache.spark.sql.hudi.command.UuidKeyGenerator",
KeyGeneratorType.SPARK_SQL_UUID),
+
Arguments.of("org.apache.spark.sql.hudi.command.MergeIntoKeyGenerator",
KeyGeneratorType.SPARK_SQL_MERGE_INTO),
+ Arguments.of("org.apache.hudi.keygen.CustomUserProvidedKeyGenerator",
KeyGeneratorType.USER_PROVIDED),
+ Arguments.of("com.example.CustomKeyGenerator",
KeyGeneratorType.USER_PROVIDED)
+ );
+ }
+
+ @ParameterizedTest
+ @MethodSource("testFromClassNameParams")
+ void testFromClassName(String className, KeyGeneratorType expectedType) {
+ KeyGeneratorType result = KeyGeneratorType.fromClassName(className);
+ assertEquals(expectedType, result);
+ }
+
+ @Test
+ void testFromClassNameWithNull() {
+ assertThrows(IllegalArgumentException.class, () ->
KeyGeneratorType.fromClassName(null));
+ }
+
+ @Test
+ void testFromClassNameWithEmpty() {
+ assertThrows(IllegalArgumentException.class, () ->
KeyGeneratorType.fromClassName(""));
+ }
}
diff --git
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
index 96d9685638a1..873803018939 100644
---
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
+++
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
@@ -41,7 +41,7 @@ import org.apache.hudi.exception.{HoodieException,
SchemaBackwardsCompatibilityE
import org.apache.hudi.exception.ExceptionUtil.getRootCause
import org.apache.hudi.hive.HiveSyncConfigHolder
import org.apache.hudi.keygen.{ComplexKeyGenerator, CustomKeyGenerator,
GlobalDeleteKeyGenerator, NonpartitionedKeyGenerator, SimpleKeyGenerator,
TimestampBasedKeyGenerator}
-import org.apache.hudi.keygen.constant.KeyGeneratorOptions
+import org.apache.hudi.keygen.constant.{KeyGeneratorOptions, KeyGeneratorType}
import org.apache.hudi.metrics.{Metrics, MetricsReporterType}
import org.apache.hudi.storage.{StoragePath, StoragePathFilter}
import org.apache.hudi.table.HoodieSparkTable
@@ -436,10 +436,11 @@ class TestCOWDataSource extends HoodieSparkClientTestBase
with ScalaAssertionSup
assertLastCommitIsUpsert()
}
- private def writeToHudi(opts: Map[String, String], df: Dataset[Row]): Unit =
{
+ private def writeToHudi(opts: Map[String, String], df: Dataset[Row],
+ operation: String =
DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL): Unit = {
df.write.format("hudi")
.options(opts)
- .option(DataSourceWriteOptions.OPERATION.key,
DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
+ .option(DataSourceWriteOptions.OPERATION.key, operation)
.mode(SaveMode.Append)
.save(basePath)
}
@@ -2112,7 +2113,6 @@ class TestCOWDataSource extends HoodieSparkClientTestBase
with ScalaAssertionSup
.save(basePath)
}
-
@Test
def testReadOfAnEmptyTable(): Unit = {
val (writeOpts, _) = getWriterReaderOpts(HoodieRecordType.AVRO)
@@ -2412,6 +2412,67 @@ class TestCOWDataSource extends
HoodieSparkClientTestBase with ScalaAssertionSup
}
}
+ @ParameterizedTest
+ @MethodSource(Array("provideParamsForKeyGenTest"))
+ def testUserProvidedKeyGeneratorClass(keyGenClass: Option[String],
+ keyGenType: Option[String]): Unit = {
+ val recordType = HoodieRecordType.AVRO
+ var opts: Map[String, String] = Map()
+ if (keyGenClass.isPresent) {
+ opts = opts ++ Map(HoodieWriteConfig.KEYGENERATOR_CLASS_NAME.key ->
keyGenClass.get)
+ }
+ if (keyGenType.isPresent) {
+ opts = opts ++ Map(HoodieWriteConfig.KEYGENERATOR_TYPE.key ->
keyGenType.get)
+ }
+ val (writeOpts, readOpts) = getWriterReaderOpts(
+ recordType,
+ CommonOptionUtils.commonOpts ++ opts ++ Map(
+ DataSourceWriteOptions.PARTITIONPATH_FIELD.key -> "partition")
+ )
+ val expectedKeyGenType = if (keyGenClass.isEmpty) {
+ // By default SIMPLE type is returned when no class is provided.
+ KeyGeneratorType.SIMPLE.name
+ } else {
+ KeyGeneratorType.fromClassName(keyGenClass.get()).name
+ }
+
+ // Insert.
+ val records = recordsToStrings(dataGen.generateInserts("000",
10)).asScala.toList
+ val inputDF = spark.read.json(spark.sparkContext.parallelize(records, 2))
+ writeToHudi(writeOpts, inputDF)
+ var actualDF = spark.read.format("hudi").options(readOpts).load(basePath)
+
+ // Transform the input keys based on the key generator to match the
expected format
+ val inputKeyDF = TestCOWDataSource.transformRecordKeyColumn(
+ inputDF.select("_row_key"), "_row_key",
KeyGeneratorType.valueOf(expectedKeyGenType))
+ .sort("_row_key")
+ var actualKeyDF =
actualDF.select("_hoodie_record_key").sort("_hoodie_record_key")
+ assertTrue(inputKeyDF.except(actualKeyDF).isEmpty &&
actualKeyDF.except(inputKeyDF).isEmpty)
+ val metaClient = getHoodieMetaClient(storageConf, basePath)
+ val actualKeyGenType = metaClient.getTableConfig
+ .getProps.getString(HoodieTableConfig.KEY_GENERATOR_TYPE.key, null)
+ assertEquals(expectedKeyGenType, actualKeyGenType)
+ // For USER_PROVIDED type, the class should exist in table config.
+ if (KeyGeneratorType.USER_PROVIDED.name == actualKeyGenType) {
+ assertEquals(keyGenClass.get(),
metaClient.getTableConfig.getKeyGeneratorClassName)
+ }
+
+ // First update.
+ val firstUpdate =
recordsToStrings(dataGen.generateUpdatesForAllRecords("001")).asScala.toList
+ val firstUpdateDF =
spark.read.json(spark.sparkContext.parallelize(firstUpdate, 2))
+ writeToHudi(writeOpts, firstUpdateDF,
DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
+ actualDF = spark.read.format("hudi").options(readOpts).load(basePath)
+ actualKeyDF =
actualDF.select("_hoodie_record_key").sort("_hoodie_record_key")
+ assertTrue(inputKeyDF.except(actualKeyDF).isEmpty &&
actualKeyDF.except(inputKeyDF).isEmpty)
+
+ // Second update.
+ // Change keyGenerator class name to generate exception.
+ val opt = writeOpts ++ Map(
+ "hoodie.datasource.write.keygenerator.class" ->
"org.apache.hudi.keygen.SqlKeyGenerator")
+ assertThrows(classOf[HoodieException])({
+ writeToHudi(opt, firstUpdateDF,
DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
+ })
+ }
}
object TestCOWDataSource {
@@ -2423,6 +2484,17 @@ object TestCOWDataSource {
}
}
+ def transformRecordKeyColumn(df: DataFrame, columnName: String, keyGenType:
KeyGeneratorType): DataFrame = {
+ keyGenType match {
+ case KeyGeneratorType.USER_PROVIDED =>
+ df.withColumn(columnName, concat(lit(s"MOCK_"), col(columnName)))
+ case KeyGeneratorType.COMPLEX =>
+ df.withColumn(columnName, concat(lit(s"_row_key:"), col(columnName)))
+ case _ =>
+ df
+ }
+ }
+
def tableVersionCreationTestCases = {
val autoUpgradeValues = Array("true", "false")
val targetVersions = Array("1", "2", "3", "4", "5", "6", "7", "8", "9",
"null")
@@ -2430,4 +2502,21 @@ object TestCOWDataSource {
(autoUpgrade: String) => targetVersions.map(
(targetVersion: String) => Arguments.of(autoUpgrade, targetVersion)))
}
+
+ def provideParamsForKeyGenTest(): java.util.List[Arguments] = {
+ java.util.Arrays.asList(
+ Arguments.of(
+ Option.of("org.apache.hudi.keygen.MockUserProvidedKeyGenerator"),
+ Option.of(KeyGeneratorType.USER_PROVIDED.name())),
+ Arguments.of(
+ Option.empty(),
+ Option.of(KeyGeneratorType.SIMPLE.name())),
+ Arguments.of(
+ Option.of("org.apache.hudi.keygen.SimpleAvroKeyGenerator"),
+ Option.of(KeyGeneratorType.SIMPLE.name())),
+ Arguments.of(
+ Option.of("org.apache.hudi.keygen.ComplexKeyGenerator"),
+ Option.empty())
+ )
+ }
}
diff --git
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
index b381c5570b17..39179e966222 100644
---
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
+++
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
@@ -31,8 +31,9 @@ import
org.apache.hudi.common.testutils.HoodieTestDataGenerator.recordsToStrings
import org.apache.hudi.common.util.Option
import org.apache.hudi.common.util.StringUtils.isNullOrEmpty
import org.apache.hudi.config.{HoodieCleanConfig, HoodieCompactionConfig,
HoodieIndexConfig, HoodieWriteConfig}
-import org.apache.hudi.exception.HoodieUpgradeDowngradeException
+import org.apache.hudi.exception.{HoodieException,
HoodieUpgradeDowngradeException}
import org.apache.hudi.index.HoodieIndex.IndexType
+import org.apache.hudi.keygen.constant.KeyGeneratorType
import
org.apache.hudi.metadata.HoodieTableMetadataUtil.{metadataPartitionExists,
PARTITION_NAME_SECONDARY_INDEX_PREFIX}
import org.apache.hudi.storage.{StoragePath, StoragePathInfo}
import org.apache.hudi.table.action.compact.CompactionTriggerStrategy
@@ -49,7 +50,7 @@ import org.apache.spark.sql.types.{LongType, StringType,
StructField, StructType
import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.{CsvSource, EnumSource, ValueSource}
+import org.junit.jupiter.params.provider.{Arguments, CsvSource, EnumSource,
MethodSource, ValueSource}
import java.io.File
import java.nio.file.Files
@@ -2195,6 +2196,71 @@ class TestMORDataSource extends
HoodieSparkClientTestBase with SparkDatasetMixin
}
}
+ @ParameterizedTest
+ @MethodSource(Array("provideParamsForKeyGenTest"))
+ def testUserProvidedKeyGeneratorClass(keyGenClass: Option[String],
+ keyGenType: Option[String]): Unit = {
+ val recordType = HoodieRecordType.AVRO
+ var opts: Map[String, String] = Map(
+ DataSourceWriteOptions.TABLE_TYPE.key ->
DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL,
+ HoodieWriteConfig.MERGE_SMALL_FILE_GROUP_CANDIDATES_LIMIT.key -> "0"
+ )
+ if (keyGenClass.isPresent) {
+ opts = opts ++ Map(HoodieWriteConfig.KEYGENERATOR_CLASS_NAME.key ->
keyGenClass.get)
+ }
+ if (keyGenType.isPresent) {
+ opts = opts ++ Map(HoodieWriteConfig.KEYGENERATOR_TYPE.key ->
keyGenType.get)
+ }
+ val (writeOpts, readOpts) = getWriterReaderOpts(
+ recordType,
+ CommonOptionUtils.commonOpts ++ opts ++ Map(
+ DataSourceWriteOptions.PARTITIONPATH_FIELD.key -> "partition")
+ )
+ val expectedKeyGenType = if (keyGenClass.isEmpty) {
+ // By default SIMPLE type is returned when no class is provided.
+ KeyGeneratorType.SIMPLE.name
+ } else {
+ KeyGeneratorType.fromClassName(keyGenClass.get()).name
+ }
+
+ // Insert.
+ val records = recordsToStrings(dataGen.generateInserts("000",
10)).asScala.toList
+ val inputDF = spark.read.json(spark.sparkContext.parallelize(records, 2))
+ writeToHudi(writeOpts, inputDF)
+ var actualDF = spark.read.format("hudi").options(readOpts).load(basePath)
+
+ // Transform the input keys based on the key generator to match the
expected format
+ val inputKeyDF = TestCOWDataSource.transformRecordKeyColumn(
+ inputDF.select("_row_key"), "_row_key",
KeyGeneratorType.valueOf(expectedKeyGenType))
+ .sort("_row_key")
+ var actualKeyDF =
actualDF.select("_hoodie_record_key").sort("_hoodie_record_key")
+ assertTrue(inputKeyDF.except(actualKeyDF).isEmpty &&
actualKeyDF.except(inputKeyDF).isEmpty)
+ val metaClient = getHoodieMetaClient(storageConf, basePath)
+ val actualKeyGenType = metaClient.getTableConfig
+ .getProps.getString(HoodieTableConfig.KEY_GENERATOR_TYPE.key, null)
+ assertEquals(expectedKeyGenType, actualKeyGenType)
+ // For USER_PROVIDED type, the class should exist in table config.
+ if (KeyGeneratorType.USER_PROVIDED.name == actualKeyGenType) {
+ assertEquals(keyGenClass.get(),
metaClient.getTableConfig.getKeyGeneratorClassName)
+ }
+
+ // First update.
+ val firstUpdate =
recordsToStrings(dataGen.generateUpdatesForAllRecords("001")).asScala.toList
+ val firstUpdateDF =
spark.read.json(spark.sparkContext.parallelize(firstUpdate, 2))
+ writeToHudi(writeOpts, firstUpdateDF,
DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
+ actualDF = spark.read.format("hudi").options(readOpts).load(basePath)
+ actualKeyDF =
actualDF.select("_hoodie_record_key").sort("_hoodie_record_key")
+ assertTrue(inputKeyDF.except(actualKeyDF).isEmpty &&
actualKeyDF.except(inputKeyDF).isEmpty)
+
+ // Second update.
+ // Change keyGenerator class name to generate exception.
+ val opt = writeOpts ++ Map(
+ "hoodie.datasource.write.keygenerator.class" ->
"org.apache.hudi.keygen.SqlKeyGenerator")
+ assertThrows(classOf[HoodieException])({
+ writeToHudi(opt, firstUpdateDF,
DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
+ })
+ }
+
private def loadFixtureTable(testBasePath: String, version:
HoodieTableVersion): HoodieTableMetaClient = {
val fixtureName = getFixtureName(version, "")
val resourcePath =
s"/upgrade-downgrade-fixtures/unsupported-upgrade-tables/$fixtureName"
@@ -2251,4 +2317,32 @@ class TestMORDataSource extends
HoodieSparkClientTestBase with SparkDatasetMixin
Row("id11", "TestUser3", 11000L, "2023-01-06")
)
}
+
+ private def writeToHudi(opts: Map[String, String], df: Dataset[Row],
+ operation: String =
DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL): Unit = {
+ df.write.format("hudi")
+ .options(opts)
+ .option(DataSourceWriteOptions.OPERATION.key, operation)
+ .mode(SaveMode.Append)
+ .save(basePath)
+ }
+}
+
+object TestMORDataSource {
+ def provideParamsForKeyGenTest(): java.util.List[Arguments] = {
+ java.util.Arrays.asList(
+ Arguments.of(
+ Option.of("org.apache.hudi.keygen.MockUserProvidedKeyGenerator"),
+ Option.of(KeyGeneratorType.USER_PROVIDED.name())),
+ Arguments.of(
+ Option.empty(),
+ Option.of(KeyGeneratorType.SIMPLE.name())),
+ Arguments.of(
+ Option.of("org.apache.hudi.keygen.SimpleAvroKeyGenerator"),
+ Option.of(KeyGeneratorType.SIMPLE.name())),
+ Arguments.of(
+ Option.of("org.apache.hudi.keygen.ComplexKeyGenerator"),
+ Option.empty())
+ )
+ }
}