Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/20858#discussion_r176902161
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -287,3 +289,152 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}
+
+/**
+ * Concatenates multiple arrays into one.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(expr, ...) - Concatenates multiple arrays into one.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
+ [1,2,3,4,5,6]
+ """)
+case class ConcatArrays(children: Seq[Expression]) extends Expression with
NullSafeEvaluation {
+
+  /**
+   * Validates the children: first that every child is an array (see
+   * `checkInputDataTypesAreArrays`), and only if that succeeds, that all children have the
+   * same type according to `TypeUtils.checkForSameTypeInputExpr`.
+   */
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val arrayCheck = checkInputDataTypesAreArrays()
+    if (arrayCheck.isFailure) {
+      arrayCheck
+    } else {
+      TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
+    }
+  }
+
+  /**
+   * Checks that every child is an array. On failure, every offending child is reported
+   * (1-based argument position, its SQL text and its actual type), joined with spaces.
+   */
+  private def checkInputDataTypesAreArrays(): TypeCheckResult = {
+    val errors = for {
+      (child, idx) <- children.zipWithIndex
+      if !ArrayType.acceptsType(child.dataType)
+    } yield s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " +
+      s"however, '${child.sql}' is of ${child.dataType.simpleString} type."
+
+    if (errors.nonEmpty) {
+      TypeCheckResult.TypeCheckFailure(errors.mkString(" "))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  /**
+   * Result type of the concatenation.
+   *
+   * The element type comes from the first child; `checkInputDataTypes` requires all children
+   * to share the same type. `containsNull` is OR-ed over all children: the result contains
+   * every input element, so it may hold nulls if any input array may.
+   * NOTE(review): assumes the same-type check ignores nullability (so children can differ in
+   * `containsNull`) — confirm against `TypeUtils.checkForSameTypeInputExpr`. Nullability nested
+   * inside the element type is not merged here.
+   * With no children, falls back to `ArrayType.defaultConcreteType`.
+   */
+  override def dataType: ArrayType = {
+    val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])
+    arrayTypes.headOption
+      .map(_.copy(containsNull = arrayTypes.exists(_.containsNull)))
+      .getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType])
+  }
+
+
+  /**
+   * Interpreted evaluation: appends the elements of each input array, in argument order, into
+   * one `GenericArrayData`. Each input is read via `toObjectArray` using the result element type.
+   * NOTE(review): presumably only invoked when no input is null (NullSafeEvaluation) — confirm.
+   */
+  override protected def nullSafeEval(inputs: Seq[Any]): Any = {
+    val concatenated = for {
+      input <- inputs
+      element <- input.asInstanceOf[ArrayData].toObjectArray(dataType.elementType)
+    } yield element
+    new GenericArrayData(concatenated)
+  }
+
+  /**
+   * Code generation: inside `nullSafeCodeGen`, dispatches on the result element type.
+   * Primitive element types use the specialized primitive path; every other element type uses
+   * the generic (complex-element) path.
+   */
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    nullSafeCodeGen(ctx, ev, arrays => dataType.elementType match {
+      case primitive if CodeGenerator.isPrimitiveType(primitive) =>
+        genCodeForConcatOfPrimitiveElements(ctx, primitive, arrays, ev.value)
+      case _ =>
+        genCodeForConcatOfComplexElements(ctx, arrays, ev.value)
+    })
+
+  /**
+   * Generates code that declares an `int` counter initialized to 0 and adds each array
+   * expression's `numElements()` to it, one statement per line.
+   * NOTE(review): the `int` sum is not checked for overflow here.
+   *
+   * @param ctx      codegen context, used to allocate a fresh counter name
+   * @param elements names of the array variables to count
+   * @return the generated code and the name of the counter variable it declares
+   */
+  private def genCodeForNumberOfElements(
+      ctx: CodegenContext,
+      elements: Seq[String]): (String, String) = {
+    val numElementsVar = ctx.freshName("numElements")
+    val declaration = s"int $numElementsVar = 0;"
+    val increments = elements.map(arr => s"$numElementsVar += $arr.numElements();")
+    ((declaration +: increments).mkString("\n"), numElementsVar)
+  }
+
+ private def genCodeForConcatOfPrimitiveElements(
+ ctx: CodegenContext,
+ elementType: DataType,
+ elements: Seq[String],
+ arrayDataName: String
+ ): String = {
+ val arrayName = ctx.freshName("array")
+ val arraySizeName = ctx.freshName("size")
+ val counter = ctx.freshName("counter")
+ val tempArrayDataName = ctx.freshName("tempArrayData")
+
+ val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx,
elements)
+
+ val unsafeArraySizeInBytes = s"""
+ |int $arraySizeName =
UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) +
+
|${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
+ |${elementType.defaultSize} * $numElemName
+ |);
+ """.stripMargin
+ val baseOffset = Platform.BYTE_ARRAY_OFFSET
+
+ val primitiveValueTypeName =
CodeGenerator.primitiveTypeName(elementType)
+ val assignments = elements.map { el =>
+ s"""
+ |for(int z = 0; z < $el.numElements(); z++) {
--- End diff --
Style: `for (` (add a space between `for` and the opening parenthesis).
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org