This is an automated email from the ASF dual-hosted git repository.
yihua pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git
The following commit(s) were added to refs/heads/master by this push:
     new 17ea14ab6d6 [HUDI-7378] Fix Spark SQL DML with custom key generator (#10615)
17ea14ab6d6 is described below
commit 17ea14ab6d6a8ca7ecef2cfcdbc67b0c87f23987
Author: Y Ethan Guo <[email protected]>
AuthorDate: Fri Apr 12 22:51:03 2024 -0700
[HUDI-7378] Fix Spark SQL DML with custom key generator (#10615)
---
.../factory/HoodieSparkKeyGeneratorFactory.java | 4 +
.../org/apache/hudi/util/SparkKeyGenUtils.scala | 16 +-
.../scala/org/apache/hudi/HoodieWriterUtils.scala | 20 +-
.../spark/sql/hudi/ProvidesHoodieConfig.scala | 60 ++-
.../spark/sql/hudi/TestProvidesHoodieConfig.scala | 79 +++
.../hudi/command/MergeIntoHoodieTableCommand.scala | 5 +-
.../TestSparkSqlWithCustomKeyGenerator.scala | 571 +++++++++++++++++++++
7 files changed, 742 insertions(+), 13 deletions(-)
diff --git
a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java
b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java
index 1ea5adcd6b4..dcc2eaec9eb 100644
---
a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java
+++
b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/keygen/factory/HoodieSparkKeyGeneratorFactory.java
@@ -79,6 +79,10 @@ public class HoodieSparkKeyGeneratorFactory {
public static KeyGenerator createKeyGenerator(TypedProperties props) throws
IOException {
String keyGeneratorClass = getKeyGeneratorClassName(props);
+ return createKeyGenerator(keyGeneratorClass, props);
+ }
+
+ public static KeyGenerator createKeyGenerator(String keyGeneratorClass,
TypedProperties props) throws IOException {
boolean autoRecordKeyGen =
KeyGenUtils.isAutoGeneratedRecordKeysEnabled(props)
//Need to prevent overwriting the keygen for spark sql merge into
because we need to extract
//the recordkey from the meta cols if it exists. Sql keygen will use
pkless keygen if needed.
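For reference, a minimal sketch (editorial illustration, not part of this patch) of calling the new overload from Scala when the caller already knows the key generator class name; the property values are borrowed from the tests added later in this commit:

    import org.apache.hudi.common.config.TypedProperties
    import org.apache.hudi.keygen.factory.HoodieSparkKeyGeneratorFactory

    val props = new TypedProperties()
    props.put("hoodie.datasource.write.recordkey.field", "id")
    props.put("hoodie.datasource.write.partitionpath.field", "ts:timestamp,segment:simple")
    props.put("hoodie.keygen.timebased.timestamp.type", "SCALAR")
    props.put("hoodie.keygen.timebased.output.dateformat", "yyyyMM")
    props.put("hoodie.keygen.timebased.timestamp.scalar.time.unit", "seconds")
    // Resolve the key generator from an explicit class name instead of reading
    // "hoodie.datasource.write.keygenerator.class" from props
    val keyGen = HoodieSparkKeyGeneratorFactory.createKeyGenerator(
      "org.apache.hudi.keygen.CustomKeyGenerator", props)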
diff --git
a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala
b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala
index 7b91ae5a728..bd094464096 100644
---
a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala
+++
b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/SparkKeyGenUtils.scala
@@ -21,8 +21,8 @@ import org.apache.hudi.common.config.TypedProperties
import org.apache.hudi.common.util.StringUtils
import org.apache.hudi.common.util.ValidationUtils.checkArgument
import org.apache.hudi.keygen.constant.KeyGeneratorOptions
-import org.apache.hudi.keygen.{AutoRecordKeyGeneratorWrapper,
AutoRecordGenWrapperKeyGenerator, CustomAvroKeyGenerator, CustomKeyGenerator,
GlobalAvroDeleteKeyGenerator, GlobalDeleteKeyGenerator, KeyGenerator,
NonpartitionedAvroKeyGenerator, NonpartitionedKeyGenerator}
import org.apache.hudi.keygen.factory.HoodieSparkKeyGeneratorFactory
+import org.apache.hudi.keygen.{AutoRecordKeyGeneratorWrapper,
CustomAvroKeyGenerator, CustomKeyGenerator, GlobalAvroDeleteKeyGenerator,
GlobalDeleteKeyGenerator, KeyGenerator, NonpartitionedAvroKeyGenerator,
NonpartitionedKeyGenerator}
object SparkKeyGenUtils {
@@ -35,6 +35,20 @@ object SparkKeyGenUtils {
getPartitionColumns(keyGenerator, props)
}
+ /**
+ * @param KeyGenClassNameOption key generator class name if present.
+ * @param props config properties.
+ * @return partition column names only, concatenated by ","
+ */
+ def getPartitionColumns(KeyGenClassNameOption: Option[String], props:
TypedProperties): String = {
+ val keyGenerator = if (KeyGenClassNameOption.isEmpty) {
+ HoodieSparkKeyGeneratorFactory.createKeyGenerator(props)
+ } else {
+
HoodieSparkKeyGeneratorFactory.createKeyGenerator(KeyGenClassNameOption.get,
props)
+ }
+ getPartitionColumns(keyGenerator, props)
+ }
+
/**
* @param keyGen key generator class name
* @return partition columns
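As an illustration (a sketch mirroring the test added in this commit, not part of the production change), the new Option-based overload strips the key generator types from the write config value when the custom key generator is in use:

    val props = new TypedProperties()
    props.put("hoodie.datasource.write.partitionpath.field", "ts:timestamp,segment:simple")
    props.put("hoodie.keygen.timebased.timestamp.type", "SCALAR")
    props.put("hoodie.keygen.timebased.output.dateformat", "yyyyMM")
    props.put("hoodie.keygen.timebased.timestamp.scalar.time.unit", "seconds")
    // Expected to return "ts,segment" (field names only) for CustomKeyGenerator
    val partitionCols = SparkKeyGenUtils.getPartitionColumns(
      Option("org.apache.hudi.keygen.CustomKeyGenerator"), props)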
diff --git
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala
index 63495b0eede..5df773542d6 100644
---
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala
+++
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieWriterUtils.scala
@@ -201,8 +201,26 @@ object HoodieWriterUtils {
diffConfigs.append(s"KeyGenerator:\t$datasourceKeyGen\t$tableConfigKeyGen\n")
}
+    // Note that validating the partition path fields requires the table's key
+    // generator class, because the custom key generator expects the write config
+    // "hoodie.datasource.write.partitionpath.field" in a different format,
+    // e.g., "col:simple,ts:timestamp", whereas the table config
+    // "hoodie.table.partition.fields" in hoodie.properties stores "col,ts".
+    // The "params" here may contain only the write config of the partition path
+    // field, so we need to pass in the validated key generator class name.
+ val validatedKeyGenClassName = if (tableConfigKeyGen != null) {
+ Option(tableConfigKeyGen)
+ } else if (datasourceKeyGen != null) {
+ Option(datasourceKeyGen)
+ } else {
+ None
+ }
val datasourcePartitionFields =
params.getOrElse(PARTITIONPATH_FIELD.key(), null)
- val currentPartitionFields = if (datasourcePartitionFields == null)
null else SparkKeyGenUtils.getPartitionColumns(TypedProperties.fromMap(params))
+ val currentPartitionFields = if (datasourcePartitionFields == null) {
+ null
+ } else {
+ SparkKeyGenUtils.getPartitionColumns(validatedKeyGenClassName,
TypedProperties.fromMap(params))
+ }
val tableConfigPartitionFields =
tableConfig.getString(HoodieTableConfig.PARTITION_FIELDS)
if (null != datasourcePartitionFields && null !=
tableConfigPartitionFields
&& currentPartitionFields != tableConfigPartitionFields) {
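To make the reasoning in the new comment concrete, a hedged sketch (hypothetical values taken from the tests in this commit) of what the validation compares once the key generator class is taken into account:

    // Write config in params:            "ts:timestamp,segment:simple"
    // Table config in hoodie.properties: "ts,segment"
    // The table config key generator class wins over the datasource one, if present
    val keyGenClass = Option(tableConfigKeyGen).orElse(Option(datasourceKeyGen))
    // Normalizing through the key generator yields "ts,segment" on both sides,
    // so the comparison no longer reports a false mismatch for the custom key generator
    val current = SparkKeyGenUtils.getPartitionColumns(keyGenClass, TypedProperties.fromMap(params))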
diff --git
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala
index a4003bbd480..7d35c490fd4 100644
---
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala
+++
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/ProvidesHoodieConfig.scala
@@ -18,34 +18,36 @@
package org.apache.spark.sql.hudi
import
org.apache.hudi.AutoRecordKeyGenerationUtils.shouldAutoGenerateRecordKeys
-import org.apache.hudi.{DataSourceWriteOptions, HoodieFileIndex}
import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.HoodieConversionUtils.toProperties
import org.apache.hudi.common.config.{DFSPropertiesConfiguration,
TypedProperties}
import org.apache.hudi.common.model.{DefaultHoodieRecordPayload,
WriteOperationType}
import org.apache.hudi.common.table.HoodieTableConfig
+import org.apache.hudi.common.util.{ReflectionUtils, StringUtils}
import org.apache.hudi.config.HoodieWriteConfig.TBL_NAME
import org.apache.hudi.config.{HoodieIndexConfig, HoodieInternalConfig,
HoodieWriteConfig}
import org.apache.hudi.hive.ddl.HiveSyncMode
import org.apache.hudi.hive.{HiveSyncConfig, HiveSyncConfigHolder,
MultiPartKeysValueExtractor}
-import org.apache.hudi.keygen.ComplexKeyGenerator
+import org.apache.hudi.keygen.{ComplexKeyGenerator, CustomAvroKeyGenerator,
CustomKeyGenerator}
import org.apache.hudi.sql.InsertMode
import org.apache.hudi.sync.common.HoodieSyncConfig
+import org.apache.hudi.{DataSourceWriteOptions, HoodieFileIndex}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo,
Literal}
import org.apache.spark.sql.execution.datasources.FileStatusCache
import org.apache.spark.sql.hive.HiveExternalCatalog
import
org.apache.spark.sql.hudi.HoodieOptionConfig.mapSqlOptionsToDataSourceWriteConfigs
import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{isHoodieConfigKey,
isUsingHiveCatalog}
-import org.apache.spark.sql.hudi.ProvidesHoodieConfig.combineOptions
+import org.apache.spark.sql.hudi.ProvidesHoodieConfig.{combineOptions,
getPartitionPathFieldWriteConfig}
import org.apache.spark.sql.hudi.command.{SqlKeyGenerator,
ValidateDuplicateKeyPayload}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PARTITION_OVERWRITE_MODE
import org.apache.spark.sql.types.StructType
-import java.util.Locale
+import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.slf4j.LoggerFactory
+import java.util.Locale
import scala.collection.JavaConverters._
trait ProvidesHoodieConfig extends Logging {
@@ -82,7 +84,8 @@ trait ProvidesHoodieConfig extends Logging {
PRECOMBINE_FIELD.key -> preCombineField,
HIVE_STYLE_PARTITIONING.key ->
tableConfig.getHiveStylePartitioningEnable,
URL_ENCODE_PARTITIONING.key -> tableConfig.getUrlEncodePartitioning,
- PARTITIONPATH_FIELD.key -> tableConfig.getPartitionFieldProp
+ PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig(
+ tableConfig.getKeyGeneratorClassName,
tableConfig.getPartitionFieldProp, hoodieCatalogTable)
)
combineOptions(hoodieCatalogTable, tableConfig,
sparkSession.sqlContext.conf,
@@ -313,7 +316,8 @@ trait ProvidesHoodieConfig extends Logging {
URL_ENCODE_PARTITIONING.key -> urlEncodePartitioning,
RECORDKEY_FIELD.key -> recordKeyConfigValue,
PRECOMBINE_FIELD.key -> preCombineField,
- PARTITIONPATH_FIELD.key -> partitionFieldsStr
+ PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig(
+ keyGeneratorClassName, partitionFieldsStr, hoodieCatalogTable)
) ++ overwriteTableOpts ++ getDropDupsConfig(useLegacyInsertModeFlow,
combinedOpts) ++ staticOverwritePartitionPathOptions
combineOptions(hoodieCatalogTable, tableConfig,
sparkSession.sqlContext.conf,
@@ -405,7 +409,8 @@ trait ProvidesHoodieConfig extends Logging {
PARTITIONS_TO_DELETE.key -> partitionsToDrop,
RECORDKEY_FIELD.key -> hoodieCatalogTable.primaryKeys.mkString(","),
PRECOMBINE_FIELD.key -> hoodieCatalogTable.preCombineKey.getOrElse(""),
- PARTITIONPATH_FIELD.key -> partitionFields,
+ PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig(
+ tableConfig.getKeyGeneratorClassName, partitionFields,
hoodieCatalogTable),
HoodieSyncConfig.META_SYNC_ENABLED.key ->
hiveSyncConfig.getString(HoodieSyncConfig.META_SYNC_ENABLED.key),
HiveSyncConfigHolder.HIVE_SYNC_ENABLED.key ->
hiveSyncConfig.getString(HiveSyncConfigHolder.HIVE_SYNC_ENABLED.key),
HiveSyncConfigHolder.HIVE_SYNC_MODE.key ->
hiveSyncConfig.getStringOrDefault(HiveSyncConfigHolder.HIVE_SYNC_MODE,
HiveSyncMode.HMS.name()),
@@ -451,7 +456,8 @@ trait ProvidesHoodieConfig extends Logging {
HIVE_STYLE_PARTITIONING.key ->
tableConfig.getHiveStylePartitioningEnable,
URL_ENCODE_PARTITIONING.key -> tableConfig.getUrlEncodePartitioning,
OPERATION.key -> DataSourceWriteOptions.DELETE_OPERATION_OPT_VAL,
- PARTITIONPATH_FIELD.key -> tableConfig.getPartitionFieldProp
+ PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig(
+ tableConfig.getKeyGeneratorClassName,
tableConfig.getPartitionFieldProp, hoodieCatalogTable)
)
combineOptions(hoodieCatalogTable, tableConfig,
sparkSession.sqlContext.conf,
@@ -496,6 +502,8 @@ trait ProvidesHoodieConfig extends Logging {
object ProvidesHoodieConfig {
+ private val log = LoggerFactory.getLogger(getClass)
+
// NOTE: PLEASE READ CAREFULLY BEFORE CHANGING
//
// Spark SQL operations configuration might be coming from a variety of
diverse sources
@@ -530,6 +538,40 @@ object ProvidesHoodieConfig {
filterNullValues(overridingOpts)
}
+ /**
+ * @param tableConfigKeyGeneratorClassName key generator class name in
the table config.
+ * @param partitionFieldNamesWithoutKeyGenType partition field names without
key generator types
+ * from the table config.
+ * @param catalogTable HoodieCatalogTable instance
to fetch table properties.
+ * @return the write config value to set for
"hoodie.datasource.write.partitionpath.field".
+ */
+ def getPartitionPathFieldWriteConfig(tableConfigKeyGeneratorClassName:
String,
+ partitionFieldNamesWithoutKeyGenType:
String,
+ catalogTable: HoodieCatalogTable):
String = {
+ if (StringUtils.isNullOrEmpty(tableConfigKeyGeneratorClassName)) {
+ partitionFieldNamesWithoutKeyGenType
+ } else {
+ val writeConfigPartitionField =
catalogTable.catalogProperties.get(PARTITIONPATH_FIELD.key())
+ val keyGenClass =
ReflectionUtils.getClass(tableConfigKeyGeneratorClassName)
+ if (classOf[CustomKeyGenerator].equals(keyGenClass)
+ || classOf[CustomAvroKeyGenerator].equals(keyGenClass)) {
+        // For the custom key generator, we have to take the write config value
+        // from "hoodie.datasource.write.partitionpath.field", which contains the
+        // key generator type, whereas the table config only contains the
+        // partition field names without key generator types.
+ if (writeConfigPartitionField.isDefined) {
+ writeConfigPartitionField.get
+ } else {
+ log.warn("Write config
\"hoodie.datasource.write.partitionpath.field\" is not set for "
+ + "custom key generator. This may fail the write operation.")
+ partitionFieldNamesWithoutKeyGenType
+ }
+ } else {
+ partitionFieldNamesWithoutKeyGenType
+ }
+ }
+ }
+
private def filterNullValues(opts: Map[String, String]): Map[String, String]
=
opts.filter { case (_, v) => v != null }
diff --git
a/hudi-spark-datasource/hudi-spark-common/src/test/scala/org/apache/spark/sql/hudi/TestProvidesHoodieConfig.scala
b/hudi-spark-datasource/hudi-spark-common/src/test/scala/org/apache/spark/sql/hudi/TestProvidesHoodieConfig.scala
new file mode 100644
index 00000000000..8414e41ca6c
--- /dev/null
+++
b/hudi-spark-datasource/hudi-spark-common/src/test/scala/org/apache/spark/sql/hudi/TestProvidesHoodieConfig.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.hudi
+
+import org.apache.hudi.DataSourceWriteOptions.PARTITIONPATH_FIELD
+import org.apache.hudi.keygen.{ComplexKeyGenerator, CustomKeyGenerator}
+
+import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Test
+import org.mockito.Mockito
+import org.mockito.Mockito.when
+
+/**
+ * Tests {@link ProvidesHoodieConfig}
+ */
+class TestProvidesHoodieConfig {
+ @Test
+ def testGetPartitionPathFieldWriteConfig(): Unit = {
+ val mockTable = Mockito.mock(classOf[HoodieCatalogTable])
+ val partitionFieldNames = "ts,segment"
+ val customKeyGenPartitionFieldWriteConfig = "ts:timestamp,segment:simple"
+
+ mockPartitionWriteConfigInCatalogProps(mockTable, None)
+ assertEquals(
+ partitionFieldNames,
+ ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+ "", partitionFieldNames, mockTable))
+ assertEquals(
+ partitionFieldNames,
+ ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+ classOf[ComplexKeyGenerator].getName, partitionFieldNames, mockTable))
+ assertEquals(
+ partitionFieldNames,
+ ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+ classOf[CustomKeyGenerator].getName, partitionFieldNames, mockTable))
+
+ mockPartitionWriteConfigInCatalogProps(mockTable,
Option(customKeyGenPartitionFieldWriteConfig))
+ assertEquals(
+ partitionFieldNames,
+ ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+ "", partitionFieldNames, mockTable))
+ assertEquals(
+ partitionFieldNames,
+ ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+ classOf[ComplexKeyGenerator].getName, partitionFieldNames, mockTable))
+ assertEquals(
+ customKeyGenPartitionFieldWriteConfig,
+ ProvidesHoodieConfig.getPartitionPathFieldWriteConfig(
+ classOf[CustomKeyGenerator].getName, partitionFieldNames, mockTable))
+ }
+
+ private def mockPartitionWriteConfigInCatalogProps(mockTable:
HoodieCatalogTable,
+ value: Option[String]):
Unit = {
+ val props = if (value.isDefined) {
+ Map(PARTITIONPATH_FIELD.key() -> value.get)
+ } else {
+ Map[String, String]()
+ }
+ when(mockTable.catalogProperties).thenReturn(props)
+ }
+}
diff --git
a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
index 18403872f4a..79cd2646e08 100644
---
a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
+++
b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.LeftOuter
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._
import org.apache.spark.sql.hudi.ProvidesHoodieConfig
-import org.apache.spark.sql.hudi.ProvidesHoodieConfig.combineOptions
+import org.apache.spark.sql.hudi.ProvidesHoodieConfig.{combineOptions,
getPartitionPathFieldWriteConfig}
import org.apache.spark.sql.hudi.analysis.HoodieAnalysis.failAnalysis
import
org.apache.spark.sql.hudi.command.MergeIntoHoodieTableCommand.{CoercedAttributeReference,
encodeAsBase64String, stripCasting, toStructType}
import
org.apache.spark.sql.hudi.command.PartialAssignmentMode.PartialAssignmentMode
@@ -729,7 +729,8 @@ case class MergeIntoHoodieTableCommand(mergeInto:
MergeIntoTable) extends Hoodie
RECORDKEY_FIELD.key -> tableConfig.getRawRecordKeyFieldProp,
PRECOMBINE_FIELD.key -> preCombineField,
TBL_NAME.key -> hoodieCatalogTable.tableName,
- PARTITIONPATH_FIELD.key -> tableConfig.getPartitionFieldProp,
+ PARTITIONPATH_FIELD.key -> getPartitionPathFieldWriteConfig(
+ tableConfig.getKeyGeneratorClassName,
tableConfig.getPartitionFieldProp, hoodieCatalogTable),
HIVE_STYLE_PARTITIONING.key ->
tableConfig.getHiveStylePartitioningEnable,
URL_ENCODE_PARTITIONING.key -> tableConfig.getUrlEncodePartitioning,
KEYGENERATOR_CLASS_NAME.key -> keyGeneratorClassName,
diff --git
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestSparkSqlWithCustomKeyGenerator.scala
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestSparkSqlWithCustomKeyGenerator.scala
new file mode 100644
index 00000000000..c85eb40bca7
--- /dev/null
+++
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestSparkSqlWithCustomKeyGenerator.scala
@@ -0,0 +1,571 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.hudi.functional
+
+import org.apache.hudi.HoodieSparkUtils
+import org.apache.hudi.common.config.TypedProperties
+import org.apache.hudi.common.table.HoodieTableMetaClient
+import org.apache.hudi.common.util.StringUtils
+import org.apache.hudi.exception.HoodieException
+import org.apache.hudi.functional.TestSparkSqlWithCustomKeyGenerator._
+import org.apache.hudi.util.SparkKeyGenUtils
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.hudi.common.HoodieSparkSqlTestBase
+import org.joda.time.DateTime
+import org.joda.time.format.DateTimeFormat
+import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
+import org.slf4j.LoggerFactory
+
+import java.io.IOException
+
+/**
+ * Tests Spark SQL DML with custom key generator and write configs.
+ */
+class TestSparkSqlWithCustomKeyGenerator extends HoodieSparkSqlTestBase {
+ private val LOG = LoggerFactory.getLogger(getClass)
+
+ test("Test Spark SQL DML with custom key generator") {
+ withTempDir { tmp =>
+ Seq(
+ Seq("COPY_ON_WRITE", "ts:timestamp,segment:simple",
+ "(ts=202401, segment='cat2')", "202401/cat2",
+ Seq("202312/cat2", "202312/cat4", "202401/cat1", "202401/cat3",
"202402/cat1", "202402/cat3", "202402/cat5"),
+ TS_FORMATTER_FUNC,
+ (ts: Integer, segment: String) => TS_FORMATTER_FUNC.apply(ts) + "/"
+ segment),
+ Seq("MERGE_ON_READ", "segment:simple",
+ "(segment='cat3')", "cat3",
+ Seq("cat1", "cat2", "cat4", "cat5"),
+ TS_TO_STRING_FUNC,
+ (_: Integer, segment: String) => segment),
+ Seq("MERGE_ON_READ", "ts:timestamp",
+ "(ts=202312)", "202312",
+ Seq("202401", "202402"),
+ TS_FORMATTER_FUNC,
+ (ts: Integer, _: String) => TS_FORMATTER_FUNC.apply(ts)),
+ Seq("MERGE_ON_READ", "ts:timestamp,segment:simple",
+ "(ts=202401, segment='cat2')", "202401/cat2",
+ Seq("202312/cat2", "202312/cat4", "202401/cat1", "202401/cat3",
"202402/cat1", "202402/cat3", "202402/cat5"),
+ TS_FORMATTER_FUNC,
+ (ts: Integer, segment: String) => TS_FORMATTER_FUNC.apply(ts) + "/"
+ segment)
+ ).foreach { testParams =>
+ withTable(generateTableName) { tableName =>
+ LOG.warn("Testing with parameters: " + testParams)
+ val tableType = testParams(0).asInstanceOf[String]
+ val writePartitionFields = testParams(1).asInstanceOf[String]
+ val dropPartitionStatement = testParams(2).asInstanceOf[String]
+ val droppedPartition = testParams(3).asInstanceOf[String]
+ val expectedPartitions = testParams(4).asInstanceOf[Seq[String]]
+ val tsGenFunc = testParams(5).asInstanceOf[Integer => String]
+ val partitionGenFunc = testParams(6).asInstanceOf[(Integer, String)
=> String]
+ val tablePath = tmp.getCanonicalPath + "/" + tableName
+ val timestampKeyGeneratorConfig = if
(writePartitionFields.contains("timestamp")) {
+ TS_KEY_GEN_CONFIGS
+ } else {
+ Map[String, String]()
+ }
+ val timestampKeyGenProps = if (timestampKeyGeneratorConfig.nonEmpty)
{
+ ", " + timestampKeyGeneratorConfig.map(e => e._1 + " = '" + e._2 +
"'").mkString(", ")
+ } else {
+ ""
+ }
+
+ prepareTableWithKeyGenerator(
+ tableName, tablePath, tableType,
+ CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields,
timestampKeyGeneratorConfig)
+
+ // SQL CTAS with table properties containing key generator write
configs
+ createTableWithSql(tableName, tablePath,
+ s"hoodie.datasource.write.partitionpath.field =
'$writePartitionFields'" + timestampKeyGenProps)
+
+ // Prepare source and test SQL INSERT INTO
+ val sourceTableName = tableName + "_source"
+ prepareParquetSource(sourceTableName, Seq(
+ "(7, 'a7', 1399.0, 1706800227, 'cat1')",
+ "(8, 'a8', 26.9, 1706800227, 'cat3')",
+ "(9, 'a9', 299.0, 1701443427, 'cat4')"))
+ spark.sql(
+ s"""
+ | INSERT INTO $tableName
+ | SELECT * from ${tableName}_source
+ | """.stripMargin)
+ validateResults(
+ tableName,
+ s"SELECT id, name, cast(price as string), cast(ts as string),
segment from $tableName",
+ tsGenFunc,
+ partitionGenFunc,
+ Seq(),
+ Seq(1, "a1", "1.6", 1704121827, "cat1"),
+ Seq(2, "a2", "10.8", 1704121827, "cat1"),
+ Seq(3, "a3", "30.0", 1706800227, "cat1"),
+ Seq(4, "a4", "103.4", 1701443427, "cat2"),
+ Seq(5, "a5", "1999.0", 1704121827, "cat2"),
+ Seq(6, "a6", "80.0", 1704121827, "cat3"),
+ Seq(7, "a7", "1399.0", 1706800227, "cat1"),
+ Seq(8, "a8", "26.9", 1706800227, "cat3"),
+ Seq(9, "a9", "299.0", 1701443427, "cat4")
+ )
+
+ // Test SQL UPDATE
+ spark.sql(
+ s"""
+ | UPDATE $tableName
+ | SET price = price + 10.0
+ | WHERE id between 4 and 7
+ | """.stripMargin)
+ validateResults(
+ tableName,
+ s"SELECT id, name, cast(price as string), cast(ts as string),
segment from $tableName",
+ tsGenFunc,
+ partitionGenFunc,
+ Seq(),
+ Seq(1, "a1", "1.6", 1704121827, "cat1"),
+ Seq(2, "a2", "10.8", 1704121827, "cat1"),
+ Seq(3, "a3", "30.0", 1706800227, "cat1"),
+ Seq(4, "a4", "113.4", 1701443427, "cat2"),
+ Seq(5, "a5", "2009.0", 1704121827, "cat2"),
+ Seq(6, "a6", "90.0", 1704121827, "cat3"),
+ Seq(7, "a7", "1409.0", 1706800227, "cat1"),
+ Seq(8, "a8", "26.9", 1706800227, "cat3"),
+ Seq(9, "a9", "299.0", 1701443427, "cat4")
+ )
+
+ // Test SQL MERGE INTO
+ spark.sql(
+ s"""
+ | MERGE INTO $tableName as target
+ | USING (
+ | SELECT 1 as id, 'a1' as name, 1.6 as price, 1704121827 as
ts, 'cat1' as segment, 'delete' as flag
+ | UNION
+ | SELECT 2 as id, 'a2' as name, 11.9 as price, 1704121827 as
ts, 'cat1' as segment, '' as flag
+ | UNION
+ | SELECT 6 as id, 'a6' as name, 99.0 as price, 1704121827 as
ts, 'cat3' as segment, '' as flag
+ | UNION
+ | SELECT 8 as id, 'a8' as name, 24.9 as price, 1706800227 as
ts, 'cat3' as segment, '' as flag
+ | UNION
+ | SELECT 10 as id, 'a10' as name, 888.8 as price, 1706800227
as ts, 'cat5' as segment, '' as flag
+ | ) source
+ | on target.id = source.id
+ | WHEN MATCHED AND flag != 'delete' THEN UPDATE SET
+ | id = source.id, name = source.name, price = source.price,
ts = source.ts, segment = source.segment
+ | WHEN MATCHED AND flag = 'delete' THEN DELETE
+ | WHEN NOT MATCHED THEN INSERT (id, name, price, ts, segment)
+ | values (source.id, source.name, source.price, source.ts,
source.segment)
+ | """.stripMargin)
+ validateResults(
+ tableName,
+ s"SELECT id, name, cast(price as string), cast(ts as string),
segment from $tableName",
+ tsGenFunc,
+ partitionGenFunc,
+ Seq(),
+ Seq(2, "a2", "11.9", 1704121827, "cat1"),
+ Seq(3, "a3", "30.0", 1706800227, "cat1"),
+ Seq(4, "a4", "113.4", 1701443427, "cat2"),
+ Seq(5, "a5", "2009.0", 1704121827, "cat2"),
+ Seq(6, "a6", "99.0", 1704121827, "cat3"),
+ Seq(7, "a7", "1409.0", 1706800227, "cat1"),
+ Seq(8, "a8", "24.9", 1706800227, "cat3"),
+ Seq(9, "a9", "299.0", 1701443427, "cat4"),
+ Seq(10, "a10", "888.8", 1706800227, "cat5")
+ )
+
+ // Test SQL DELETE
+ spark.sql(
+ s"""
+ | DELETE FROM $tableName
+ | WHERE id = 7
+ | """.stripMargin)
+ validateResults(
+ tableName,
+ s"SELECT id, name, cast(price as string), cast(ts as string),
segment from $tableName",
+ tsGenFunc,
+ partitionGenFunc,
+ Seq(),
+ Seq(2, "a2", "11.9", 1704121827, "cat1"),
+ Seq(3, "a3", "30.0", 1706800227, "cat1"),
+ Seq(4, "a4", "113.4", 1701443427, "cat2"),
+ Seq(5, "a5", "2009.0", 1704121827, "cat2"),
+ Seq(6, "a6", "99.0", 1704121827, "cat3"),
+ Seq(8, "a8", "24.9", 1706800227, "cat3"),
+ Seq(9, "a9", "299.0", 1701443427, "cat4"),
+ Seq(10, "a10", "888.8", 1706800227, "cat5")
+ )
+
+ // Test DROP PARTITION
+
assertTrue(getSortedTablePartitions(tableName).contains(droppedPartition))
+ spark.sql(
+ s"""
+ | ALTER TABLE $tableName DROP PARTITION $dropPartitionStatement
+ |""".stripMargin)
+ validatePartitions(tableName, Seq(droppedPartition),
expectedPartitions)
+
+ if (HoodieSparkUtils.isSpark3) {
+ // Test INSERT OVERWRITE, only supported in Spark 3.x
+ spark.sql(
+ s"""
+ | INSERT OVERWRITE $tableName
+ | SELECT 100 as id, 'a100' as name, 299.0 as price,
1706800227 as ts, 'cat10' as segment
+ | """.stripMargin)
+ validateResults(
+ tableName,
+ s"SELECT id, name, cast(price as string), cast(ts as string),
segment from $tableName",
+ tsGenFunc,
+ partitionGenFunc,
+ Seq(),
+ Seq(100, "a100", "299.0", 1706800227, "cat10")
+ )
+ }
+ }
+ }
+ }
+ }
+
+ test("Test table property isolation for partition path field config "
+ + "with custom key generator for Spark 3.1 and above") {
+ // Only testing Spark 3.1 and above as lower Spark versions do not support
+ // ALTER TABLE .. SET TBLPROPERTIES .. to store table-level properties in
Hudi Catalog
+ if (HoodieSparkUtils.gteqSpark3_1) {
+ withTempDir { tmp => {
+ val tableNameNonPartitioned = generateTableName
+ val tableNameSimpleKey = generateTableName
+ val tableNameCustom1 = generateTableName
+ val tableNameCustom2 = generateTableName
+
+ val tablePathNonPartitioned = tmp.getCanonicalPath + "/" +
tableNameNonPartitioned
+ val tablePathSimpleKey = tmp.getCanonicalPath + "/" +
tableNameSimpleKey
+ val tablePathCustom1 = tmp.getCanonicalPath + "/" + tableNameCustom1
+ val tablePathCustom2 = tmp.getCanonicalPath + "/" + tableNameCustom2
+
+ val tableType = "MERGE_ON_READ"
+ val writePartitionFields1 = "segment:simple"
+ val writePartitionFields2 = "ts:timestamp,segment:simple"
+
+ prepareTableWithKeyGenerator(
+ tableNameNonPartitioned, tablePathNonPartitioned, tableType,
+ NONPARTITIONED_KEY_GEN_CLASS_NAME, "", Map())
+ prepareTableWithKeyGenerator(
+ tableNameSimpleKey, tablePathSimpleKey, tableType,
+ SIMPLE_KEY_GEN_CLASS_NAME, "segment", Map())
+ prepareTableWithKeyGenerator(
+ tableNameCustom1, tablePathCustom1, tableType,
+ CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields1, Map())
+ prepareTableWithKeyGenerator(
+ tableNameCustom2, tablePathCustom2, tableType,
+ CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields2, TS_KEY_GEN_CONFIGS)
+
+ // Non-partitioned table does not require additional partition path
field write config
+ createTableWithSql(tableNameNonPartitioned, tablePathNonPartitioned,
"")
+ // Partitioned table with simple key generator does not require
additional partition path field write config
+ createTableWithSql(tableNameSimpleKey, tablePathSimpleKey, "")
+ // Partitioned table with custom key generator requires additional
partition path field write config
+      // Without it, the SQL DML currently fails
+ createTableWithSql(tableNameCustom1, tablePathCustom1, "")
+ createTableWithSql(tableNameCustom2, tablePathCustom2,
+ s"hoodie.datasource.write.partitionpath.field =
'$writePartitionFields2', "
+ + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 +
"'").mkString(", "))
+
+ val segmentPartitionFunc = (_: Integer, segment: String) => segment
+ val customPartitionFunc = (ts: Integer, segment: String) =>
TS_FORMATTER_FUNC.apply(ts) + "/" + segment
+
+ testFirstRoundInserts(tableNameNonPartitioned, TS_TO_STRING_FUNC, (_,
_) => "")
+ testFirstRoundInserts(tableNameSimpleKey, TS_TO_STRING_FUNC,
segmentPartitionFunc)
+ // INSERT INTO should fail for tableNameCustom1
+ val sourceTableName = tableNameCustom1 + "_source"
+ prepareParquetSource(sourceTableName, Seq("(7, 'a7', 1399.0,
1706800227, 'cat1')"))
+ assertThrows[IOException] {
+ spark.sql(
+ s"""
+ | INSERT INTO $tableNameCustom1
+ | SELECT * from $sourceTableName
+ | """.stripMargin)
+ }
+ testFirstRoundInserts(tableNameCustom2, TS_FORMATTER_FUNC,
customPartitionFunc)
+
+ // Now add the missing partition path field write config for
tableNameCustom1
+ spark.sql(
+ s"""ALTER TABLE $tableNameCustom1
+ | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field
= '$writePartitionFields1')
+ | """.stripMargin)
+
+ // All tables should be able to do INSERT INTO without any problem,
+ // since the scope of the added write config is at the catalog table
level
+ testSecondRoundInserts(tableNameNonPartitioned, TS_TO_STRING_FUNC, (_,
_) => "")
+ testSecondRoundInserts(tableNameSimpleKey, TS_TO_STRING_FUNC,
segmentPartitionFunc)
+ testFirstRoundInserts(tableNameCustom1, TS_TO_STRING_FUNC,
segmentPartitionFunc)
+ testSecondRoundInserts(tableNameCustom2, TS_FORMATTER_FUNC,
customPartitionFunc)
+ }
+ }
+ }
+ }
+
+ test("Test wrong partition path field write config with custom key
generator") {
+ withTempDir { tmp => {
+ val tableName = generateTableName
+ val tablePath = tmp.getCanonicalPath + "/" + tableName
+ val tableType = "MERGE_ON_READ"
+ val writePartitionFields = "segment:simple,ts:timestamp"
+ val wrongWritePartitionFields = "segment:simple"
+ val customPartitionFunc = (ts: Integer, segment: String) => segment +
"/" + TS_FORMATTER_FUNC.apply(ts)
+
+ prepareTableWithKeyGenerator(
+ tableName, tablePath, "MERGE_ON_READ",
+ CUSTOM_KEY_GEN_CLASS_NAME, writePartitionFields, TS_KEY_GEN_CONFIGS)
+
+ // CREATE TABLE should fail due to config conflict
+ assertThrows[HoodieException] {
+ createTableWithSql(tableName, tablePath,
+ s"hoodie.datasource.write.partitionpath.field =
'$wrongWritePartitionFields', "
+ + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 +
"'").mkString(", "))
+ }
+
+ createTableWithSql(tableName, tablePath,
+ s"hoodie.datasource.write.partitionpath.field =
'$writePartitionFields', "
+ + TS_KEY_GEN_CONFIGS.map(e => e._1 + " = '" + e._2 +
"'").mkString(", "))
+ // Set wrong write config
+ spark.sql(
+ s"""ALTER TABLE $tableName
+ | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field =
'$wrongWritePartitionFields')
+ | """.stripMargin)
+
+ // INSERT INTO should fail due to conflict between write and table
config of partition path fields
+ val sourceTableName = tableName + "_source"
+ prepareParquetSource(sourceTableName, Seq("(7, 'a7', 1399.0, 1706800227,
'cat1')"))
+ assertThrows[HoodieException] {
+ spark.sql(
+ s"""
+ | INSERT INTO $tableName
+ | SELECT * from $sourceTableName
+ | """.stripMargin)
+ }
+
+ // Only testing Spark 3.1 and above as lower Spark versions do not
support
+ // ALTER TABLE .. SET TBLPROPERTIES .. to store table-level properties
in Hudi Catalog
+ if (HoodieSparkUtils.gteqSpark3_1) {
+ // Now fix the partition path field write config for tableName
+ spark.sql(
+ s"""ALTER TABLE $tableName
+ | SET TBLPROPERTIES (hoodie.datasource.write.partitionpath.field
= '$writePartitionFields')
+ | """.stripMargin)
+
+ // INSERT INTO should succeed now
+ testFirstRoundInserts(tableName, TS_FORMATTER_FUNC,
customPartitionFunc)
+ }
+ }
+ }
+ }
+
+ private def testFirstRoundInserts(tableName: String,
+ tsGenFunc: Integer => String,
+ partitionGenFunc: (Integer, String) =>
String): Unit = {
+ val sourceTableName = tableName + "_source1"
+ prepareParquetSource(sourceTableName, Seq("(7, 'a7', 1399.0, 1706800227,
'cat1')"))
+ spark.sql(
+ s"""
+ | INSERT INTO $tableName
+ | SELECT * from $sourceTableName
+ | """.stripMargin)
+ validateResults(
+ tableName,
+ s"SELECT id, name, cast(price as string), cast(ts as string), segment
from $tableName",
+ tsGenFunc,
+ partitionGenFunc,
+ Seq(),
+ Seq(1, "a1", "1.6", 1704121827, "cat1"),
+ Seq(2, "a2", "10.8", 1704121827, "cat1"),
+ Seq(3, "a3", "30.0", 1706800227, "cat1"),
+ Seq(4, "a4", "103.4", 1701443427, "cat2"),
+ Seq(5, "a5", "1999.0", 1704121827, "cat2"),
+ Seq(6, "a6", "80.0", 1704121827, "cat3"),
+ Seq(7, "a7", "1399.0", 1706800227, "cat1")
+ )
+ }
+
+ private def testSecondRoundInserts(tableName: String,
+ tsGenFunc: Integer => String,
+ partitionGenFunc: (Integer, String) =>
String): Unit = {
+ val sourceTableName = tableName + "_source2"
+ prepareParquetSource(sourceTableName, Seq("(8, 'a8', 26.9, 1706800227,
'cat3')"))
+ spark.sql(
+ s"""
+ | INSERT INTO $tableName
+ | SELECT * from $sourceTableName
+ | """.stripMargin)
+ validateResults(
+ tableName,
+ s"SELECT id, name, cast(price as string), cast(ts as string), segment
from $tableName",
+ tsGenFunc,
+ partitionGenFunc,
+ Seq(),
+ Seq(1, "a1", "1.6", 1704121827, "cat1"),
+ Seq(2, "a2", "10.8", 1704121827, "cat1"),
+ Seq(3, "a3", "30.0", 1706800227, "cat1"),
+ Seq(4, "a4", "103.4", 1701443427, "cat2"),
+ Seq(5, "a5", "1999.0", 1704121827, "cat2"),
+ Seq(6, "a6", "80.0", 1704121827, "cat3"),
+ Seq(7, "a7", "1399.0", 1706800227, "cat1"),
+ Seq(8, "a8", "26.9", 1706800227, "cat3")
+ )
+ }
+
+ private def prepareTableWithKeyGenerator(tableName: String,
+ tablePath: String,
+ tableType: String,
+ keyGenClassName: String,
+ writePartitionFields: String,
+ timestampKeyGeneratorConfig:
Map[String, String]): Unit = {
+ val df = spark.sql(
+ s"""SELECT 1 as id, 'a1' as name, 1.6 as price, 1704121827 as ts, 'cat1'
as segment
+ | UNION
+ | SELECT 2 as id, 'a2' as name, 10.8 as price, 1704121827 as ts,
'cat1' as segment
+ | UNION
+ | SELECT 3 as id, 'a3' as name, 30.0 as price, 1706800227 as ts,
'cat1' as segment
+ | UNION
+ | SELECT 4 as id, 'a4' as name, 103.4 as price, 1701443427 as ts,
'cat2' as segment
+ | UNION
+ | SELECT 5 as id, 'a5' as name, 1999.0 as price, 1704121827 as ts,
'cat2' as segment
+ | UNION
+ | SELECT 6 as id, 'a6' as name, 80.0 as price, 1704121827 as ts,
'cat3' as segment
+ |""".stripMargin)
+
+ df.write.format("hudi")
+ .option("hoodie.datasource.write.table.type", tableType)
+ .option("hoodie.datasource.write.keygenerator.class", keyGenClassName)
+ .option("hoodie.datasource.write.partitionpath.field",
writePartitionFields)
+ .option("hoodie.datasource.write.recordkey.field", "id")
+ .option("hoodie.datasource.write.precombine.field", "name")
+ .option("hoodie.table.name", tableName)
+ .option("hoodie.insert.shuffle.parallelism", "1")
+ .option("hoodie.upsert.shuffle.parallelism", "1")
+ .option("hoodie.bulkinsert.shuffle.parallelism", "1")
+ .options(timestampKeyGeneratorConfig)
+ .mode(SaveMode.Overwrite)
+ .save(tablePath)
+
+ // Validate that the generated table has expected table configs of key
generator and partition path fields
+ val metaClient = HoodieTableMetaClient.builder()
+ .setConf(spark.sparkContext.hadoopConfiguration)
+ .setBasePath(tablePath)
+ .build()
+ assertEquals(keyGenClassName,
metaClient.getTableConfig.getKeyGeneratorClassName)
+    // Validate that the partition path fields in the table config always
+ // contain the field names only (no key generator type like
"segment:simple")
+ if (CUSTOM_KEY_GEN_CLASS_NAME.equals(keyGenClassName)) {
+ val props = new TypedProperties()
+ props.put("hoodie.datasource.write.partitionpath.field",
writePartitionFields)
+ timestampKeyGeneratorConfig.foreach(e => {
+ props.put(e._1, e._2)
+ })
+ // For custom key generator, the
"hoodie.datasource.write.partitionpath.field"
+ // contains the key generator type, like "ts:timestamp,segment:simple",
+      // whereas the partition path fields in the table config are "ts,segment"
+ assertEquals(
+
SparkKeyGenUtils.getPartitionColumns(Option(CUSTOM_KEY_GEN_CLASS_NAME), props),
+ metaClient.getTableConfig.getPartitionFieldProp)
+ } else {
+ assertEquals(writePartitionFields,
metaClient.getTableConfig.getPartitionFieldProp)
+ }
+ }
+
+ private def createTableWithSql(tableName: String,
+ tablePath: String,
+ tblProps: String): Unit = {
+ val tblPropsStatement = if (StringUtils.isNullOrEmpty(tblProps)) {
+ ""
+ } else {
+ "TBLPROPERTIES (\n" + tblProps + "\n)"
+ }
+ spark.sql(
+ s"""
+ | CREATE TABLE $tableName USING HUDI
+ | location '$tablePath'
+ | $tblPropsStatement
+ | """.stripMargin)
+ }
+
+ private def prepareParquetSource(sourceTableName: String,
+ rows: Seq[String]): Unit = {
+ spark.sql(
+ s"""CREATE TABLE $sourceTableName
+ | (id int, name string, price decimal(5, 1), ts int, segment string)
+ | USING PARQUET
+ |""".stripMargin)
+ spark.sql(
+ s"""
+ | INSERT INTO $sourceTableName values
+ | ${rows.mkString(", ")}
+ | """.stripMargin)
+ }
+
+ private def validateResults(tableName: String,
+ sql: String,
+ tsGenFunc: Integer => String,
+ partitionGenFunc: (Integer, String) => String,
+ droppedPartitions: Seq[String],
+ expects: Seq[Any]*): Unit = {
+ checkAnswer(sql)(
+ expects.map(e => Seq(e(0), e(1), e(2),
tsGenFunc.apply(e(3).asInstanceOf[Integer]), e(4))): _*
+ )
+ val expectedPartitions: Seq[String] = expects
+ .map(e => partitionGenFunc.apply(e(3).asInstanceOf[Integer],
e(4).asInstanceOf[String]))
+ .distinct.sorted
+ validatePartitions(tableName, droppedPartitions, expectedPartitions)
+ }
+
+ private def getSortedTablePartitions(tableName: String): Seq[String] = {
+ spark.sql(s"SHOW PARTITIONS $tableName").collect()
+ .map(row => row.getString(0))
+ .sorted.toSeq
+ }
+
+ private def validatePartitions(tableName: String,
+ droppedPartitions: Seq[String],
+ expectedPartitions: Seq[String]): Unit = {
+ val actualPartitions: Seq[String] = getSortedTablePartitions(tableName)
+ if (expectedPartitions.size == 1 && expectedPartitions.head.isEmpty) {
+ assertTrue(actualPartitions.isEmpty)
+ } else {
+ assertEquals(expectedPartitions, actualPartitions)
+ }
+ droppedPartitions.foreach(dropped =>
assertFalse(actualPartitions.contains(dropped)))
+ }
+}
+
+object TestSparkSqlWithCustomKeyGenerator {
+ val SIMPLE_KEY_GEN_CLASS_NAME = "org.apache.hudi.keygen.SimpleKeyGenerator"
+ val NONPARTITIONED_KEY_GEN_CLASS_NAME =
"org.apache.hudi.keygen.NonpartitionedKeyGenerator"
+ val CUSTOM_KEY_GEN_CLASS_NAME = "org.apache.hudi.keygen.CustomKeyGenerator"
+ val DATE_FORMAT_PATTERN = "yyyyMM"
+ val TS_KEY_GEN_CONFIGS = Map(
+ "hoodie.keygen.timebased.timestamp.type" -> "SCALAR",
+ "hoodie.keygen.timebased.output.dateformat" -> DATE_FORMAT_PATTERN,
+ "hoodie.keygen.timebased.timestamp.scalar.time.unit" -> "seconds"
+ )
+ val TS_TO_STRING_FUNC = (tsSeconds: Integer) => tsSeconds.toString
+ val TS_FORMATTER_FUNC = (tsSeconds: Integer) => {
+ new DateTime(tsSeconds *
1000L).toString(DateTimeFormat.forPattern(DATE_FORMAT_PATTERN))
+ }
+
+ def getTimestampKeyGenConfigs: Map[String, String] = {
+ Map(
+ "hoodie.keygen.timebased.timestamp.type" -> "SCALAR",
+ "hoodie.keygen.timebased.output.dateformat" -> DATE_FORMAT_PATTERN,
+ "hoodie.keygen.timebased.timestamp.scalar.time.unit" -> "seconds"
+ )
+ }
+}