Github user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/19730#discussion_r152222498
--- Diff:
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
---
@@ -827,4 +827,49 @@ class CastSuite extends SparkFunSuite with
ExpressionEvalHelper {
checkEvaluation(cast(Literal.create(input, from), to), input)
}
+
+ test("SPARK-22500: cast for struct should not generate codes beyond
64KB") {
+ val N = 1000
+ val M = 250
+
+ val from1 = new StructType(
+ (1 to N).map(i => StructField(s"s$i", StringType)).toArray)
+ val to1 = new StructType(
+ (1 to N).map(i => StructField(s"i$i", IntegerType)).toArray)
+ val input1 = Row.fromSeq((1 to N).map(i => i.toString))
+ val output1 = Row.fromSeq((1 to N))
+ checkEvaluation(cast(Literal.create(input1, from1), to1), output1)
+
+ val from2 = new StructType(
+ (1 to N).map(i => StructField(s"a$i", ArrayType(StringType,
containsNull = false))).toArray)
+ val to2 = new StructType(
+ (1 to N).map(i => StructField(s"i$i", ArrayType(IntegerType,
containsNull = true))).toArray)
+ val input2 = Row.fromSeq((1 to N).map(_ => Seq("456", "true", "78.9")))
+ val output2 = Row.fromSeq((1 to N).map(_ => Seq(456, null, 78)))
+ checkEvaluation(cast(Literal.create(input2, from2), to2), output2)
+
+ val from3 = new StructType(
+ (1 to N).map(i => StructField(s"s$i",
+ StructType(Seq(StructField("l$i", IntegerType, nullable =
true))))).toArray)
+ val to3 = new StructType(
+ (1 to N).map(i => StructField(s"s$i",
+ StructType(Seq(StructField("l$i", LongType, nullable =
true))))).toArray)
+ val input3 = Row.fromSeq((1 to N).map(i => Row(i)))
+ val output3 = Row.fromSeq((1 to N).map(i => Row(i.toLong)))
+ checkEvaluation(cast(Literal.create(input3, from3), to3), output3)
+
+ val fromInner = new StructType(
+ (1 to M).map(i => StructField(s"s$i", DoubleType)).toArray)
+ val toInner = new StructType(
+ (1 to M).map(i => StructField(s"i$i", IntegerType)).toArray)
+ val inputInner = Row.fromSeq((1 to M).map(i => i + 0.5))
+ val outputInner = Row.fromSeq((1 to M))
+ val fromOuter = new StructType(
+ (1 to M).map(i => StructField(s"s$i", fromInner)).toArray)
+ val toOuter = new StructType(
+ (1 to M).map(i => StructField(s"s$i", toInner)).toArray)
+ val inputOuter = Row.fromSeq((1 to M).map(_ => inputInner))
+ val outputOuter = Row.fromSeq((1 to M).map(_ => outputInner))
+ checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter),
outputOuter)
--- End diff --
Isn't this case good enough to cover all of the cases above?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]