GitHub user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/22468#discussion_r238534101
--- Diff:
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
---
@@ -535,4 +535,98 @@ class UnsafeRowConverterSuite extends SparkFunSuite
with Matchers with PlanTestB
assert(unsafeRow.getSizeInBytes ==
8 + 8 * 2 + roundedSize(field1.getSizeInBytes) +
roundedSize(field2.getSizeInBytes))
}
+
+ testBothCodegenAndInterpreted("SPARK-25374 converts back into safe
representation") {
+ def convertBackToInternalRow(inputRow: InternalRow, fields:
Array[DataType]): InternalRow = {
+ val unsafeProj = UnsafeProjection.create(fields)
+ val unsafeRow = unsafeProj(inputRow)
+ val safeProj = SafeProjection.create(fields)
+ safeProj(unsafeRow)
+ }
+
+ // Simple tests
+ val inputRow = InternalRow.fromSeq(Seq(
+ false, 3.toByte, 15.toShort, -83, 129L, 1.0f, 8.0,
UTF8String.fromString("test"),
+ Decimal(255), CalendarInterval.fromString("interval 1 day"),
Array[Byte](1, 2)
+ ))
+ val fields1 = Array(
+ BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
+ DoubleType, StringType, DecimalType.defaultConcreteType,
CalendarIntervalType,
+ BinaryType)
+
+ assert(convertBackToInternalRow(inputRow, fields1) === inputRow)
+
+ // Array tests
+ val arrayRow = InternalRow.fromSeq(Seq(
+ createArray(1, 2, 3),
+ createArray(
+ createArray(Seq("a", "b", "c").map(UTF8String.fromString): _*),
+ createArray(Seq("d").map(UTF8String.fromString): _*))
+ ))
+ val fields2 = Array[DataType](
+ ArrayType(IntegerType),
+ ArrayType(ArrayType(StringType)))
+
+ assert(convertBackToInternalRow(arrayRow, fields2) === arrayRow)
+
+ // Struct tests
+ val structRow = InternalRow.fromSeq(Seq(
+ InternalRow.fromSeq(Seq[Any](1, 4.0)),
+ InternalRow.fromSeq(Seq(
+ UTF8String.fromString("test"),
+ InternalRow.fromSeq(Seq(
+ 1,
+ createArray(Seq("2", "3").map(UTF8String.fromString): _*)
+ ))
+ ))
+ ))
+ val fields3 = Array[DataType](
+ StructType(
+ StructField("c0", IntegerType) ::
+ StructField("c1", DoubleType) ::
+ Nil),
+ StructType(
+ StructField("c2", StringType) ::
+ StructField("c3", StructType(
+ StructField("c4", IntegerType) ::
+ StructField("c5", ArrayType(StringType)) ::
+ Nil)) ::
+ Nil))
+
+ assert(convertBackToInternalRow(structRow, fields3) === structRow)
+
+ // Map tests
+ val mapRow = InternalRow.fromSeq(Seq(
+ createMap(Seq("k1", "k2").map(UTF8String.fromString): _*)(1, 2),
+ createMap(
+ createMap(3, 5)(Seq("v1", "v2").map(UTF8String.fromString): _*),
+ createMap(7, 9)(Seq("v3", "v4").map(UTF8String.fromString): _*)
+ )(
+ createMap(Seq("k3", "k4").map(UTF8String.fromString):
_*)(3.toShort, 4.toShort),
+ createMap(Seq("k5", "k6").map(UTF8String.fromString):
_*)(5.toShort, 6.toShort)
+ )))
+ val fields4 = Array[DataType](
+ MapType(StringType, IntegerType),
+ MapType(MapType(IntegerType, StringType), MapType(StringType,
ShortType)))
+
+ val mapResultRow = convertBackToInternalRow(mapRow,
fields4).toSeq(fields4)
+ val mapExpectedRow = mapRow.toSeq(fields4)
+ // Since `ArrayBasedMapData` does not override `equals` and `hashCode`,
--- End diff --
Alternatively, we could use `ExpressionEvalHelper.checkResult` here.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]