Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/21121#discussion_r183237327 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -883,3 +884,140 @@ case class Concat(children: Seq[Expression]) extends Expression { override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" } + +/** + * Returns the maximum value in the array. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(array[, indexFirst]) - Transforms the input array by encapsulating elements into pairs with indexes indicating the order.", + examples = """ + Examples: + > SELECT _FUNC_(array("d", "a", null, "b")); + [("d",0),("a",1),(null,2),("b",3)] + > SELECT _FUNC_(array("d", "a", null, "b"), true); + [(0,"d"),(1,"a"),(2,null),(3,"b")] + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ZipWithIndex(child: Expression, indexFirst: Expression) + extends UnaryExpression with ExpectsInputTypes { + + def this(e: Expression) = this(e, Literal.FalseLiteral) + + private val idxFirst: Boolean = indexFirst match { + case Literal(v: Boolean, BooleanType) => v + case _ => throw new AnalysisException("The second argument has to be a boolean constant.") + } + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + lazy val childArrayType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def dataType: DataType = { + val elementField = StructField("value", childArrayType.elementType, childArrayType.containsNull) + val indexField = StructField("index", IntegerType, false) + + val fields = if (idxFirst) Seq(indexField, elementField) else Seq(elementField, indexField) + + ArrayType(StructType(fields), false) + } + + override protected def nullSafeEval(input: Any): Any = { + val array = input.asInstanceOf[ArrayData].toObjectArray(childArrayType.elementType) + + val makeStruct = (v: Any, i: Int) => if (idxFirst) InternalRow(i, v) else InternalRow(v, i) + val resultData = array.zipWithIndex.map{case (v, i) => makeStruct(v, i)} + + new GenericArrayData(resultData) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + if (CodeGenerator.isPrimitiveType(childArrayType.elementType)) { + genCodeForPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForNonPrimitiveElements(ctx, c, ev.value) + } + }) + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayData: String): String = { + val numElements = ctx.freshName("numElements") + val byteArraySize = ctx.freshName("byteArraySize") + val data = ctx.freshName("byteArray") + val unsafeRow = ctx.freshName("unsafeRow") + val structSize = ctx.freshName("structSize") + val unsafeArrayData = ctx.freshName("unsafeArrayData") + val structsOffset = ctx.freshName("structsOffset") + val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray" + val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val longSize = LongType.defaultSize + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(childArrayType.elementType) + val (valuePosition, indexPosition) = if (idxFirst) ("1", "0") else ("0", "1") + + s""" + |final int $numElements = $childVariableName.numElements(); + |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2}; + |final long $byteArraySize = $calculateArraySize($numElements, $longSize + $structSize); + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize; + |if ($byteArraySize > $MAX_ARRAY_LENGTH) { --- End diff -- Btw, if we use `GenericArrayData` as output array, can't we avoid this limit?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org