beliefer commented on code in PR #40789:
URL: https://github.com/apache/spark/pull/40789#discussion_r1169518326


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -4819,148 +4825,181 @@ case class ArrayInsert(srcArrayExpr: Expression, 
posExpr: Expression, itemExpr:
   }
 
   override def nullSafeEval(arr: Any, pos: Any, item: Any): Any = {
-    var posInt = pos.asInstanceOf[Int]
-    if (posInt == 0) {
-      throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
-    }
     val baseArr = arr.asInstanceOf[ArrayData]
     val arrayElementType = dataType.asInstanceOf[ArrayType].elementType
 
-    val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > 
baseArr.numElements())
-
-    if (newPosExtendsArrayLeft) {
-      // special case- if the new position is negative but larger than the 
current array size
-      // place the new item at start of array, place the current array 
contents at the end
-      // and fill the newly created array elements inbetween with a null
-
-      val newArrayLength = -posInt + 1
-
+    if (canBeOptimizedForPrepend) {
+      val newArrayLength = baseArr.numElements() + 1
       if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
         throw 
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
       }
-
       val newArray = new Array[Any](newArrayLength)
-
-      baseArr.foreach(arrayElementType, (i, v) => {
-        // current position, offset by new item + new null array elements
-        val elementPosition = i + 1 + math.abs(posInt + baseArr.numElements())
-        newArray(elementPosition) = v
-      })
-
-      newArray(0) = item
-
+      newArray.update(0, item)
+      baseArr.foreach(elementType, (i: Int, v: Any) => newArray.update(i + 1, 
v))
       return new GenericArrayData(newArray)
     } else {
-      if (posInt < 0) {
-        posInt = posInt + baseArr.numElements()
-      } else if (posInt > 0) {
-        posInt = posInt - 1
+      var posInt = pos.asInstanceOf[Int]
+      if (posInt == 0) {
+        throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
       }
+      val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > 
baseArr.numElements())
 
-      val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1)
+      if (newPosExtendsArrayLeft) {
+        // special case- if the new position is negative but larger than the 
current array size
+        // place the new item at start of array, place the current array 
contents at the end
+        // and fill the newly created array elements inbetween with a null
 
-      if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-        throw 
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
-      }
+        val newArrayLength = -posInt + 1
 
-      val newArray = new Array[Any](newArrayLength)
+        if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+          throw 
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
+        }
 
-      baseArr.foreach(arrayElementType, (i, v) => {
-        if (i >= posInt) {
-          newArray(i + 1) = v
-        } else {
-          newArray(i) = v
+        val newArray = new Array[Any](newArrayLength)
+
+        baseArr.foreach(elementType, (i, v) => {
+          // current position, offset by new item + new null array elements
+          val elementPosition = i + 1 + math.abs(posInt + 
baseArr.numElements())
+          newArray(elementPosition) = v
+        })
+
+        newArray(0) = item
+
+        return new GenericArrayData(newArray)
+      } else {
+        if (posInt < 0) {
+          posInt = posInt + baseArr.numElements()
+        } else if (posInt > 0) {
+          posInt = posInt - 1
         }
-      })
 
-      newArray(posInt) = item
+        val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1)
 
-      return new GenericArrayData(newArray)
+        if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+          throw 
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
+        }
+
+        val newArray = new Array[Any](newArrayLength)
+
+        baseArr.foreach(elementType, (i, v) => {
+          if (i >= posInt) {
+            newArray(i + 1) = v
+          } else {
+            newArray(i) = v
+          }
+        })
+
+        newArray(posInt) = item
+
+        return new GenericArrayData(newArray)
+      }
     }
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val f = (arrExpr: ExprCode, posExpr: ExprCode, itemExpr: ExprCode) => {
       val arr = arrExpr.value
-      val pos = posExpr.value
       val item = itemExpr.value
 
-      val itemInsertionIndex = ctx.freshName("itemInsertionIndex")
       val adjustedAllocIdx = ctx.freshName("adjustedAllocIdx")
       val resLength = ctx.freshName("resLength")
       val insertedItemIsNull = ctx.freshName("insertedItemIsNull")
       val i = ctx.freshName("i")
-      val j = ctx.freshName("j")
       val values = ctx.freshName("values")
 
       val allocation = CodeGenerator.createArrayData(
         values, elementType, resLength, s"$prettyName failed.")
       val assignment = CodeGenerator.createArrayAssignment(values, 
elementType, arr,
         adjustedAllocIdx, i, 
first.dataType.asInstanceOf[ArrayType].containsNull)
-      val errorContext = getContextOrNullCode(ctx)
 
-      s"""
-         |int $itemInsertionIndex = 0;
-         |int $resLength = 0;
-         |int $adjustedAllocIdx = 0;
-         |boolean $insertedItemIsNull = ${itemExpr.isNull};
-         |
-         |if ($pos == 0) {
-         |  throw QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
-         |}
-         |
-         |if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
-         |
-         |  $resLength = java.lang.Math.abs($pos) + 1;
-         |  if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-         |    throw 
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
-         |  }
-         |
-         |  $allocation
-         |  for (int $i = 0; $i < $arr.numElements(); $i ++) {
-         |    $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos + 
$arr.numElements());
-         |    $assignment
-         |  }
-         |  ${CodeGenerator.setArrayElement(
-              values, elementType, itemInsertionIndex, item, 
Some(insertedItemIsNull))}
-         |
-         |  for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
-         |    $values.setNullAt($j + 1 + java.lang.Math.abs($pos + 
$arr.numElements()));
-         |  }
-         |
-         |  ${ev.value} = $values;
-         |} else {
-         |
-         |  $itemInsertionIndex = 0;
-         |  if ($pos < 0) {
-         |    $itemInsertionIndex = $pos + $arr.numElements();
-         |  } else if ($pos > 0) {
-         |    $itemInsertionIndex = $pos - 1;
-         |  }
-         |
-         |  $resLength = java.lang.Math.max($arr.numElements() + 1, 
$itemInsertionIndex + 1);
-         |  if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-         |    throw 
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
-         |  }
-         |
-         |  $allocation
-         |  for (int $i = 0; $i < $arr.numElements(); $i ++) {
-         |    $adjustedAllocIdx = $i;
-         |    if ($i >= $itemInsertionIndex) {
-         |      $adjustedAllocIdx = $adjustedAllocIdx + 1;
-         |    }
-         |    $assignment
-         |  }
-         |  ${CodeGenerator.setArrayElement(
-              values, elementType, itemInsertionIndex, item, 
Some(insertedItemIsNull))}
-         |
-         |  for (int $j = $arr.numElements(); $j < $resLength - 1; $j ++) {
-         |    $values.setNullAt($j);
-         |  }
-         |
-         |  ${ev.value} = $values;
-         |}
-      """.stripMargin
+      if (canBeOptimizedForPrepend) {
+        val zero = "0"
+        s"""
+           |int $resLength = $arr.numElements() + 1;
+           |int $adjustedAllocIdx = 0;
+           |if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+           |  throw 
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
+           |}
+           |boolean $insertedItemIsNull = ${itemExpr.isNull};
+           |
+           |$allocation
+           |for (int $i = 0; $i < $arr.numElements(); $i ++) {
+           |  $adjustedAllocIdx = $i + 1;
+           |  $assignment
+           |}
+           |${CodeGenerator.setArrayElement(
+              values, elementType, zero, item, Some(insertedItemIsNull))}
+           |
+           |${ev.value} = $values;
+           |""".stripMargin
+      } else {
+        val pos = posExpr.value
+        val itemInsertionIndex = ctx.freshName("itemInsertionIndex")
+        val j = ctx.freshName("j")
+        val errorContext = getContextOrNullCode(ctx)
+        s"""
+           |int $itemInsertionIndex = 0;
+           |int $resLength = 0;
+           |int $adjustedAllocIdx = 0;
+           |boolean $insertedItemIsNull = ${itemExpr.isNull};
+           |
+           |if ($pos == 0) {
+           |  throw 
QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
+           |}
+           |
+           |if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
+           |
+           |  $resLength = java.lang.Math.abs($pos) + 1;
+           |  if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+           |    throw 
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
+           |  }
+           |
+           |  $allocation
+           |  for (int $i = 0; $i < $arr.numElements(); $i ++) {
+           |    $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos + 
$arr.numElements());
+           |    $assignment
+           |  }
+           |  ${CodeGenerator.setArrayElement(
+                values, elementType, itemInsertionIndex, item, 
Some(insertedItemIsNull))}
+           |
+           |  for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
+           |    $values.setNullAt($j + 1 + java.lang.Math.abs($pos + 
$arr.numElements()));
+           |  }
+           |
+           |  ${ev.value} = $values;
+           |} else {
+           |

Review Comment:
   @cloud-fan https://github.com/apache/spark/pull/40833 has been opened to 
replace this one.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to