Github user kiszk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20858#discussion_r180181355
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
    @@ -287,3 +290,191 @@ case class ArrayContains(left: Expression, right: Expression)
     
       override def prettyName: String = "array_contains"
     }
    +
    +/**
    + * Concatenates multiple input columns together into a single column.
    + * The function works with strings, binary and compatible array columns.
    + */
    +@ExpressionDescription(
    +  usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of 
col1, col2, ..., colN.",
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_('Spark', 'SQL');
    +       SparkSQL
    +      > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
    + |     [1,2,3,4,5,6]
    +  """)
    +case class Concat(children: Seq[Expression]) extends Expression {
    +
    +  // Input type families accepted by this expression. The bare `ArrayType`
    +  // companion is presumably meant to match arrays of any element type;
    +  // agreement between the children's types is left to the same-type check below.
    +  val allowedTypes = Seq(StringType, BinaryType, ArrayType)
    +
    +  /**
    +   * Validates the children's types. An empty child list is accepted. Otherwise
    +   * every child type must be accepted by one of `allowedTypes`, and the final
    +   * check is delegated to `TypeUtils.checkForSameTypeInputExpr` (presumably
    +   * requiring all children to have the same type — confirm).
    +   */
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    if (children.isEmpty) {
    +      TypeCheckResult.TypeCheckSuccess
    +    } else {
    +      val childTypes = children.map(_.dataType)
    +      // Reject any child whose type is not a string, binary or array type.
    +      // NOTE(review): the explicit `return` could become an if/else expression.
    +      if (childTypes.exists(tpe => 
!allowedTypes.exists(_.acceptsType(tpe)))) {
    +        return TypeCheckResult.TypeCheckFailure(
    +          s"input to function $prettyName should have been StringType, 
BinaryType or ArrayType," +
    +            s" but it's " + childTypes.map(_.simpleString).mkString("[", 
", ", "]"))
    +      }
    +      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function 
$prettyName")
    +    }
    +  }
    +
    +  // Result type: the first child's type, or StringType when there are no children.
    +  override def dataType: DataType = 
children.map(_.dataType).headOption.getOrElse(StringType)
    +
    +  // The result may be null if any child may be null; it is foldable only
    +  // when every child is foldable.
    +  override def nullable: Boolean = children.exists(_.nullable)
    +  override def foldable: Boolean = children.forall(_.foldable)
    +
    +  /**
    +   * Interpreted evaluation: evaluates the children and concatenates their values
    +   * according to the result type. The match has no default case; any other type
    +   * is expected to have been rejected by checkInputDataTypes.
    +   */
    +  override def eval(input: InternalRow): Any = dataType match {
    +    case BinaryType =>
    +      // Null handling is delegated to ByteArray.concat (presumably null if any
    +      // input is null — confirm).
    +      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
    +      ByteArray.concat(inputs: _*)
    +    case StringType =>
    +      // Null handling is delegated to UTF8String.concat (presumably null if any
    +      // input is null — confirm).
    +      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
    +      UTF8String.concat(inputs : _*)
    +    case ArrayType(elementType, _) =>
    +      // A lazy Stream lets `contains(null)` stop evaluating children at the first
    +      // null; values already evaluated are memoized and reused by flatMap below.
    +      val inputs = children.toStream.map(_.eval(input))
    +      if (inputs.contains(null)) {
    +        null
    +      } else {
    +        val elements = 
inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType))
    +        new GenericArrayData(elements)
    +      }
    +  }
    +
    +  /**
    +   * Generated-code path. Each child's code stores its value into slot i of a
    +   * per-evaluation `args` array; slots of null children keep the array's default
    +   * null. A type-specific concatenator's `concat(args)` then produces the result,
    +   * and a null return from `concat` marks the result as null.
    +   */
    +  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
    +    val evals = children.map(_.genCode(ctx))
    +    val args = ctx.freshName("args")
    +
    +    // One snippet per child: run its code, then copy the value into args[index]
    +    // only when it is not null.
    +    val inputs = evals.zipWithIndex.map { case (eval, index) =>
    +      s"""
    +        ${eval.code}
    +        if (!${eval.isNull}) {
    +          $args[$index] = ${eval.value};
    +        }
    +      """
    +    }
    +
    +    // Pair of (class/object whose `concat` is invoked, declaration of the args
    +    // array) for the result type. Binary and string reuse ByteArray / UTF8String;
    +    // arrays use code from genCodeForPrimitiveArrayConcat or
    +    // genCodeForComplexArrayConcat depending on whether the element type is primitive.
    +    val (concatenator, initCode) = dataType match {
    +      case BinaryType =>
    +        (classOf[ByteArray].getName, s"byte[][] $args = new 
byte[${evals.length}][];")
    +      case StringType =>
    +        ("UTF8String", s"UTF8String[] $args = new 
UTF8String[${evals.length}];")
    +      case ArrayType(elementType, _) =>
    +        val arrayConcatClass = if 
(CodeGenerator.isPrimitiveType(elementType)) {
    +          genCodeForPrimitiveArrayConcat(ctx, elementType)
    +        } else {
    +          genCodeForComplexArrayConcat(ctx)
    +        }
    +        (arrayConcatClass, s"ArrayData[] $args = new 
ArrayData[${evals.length}];")
    +    }
    +    // The child snippets go through splitExpressionsWithCurrentInputs (presumably
    +    // split into helper methods based on "valueConcat" when large); the args
    +    // array is passed to those helpers as an extra argument.
    +    val codes = ctx.splitExpressionsWithCurrentInputs(
    +      expressions = inputs,
    +      funcName = "valueConcat",
    +      extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: 
Nil)
    +    // Declare args, fill it, concatenate; nullness is derived from the result.
    +    ev.copy(s"""
    +      $initCode
    +      $codes
    +      ${CodeGenerator.javaType(dataType)} ${ev.value} = 
$concatenator.concat($args);
    +      boolean ${ev.isNull} = ${ev.value} == null;
    +    """)
    +  }
    +
    +  /**
    +   * Generates Java code that declares an `int` counter and adds
    +   * `args[i].numElements()` for every child index i. Returns (code, counter name).
    +   * The code refers to a variable literally named `args` and does not null-check
    +   * its entries, so it must be emitted where `args` is in scope and non-null.
    +   * NOTE(review): the `int` total can overflow for very large inputs; consider a
    +   * `long` accumulator with an explicit size check.
    +   */
    +  private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, 
String) = {
    +    val variableName = ctx.freshName("numElements")
    +    val code = (0 until children.length)
    +      .map(idx => s"$variableName += args[$idx].numElements();")
    +      .foldLeft(s"int $variableName = 0;")((acc, s) => acc + "\n" + s)
    +    (code, variableName)
    +  }
    +
    +  /**
    +   * Generates one `if (args[i] == null) return null;` guard for each child whose
    +   * expression is nullable. Children that cannot be null get no guard.
    +   */
    +  private def nullArgumentProtection() : String = {
    +    children.zipWithIndex
    +      .filter(_._1.nullable)
    +      .map(ci => s"if (args[${ci._2}] == null) return null;")
    +      .mkString("\n")
    +  }
    +
    +  /**
    +   * Generates the Java source of an anonymous object whose `concat(args)` method
    +   * (args typed as the Java type of this expression's dataType, followed by `[]`)
    +   * returns a new UnsafeArrayData. That result holds all elements of `args` in
    +   * order, or the method returns null when a nullable argument is null. Null
    +   * elements are preserved with setNullAt; other elements are copied with the
    +   * typed get/set accessors for `elementType`.
    +   * NOTE(review): the per-child copy loops are joined with mkString into one
    +   * generated method, which may exceed the JVM 64KB method limit for many children.
    +   */
    +  private def genCodeForPrimitiveArrayConcat(ctx: CodegenContext, 
elementType: DataType): String = {
    +    val arrayName = ctx.freshName("array")
    +    val arraySizeName = ctx.freshName("size")
    +    val counter = ctx.freshName("counter")
    +    val arrayDataName = ctx.freshName("arrayData")
    +
    +    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
    +
    +    // Byte size of the backing buffer: the UnsafeArrayData header portion plus
    +    // `defaultSize * numElements` element bytes, rounded up to a whole word.
    +    // NOTE(review): that product is int arithmetic in the generated code and can
    +    // overflow for very large inputs.
    +    val unsafeArraySizeInBytes = s"""
    +      |int $arraySizeName = 
UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) +
    +      
|${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
    +      |  ${elementType.defaultSize} * $numElemName
    +      |);
    +      """.stripMargin
    +    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    +
    +    val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(elementType)
    +    // One loop per child, copying its elements to consecutive positions tracked
    +    // by the shared counter.
    +    val assignments = (0 until children.length).map { idx =>
    +      s"""
    +         |for (int z = 0; z < args[$idx].numElements(); z++) {
    +         |  if (args[$idx].isNullAt(z)) {
    +         |    $arrayDataName.setNullAt($counter);
    +         |  } else {
    +         |    $arrayDataName.set$primitiveValueTypeName(
    +         |      $counter,
    +         |      args[$idx].get$primitiveValueTypeName(z)
    +         |    );
    +         |  }
    +         |  $counter++;
    +         |}
    +        """.stripMargin
    +    }.mkString("\n")
    +
    +    // The generated method first applies the null guards and sizes the buffer.
    +    // It then writes the element count with Platform.putLong at the buffer start
    +    // and points the UnsafeArrayData at the buffer before copying elements.
    +    s"""new Object() {
    +       |  public ArrayData concat(${CodeGenerator.javaType(dataType)}[] 
args) {
    +       |    ${nullArgumentProtection()}
    +       |    $numElemCode
    +       |    $unsafeArraySizeInBytes
    +       |    byte[] $arrayName = new byte[$arraySizeName];
    +       |    UnsafeArrayData $arrayDataName = new UnsafeArrayData();
    +       |    Platform.putLong($arrayName, $baseOffset, $numElemName);
    +       |    $arrayDataName.pointTo($arrayName, $baseOffset, 
$arraySizeName);
    +       |    int $counter = 0;
    +       |    $assignments
    +       |    return $arrayDataName;
    +       |  }
    +       |}""".stripMargin
    +  }
    +
    +  private def genCodeForComplexArrayConcat(ctx: CodegenContext): String = {
    +    val genericArrayClass = classOf[GenericArrayData].getName
    +    val arrayName = ctx.freshName("arrayObject")
    +    val counter = ctx.freshName("counter")
    +
    +    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
    +
    +    val assignments = (0 until children.length).map { idx =>
    +      s"""
    +         |for (int z = 0; z < args[$idx].numElements(); z++) {
    +         |  $arrayName[$counter] = args[$idx].array()[z];
    +         |  $counter++;
    +         |}
    +        """.stripMargin
    +    }.mkString("\n")
    --- End diff --
    
    Using `mkString` here joins one generated loop per child into a single Java
    method, which may fail to compile because of the JVM's 64KB bytecode limit per
    method. Would it be possible to use `CodegenContext.splitExpressions()` instead?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to