This is an automated email from the ASF dual-hosted git repository.
xushiyan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git
The following commit(s) were added to refs/heads/master by this push:
new 6c73075afca [HUDI-8783] Add tests for decimal data type (#12519)
6c73075afca is described below
commit 6c73075afcacbfa5f55b929ae98238ec54e458e8
Author: Lin Liu <[email protected]>
AuthorDate: Mon Jan 6 11:06:52 2025 -0800
[HUDI-8783] Add tests for decimal data type (#12519)
---
.../SparkFileFormatInternalRowReaderContext.scala | 12 +-
.../apache/hudi/TestDecimalTypeDataWorkflow.scala | 130 +++++++++++++++++++++
2 files changed, 138 insertions(+), 4 deletions(-)
diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkFileFormatInternalRowReaderContext.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkFileFormatInternalRowReaderContext.scala
index 06d13e12a01..9ffc12a3db2 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkFileFormatInternalRowReaderContext.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkFileFormatInternalRowReaderContext.scala
@@ -31,7 +31,6 @@ import org.apache.hudi.common.util.collection.{CachingIterator, ClosableIterator
 import org.apache.hudi.io.storage.{HoodieSparkFileReaderFactory, HoodieSparkParquetReader}
 import org.apache.hudi.storage.{HoodieStorage, StorageConfiguration, StoragePath}
 import org.apache.hudi.util.CloseableInternalRowIterator
-
 import org.apache.avro.Schema
 import org.apache.avro.Schema.Type
 import org.apache.avro.generic.{GenericRecord, IndexedRecord}
@@ -42,10 +41,11 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.JoinedRow
 import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, SparkParquetReader}
+import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.hudi.SparkAdapter
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.{LongType, MetadataBuilder, StructField, StructType}
-import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.types.{DecimalType, LongType, MetadataBuilder, StructField, StructType}
+import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
 import org.apache.spark.unsafe.types.UTF8String
 
 import scala.collection.mutable
@@ -263,13 +263,15 @@ class SparkFileFormatInternalRowReaderContext(parquetFileReader: SparkParquetReader
   }
 
   override def castValue(value: Comparable[_], newType: Schema.Type): Comparable[_] = {
-    value match {
+    val valueToCast = if (value == null) 0 else value
+    valueToCast match {
       case v: Integer => newType match {
         case Type.INT => v
         case Type.LONG => v.longValue()
         case Type.FLOAT => v.floatValue()
         case Type.DOUBLE => v.doubleValue()
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)
         case x => throw new UnsupportedOperationException(s"Cast from Integer to $x is not supported")
       }
       case v: java.lang.Long => newType match {
@@ -277,6 +279,7 @@ class SparkFileFormatInternalRowReaderContext(parquetFileReader: SparkParquetReader
         case Type.FLOAT => v.floatValue()
         case Type.DOUBLE => v.doubleValue()
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)
         case x => throw new UnsupportedOperationException(s"Cast from Long to $x is not supported")
       }
       case v: java.lang.Float => newType match {
@@ -288,6 +291,7 @@ class SparkFileFormatInternalRowReaderContext(parquetFileReader: SparkParquetReader
       case v: java.lang.Double => newType match {
         case Type.DOUBLE => v
         case Type.STRING => UTF8String.fromString(v.toString)
+        case Type.FIXED => BigDecimal(v)
         case x => throw new UnsupportedOperationException(s"Cast from Double to $x is not supported")
       }
       case v: String => newType match {
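
For context on the change above: Avro models decimal values as a fixed type
carrying a decimal logical type, so an ordering value destined for such a
field must be cast to something decimal-comparable; the patch also coerces a
null input to 0 before matching. Below is a minimal standalone sketch of that
behavior, assuming only the Scala standard library; castToFixedSketch is a
hypothetical name, not a Hudi API.

    // Sketch: numeric values aimed at an Avro FIXED (decimal) field become
    // scala.math.BigDecimal, and null is first coerced to 0 so ordering
    // comparisons never see a null.
    def castToFixedSketch(value: Comparable[_]): Comparable[_] = {
      val valueToCast = if (value == null) 0 else value
      valueToCast match {
        case v: Integer          => BigDecimal(v.intValue())
        case v: java.lang.Long   => BigDecimal(v.longValue())
        case v: java.lang.Double => BigDecimal(v.doubleValue())
        case x => throw new UnsupportedOperationException(
          s"Cast from ${x.getClass} to FIXED is not supported")
      }
    }

    castToFixedSketch(Integer.valueOf(7)) // BigDecimal(7)
    castToFixedSketch(null)               // BigDecimal(0), via the null-to-0 coercion

Returning scala.math.BigDecimal keeps the result Comparable, which is all the
merge path needs to order records.
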
diff --git a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/TestDecimalTypeDataWorkflow.scala b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/TestDecimalTypeDataWorkflow.scala
new file mode 100644
index 00000000000..c4014bc5719
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/TestDecimalTypeDataWorkflow.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.hudi
+
+import org.apache.hudi.DataSourceWriteOptions._
+import org.apache.hudi.common.config.{HoodieReaderConfig, HoodieStorageConfig}
+import org.apache.hudi.config.HoodieWriteConfig
+import org.apache.hudi.testutils.SparkClientFunctionalTestHarness
+import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode}
+import org.junit.jupiter.api.Assertions.assertTrue
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.CsvSource
+
+class TestDecimalTypeDataWorkflow extends SparkClientFunctionalTestHarness {
+  val sparkOpts: Map[String, String] = Map(
+    HoodieStorageConfig.LOGFILE_DATA_BLOCK_FORMAT.key -> "parquet",
+    HoodieWriteConfig.RECORD_MERGE_IMPL_CLASSES.key -> classOf[DefaultSparkRecordMerger].getName)
+  val fgReaderOpts: Map[String, String] = Map(
+    HoodieReaderConfig.FILE_GROUP_READER_ENABLED.key -> "true",
+    HoodieReaderConfig.MERGE_USE_RECORD_POSITIONS.key -> "true")
+  val opts = sparkOpts ++ fgReaderOpts
+
+  @ParameterizedTest
+  @CsvSource(value = Array("10,2", "15,5", "20,10", "38,18", "5,0"))
+  def testDecimalInsertUpdateDeleteRead(precision: String, scale: String): Unit = {
+    // Create schema.
+    val schema = StructType(Seq(
+      StructField("id", IntegerType, nullable = true),
+      StructField(
+        "decimal_col",
+        DecimalType(Integer.valueOf(precision), Integer.valueOf(scale)),
+        nullable = true)))
+    // Build data conforming to the schema.
+    val tablePath = basePath
+    val data: Seq[(Int, Decimal)] = Seq(
+      (1, Decimal("123.45")),
+      (2, Decimal("987.65")),
+      (3, Decimal("-10.23")),
+      (4, Decimal("0.01")),
+      (5, Decimal("1000.00")))
+    val rows = data.map {
+      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
+    }
+    val rddData = spark.sparkContext.parallelize(rows)
+
+    // Insert.
+    val insertDf: DataFrame = spark.sqlContext.createDataFrame(rddData, schema)
+      .toDF("id", "decimal_col").sort("id")
+    insertDf.write.format("hudi")
+      .option(RECORDKEY_FIELD.key(), "id")
+      .option(PRECOMBINE_FIELD.key(), "decimal_col")
+      .option(TABLE_TYPE.key, "MERGE_ON_READ")
+      .option(TABLE_NAME.key, "test_table")
+      .options(opts)
+      .mode(SaveMode.Overwrite)
+      .save(tablePath)
+
+    // Update.
+    val update: Seq[(Int, Decimal)] = Seq(
+      (1, Decimal("543.21")),
+      (2, Decimal("111.11")),
+      (6, Decimal("1001.00")))
+    val updateRows = update.map {
+      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
+    }
+    val rddUpdate = spark.sparkContext.parallelize(updateRows)
+    val updateDf: DataFrame = spark.createDataFrame(rddUpdate, schema)
+      .toDF("id", "decimal_col").sort("id")
+    updateDf.write.format("hudi")
+      .option(OPERATION.key(), "upsert")
+      .options(opts)
+      .mode(SaveMode.Append)
+      .save(tablePath)
+
+    // Delete.
+    val delete: Seq[(Int, Decimal)] = Seq(
+      (3, Decimal("543.21")),
+      (4, Decimal("111.11")))
+    val deleteRows = delete.map {
+      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
+    }
+    val rddDelete = spark.sparkContext.parallelize(deleteRows)
+    val deleteDf: DataFrame = spark.createDataFrame(rddDelete, schema)
+      .toDF("id", "decimal_col").sort("id")
+    deleteDf.write.format("hudi")
+      .option(OPERATION.key(), "delete")
+      .options(opts)
+      .mode(SaveMode.Append)
+      .save(tablePath)
+
+    // Asserts.
+    val actual = spark.read.format("hudi").load(tablePath).select("id", "decimal_col")
+    val expected: Seq[(Int, Decimal)] = Seq(
+      (1, Decimal("543.21")),
+      (2, Decimal("987.65")),
+      (5, Decimal("1000.00")),
+      (6, Decimal("1001.00")))
+    val expectedRows = expected.map {
+      case (id, decimalVal) => Row(id, decimalVal.toJavaBigDecimal)
+    }
+    val rddExpected = spark.sparkContext.parallelize(expectedRows)
+    val expectedDf: DataFrame = spark.createDataFrame(rddExpected, schema)
+      .toDF("id", "decimal_col").sort("id")
+    val expectedMinusActual = expectedDf.except(actual)
+    val actualMinusExpected = actual.except(expectedDf)
+    expectedDf.show(false)
+    actual.show(false)
+    expectedMinusActual.show(false)
+    actualMinusExpected.show(false)
+    assertTrue(expectedMinusActual.isEmpty && actualMinusExpected.isEmpty)
+  }
+}
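
A note on the parameterization in the test above: each @CsvSource pair becomes
the precision and scale of decimal_col's DecimalType. As a standalone Spark
sketch (plain spark-sql, no Hudi involved), this is how one of the test
literals behaves at two of the tested precision/scale pairs:

    import org.apache.spark.sql.types.Decimal

    // changePrecision mutates the Decimal in place and returns whether the
    // value is representable at the requested precision/scale.
    val d1 = Decimal("123.45")
    println(d1.changePrecision(10, 2)) // true: 123.45 fits DecimalType(10, 2)

    val d2 = Decimal("123.45")
    println(d2.changePrecision(5, 0))  // true: rounds to 123 for DecimalType(5, 0)
    println(d2)                        // 123

Sweeping the pairs from (5,0) up to (38,18) exercises both the compact
Long-backed representation Spark uses at small precisions and the
BigDecimal-backed one it falls back to at large ones.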