This is an automated email from the ASF dual-hosted git repository.

xushiyan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new 6c73075afca [HUDI-8783] Add tests for decimal data type (#12519)
6c73075afca is described below

commit 6c73075afcacbfa5f55b929ae98238ec54e458e8
Author: Lin Liu <[email protected]>
AuthorDate: Mon Jan 6 11:06:52 2025 -0800

    [HUDI-8783] Add tests for decimal data type (#12519)
---
 .../SparkFileFormatInternalRowReaderContext.scala  |  12 +-
 .../apache/hudi/TestDecimalTypeDataWorkflow.scala  | 130 +++++++++++++++++++++
 2 files changed, 138 insertions(+), 4 deletions(-)

diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkFileFormatInternalRowReaderContext.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkFileFormatInternalRowReaderContext.scala
index 06d13e12a01..9ffc12a3db2 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkFileFormatInternalRowReaderContext.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkFileFormatInternalRowReaderContext.scala
@@ -31,7 +31,6 @@ import org.apache.hudi.common.util.collection.{CachingIterator, ClosableIterator
 import org.apache.hudi.io.storage.{HoodieSparkFileReaderFactory, HoodieSparkParquetReader}
 import org.apache.hudi.storage.{HoodieStorage, StorageConfiguration, StoragePath}
 import org.apache.hudi.util.CloseableInternalRowIterator
-
 import org.apache.avro.Schema
 import org.apache.avro.Schema.Type
 import org.apache.avro.generic.{GenericRecord, IndexedRecord}
@@ -42,10 +41,11 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.JoinedRow
 import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, SparkParquetReader}
+import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.hudi.SparkAdapter
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.{LongType, MetadataBuilder, StructField, StructType}
-import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.types.{DecimalType, LongType, MetadataBuilder, StructField, StructType}
+import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
 import org.apache.spark.unsafe.types.UTF8String
 
 import scala.collection.mutable
@@ -263,13 +263,15 @@ class SparkFileFormatInternalRowReaderContext(parquetFileReader: SparkParquetRea
   }
 
   override def castValue(value: Comparable[_], newType: Schema.Type): Comparable[_] = {
-    value match {
+    val valueToCast = if (value == null) 0 else value
+    valueToCast match {
       case v: Integer => newType match {
         case Type.INT => v
         case Type.LONG => v.longValue()
         case Type.FLOAT => v.floatValue()
         case Type.DOUBLE => v.doubleValue()
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)
         case x => throw new UnsupportedOperationException(s"Cast from Integer to $x is not supported")
       }
       case v: java.lang.Long => newType match {
@@ -277,6 +279,7 @@ class SparkFileFormatInternalRowReaderContext(parquetFileReader: SparkParquetRea
         case Type.FLOAT => v.floatValue()
         case Type.DOUBLE => v.doubleValue()
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)
         case x => throw new UnsupportedOperationException(s"Cast from Long to $x is not supported")
       }
       case v: java.lang.Float => newType match {
@@ -288,6 +291,7 @@ class SparkFileFormatInternalRowReaderContext(parquetFileReader: SparkParquetRea
       case v: java.lang.Double => newType match {
         case Type.DOUBLE => v
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)
         case x => throw new UnsupportedOperationException(s"Cast from Double to $x is not supported")
       }
       case v: String => newType match {
diff --git a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/TestDecimalTypeDataWorkflow.scala b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/TestDecimalTypeDataWorkflow.scala
new file mode 100644
index 00000000000..c4014bc5719
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/TestDecimalTypeDataWorkflow.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.hudi
+
+import org.apache.hudi.DataSourceWriteOptions._
+import org.apache.hudi.common.config.{HoodieReaderConfig, HoodieStorageConfig}
+import org.apache.hudi.config.HoodieWriteConfig
+import org.apache.hudi.testutils.SparkClientFunctionalTestHarness
+import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode}
+import org.junit.jupiter.api.Assertions.assertTrue
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.CsvSource
+
+class TestDecimalTypeDataWorkflow extends SparkClientFunctionalTestHarness{
+  val sparkOpts: Map[String, String] = Map(
+    HoodieStorageConfig.LOGFILE_DATA_BLOCK_FORMAT.key -> "parquet",
+    HoodieWriteConfig.RECORD_MERGE_IMPL_CLASSES.key -> classOf[DefaultSparkRecordMerger].getName)
+  val fgReaderOpts: Map[String, String] = Map(
+    HoodieReaderConfig.FILE_GROUP_READER_ENABLED.key -> "true",
+    HoodieReaderConfig.MERGE_USE_RECORD_POSITIONS.key -> "true")
+  val opts = sparkOpts ++ fgReaderOpts
+
+  @ParameterizedTest
+  @CsvSource(value = Array("10,2", "15,5", "20,10", "38,18", "5,0"))
+  def testDecimalInsertUpdateDeleteRead(precision: String, scale: String): Unit = {
+    // Create schema
+    val schema = StructType(Seq(
+      StructField("id", IntegerType, nullable = true),
+      StructField(
+        "decimal_col",
+        DecimalType(Integer.valueOf(precision), Integer.valueOf(scale)),
+        nullable = true)))
+    // Build data conforming to the schema.
+    val tablePath = basePath
+    val data: Seq[(Int, Decimal)] = Seq(
+      (1, Decimal("123.45")),
+      (2, Decimal("987.65")),
+      (3, Decimal("-10.23")),
+      (4, Decimal("0.01")),
+      (5, Decimal("1000.00")))
+    val rows = data.map{
+      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)}
+    val rddData = spark.sparkContext.parallelize(rows)
+
+    // Insert.
+    val insertDf: DataFrame = spark.sqlContext.createDataFrame(rddData, schema)
+      .toDF("id", "decimal_col").sort("id")
+    insertDf.write.format("hudi")
+      .option(RECORDKEY_FIELD.key(), "id")
+      .option(PRECOMBINE_FIELD.key(), "decimal_col")
+      .option(TABLE_TYPE.key, "MERGE_ON_READ")
+      .option(TABLE_NAME.key, "test_table")
+      .options(opts)
+      .mode(SaveMode.Overwrite)
+      .save(tablePath)
+
+    // Update.
+    val update: Seq[(Int, Decimal)] = Seq(
+      (1, Decimal("543.21")),
+      (2, Decimal("111.11")),
+      (6, Decimal("1001.00")))
+    val updateRows = update.map {
+      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
+    }
+    val rddUpdate = spark.sparkContext.parallelize(updateRows)
+    val updateDf: DataFrame = spark.createDataFrame(rddUpdate, schema)
+      .toDF("id", "decimal_col").sort("id")
+    updateDf.write.format("hudi")
+      .option(OPERATION.key(), "upsert")
+      .options(opts)
+      .mode(SaveMode.Append)
+      .save(tablePath)
+
+    // Delete.
+    val delete: Seq[(Int, Decimal)] = Seq(
+      (3, Decimal("543.21")),
+      (4, Decimal("111.11")))
+    val deleteRows = delete.map {
+      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
+    }
+    val rddDelete = spark.sparkContext.parallelize(deleteRows)
+    val deleteDf: DataFrame = spark.createDataFrame(rddDelete, schema)
+      .toDF("id", "decimal_col").sort("id")
+    deleteDf.write.format("hudi")
+      .option(OPERATION.key(), "delete")
+      .options(opts)
+      .mode(SaveMode.Append)
+      .save(tablePath)
+
+    // Asserts
+    val actual = spark.read.format("hudi").load(tablePath).select("id", "decimal_col")
+    val expected: Seq[(Int, Decimal)] = Seq(
+      (1, Decimal("543.21")),
+      (2, Decimal("987.65")),
+      (5, Decimal("1000.00")),
+      (6, Decimal("1001.00")))
+    val expectedRows = expected.map {
+      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
+    }
+    val rddExpected = spark.sparkContext.parallelize(expectedRows)
+    val expectedDf: DataFrame = spark.createDataFrame(rddExpected, schema)
+      .toDF("id", "decimal_col").sort("id")
+    val expectedMinusActual = expectedDf.except(actual)
+    val actualMinusExpected = actual.except(expectedDf)
+    expectedDf.show(false)
+    actual.show(false)
+    expectedMinusActual.show(false)
+    actualMinusExpected.show(false)
+    assertTrue(expectedMinusActual.isEmpty && actualMinusExpected.isEmpty)
+  }
+}

Reply via email to