GitHub user kiszk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21121#discussion_r183947619
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -883,3 +884,157 @@ case class Concat(children: Seq[Expression]) extends 
Expression {
     
       override def sql: String = s"concat(${children.map(_.sql).mkString(", 
")})"
     }
    +
    +/**
    + * Transforms an array by assigning an order number to each element.
    + */
    +// scalastyle:off line.size.limit
    +@ExpressionDescription(
    +  usage = "_FUNC_(array[, indexFirst, startFromZero]) - Transforms the 
input array by encapsulating elements into pairs with indexes indicating the 
order.",
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(array("d", "a", null, "b"));
    +       [("d",1),("a",2),(null,3),("b",4)]
    +      > SELECT _FUNC_(array("d", "a", null, "b"), true, false);
    +       [(1,"d"),(2,"a"),(3,null),(4,"b")]
    +      > SELECT _FUNC_(array("d", "a", null, "b"), true, true);
    +       [(0,"d"),(1,"a"),(2,null),(3,"b")]
    +  """,
    +  since = "2.4.0")
    +// scalastyle:on line.size.limit
    +case class ZipWithIndex(child: Expression, indexFirst: Expression, 
startFromZero: Expression)
    +  extends UnaryExpression with ExpectsInputTypes {
    +
    +  def this(e: Expression) = this(e, Literal.FalseLiteral, 
Literal.FalseLiteral)
    +
    +  def exprToFlag(e: Expression, order: String): Boolean = e match {
    +    case Literal(v: Boolean, BooleanType) => v
    +    case _ => throw new AnalysisException(s"The $order argument has to be 
a boolean constant.")
    +  }
    +
    +  private val idxFirst: Boolean = exprToFlag(indexFirst, "second")
    +
    +  private val (idxShift, idxGen): (Int, String) = if 
(exprToFlag(startFromZero, "third")) {
    +    (0, "z")
    +  } else {
    +    (1, "z + 1")
    +  }
    +
    +  private val MAX_ARRAY_LENGTH: Int = 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
    +
    +  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
    +
    +  lazy val childArrayType: ArrayType = 
child.dataType.asInstanceOf[ArrayType]
    +
    +  override def dataType: DataType = {
    +    val elementField = StructField("value", childArrayType.elementType, 
childArrayType.containsNull)
    +    val indexField = StructField("index", IntegerType, false)
    +
    +    val fields = if (idxFirst) Seq(indexField, elementField) else 
Seq(elementField, indexField)
    +
    +    ArrayType(StructType(fields), false)
    +  }
    +
    +  override protected def nullSafeEval(input: Any): Any = {
    +    val array = 
input.asInstanceOf[ArrayData].toObjectArray(childArrayType.elementType)
    +
    +    val makeStruct = (v: Any, i: Int) => if (idxFirst) InternalRow(i, v) 
else InternalRow(v, i)
    +    val resultData = array.zipWithIndex.map{case (v, i) => makeStruct(v, i 
+ idxShift)}
    +
    +    new GenericArrayData(resultData)
    +  }
    +
    +  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
    +    nullSafeCodeGen(ctx, ev, c => {
    +      val numElements = ctx.freshName("numElements")
    +      val code = if 
(CodeGenerator.isPrimitiveType(childArrayType.elementType)) {
    +        genCodeForPrimitiveElements(ctx, c, ev.value, numElements)
    +      } else {
    +        genCodeForAnyElements(ctx, c, ev.value, numElements)
    +      }
    +      s"""
    +         |final int $numElements = $c.numElements();
    +         |$code
    +       """.stripMargin
    +    })
    +  }
    +
    +  private def genCodeForPrimitiveElements(
    +      ctx: CodegenContext,
    +      childVariableName: String,
    +      arrayData: String,
    +      numElements: String): String = {
    +    val byteArraySize = ctx.freshName("byteArraySize")
    +    val data = ctx.freshName("byteArray")
    +    val unsafeRow = ctx.freshName("unsafeRow")
    +    val structSize = ctx.freshName("structSize")
    +    val unsafeArrayData = ctx.freshName("unsafeArrayData")
    +    val structsOffset = ctx.freshName("structsOffset")
    +    val calculateArraySize = 
"UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
    +    val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
    +
    +    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    +    val longSize = LongType.defaultSize
    +    val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(childArrayType.elementType)
    +    val (valuePosition, indexPosition) = if (idxFirst) ("1", "0") else 
("0", "1")
    +
    +    s"""
    +       |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) 
+ longSize * 2};
    +       |final long $byteArraySize = $calculateArraySize($numElements, 
$longSize + $structSize);
    +       |final int $structsOffset = $calculateHeader($numElements) + 
$numElements * $longSize;
    +       |if ($byteArraySize > $MAX_ARRAY_LENGTH) {
    +       |  ${genCodeForAnyElements(ctx, childVariableName, arrayData, 
numElements)}
    +       |} else {
    +       |  final byte[] $data = new byte[(int)$byteArraySize];
    +       |  UnsafeArrayData $unsafeArrayData = new UnsafeArrayData();
    +       |  Platform.putLong($data, $baseOffset, $numElements);
    +       |  $unsafeArrayData.pointTo($data, $baseOffset, 
(int)$byteArraySize);
    +       |  UnsafeRow $unsafeRow = new UnsafeRow(2);
    +       |  for (int z = 0; z < $numElements; z++) {
    +       |    long offset = $structsOffset + z * $structSize;
    +       |    $unsafeArrayData.setLong(z, (offset << 32) + $structSize);
    +       |    $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
    +       |    if ($childVariableName.isNullAt(z)) {
    +       |      $unsafeRow.setNullAt($valuePosition);
    +       |    } else {
    +       |      $unsafeRow.set$primitiveValueTypeName(
    +       |        $valuePosition,
    +       |        ${CodeGenerator.getValue(childVariableName, 
childArrayType.elementType, "z")}
    +       |      );
    +       |    }
    +       |    $unsafeRow.setInt($indexPosition, $idxGen);
    +       |  }
    +       |  $arrayData = $unsafeArrayData;
    +       |}
    +     """.stripMargin
    +  }
    +
    +  private def genCodeForAnyElements(
    +      ctx: CodegenContext,
    +      childVariableName: String,
    +      arrayData: String,
    +      numElements: String): String = {
    +    val genericArrayClass = classOf[GenericArrayData].getName
    +    val rowClass = classOf[GenericInternalRow].getName
    +    val data = ctx.freshName("internalRowArray")
    +
    +    val getElement = CodeGenerator.getValue(childVariableName, 
childArrayType.elementType, "z")
    +    val elementValue = if 
(CodeGenerator.isPrimitiveType(childArrayType.elementType)) {
    --- End diff --
    
    Can we also remove the null check when `containsNull` is false, even when the 
element type is not a primitive type?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to