Github user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20858#discussion_r181643397
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -287,3 +290,231 @@ case class ArrayContains(left: Expression, right: 
Expression)
     
       override def prettyName: String = "array_contains"
     }
    +
    +/**
    + * Concatenates multiple input columns together into a single column.
    + * The function works with strings, binary and compatible array columns.
    + */
    +@ExpressionDescription(
    +  usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of 
col1, col2, ..., colN.",
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_('Spark', 'SQL');
    +       SparkSQL
    +      > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
    + |     [1,2,3,4,5,6]
    +  """)
    +case class Concat(children: Seq[Expression]) extends Expression {
    +
    +  val allowedTypes = Seq(StringType, BinaryType, ArrayType)
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    if (children.isEmpty) {
    +      TypeCheckResult.TypeCheckSuccess
    +    } else {
    +      val childTypes = children.map(_.dataType)
    +      if (childTypes.exists(tpe => 
!allowedTypes.exists(_.acceptsType(tpe)))) {
    +        return TypeCheckResult.TypeCheckFailure(
    +          s"input to function $prettyName should have been StringType, 
BinaryType or ArrayType," +
    +            s" but it's " + childTypes.map(_.simpleString).mkString("[", 
", ", "]"))
    +      }
    +      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function 
$prettyName")
    +    }
    +  }
    +
    +  override def dataType: DataType = 
children.map(_.dataType).headOption.getOrElse(StringType)
    +
    +  lazy val javaType: String = CodeGenerator.javaType(dataType)
    +
    +  override def nullable: Boolean = children.exists(_.nullable)
    +  override def foldable: Boolean = children.forall(_.foldable)
    +
    +  override def eval(input: InternalRow): Any = dataType match {
    +    case BinaryType =>
    +      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
    +      ByteArray.concat(inputs: _*)
    +    case StringType =>
    +      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
    +      UTF8String.concat(inputs : _*)
    +    case ArrayType(elementType, _) =>
    +      val inputs = children.toStream.map(_.eval(input))
    +      if (inputs.contains(null)) {
    +        null
    +      } else {
    +        val elements = 
inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType))
    +        new GenericArrayData(elements)
    +      }
    +  }
    +
    +  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
    +    val evals = children.map(_.genCode(ctx))
    +    val args = ctx.freshName("args")
    +
    +    val inputs = evals.zipWithIndex.map { case (eval, index) =>
    +      s"""
    +        ${eval.code}
    +        if (!${eval.isNull}) {
    +          $args[$index] = ${eval.value};
    +        }
    +      """
    +    }
    +
    +    val (concatenator, initCode) = dataType match {
    +      case BinaryType =>
    +        (classOf[ByteArray].getName, s"byte[][] $args = new 
byte[${evals.length}][];")
    +      case StringType =>
    +        ("UTF8String", s"UTF8String[] $args = new 
UTF8String[${evals.length}];")
    +      case ArrayType(elementType, _) =>
    +        val arrayConcatClass = if 
(CodeGenerator.isPrimitiveType(elementType)) {
    +          genCodeForPrimitiveArrayConcat(ctx, elementType)
    +        } else {
    +          genCodeForComplexArrayConcat(ctx, elementType)
    +        }
    +        (arrayConcatClass, s"ArrayData[] $args = new 
ArrayData[${evals.length}];")
    +    }
    +    val codes = ctx.splitExpressionsWithCurrentInputs(
    +      expressions = inputs,
    +      funcName = "valueConcat",
    +      extraArguments = (s"${javaType}[]", args) :: Nil)
    +    ev.copy(s"""
    +      $initCode
    +      $codes
    +      ${javaType} ${ev.value} = $concatenator.concat($args);
    +      boolean ${ev.isNull} = ${ev.value} == null;
    +    """)
    +  }
    +
    +  private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, 
String) = {
    +    val tempVariableName = ctx.freshName("tempNumElements")
    +    val numElementsConstant = ctx.freshName("numElements")
    +    val assignments = (0 until children.length)
    +      .map(idx => s"$tempVariableName[0] += args[$idx].numElements();")
    +
    +    val assignmentSection = ctx.splitExpressions(
    +      expressions = assignments,
    +      funcName = "complexArrayConcat",
    +      arguments = Seq((s"${javaType}[]", "args"), ("int[]", 
tempVariableName)))
    +
    +    (s"""
    +        |int[] $tempVariableName = new int[]{0};
    +        |$assignmentSection
    +        |final int $numElementsConstant = $tempVariableName[0];
    +      """.stripMargin,
    +     numElementsConstant)
    +  }
    +
    +  private def nullArgumentProtection(ctx: CodegenContext) : String = {
    +    val isNullVariable = ctx.freshName("isArrayNull")
    +    val assignments = children
    +      .zipWithIndex
    +      .filter(_._1.nullable)
    +      .map(ci => s"$isNullVariable[0] |= args[${ci._2}] == null;")
    +
    +    if (assignments.length > 0) {
    +      val assignmentSection = ctx.splitExpressions(
    +        expressions = assignments,
    +        funcName = "isNullArrayConcat",
    +        arguments = Seq((s"${javaType}[]", "args"), ("boolean[]", 
isNullVariable)))
    +
    +      s"""
    +         |boolean[] $isNullVariable = new boolean[]{false};
    +         |$assignmentSection;
    +         |if ($isNullVariable[0]) return null;
    +       """.stripMargin
    +    } else {
    +      ""
    +    }
    +  }
    +
    +  private def genCodeForPrimitiveArrayConcat(ctx: CodegenContext, 
elementType: DataType): String = {
    +    val arrayName = ctx.freshName("array")
    +    val arraySizeName = ctx.freshName("size")
    +    val counter = ctx.freshName("counter")
    +    val arrayData = ctx.freshName("arrayData")
    +
    +    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
    +
    +    val unsafeArraySizeInBytes = s"""
    +      |int $arraySizeName = 
UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) +
    +      
|${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
    +      |  ${elementType.defaultSize} * $numElemName
    +      |);
    +      """.stripMargin
    +    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    +
    +    val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(elementType)
    +    val assignments = (0 until children.length).map { idx =>
    +      s"""
    +         |for (int z = 0; z < args[$idx].numElements(); z++) {
    +         |  if (args[$idx].isNullAt(z)) {
    +         |    $arrayData.setNullAt($counter[0]);
    +         |  } else {
    +         |    $arrayData.set$primitiveValueTypeName(
    +         |      $counter[0],
    +         |      ${CodeGenerator.getValue(s"args[$idx]", elementType, "z")}
    +         |    );
    +         |  }
    +         |  $counter[0]++;
    +         |}
    +        """.stripMargin
    +    }
    +    val assignmentSection = ctx.splitExpressions(
    +      expressions = assignments,
    +      funcName = "primitiveArrayConcat",
    +      arguments = Seq(
    +        (s"${javaType}[]", "args"),
    +        ("UnsafeArrayData", arrayData),
    +        ("int[]", counter)))
    +
    +    s"""new Object() {
    +       |  public ArrayData concat(${CodeGenerator.javaType(dataType)}[] 
args) {
    +       |    ${nullArgumentProtection(ctx)}
    +       |    $numElemCode
    +       |    $unsafeArraySizeInBytes
    +       |    byte[] $arrayName = new byte[$arraySizeName];
    +       |    UnsafeArrayData $arrayData = new UnsafeArrayData();
    +       |    Platform.putLong($arrayName, $baseOffset, $numElemName);
    +       |    $arrayData.pointTo($arrayName, $baseOffset, $arraySizeName);
    +       |    int[] $counter = new int[]{0};
    +       |    $assignmentSection
    +       |    return $arrayData;
    +       |  }
    +       |}""".stripMargin
    +  }
    +
    +  private def genCodeForComplexArrayConcat(ctx: CodegenContext, 
elementType: DataType): String = {
    +    val genericArrayClass = classOf[GenericArrayData].getName
    +    val arrayData = ctx.freshName("arrayObjects")
    +    val counter = ctx.freshName("counter")
    +
    +    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
    +
    +    val assignments = (0 until children.length).map { idx =>
    +      s"""
    +         |for (int z = 0; z < args[$idx].numElements(); z++) {
    +         |  $arrayData[$counter[0]] = 
${CodeGenerator.getValue(s"args[$idx]", elementType, "z")};
    --- End diff --
    
    Do we need to check for null here?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to