nsivabalan commented on a change in pull request #2431:
URL: https://github.com/apache/hudi/pull/2431#discussion_r573907738



##########
File path: 
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/DataSourceOptions.scala
##########
@@ -192,6 +194,42 @@ object DataSourceWriteOptions {
     }
   }
 
+  /**
+    * Translate spark parameters to hudi parameters
+    *
+    * @param optParams Parameters to be translated
+    * @return Parameters after translation
+    */
+  def translateSqlOptions(optParams: Map[String, String]): Map[String, String] = {
+    var translatedOptParams = optParams
+    // translate the api partitionBy of spark DataFrameWriter to PARTITIONPATH_FIELD_OPT_KEY
+    if (optParams.contains(SparkDataSourceUtils.PARTITIONING_COLUMNS_KEY)) {
+      val partitionColumns = 
optParams.get(SparkDataSourceUtils.PARTITIONING_COLUMNS_KEY)
+        .map(SparkDataSourceUtils.decodePartitioningColumns)
+        .getOrElse(Nil)
+      val keyGeneratorClass = 
optParams.getOrElse(DataSourceWriteOptions.KEYGENERATOR_CLASS_OPT_KEY,
+        DataSourceWriteOptions.DEFAULT_KEYGENERATOR_CLASS_OPT_VAL)
+
+      val partitionPathField =
+        keyGeneratorClass match {
+          // Only CustomKeyGenerator needs special treatment, because it needs to be specified in a way
+          // such as "field1:PartitionKeyType1,field2:PartitionKeyType2".

Review comment:
       nice.

##########
File path: 
hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
##########
@@ -348,4 +352,141 @@ class TestCOWDataSource extends HoodieClientTestBase {
 
     assertTrue(HoodieDataSourceHelpers.hasNewCommits(fs, basePath, "000"))
   }
+
+  private def getDataFrameWriter(keyGenerator: String): DataFrameWriter[Row] = 
{
+    val records = recordsToStrings(dataGen.generateInserts("000", 100)).toList
+    val inputDF = spark.read.json(spark.sparkContext.parallelize(records, 2))
+
+    inputDF.write.format("hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.KEYGENERATOR_CLASS_OPT_KEY, keyGenerator)
+      .mode(SaveMode.Overwrite)
+  }
+
+  @Test def testTranslateSparkParamsToHudiParamsWithCustomKeyGenerator(): Unit 
= {
+    // Without fieldType, the default is SIMPLE
+    var writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
+    writer.partitionBy("current_ts")
+      .save(basePath)
+
+    var recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= 
col("current_ts").cast("string")).count() == 0)
+
+    // Specify fieldType as TIMESTAMP
+    writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
+    writer.partitionBy("current_ts:TIMESTAMP")
+      .option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
+      .option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
+      .save(basePath)
+
+    recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    val udf_date_format = udf((data: Long) => new 
DateTime(data).toString(DateTimeFormat.forPattern("yyyyMMdd")))
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= 
udf_date_format(col("current_ts"))).count() == 0)
+
+    // Mixed fieldType
+    writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
+    writer.partitionBy("driver", "rider:SIMPLE", "current_ts:TIMESTAMP")
+      .option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
+      .option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
+      .save(basePath)
+
+    recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*/*")
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!=
+      concat(col("driver"), lit("/"), col("rider"), lit("/"), 
udf_date_format(col("current_ts")))).count() == 0)
+
+    // Test invalid partitionKeyType
+    writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
+    writer = writer.partitionBy("current_ts:DUMMY")
+      .option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
+      .option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
+    try {
+      writer.save(basePath)
+      fail("should fail when invalid PartitionKeyType is provided!")
+    } catch {
+      case e: Exception =>
+        assertTrue(e.getMessage.contains("No enum constant org.apache.hudi.keygen.CustomAvroKeyGenerator.PartitionKeyType.DUMMY"))
+    }
+  }
+
+  @Test def testTranslateSparkParamsToHudiParamsWithSimpleKeyGenerator() {
+    // Use the `driver` field as the partition key
+    var writer = getDataFrameWriter(classOf[SimpleKeyGenerator].getName)
+    writer.partitionBy("driver")
+      .save(basePath)
+
+    var recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= 
col("driver")).count() == 0)
+
+    // Use the `driver,rider` field as the partition key, If no such field exists, the default value `default` is used
+    writer = getDataFrameWriter(classOf[SimpleKeyGenerator].getName)
+    writer.partitionBy("driver", "rider")
+      .save(basePath)
+
+    recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= 
lit("default")).count() == 0)
+  }
+
+  @Test def testTranslateSparkParamsToHudiParamsWithComplexKeyGenerator() {
+    // Use the `driver` field as the partition key
+    var writer = getDataFrameWriter(classOf[ComplexKeyGenerator].getName)
+    writer.partitionBy("driver")
+      .save(basePath)
+
+    var recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= 
col("driver")).count() == 0)
+
+    // Use the `driver`,`rider` field as the partition key
+    writer = getDataFrameWriter(classOf[ComplexKeyGenerator].getName)
+    writer.partitionBy("driver", "rider")
+      .save(basePath)
+
+    recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= 
concat(col("driver"), lit("/"), col("rider"))).count() == 0)
+  }
+
+  @Test def 
testTranslateSparkParamsToHudiParamsWithTimestampBasedKeyGenerator() {
+    val writer = 
getDataFrameWriter(classOf[TimestampBasedKeyGenerator].getName)

Review comment:
       https://issues.apache.org/jira/browse/HUDI-1610




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to