cloud-fan commented on code in PR #40789:
URL: https://github.com/apache/spark/pull/40789#discussion_r1166805724
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -4819,148 +4825,181 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
}
override def nullSafeEval(arr: Any, pos: Any, item: Any): Any = {
- var posInt = pos.asInstanceOf[Int]
- if (posInt == 0) {
- throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
- }
val baseArr = arr.asInstanceOf[ArrayData]
val arrayElementType = dataType.asInstanceOf[ArrayType].elementType
- val newPosExtendsArrayLeft = (posInt < 0) && (-posInt >
baseArr.numElements())
-
- if (newPosExtendsArrayLeft) {
- // special case- if the new position is negative but larger than the
current array size
- // place the new item at start of array, place the current array
contents at the end
- // and fill the newly created array elements inbetween with a null
-
- val newArrayLength = -posInt + 1
-
+ if (canBeOptimizedForPrepend) {
+ val newArrayLength = baseArr.numElements() + 1
if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
}
-
val newArray = new Array[Any](newArrayLength)
-
- baseArr.foreach(arrayElementType, (i, v) => {
- // current position, offset by new item + new null array elements
- val elementPosition = i + 1 + math.abs(posInt + baseArr.numElements())
- newArray(elementPosition) = v
- })
-
- newArray(0) = item
-
+ newArray.update(0, item)
+ baseArr.foreach(elementType, (i: Int, v: Any) => newArray.update(i + 1,
v))
return new GenericArrayData(newArray)
} else {
- if (posInt < 0) {
- posInt = posInt + baseArr.numElements()
- } else if (posInt > 0) {
- posInt = posInt - 1
+ var posInt = pos.asInstanceOf[Int]
+ if (posInt == 0) {
+ throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
}
+ val newPosExtendsArrayLeft = (posInt < 0) && (-posInt >
baseArr.numElements())
- val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1)
+ if (newPosExtendsArrayLeft) {
+ // special case- if the new position is negative but larger than the
current array size
+ // place the new item at start of array, place the current array
contents at the end
+ // and fill the newly created array elements inbetween with a null
- if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
- throw
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
- }
+ val newArrayLength = -posInt + 1
- val newArray = new Array[Any](newArrayLength)
+ if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ throw
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
+ }
- baseArr.foreach(arrayElementType, (i, v) => {
- if (i >= posInt) {
- newArray(i + 1) = v
- } else {
- newArray(i) = v
+ val newArray = new Array[Any](newArrayLength)
+
+ baseArr.foreach(elementType, (i, v) => {
+ // current position, offset by new item + new null array elements
+ val elementPosition = i + 1 + math.abs(posInt +
baseArr.numElements())
+ newArray(elementPosition) = v
+ })
+
+ newArray(0) = item
+
+ return new GenericArrayData(newArray)
+ } else {
+ if (posInt < 0) {
+ posInt = posInt + baseArr.numElements()
+ } else if (posInt > 0) {
+ posInt = posInt - 1
}
- })
- newArray(posInt) = item
+ val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1)
- return new GenericArrayData(newArray)
+ if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ throw
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
+ }
+
+ val newArray = new Array[Any](newArrayLength)
+
+ baseArr.foreach(elementType, (i, v) => {
+ if (i >= posInt) {
+ newArray(i + 1) = v
+ } else {
+ newArray(i) = v
+ }
+ })
+
+ newArray(posInt) = item
+
+ return new GenericArrayData(newArray)
+ }
}
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val f = (arrExpr: ExprCode, posExpr: ExprCode, itemExpr: ExprCode) => {
val arr = arrExpr.value
- val pos = posExpr.value
val item = itemExpr.value
- val itemInsertionIndex = ctx.freshName("itemInsertionIndex")
val adjustedAllocIdx = ctx.freshName("adjustedAllocIdx")
val resLength = ctx.freshName("resLength")
val insertedItemIsNull = ctx.freshName("insertedItemIsNull")
val i = ctx.freshName("i")
- val j = ctx.freshName("j")
val values = ctx.freshName("values")
val allocation = CodeGenerator.createArrayData(
values, elementType, resLength, s"$prettyName failed.")
val assignment = CodeGenerator.createArrayAssignment(values,
elementType, arr,
adjustedAllocIdx, i,
first.dataType.asInstanceOf[ArrayType].containsNull)
- val errorContext = getContextOrNullCode(ctx)
- s"""
- |int $itemInsertionIndex = 0;
- |int $resLength = 0;
- |int $adjustedAllocIdx = 0;
- |boolean $insertedItemIsNull = ${itemExpr.isNull};
- |
- |if ($pos == 0) {
- | throw QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
- |}
- |
- |if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
- |
- | $resLength = java.lang.Math.abs($pos) + 1;
- | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
- | throw
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
- | }
- |
- | $allocation
- | for (int $i = 0; $i < $arr.numElements(); $i ++) {
- | $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos +
$arr.numElements());
- | $assignment
- | }
- | ${CodeGenerator.setArrayElement(
- values, elementType, itemInsertionIndex, item,
Some(insertedItemIsNull))}
- |
- | for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
- | $values.setNullAt($j + 1 + java.lang.Math.abs($pos +
$arr.numElements()));
- | }
- |
- | ${ev.value} = $values;
- |} else {
- |
- | $itemInsertionIndex = 0;
- | if ($pos < 0) {
- | $itemInsertionIndex = $pos + $arr.numElements();
- | } else if ($pos > 0) {
- | $itemInsertionIndex = $pos - 1;
- | }
- |
- | $resLength = java.lang.Math.max($arr.numElements() + 1,
$itemInsertionIndex + 1);
- | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
- | throw
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
- | }
- |
- | $allocation
- | for (int $i = 0; $i < $arr.numElements(); $i ++) {
- | $adjustedAllocIdx = $i;
- | if ($i >= $itemInsertionIndex) {
- | $adjustedAllocIdx = $adjustedAllocIdx + 1;
- | }
- | $assignment
- | }
- | ${CodeGenerator.setArrayElement(
- values, elementType, itemInsertionIndex, item,
Some(insertedItemIsNull))}
- |
- | for (int $j = $arr.numElements(); $j < $resLength - 1; $j ++) {
- | $values.setNullAt($j);
- | }
- |
- | ${ev.value} = $values;
- |}
- """.stripMargin
+ if (canBeOptimizedForPrepend) {
+ val zero = "0"
+ s"""
+ |int $resLength = $arr.numElements() + 1;
+ |int $adjustedAllocIdx = 0;
+ |if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+ | throw
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
+ |}
+ |boolean $insertedItemIsNull = ${itemExpr.isNull};
+ |
+ |$allocation
+ |for (int $i = 0; $i < $arr.numElements(); $i ++) {
+ | $adjustedAllocIdx = $i + 1;
+ | $assignment
+ |}
+ |${CodeGenerator.setArrayElement(
+ values, elementType, zero, item, Some(insertedItemIsNull))}
+ |
+ |${ev.value} = $values;
+ |""".stripMargin
+ } else {
+ val pos = posExpr.value
+ val itemInsertionIndex = ctx.freshName("itemInsertionIndex")
+ val j = ctx.freshName("j")
+ val errorContext = getContextOrNullCode(ctx)
+ s"""
+ |int $itemInsertionIndex = 0;
+ |int $resLength = 0;
+ |int $adjustedAllocIdx = 0;
+ |boolean $insertedItemIsNull = ${itemExpr.isNull};
+ |
+ |if ($pos == 0) {
+ | throw
QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
+ |}
+ |
+ |if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
+ |
+ | $resLength = java.lang.Math.abs($pos) + 1;
+ | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+ | throw
QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
+ | }
+ |
+ | $allocation
+ | for (int $i = 0; $i < $arr.numElements(); $i ++) {
+ | $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos +
$arr.numElements());
+ | $assignment
+ | }
+ | ${CodeGenerator.setArrayElement(
+ values, elementType, itemInsertionIndex, item,
Some(insertedItemIsNull))}
+ |
+ | for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
+ | $values.setNullAt($j + 1 + java.lang.Math.abs($pos +
$arr.numElements()));
+ | }
+ |
+ | ${ev.value} = $values;
+ |} else {
+ |
Review Comment:
Can you take a close look at the generated code and see how we can simplify
it when the position is a constant value? From just a 10-second look, I found that if
the position is constant and positive, the code can be simplified to
```
final int $resLength = java.lang.Math.max($arr.numElements() + 1, $pos);
...
```
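Let's not be so specific as to target only the case where `position` is `1`.
We can generate specialized code whenever the position is a positive constant, which is a very
common case: the item always lands at index `$pos - 1`, so the runtime zero check and the
negative-position branches can be dropped entirely.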
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]