Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/21121#discussion_r183947619
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -883,3 +884,157 @@ case class Concat(children: Seq[Expression]) extends
Expression {
override def sql: String = s"concat(${children.map(_.sql).mkString(",
")})"
}
+
+/**
+ * Transforms an array by assigning an order number to each element.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(array[, indexFirst, startFromZero]) - Transforms the
input array by encapsulating elements into pairs with indexes indicating the
order.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array("d", "a", null, "b"));
+ [("d",1),("a",2),(null,3),("b",4)]
+ > SELECT _FUNC_(array("d", "a", null, "b"), true, false);
+ [(1,"d"),(2,"a"),(3,null),(4,"b")]
+ > SELECT _FUNC_(array("d", "a", null, "b"), true, true);
+ [(0,"d"),(1,"a"),(2,null),(3,"b")]
+ """,
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class ZipWithIndex(child: Expression, indexFirst: Expression,
startFromZero: Expression)
+ extends UnaryExpression with ExpectsInputTypes {
+
+ def this(e: Expression) = this(e, Literal.FalseLiteral,
Literal.FalseLiteral)
+
+ def exprToFlag(e: Expression, order: String): Boolean = e match {
+ case Literal(v: Boolean, BooleanType) => v
+ case _ => throw new AnalysisException(s"The $order argument has to be
a boolean constant.")
+ }
+
+ private val idxFirst: Boolean = exprToFlag(indexFirst, "second")
+
+ private val (idxShift, idxGen): (Int, String) = if
(exprToFlag(startFromZero, "third")) {
+ (0, "z")
+ } else {
+ (1, "z + 1")
+ }
+
+ private val MAX_ARRAY_LENGTH: Int =
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+ lazy val childArrayType: ArrayType =
child.dataType.asInstanceOf[ArrayType]
+
+ override def dataType: DataType = {
+ val elementField = StructField("value", childArrayType.elementType,
childArrayType.containsNull)
+ val indexField = StructField("index", IntegerType, false)
+
+ val fields = if (idxFirst) Seq(indexField, elementField) else
Seq(elementField, indexField)
+
+ ArrayType(StructType(fields), false)
+ }
+
+ override protected def nullSafeEval(input: Any): Any = {
+ val array =
input.asInstanceOf[ArrayData].toObjectArray(childArrayType.elementType)
+
+ val makeStruct = (v: Any, i: Int) => if (idxFirst) InternalRow(i, v)
else InternalRow(v, i)
+ val resultData = array.zipWithIndex.map{case (v, i) => makeStruct(v, i
+ idxShift)}
+
+ new GenericArrayData(resultData)
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => {
+ val numElements = ctx.freshName("numElements")
+ val code = if
(CodeGenerator.isPrimitiveType(childArrayType.elementType)) {
+ genCodeForPrimitiveElements(ctx, c, ev.value, numElements)
+ } else {
+ genCodeForAnyElements(ctx, c, ev.value, numElements)
+ }
+ s"""
+ |final int $numElements = $c.numElements();
+ |$code
+ """.stripMargin
+ })
+ }
+
+ private def genCodeForPrimitiveElements(
+ ctx: CodegenContext,
+ childVariableName: String,
+ arrayData: String,
+ numElements: String): String = {
+ val byteArraySize = ctx.freshName("byteArraySize")
+ val data = ctx.freshName("byteArray")
+ val unsafeRow = ctx.freshName("unsafeRow")
+ val structSize = ctx.freshName("structSize")
+ val unsafeArrayData = ctx.freshName("unsafeArrayData")
+ val structsOffset = ctx.freshName("structsOffset")
+ val calculateArraySize =
"UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
+ val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
+
+ val baseOffset = Platform.BYTE_ARRAY_OFFSET
+ val longSize = LongType.defaultSize
+ val primitiveValueTypeName =
CodeGenerator.primitiveTypeName(childArrayType.elementType)
+ val (valuePosition, indexPosition) = if (idxFirst) ("1", "0") else
("0", "1")
+
+ s"""
+ |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2)
+ longSize * 2};
+ |final long $byteArraySize = $calculateArraySize($numElements,
$longSize + $structSize);
+ |final int $structsOffset = $calculateHeader($numElements) +
$numElements * $longSize;
+ |if ($byteArraySize > $MAX_ARRAY_LENGTH) {
+ | ${genCodeForAnyElements(ctx, childVariableName, arrayData,
numElements)}
+ |} else {
+ | final byte[] $data = new byte[(int)$byteArraySize];
+ | UnsafeArrayData $unsafeArrayData = new UnsafeArrayData();
+ | Platform.putLong($data, $baseOffset, $numElements);
+ | $unsafeArrayData.pointTo($data, $baseOffset,
(int)$byteArraySize);
+ | UnsafeRow $unsafeRow = new UnsafeRow(2);
+ | for (int z = 0; z < $numElements; z++) {
+ | long offset = $structsOffset + z * $structSize;
+ | $unsafeArrayData.setLong(z, (offset << 32) + $structSize);
+ | $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
+ | if ($childVariableName.isNullAt(z)) {
+ | $unsafeRow.setNullAt($valuePosition);
+ | } else {
+ | $unsafeRow.set$primitiveValueTypeName(
+ | $valuePosition,
+ | ${CodeGenerator.getValue(childVariableName,
childArrayType.elementType, "z")}
+ | );
+ | }
+ | $unsafeRow.setInt($indexPosition, $idxGen);
+ | }
+ | $arrayData = $unsafeArrayData;
+ |}
+ """.stripMargin
+ }
+
+ private def genCodeForAnyElements(
+ ctx: CodegenContext,
+ childVariableName: String,
+ arrayData: String,
+ numElements: String): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val rowClass = classOf[GenericInternalRow].getName
+ val data = ctx.freshName("internalRowArray")
+
+ val getElement = CodeGenerator.getValue(childVariableName,
childArrayType.elementType, "z")
+ val elementValue = if
(CodeGenerator.isPrimitiveType(childArrayType.elementType)) {
--- End diff --
Can we remove the null check when `containsNull` is false, even when the
element type is not a primitive type?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]