This is an automated email from the ASF dual-hosted git repository.
huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b6681fbf32fa [SPARK-49787][SQL] Cast between UDT and other types
b6681fbf32fa is described below
commit b6681fbf32fa3596d7649d413f20cc5c6da64991
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Sep 27 13:42:54 2024 -0700
[SPARK-49787][SQL] Cast between UDT and other types
### What changes were proposed in this pull request?
This patch adds UDT support to `Cast` expression.
### Why are the changes needed?
Our customer faced an error when migrating queries that write a UDT column
from Hive to an Iceberg table.
The error happens when Spark tries to cast a UDT column to the data type
(i.e., the SQL type of the UDT) of the table column. The cast is added by the
table column resolution rule for V2 writing commands.
Currently the `Cast` expression doesn't support casting between UDT and other
types. However, since a UDT is serialized as its underlying `sqlType`, `Cast`
should be able to cast between that `sqlType` and other types.
### Does this PR introduce _any_ user-facing change?
Yes. User queries can now cast between UDT and other types.
### How was this patch tested?
Unit test
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48251 from viirya/cast_udt.
Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: huaxingao <[email protected]>
---
python/pyspark/sql/tests/test_types.py | 16 +-
.../org/apache/spark/sql/types/UpCastRule.scala | 4 +
.../spark/sql/catalyst/expressions/Cast.scala | 175 ++++++++++++---------
.../spark/sql/catalyst/expressions/literals.scala | 84 +++++-----
.../sql/catalyst/expressions/CastSuiteBase.scala | 42 ++++-
5 files changed, 202 insertions(+), 119 deletions(-)
diff --git a/python/pyspark/sql/tests/test_types.py
b/python/pyspark/sql/tests/test_types.py
index 8610ace52d86..c240a84d1edb 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -28,7 +28,6 @@ from dataclasses import dataclass, asdict
from pyspark.sql import Row
from pyspark.sql import functions as F
from pyspark.errors import (
- AnalysisException,
ParseException,
PySparkTypeError,
PySparkValueError,
@@ -1130,10 +1129,17 @@ class TypesTestsMixin:
def test_cast_to_udt_with_udt(self):
row = Row(point=ExamplePoint(1.0, 2.0),
python_only_point=PythonOnlyPoint(1.0, 2.0))
df = self.spark.createDataFrame([row])
- with self.assertRaises(AnalysisException):
- df.select(F.col("point").cast(PythonOnlyUDT())).collect()
- with self.assertRaises(AnalysisException):
-
df.select(F.col("python_only_point").cast(ExamplePointUDT())).collect()
+ result = df.select(F.col("point").cast(PythonOnlyUDT())).collect()
+ self.assertEqual(
+ result,
+ [Row(point=PythonOnlyPoint(1.0, 2.0))],
+ )
+
+ result =
df.select(F.col("python_only_point").cast(ExamplePointUDT())).collect()
+ self.assertEqual(
+ result,
+ [Row(python_only_point=ExamplePoint(1.0, 2.0))],
+ )
def test_struct_type(self):
struct1 = StructType().add("f1", StringType(), True).add("f2",
StringType(), True, None)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala
b/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala
index 4993e249b305..6f2fd41f1f79 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala
@@ -66,6 +66,10 @@ private[sql] object UpCastRule {
case (from: UserDefinedType[_], to: UserDefinedType[_]) if
to.acceptsType(from) => true
+ case (udt: UserDefinedType[_], toType) => canUpCast(udt.sqlType, toType)
+
+ case (fromType, udt: UserDefinedType[_]) => canUpCast(fromType,
udt.sqlType)
+
case _ => false
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 7a2799e99fe2..9a29cb4a2bfb 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -150,6 +150,10 @@ object Cast extends QueryErrorsBase {
case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if
udt2.acceptsType(udt1) => true
+ case (udt: UserDefinedType[_], toType) => canAnsiCast(udt.sqlType, toType)
+
+ case (fromType, udt: UserDefinedType[_]) => canAnsiCast(fromType,
udt.sqlType)
+
case _ => false
}
@@ -267,6 +271,10 @@ object Cast extends QueryErrorsBase {
case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if
udt2.acceptsType(udt1) => true
+ case (udt: UserDefinedType[_], toType) => canCast(udt.sqlType, toType)
+
+ case (fromType, udt: UserDefinedType[_]) => canCast(fromType, udt.sqlType)
+
case _ => false
}
@@ -1123,33 +1131,42 @@ case class Cast(
variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId,
zoneId)
})
} else {
- to match {
- case dt if dt == from => identity[Any]
- case VariantType => input =>
variant.VariantExpressionEvalUtils.castToVariant(input, from)
- case _: StringType => castToString(from)
- case BinaryType => castToBinary(from)
- case DateType => castToDate(from)
- case decimal: DecimalType => castToDecimal(from, decimal)
- case TimestampType => castToTimestamp(from)
- case TimestampNTZType => castToTimestampNTZ(from)
- case CalendarIntervalType => castToInterval(from)
- case it: DayTimeIntervalType => castToDayTimeInterval(from, it)
- case it: YearMonthIntervalType => castToYearMonthInterval(from, it)
- case BooleanType => castToBoolean(from)
- case ByteType => castToByte(from)
- case ShortType => castToShort(from)
- case IntegerType => castToInt(from)
- case FloatType => castToFloat(from)
- case LongType => castToLong(from)
- case DoubleType => castToDouble(from)
- case array: ArrayType =>
- castArray(from.asInstanceOf[ArrayType].elementType,
array.elementType)
- case map: MapType => castMap(from.asInstanceOf[MapType], map)
- case struct: StructType => castStruct(from.asInstanceOf[StructType],
struct)
- case udt: UserDefinedType[_] if udt.acceptsType(from) =>
- identity[Any]
- case _: UserDefinedType[_] =>
- throw QueryExecutionErrors.cannotCastError(from, to)
+ from match {
+ // `castToString` has special handling for `UserDefinedType`
+ case udt: UserDefinedType[_] if !to.isInstanceOf[StringType] =>
+ castInternal(udt.sqlType, to)
+ case _ =>
+ to match {
+ case dt if dt == from => identity[Any]
+ case VariantType => input =>
+ variant.VariantExpressionEvalUtils.castToVariant(input, from)
+ case _: StringType => castToString(from)
+ case BinaryType => castToBinary(from)
+ case DateType => castToDate(from)
+ case decimal: DecimalType => castToDecimal(from, decimal)
+ case TimestampType => castToTimestamp(from)
+ case TimestampNTZType => castToTimestampNTZ(from)
+ case CalendarIntervalType => castToInterval(from)
+ case it: DayTimeIntervalType => castToDayTimeInterval(from, it)
+ case it: YearMonthIntervalType => castToYearMonthInterval(from, it)
+ case BooleanType => castToBoolean(from)
+ case ByteType => castToByte(from)
+ case ShortType => castToShort(from)
+ case IntegerType => castToInt(from)
+ case FloatType => castToFloat(from)
+ case LongType => castToLong(from)
+ case DoubleType => castToDouble(from)
+ case array: ArrayType =>
+ castArray(from.asInstanceOf[ArrayType].elementType,
array.elementType)
+ case map: MapType => castMap(from.asInstanceOf[MapType], map)
+ case struct: StructType =>
castStruct(from.asInstanceOf[StructType], struct)
+ case udt: UserDefinedType[_] if udt.acceptsType(from) =>
+ identity[Any]
+ case udt: UserDefinedType[_] =>
+ castInternal(from, udt.sqlType)
+ case _ =>
+ throw QueryExecutionErrors.cannotCastError(from, to)
+ }
}
}
}
@@ -1211,54 +1228,64 @@ case class Cast(
private[this] def nullSafeCastFunction(
from: DataType,
to: DataType,
- ctx: CodegenContext): CastFunction = to match {
-
- case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;"
- case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;"
- case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) =>
- val tmp = ctx.freshVariable("tmp", classOf[Object])
- val dataTypeArg = ctx.addReferenceObj("dataType", to)
- val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
- val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId,
classOf[ZoneId].getName)
- val failOnError = evalMode != EvalMode.TRY
- val cls = classOf[variant.VariantGet].getName
- code"""
- Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneStrArg,
$zoneIdArg);
- if ($tmp == null) {
- $evNull = true;
- } else {
- $evPrim = (${CodeGenerator.boxedType(to)})$tmp;
+ ctx: CodegenContext): CastFunction = {
+ from match {
+ // `castToStringCode` has special handling for `UserDefinedType`
+ case udt: UserDefinedType[_] if !to.isInstanceOf[StringType] =>
+ nullSafeCastFunction(udt.sqlType, to, ctx)
+ case _ =>
+ to match {
+
+ case _ if from == NullType => (c, evPrim, evNull) => code"$evNull =
true;"
+ case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;"
+ case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) =>
+ val tmp = ctx.freshVariable("tmp", classOf[Object])
+ val dataTypeArg = ctx.addReferenceObj("dataType", to)
+ val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
+ val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId,
classOf[ZoneId].getName)
+ val failOnError = evalMode != EvalMode.TRY
+ val cls = classOf[variant.VariantGet].getName
+ code"""
+ Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError,
$zoneStrArg, $zoneIdArg);
+ if ($tmp == null) {
+ $evNull = true;
+ } else {
+ $evPrim = (${CodeGenerator.boxedType(to)})$tmp;
+ }
+ """
+ case VariantType =>
+ val cls =
variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$")
+ val fromArg = ctx.addReferenceObj("from", from)
+ (c, evPrim, evNull) => code"$evPrim = $cls.castToVariant($c,
$fromArg);"
+ case _: StringType => (c, evPrim, _) => castToStringCode(from,
ctx).apply(c, evPrim)
+ case BinaryType => castToBinaryCode(from)
+ case DateType => castToDateCode(from, ctx)
+ case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
+ case TimestampType => castToTimestampCode(from, ctx)
+ case TimestampNTZType => castToTimestampNTZCode(from, ctx)
+ case CalendarIntervalType => castToIntervalCode(from)
+ case it: DayTimeIntervalType => castToDayTimeIntervalCode(from, it)
+ case it: YearMonthIntervalType => castToYearMonthIntervalCode(from,
it)
+ case BooleanType => castToBooleanCode(from, ctx)
+ case ByteType => castToByteCode(from, ctx)
+ case ShortType => castToShortCode(from, ctx)
+ case IntegerType => castToIntCode(from, ctx)
+ case FloatType => castToFloatCode(from, ctx)
+ case LongType => castToLongCode(from, ctx)
+ case DoubleType => castToDoubleCode(from, ctx)
+
+ case array: ArrayType =>
+ castArrayCode(from.asInstanceOf[ArrayType].elementType,
array.elementType, ctx)
+ case map: MapType => castMapCode(from.asInstanceOf[MapType], map,
ctx)
+ case struct: StructType =>
castStructCode(from.asInstanceOf[StructType], struct, ctx)
+ case udt: UserDefinedType[_] if udt.acceptsType(from) =>
+ (c, evPrim, evNull) => code"$evPrim = $c;"
+ case udt: UserDefinedType[_] =>
+ nullSafeCastFunction(from, udt.sqlType, ctx)
+ case _ =>
+ throw QueryExecutionErrors.cannotCastError(from, to)
}
- """
- case VariantType =>
- val cls =
variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$")
- val fromArg = ctx.addReferenceObj("from", from)
- (c, evPrim, evNull) => code"$evPrim = $cls.castToVariant($c, $fromArg);"
- case _: StringType => (c, evPrim, _) => castToStringCode(from,
ctx).apply(c, evPrim)
- case BinaryType => castToBinaryCode(from)
- case DateType => castToDateCode(from, ctx)
- case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
- case TimestampType => castToTimestampCode(from, ctx)
- case TimestampNTZType => castToTimestampNTZCode(from, ctx)
- case CalendarIntervalType => castToIntervalCode(from)
- case it: DayTimeIntervalType => castToDayTimeIntervalCode(from, it)
- case it: YearMonthIntervalType => castToYearMonthIntervalCode(from, it)
- case BooleanType => castToBooleanCode(from, ctx)
- case ByteType => castToByteCode(from, ctx)
- case ShortType => castToShortCode(from, ctx)
- case IntegerType => castToIntCode(from, ctx)
- case FloatType => castToFloatCode(from, ctx)
- case LongType => castToLongCode(from, ctx)
- case DoubleType => castToDoubleCode(from, ctx)
-
- case array: ArrayType =>
- castArrayCode(from.asInstanceOf[ArrayType].elementType,
array.elementType, ctx)
- case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
- case struct: StructType => castStructCode(from.asInstanceOf[StructType],
struct, ctx)
- case udt: UserDefinedType[_] if udt.acceptsType(from) =>
- (c, evPrim, evNull) => code"$evPrim = $c;"
- case _: UserDefinedType[_] =>
- throw QueryExecutionErrors.cannotCastError(from, to)
+ }
}
// Since we need to cast input expressions recursively inside ComplexTypes,
such as Map's
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 4cffc7f0b53a..362bb9af1661 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -441,47 +441,53 @@ case class Literal (value: Any, dataType: DataType)
extends LeafExpression {
override def eval(input: InternalRow): Any = value
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val javaType = CodeGenerator.javaType(dataType)
- if (value == null) {
- ExprCode.forNullValue(dataType)
- } else {
- def toExprCode(code: String): ExprCode = {
- ExprCode.forNonNullValue(JavaCode.literal(code, dataType))
- }
- dataType match {
- case BooleanType | IntegerType | DateType | _: YearMonthIntervalType =>
- toExprCode(value.toString)
- case FloatType =>
- value.asInstanceOf[Float] match {
- case v if v.isNaN =>
- toExprCode("Float.NaN")
- case Float.PositiveInfinity =>
- toExprCode("Float.POSITIVE_INFINITY")
- case Float.NegativeInfinity =>
- toExprCode("Float.NEGATIVE_INFINITY")
- case _ =>
- toExprCode(s"${value}F")
- }
- case DoubleType =>
- value.asInstanceOf[Double] match {
- case v if v.isNaN =>
- toExprCode("Double.NaN")
- case Double.PositiveInfinity =>
- toExprCode("Double.POSITIVE_INFINITY")
- case Double.NegativeInfinity =>
- toExprCode("Double.NEGATIVE_INFINITY")
- case _ =>
- toExprCode(s"${value}D")
- }
- case ByteType | ShortType =>
- ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value",
dataType))
- case TimestampType | TimestampNTZType | LongType | _:
DayTimeIntervalType =>
- toExprCode(s"${value}L")
- case _ =>
- val constRef = ctx.addReferenceObj("literal", value, javaType)
- ExprCode.forNonNullValue(JavaCode.global(constRef, dataType))
+ def gen(ctx: CodegenContext, ev: ExprCode, dataType: DataType): ExprCode =
{
+ val javaType = CodeGenerator.javaType(dataType)
+ if (value == null) {
+ ExprCode.forNullValue(dataType)
+ } else {
+ def toExprCode(code: String): ExprCode = {
+ ExprCode.forNonNullValue(JavaCode.literal(code, dataType))
+ }
+
+ dataType match {
+ case BooleanType | IntegerType | DateType | _: YearMonthIntervalType
=>
+ toExprCode(value.toString)
+ case FloatType =>
+ value.asInstanceOf[Float] match {
+ case v if v.isNaN =>
+ toExprCode("Float.NaN")
+ case Float.PositiveInfinity =>
+ toExprCode("Float.POSITIVE_INFINITY")
+ case Float.NegativeInfinity =>
+ toExprCode("Float.NEGATIVE_INFINITY")
+ case _ =>
+ toExprCode(s"${value}F")
+ }
+ case DoubleType =>
+ value.asInstanceOf[Double] match {
+ case v if v.isNaN =>
+ toExprCode("Double.NaN")
+ case Double.PositiveInfinity =>
+ toExprCode("Double.POSITIVE_INFINITY")
+ case Double.NegativeInfinity =>
+ toExprCode("Double.NEGATIVE_INFINITY")
+ case _ =>
+ toExprCode(s"${value}D")
+ }
+ case ByteType | ShortType =>
+ ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value",
dataType))
+ case TimestampType | TimestampNTZType | LongType | _:
DayTimeIntervalType =>
+ toExprCode(s"${value}L")
+ case udt: UserDefinedType[_] =>
+ gen(ctx, ev, udt.sqlType)
+ case _ =>
+ val constRef = ctx.addReferenceObj("literal", value, javaType)
+ ExprCode.forNonNullValue(JavaCode.global(constRef, dataType))
+ }
}
}
+ gen(ctx, ev, dataType)
}
override def sql: String = (value, dataType) match {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index e87b54339821..f915d6efeb82 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
-import java.time.{Duration, LocalDate, LocalDateTime, Period}
+import java.time.{Duration, LocalDate, LocalDateTime, Period, Year => JYear}
import java.time.temporal.ChronoUnit
import java.util.{Calendar, Locale, TimeZone}
@@ -37,6 +37,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes,
yearMonthIntervalTypes}
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE,
SECOND}
+import org.apache.spark.sql.types.TestUDT._
import org.apache.spark.sql.types.UpCastRule.numericPrecedence
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
import org.apache.spark.unsafe.types.UTF8String
@@ -1409,4 +1410,43 @@ abstract class CastSuiteBase extends SparkFunSuite with
ExpressionEvalHelper {
assert(!Cast(timestampLiteral, TimestampNTZType).resolved)
assert(!Cast(timestampNTZLiteral, TimestampType).resolved)
}
+
+ test("SPARK-49787: Cast between UDT and other types") {
+ val value = new MyDenseVector(Array(1.0, 2.0, -1.0))
+ val udtType = new MyDenseVectorUDT()
+ val targetType = ArrayType(DoubleType, containsNull = false)
+
+ val serialized = udtType.serialize(value)
+
+ checkEvaluation(Cast(new Literal(serialized, udtType), targetType),
serialized)
+ checkEvaluation(Cast(new Literal(serialized, targetType), udtType),
serialized)
+
+ val year = JYear.parse("2024")
+ val yearUDTType = new YearUDT()
+
+ val yearSerialized = yearUDTType.serialize(year)
+
+ checkEvaluation(Cast(new Literal(yearSerialized, yearUDTType),
IntegerType), 2024)
+ checkEvaluation(Cast(new Literal(2024, IntegerType), yearUDTType),
yearSerialized)
+
+ val yearString = UTF8String.fromString("2024")
+ checkEvaluation(Cast(new Literal(yearSerialized, yearUDTType),
StringType), yearString)
+ checkEvaluation(Cast(new Literal(yearString, StringType), yearUDTType),
yearSerialized)
+ }
+}
+
+private[sql] class YearUDT extends UserDefinedType[JYear] {
+ override def sqlType: DataType = IntegerType
+
+ override def serialize(obj: JYear): Int = {
+ obj.getValue
+ }
+
+ def deserialize(datum: Any): JYear = datum match {
+ case value: Int => JYear.of(value)
+ }
+
+ override def userClass: Class[JYear] = classOf[JYear]
+
+ private[spark] override def asNullable: YearUDT = this
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]