panbingkun commented on code in PR #47984:
URL: https://github.com/apache/spark/pull/47984#discussion_r1770932660
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -1525,6 +1526,126 @@ case class ArrayContains(left: Expression, right: Expression)
copy(left = newLeft, right = newRight)
}
+/**
+ * This expression converts data of `ArrayData` to an array of java type.
+ *
+ * NOTE: When the data type of expression is `ArrayType`, and the expression is foldable,
+ * the `ConstantFolding` can do constant folding optimization automatically,
+ * (avoiding frequent calls to `ArrayData.to{XXX}Array()`).
+ */
+case class ToJavaArray(array: Expression)
+ extends UnaryExpression
+ with ImplicitCastInputTypes
+ with NullIntolerant
+ with QueryErrorsBase {
+
+ override def checkInputDataTypes(): TypeCheckResult = array.dataType match {
+ case ArrayType(_, _) =>
+ TypeCheckResult.TypeCheckSuccess
+ case _ =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(0),
+ "requiredType" -> toSQLType(ArrayType),
+ "inputSql" -> toSQLExpr(array),
+ "inputType" -> toSQLType(array.dataType))
+ )
+ }
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(array.dataType)
+ override def dataType: DataType = {
+ if (canPerformFast) {
+ elementType match {
+ case ByteType => ObjectType(classOf[Array[Byte]])
+ case ShortType => ObjectType(classOf[Array[Short]])
+ case IntegerType => ObjectType(classOf[Array[Int]])
+ case LongType => ObjectType(classOf[Array[Long]])
+ case FloatType => ObjectType(classOf[Array[Float]])
+ case DoubleType => ObjectType(classOf[Array[Double]])
+ }
+ } else if (isPrimitiveType) {
+ elementType match {
+ case BooleanType => ObjectType(classOf[Array[java.lang.Boolean]])
+ case ByteType => ObjectType(classOf[Array[java.lang.Byte]])
+ case ShortType => ObjectType(classOf[Array[java.lang.Short]])
+ case IntegerType => ObjectType(classOf[Array[java.lang.Integer]])
+ case LongType => ObjectType(classOf[Array[java.lang.Long]])
+ case FloatType => ObjectType(classOf[Array[java.lang.Float]])
+ case DoubleType => ObjectType(classOf[Array[java.lang.Double]])
+ }
+ } else {
+ ObjectType(classOf[Array[Object]])
+ }
+ }
+
+ override def child: Expression = array
+ override def prettyName: String = "to_java_array"
+
+ @transient lazy val elementType: DataType =
+ array.dataType.asInstanceOf[ArrayType].elementType
+ private def resultArrayElementNullable: Boolean =
+ array.dataType.asInstanceOf[ArrayType].containsNull
+ private def isPrimitiveType: Boolean =
CodeGenerator.isPrimitiveType(elementType)
+ private def canPerformFast: Boolean =
+ isPrimitiveType && elementType != BooleanType &&
!resultArrayElementNullable
+
+ private def toJavaArray(array: Any): Any = {
+ val arrayData = array.asInstanceOf[ArrayData]
+ if (canPerformFast) {
+ elementType match {
+ case ByteType => arrayData.toByteArray()
+ case ShortType => arrayData.toShortArray()
+ case IntegerType => arrayData.toIntArray()
+ case LongType => arrayData.toLongArray()
+ case FloatType => arrayData.toFloatArray()
+ case DoubleType => arrayData.toDoubleArray()
+ }
+ } else if (isPrimitiveType) {
+ elementType match {
+ case BooleanType => arrayData.toArray[java.lang.Boolean](BooleanType)
+ case ByteType => arrayData.toArray[java.lang.Byte](ByteType)
+ case ShortType => arrayData.toArray[java.lang.Short](ShortType)
+ case IntegerType => arrayData.toArray[java.lang.Integer](IntegerType)
+ case LongType => arrayData.toArray[java.lang.Long](LongType)
+ case FloatType => arrayData.toArray[java.lang.Float](FloatType)
+ case DoubleType => arrayData.toArray[java.lang.Double](DoubleType)
Review Comment:
If we use `Invoke` to implement `ToJavaArray`, we must wrap the element type
with `java.lang.xxx`; otherwise `ConstantFolding` cannot optimize it, because
of the following code:
https://github.com/apache/spark/blob/d2e8c1cb60e34a1c7e92374c07d682aa5ca79145/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala#L74-L75
https://github.com/apache/spark/blob/d2e8c1cb60e34a1c7e92374c07d682aa5ca79145/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala#L88-L94
That is, an expression of type `ObjectType(java.lang.Object)` cannot be
optimized.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]