This is an automated email from the ASF dual-hosted git repository.

danny0405 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new 696911ed8c4 [HUDI-7305] Fix cast exception for byte/short/float partitioned field (#10518)
696911ed8c4 is described below

commit 696911ed8c48bd74cd1a93322a4c1d39bba11a6c
Author: stream2000 <[email protected]>
AuthorDate: Fri Jan 19 10:12:43 2024 +0800

    [HUDI-7305] Fix cast exception for byte/short/float partitioned field (#10518)
---
 .../apache/spark/sql/hudi/TestInsertTable.scala    | 37 ++++++++++++++++++++++
 .../datasources/Spark3ParsePartitionUtil.scala     | 10 +++---
 2 files changed, 43 insertions(+), 4 deletions(-)

diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
index 044b6451cdf..05a04daf417 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
@@ -2334,6 +2334,43 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
     })
   }
 
+  test("Test various data types as partition fields") {
+    withRecordType()(withTempDir { tmp =>
+      val tableName = generateTableName
+      spark.sql(
+        s"""
+           |CREATE TABLE $tableName (
+           |  id INT,
+           |  boolean_field BOOLEAN,
+           |  float_field FLOAT,
+           |  byte_field BYTE,
+           |  short_field SHORT,
+           |  decimal_field DECIMAL(10, 5),
+           |  date_field DATE,
+           |  string_field STRING,
+           |  timestamp_field TIMESTAMP
+           |) USING hudi
+           | TBLPROPERTIES (primaryKey = 'id')
+           | PARTITIONED BY (boolean_field, float_field, byte_field, 
short_field, decimal_field, date_field, string_field, timestamp_field)
+           |LOCATION '${tmp.getCanonicalPath}'
+     """.stripMargin)
+
+      // Insert data into partitioned table
+      spark.sql(
+        s"""
+           |INSERT INTO $tableName VALUES
+           |(1, TRUE, CAST(1.0 as FLOAT), 1, 1, 1234.56789, DATE '2021-01-05', 
'partition1', TIMESTAMP '2021-01-05 10:00:00'),
+           |(2, FALSE,CAST(2.0 as FLOAT), 2, 2, 6789.12345, DATE '2021-01-06', 
'partition2', TIMESTAMP '2021-01-06 11:00:00')
+     """.stripMargin)
+
+      checkAnswer(s"SELECT id, boolean_field FROM $tableName ORDER BY id")(
+        Seq(1, true),
+        Seq(2, false)
+      )
+    })
+  }
+
+
   def ingestAndValidateDataDupPolicy(tableType: String, tableName: String, 
tmp: File,
                                      expectedOperationtype: WriteOperationType 
= WriteOperationType.INSERT,
                                      setOptions: List[String] = List.empty,
diff --git 
a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala
 
b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala
index ebe92a5a32a..fca21d202a9 100644
--- 
a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala
+++ 
b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/execution/datasources/Spark3ParsePartitionUtil.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.hadoop.fs.Path
 import 
org.apache.hudi.common.util.PartitionPathEncodeUtils.DEFAULT_PARTITION_PATH
 import org.apache.hudi.spark3.internal.ReflectUtil
-import org.apache.hudi.util.JFunction
 import org.apache.spark.sql.catalyst.InternalRow
 import 
org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
@@ -29,10 +28,9 @@ import 
org.apache.spark.sql.execution.datasources.PartitioningUtils.timestampPar
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-import java.lang.{Boolean => JBoolean, Double => JDouble, Long => JLong}
+import java.lang.{Double => JDouble, Long => JLong}
 import java.math.{BigDecimal => JBigDecimal}
 import java.time.ZoneId
-import java.util
 import java.util.concurrent.ConcurrentHashMap
 import java.util.{Locale, TimeZone}
 import scala.collection.convert.Wrappers.JConcurrentMapWrapper
@@ -259,10 +257,12 @@ object Spark3ParsePartitionUtil extends 
SparkParsePartitionUtil {
       zoneId: ZoneId): Any = desiredType match {
     case _ if value == DEFAULT_PARTITION_PATH => null
     case NullType => null
-    case BooleanType => JBoolean.parseBoolean(value)
     case StringType => UTF8String.fromString(unescapePathName(value))
+    case ByteType => Integer.parseInt(value).toByte
+    case ShortType => Integer.parseInt(value).toShort
     case IntegerType => Integer.parseInt(value)
     case LongType => JLong.parseLong(value)
+    case FloatType => JDouble.parseDouble(value).toFloat
     case DoubleType => JDouble.parseDouble(value)
     case _: DecimalType => Literal(new JBigDecimal(value)).value
     case DateType =>
@@ -274,6 +274,8 @@ object Spark3ParsePartitionUtil extends 
SparkParsePartitionUtil {
       }.getOrElse {
         Cast(Cast(Literal(value), DateType, Some(zoneId.getId)), dt).eval()
       }
+    case BinaryType => value.getBytes()
+    case BooleanType => value.toBoolean
     case dt => throw new IllegalArgumentException(s"Unexpected type $dt")
   }
 

Reply via email to