beliefer commented on code in PR #40789:
URL: https://github.com/apache/spark/pull/40789#discussion_r1169518326
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -4819,148 +4825,181 @@ case class ArrayInsert(srcArrayExpr: Expression,
posExpr: Expression, itemExpr:
}
override def nullSafeEval(arr: Any, pos: Any, item: Any): Any = {
- var posInt = pos.asInstanceOf[Int]
- if (posInt == 0) {
- throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
- }
val baseArr = arr.asInstanceOf[ArrayData]
val arrayElementType = dataType.asInstanceOf[ArrayType].elementType
- val newPosExtendsArrayLeft = (posInt < 0) && (-posInt >
baseArr.numElements())
-
- if (newPosExtendsArrayLeft) {
- // special case- if the new position is negative but larger than the
current array size
- // place the new item at start of array, place the current array
contents at the end
- // and fill the newly created array elements inbetween with a null
-
- val newArrayLength = -posInt + 1
-
+ if (canBeOptimizedForPrepend) {
+ val newArrayLength = baseArr.numElements() + 1
if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
}
-
val newArray = new Array[Any](newArrayLength)
-
- baseArr.foreach(arrayElementType, (i, v) => {
- // current position, offset by new item + new null array elements
- val elementPosition = i + 1 + math.abs(posInt + baseArr.numElements())
- newArray(elementPosition) = v
- })
-
- newArray(0) = item
-
+ newArray.update(0, item)
+ baseArr.foreach(elementType, (i: Int, v: Any) => newArray.update(i + 1,
v))
return new GenericArrayData(newArray)
} else {
- if (posInt < 0) {
- posInt = posInt + baseArr.numElements()
- } else if (posInt > 0) {
- posInt = posInt - 1
+ var posInt = pos.asInstanceOf[Int]
+ if (posInt == 0) {
+ throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
}
+ val newPosExtendsArrayLeft = (posInt < 0) && (-posInt >
baseArr.numElements())
- val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1)
+ if (newPosExtendsArrayLeft) {
+ // special case- if the new position is negative but larger than the
current array size
+ // place the new item at start of array, place the current array
contents at the end
+ // and fill the newly created array elements inbetween with a null
- if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
- throw
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
- }
+ val newArrayLength = -posInt + 1
- val newArray = new Array[Any](newArrayLength)
+ if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ throw
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
+ }
- baseArr.foreach(arrayElementType, (i, v) => {
- if (i >= posInt) {
- newArray(i + 1) = v
- } else {
- newArray(i) = v
+ val newArray = new Array[Any](newArrayLength)
+
+ baseArr.foreach(elementType, (i, v) => {
+ // current position, offset by new item + new null array elements
+ val elementPosition = i + 1 + math.abs(posInt +
baseArr.numElements())
+ newArray(elementPosition) = v
+ })
+
+ newArray(0) = item
+
+ return new GenericArrayData(newArray)
+ } else {
+ if (posInt < 0) {
+ posInt = posInt + baseArr.numElements()
+ } else if (posInt > 0) {
+ posInt = posInt - 1
}
- })
- newArray(posInt) = item
+ val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1)
- return new GenericArrayData(newArray)
+ if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ throw
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
+ }
+
+ val newArray = new Array[Any](newArrayLength)
+
+ baseArr.foreach(elementType, (i, v) => {
+ if (i >= posInt) {
+ newArray(i + 1) = v
+ } else {
+ newArray(i) = v
+ }
+ })
+
+ newArray(posInt) = item
+
+ return new GenericArrayData(newArray)
+ }
}
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val f = (arrExpr: ExprCode, posExpr: ExprCode, itemExpr: ExprCode) => {
val arr = arrExpr.value
- val pos = posExpr.value
val item = itemExpr.value
- val itemInsertionIndex = ctx.freshName("itemInsertionIndex")
val adjustedAllocIdx = ctx.freshName("adjustedAllocIdx")
val resLength = ctx.freshName("resLength")
val insertedItemIsNull = ctx.freshName("insertedItemIsNull")
val i = ctx.freshName("i")
- val j = ctx.freshName("j")
val values = ctx.freshName("values")
val allocation = CodeGenerator.createArrayData(
values, elementType, resLength, s"$prettyName failed.")
val assignment = CodeGenerator.createArrayAssignment(values,
elementType, arr,
adjustedAllocIdx, i,
first.dataType.asInstanceOf[ArrayType].containsNull)
- val errorContext = getContextOrNullCode(ctx)
- s"""
- |int $itemInsertionIndex = 0;
- |int $resLength = 0;
- |int $adjustedAllocIdx = 0;
- |boolean $insertedItemIsNull = ${itemExpr.isNull};
- |
- |if ($pos == 0) {
- | throw QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
- |}
- |
- |if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
- |
- | $resLength = java.lang.Math.abs($pos) + 1;
- | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
- | throw
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
- | }
- |
- | $allocation
- | for (int $i = 0; $i < $arr.numElements(); $i ++) {
- | $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos +
$arr.numElements());
- | $assignment
- | }
- | ${CodeGenerator.setArrayElement(
- values, elementType, itemInsertionIndex, item,
Some(insertedItemIsNull))}
- |
- | for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
- | $values.setNullAt($j + 1 + java.lang.Math.abs($pos +
$arr.numElements()));
- | }
- |
- | ${ev.value} = $values;
- |} else {
- |
- | $itemInsertionIndex = 0;
- | if ($pos < 0) {
- | $itemInsertionIndex = $pos + $arr.numElements();
- | } else if ($pos > 0) {
- | $itemInsertionIndex = $pos - 1;
- | }
- |
- | $resLength = java.lang.Math.max($arr.numElements() + 1,
$itemInsertionIndex + 1);
- | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
- | throw
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
- | }
- |
- | $allocation
- | for (int $i = 0; $i < $arr.numElements(); $i ++) {
- | $adjustedAllocIdx = $i;
- | if ($i >= $itemInsertionIndex) {
- | $adjustedAllocIdx = $adjustedAllocIdx + 1;
- | }
- | $assignment
- | }
- | ${CodeGenerator.setArrayElement(
- values, elementType, itemInsertionIndex, item,
Some(insertedItemIsNull))}
- |
- | for (int $j = $arr.numElements(); $j < $resLength - 1; $j ++) {
- | $values.setNullAt($j);
- | }
- |
- | ${ev.value} = $values;
- |}
- """.stripMargin
+ if (canBeOptimizedForPrepend) {
+ val zero = "0"
+ s"""
+ |int $resLength = $arr.numElements() + 1;
+ |int $adjustedAllocIdx = 0;
+ |if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+ | throw
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
+ |}
+ |boolean $insertedItemIsNull = ${itemExpr.isNull};
+ |
+ |$allocation
+ |for (int $i = 0; $i < $arr.numElements(); $i ++) {
+ | $adjustedAllocIdx = $i + 1;
+ | $assignment
+ |}
+ |${CodeGenerator.setArrayElement(
+ values, elementType, zero, item, Some(insertedItemIsNull))}
+ |
+ |${ev.value} = $values;
+ |""".stripMargin
+ } else {
+ val pos = posExpr.value
+ val itemInsertionIndex = ctx.freshName("itemInsertionIndex")
+ val j = ctx.freshName("j")
+ val errorContext = getContextOrNullCode(ctx)
+ s"""
+ |int $itemInsertionIndex = 0;
+ |int $resLength = 0;
+ |int $adjustedAllocIdx = 0;
+ |boolean $insertedItemIsNull = ${itemExpr.isNull};
+ |
+ |if ($pos == 0) {
+ | throw
QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
+ |}
+ |
+ |if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
+ |
+ | $resLength = java.lang.Math.abs($pos) + 1;
+ | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+ | throw
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
+ | }
+ |
+ | $allocation
+ | for (int $i = 0; $i < $arr.numElements(); $i ++) {
+ | $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos +
$arr.numElements());
+ | $assignment
+ | }
+ | ${CodeGenerator.setArrayElement(
+ values, elementType, itemInsertionIndex, item,
Some(insertedItemIsNull))}
+ |
+ | for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
+ | $values.setNullAt($j + 1 + java.lang.Math.abs($pos +
$arr.numElements()));
+ | }
+ |
+ | ${ev.value} = $values;
+ |} else {
+ |
Review Comment:
@cloud-fan https://github.com/apache/spark/pull/40833 is intended to replace this
one.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]