This is an automated email from the ASF dual-hosted git repository.
danny0405 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git
The following commit(s) were added to refs/heads/master by this push:
new 696911ed8c4 [HUDI-7305] Fix cast exception for byte/short/float partitioned field (#10518)
696911ed8c4 is described below
commit 696911ed8c48bd74cd1a93322a4c1d39bba11a6c
Author: stream2000 <[email protected]>
AuthorDate: Fri Jan 19 10:12:43 2024 +0800
[HUDI-7305] Fix cast exception for byte/short/float partitioned field (#10518)
---
.../apache/spark/sql/hudi/TestInsertTable.scala | 37 ++++++++++++++++++++++
.../datasources/Spark3ParsePartitionUtil.scala | 10 +++---
2 files changed, 43 insertions(+), 4 deletions(-)
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
index 044b6451cdf..05a04daf417 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
@@ -2334,6 +2334,43 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
})
}
+ test("Test various data types as partition fields") {
+ withRecordType()(withTempDir { tmp =>
+ val tableName = generateTableName
+ spark.sql(
+ s"""
+ |CREATE TABLE $tableName (
+ | id INT,
+ | boolean_field BOOLEAN,
+ | float_field FLOAT,
+ | byte_field BYTE,
+ | short_field SHORT,
+ | decimal_field DECIMAL(10, 5),
+ | date_field DATE,
+ | string_field STRING,
+ | timestamp_field TIMESTAMP
+ |) USING hudi
+ | TBLPROPERTIES (primaryKey = 'id')
+ | PARTITIONED BY (boolean_field, float_field, byte_field, short_field, decimal_field, date_field, string_field, timestamp_field)
+ |LOCATION '${tmp.getCanonicalPath}'
+ """.stripMargin)
+
+ // Insert data into partitioned table
+ spark.sql(
+ s"""
+ |INSERT INTO $tableName VALUES
+ |(1, TRUE, CAST(1.0 as FLOAT), 1, 1, 1234.56789, DATE '2021-01-05', 'partition1', TIMESTAMP '2021-01-05 10:00:00'),
+ |(2, FALSE,CAST(2.0 as FLOAT), 2, 2, 6789.12345, DATE '2021-01-06', 'partition2', TIMESTAMP '2021-01-06 11:00:00')
+ """.stripMargin)
+
+ checkAnswer(s"SELECT id, boolean_field FROM $tableName ORDER BY id")(
+ Seq(1, true),
+ Seq(2, false)
+ )
+ })
+ }
+
+
def ingestAndValidateDataDupPolicy(tableType: String, tableName: String,
tmp: File,
expectedOperationtype: WriteOperationType
= WriteOperationType.INSERT,
setOptions: List[String] = List.empty,
diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala
index ebe92a5a32a..fca21d202a9 100644
--- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala
+++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.datasources
import org.apache.hadoop.fs.Path
import org.apache.hudi.common.util.PartitionPathEncodeUtils.DEFAULT_PARTITION_PATH
import org.apache.hudi.spark3.internal.ReflectUtil
-import org.apache.hudi.util.JFunction
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
@@ -29,10 +28,9 @@ import org.apache.spark.sql.execution.datasources.PartitioningUtils.timestampPartitionPattern
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-import java.lang.{Boolean => JBoolean, Double => JDouble, Long => JLong}
+import java.lang.{Double => JDouble, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import java.time.ZoneId
-import java.util
import java.util.concurrent.ConcurrentHashMap
import java.util.{Locale, TimeZone}
import scala.collection.convert.Wrappers.JConcurrentMapWrapper
@@ -259,10 +257,12 @@ object Spark3ParsePartitionUtil extends SparkParsePartitionUtil {
zoneId: ZoneId): Any = desiredType match {
case _ if value == DEFAULT_PARTITION_PATH => null
case NullType => null
- case BooleanType => JBoolean.parseBoolean(value)
case StringType => UTF8String.fromString(unescapePathName(value))
+ case ByteType => Integer.parseInt(value).toByte
+ case ShortType => Integer.parseInt(value).toShort
case IntegerType => Integer.parseInt(value)
case LongType => JLong.parseLong(value)
+ case FloatType => JDouble.parseDouble(value).toFloat
case DoubleType => JDouble.parseDouble(value)
case _: DecimalType => Literal(new JBigDecimal(value)).value
case DateType =>
@@ -274,6 +274,8 @@ object Spark3ParsePartitionUtil extends SparkParsePartitionUtil {
}.getOrElse {
Cast(Cast(Literal(value), DateType, Some(zoneId.getId)), dt).eval()
}
+ case BinaryType => value.getBytes()
+ case BooleanType => value.toBoolean
case dt => throw new IllegalArgumentException(s"Unexpected type $dt")
}