Github user mn-mikke commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20858#discussion_r178759108
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -287,3 +289,152 @@ case class ArrayContains(left: Expression, right: 
Expression)
     
       override def prettyName: String = "array_contains"
     }
    +
    +/**
    + * Concatenates multiple arrays into one.
    + */
    +@ExpressionDescription(
    +  usage = "_FUNC_(expr, ...) - Concatenates multiple arrays into one.",
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
    +       [1,2,3,4,5,6]
    +  """)
    +case class ConcatArrays(children: Seq[Expression]) extends Expression with 
NullSafeEvaluation {
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    val arrayCheck = checkInputDataTypesAreArrays
    +    if(arrayCheck.isFailure) arrayCheck
    +    else TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), 
s"function $prettyName")
    +  }
    +
    +  private def checkInputDataTypesAreArrays(): TypeCheckResult =
    +  {
    +    val mismatches = children.zipWithIndex.collect {
    +      case (child, idx) if !ArrayType.acceptsType(child.dataType) =>
    +        s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " +
    +          s"however, '${child.sql}' is of ${child.dataType.simpleString} 
type."
    +    }
    +
    +    if (mismatches.isEmpty) {
    +      TypeCheckResult.TypeCheckSuccess
    +    } else {
    +      TypeCheckResult.TypeCheckFailure(mismatches.mkString(" "))
    +    }
    +  }
    +
    +  override def dataType: ArrayType =
    +    children
    +      .headOption.map(_.dataType.asInstanceOf[ArrayType])
    +      .getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType])
    +
    +
    +  override protected def nullSafeEval(inputs: Seq[Any]): Any = {
    +    val elements = 
inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(dataType.elementType))
    +    new GenericArrayData(elements)
    +  }
    +
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    nullSafeCodeGen(ctx, ev, arrays => {
    +      val elementType = dataType.elementType
    +      if (CodeGenerator.isPrimitiveType(elementType)) {
    +        genCodeForConcatOfPrimitiveElements(ctx, elementType, arrays, 
ev.value)
    +      } else {
    +        genCodeForConcatOfComplexElements(ctx, arrays, ev.value)
    +      }
    +    })
    +  }
    +
    +  private def genCodeForNumberOfElements(
    +    ctx: CodegenContext,
    +    elements: Seq[String]
    +  ) : (String, String) = {
    +    val variableName = ctx.freshName("numElements")
    +    val code = elements
    +      .map(el => s"$variableName += $el.numElements();")
    +      .foldLeft( s"int $variableName = 0;")((acc, s) => acc + "\n" + s)
    +    (code, variableName)
    +  }
    +
    +  private def genCodeForConcatOfPrimitiveElements(
    +    ctx: CodegenContext,
    +    elementType: DataType,
    +    elements: Seq[String],
    +    arrayDataName: String
    +  ): String = {
    +    val arrayName = ctx.freshName("array")
    +    val arraySizeName = ctx.freshName("size")
    +    val counter = ctx.freshName("counter")
    +    val tempArrayDataName = ctx.freshName("tempArrayData")
    +
    +    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, 
elements)
    +
    +    val unsafeArraySizeInBytes = s"""
    +      |int $arraySizeName = 
UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) +
    +      
|${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
    +      |${elementType.defaultSize} * $numElemName
    +      |);
    +      """.stripMargin
    +    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    +
    +    val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(elementType)
    +    val assignments = elements.map { el =>
    +      s"""
    +        |for(int z = 0; z < $el.numElements(); z++) {
    +        | if($el.isNullAt(z)) {
    +        |   $tempArrayDataName.setNullAt($counter);
    +        | } else {
    +        |   $tempArrayDataName.set$primitiveValueTypeName(
    +        |     $counter,
    +        |     $el.get$primitiveValueTypeName(z)
    +        |   );
    +        | }
    +        | $counter++;
    +        |}
    +        """.stripMargin
    +    }.mkString("\n")
    +
    +    s"""
    +      |$numElemCode
    +      |$unsafeArraySizeInBytes
    +      |byte[] $arrayName = new byte[$arraySizeName];
    +      |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
    +      |Platform.putLong($arrayName, $baseOffset, $numElemName);
    +      |$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName);
    +      |int $counter = 0;
    +      |$assignments
    +      |$arrayDataName = $tempArrayDataName;
    +    """.stripMargin
    +
    +  }
    +
    +  private def genCodeForConcatOfComplexElements(
    +   ctx: CodegenContext,
    +   elements: Seq[String],
    +   arrayDataName: String
    +  ): String = {
    +    val genericArrayClass = classOf[GenericArrayData].getName
    +    val arrayName = ctx.freshName("arrayObject")
    +    val counter = ctx.freshName("counter")
    +    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, 
elements)
    +
    +    val assignments = elements.map { el =>
    +      s"""
    +        |for(int z = 0; z < $el.numElements(); z++) {
    +        |  $arrayName[$counter] = $el.array()[z];
    +        |  $counter++;
    +        |}
    +     """.stripMargin
    +    }.mkString("\n")
    +
    +    s"""
    +      |$numElemCode
    +      |Object[] $arrayName = new Object[$numElemName];
    +      |int $counter = 0;
    +      |$assignments
    +      |$arrayDataName = new $genericArrayClass($arrayName);
    --- End diff --
    
    Yeah, currently there are no `write` methods on `UnsafeArrayWriter` or 
`set` methods on `UnsafeArrayData` that we could leverage for complex types. In 
theory, we could follow the same approach as in `InterpretedUnsafeProjection`: 
write each complex type to a byte array and subsequently insert the produced 
byte array into the target `UnsafeArrayData`. Since this logic could be utilized 
from more places (e.g. `CreateArray`), it should first be encapsulated in 
`UnsafeArrayWriter` or `UnsafeArrayData`. What do you think?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to