Github user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21073#discussion_r197061444
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -474,6 +473,221 @@ case class MapEntries(child: Expression) extends 
UnaryExpression with ExpectsInp
       override def prettyName: String = "map_entries"
     }
     
    +/**
    + * Returns the union of all the given maps.
    + */
    +@ExpressionDescription(
    +  usage = "_FUNC_(map, ...) - Returns the union of all the given maps",
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
    +       [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]]
    +  """, since = "2.4.0")
    +case class MapConcat(children: Seq[Expression]) extends Expression {
    +
    +  // Upper bound on the number of entries in the concatenated map; both eval and the
    +  // generated code throw a RuntimeException when the combined input size exceeds it.
    +  private val MAX_MAP_SIZE: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
    +
    +  /**
    +   * Accepts the inputs only when every child is a map and all maps share one key type and
    +   * one value type. Key and value types are compared on their own rather than as whole
    +   * MapTypes, so maps whose only difference is valueContainsNull are still accepted.
    +   */
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    def inputTypesString: String =
    +      children.map(_.dataType.simpleString).mkString("[", ", ", "]")
    +    val mapInputs = children.map(_.dataType).collect { case mt: MapType => mt }
    +    val allMaps = mapInputs.length == children.length
    +    if (!allMaps) {
    +      TypeCheckResult.TypeCheckFailure(
    +        s"The given input of function $prettyName should all be of type map, " +
    +          "but they are " + inputTypesString)
    +    } else if (mapInputs.map(_.keyType).distinct.length > 1 ||
    +        mapInputs.map(_.valueType).distinct.length > 1) {
    +      TypeCheckResult.TypeCheckFailure(
    +        s"The given input maps of function $prettyName should all be the same type, " +
    +          "but they are " + inputTypesString)
    +    } else {
    +      TypeCheckResult.TypeCheckSuccess
    +    }
    +  }
    +
    +  /**
    +   * Result type: key and value types come from the first input (checkInputDataTypes requires
    +   * all inputs to agree on them); values may be null if any input allows null values.
    +   * With no inputs the result falls back to map<string, string>.
    +   */
    +  override def dataType: MapType = {
    +    val inputMapTypes = children.map(_.dataType.asInstanceOf[MapType])
    +    inputMapTypes.headOption match {
    +      case Some(first) =>
    +        MapType(
    +          keyType = first.keyType,
    +          valueType = first.valueType,
    +          valueContainsNull = inputMapTypes.exists(_.valueContainsNull))
    +      case None =>
    +        MapType(keyType = StringType, valueType = StringType, valueContainsNull = false)
    +    }
    +  }
    +
    +  // A null input map makes the whole result null, so the result is nullable exactly when
    +  // at least one child is.
    +  override def nullable: Boolean = children.exists(child => child.nullable)
    +
    +  /**
    +   * Interpreted evaluation: evaluates every child and returns null if any input map is null.
    +   * Otherwise appends the entries of all input maps, in child order, into one
    +   * ArrayBasedMapData (duplicate keys are kept as-is).
    +   *
    +   * @throws RuntimeException if the combined number of entries exceeds MAX_MAP_SIZE
    +   */
    +  override def eval(input: InternalRow): Any = {
    +    val maps = children.map(_.eval(input))
    +    if (maps.contains(null)) {
    +      null
    +    } else {
    +      val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray())
    +      val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray())
    +
    +      // Sum in Long so the size check itself cannot overflow before the Int narrowing below.
    +      val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements())
    +      if (numElements > MAX_MAP_SIZE) {
    +        throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements" +
    +          s" elements due to exceeding the map size limit" +
    +          s" $MAX_MAP_SIZE.")
    +      }
    +
    +      // `dataType` is a def that re-derives the MapType from every child on each call;
    +      // resolve the element types once instead of twice per input map.
    +      val resultType = dataType
    +      val keyType = resultType.keyType
    +      val valueType = resultType.valueType
    +
    +      val finalKeyArray = new Array[AnyRef](numElements.toInt)
    +      val finalValueArray = new Array[AnyRef](numElements.toInt)
    +      var position = 0
    +      for (i <- keyArrayDatas.indices) {
    +        val keyArray = keyArrayDatas(i).toObjectArray(keyType)
    +        val valueArray = valueArrayDatas(i).toObjectArray(valueType)
    +        Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length)
    +        Array.copy(valueArray, 0, finalValueArray, position, valueArray.length)
    +        position += keyArray.length
    +      }
    +
    +      new ArrayBasedMapData(new GenericArrayData(finalKeyArray),
    +        new GenericArrayData(finalValueArray))
    +    }
    +  }
    +
    +  /**
    +   * Generated-code counterpart of `eval`: evaluates every child, yields null if any input
    +   * map is null, otherwise concatenates all key arrays and all value arrays into a single
    +   * ArrayBasedMapData, throwing a RuntimeException if the total entry count exceeds
    +   * MAX_MAP_SIZE.
    +   */
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    val mapCodes = children.map(_.genCode(ctx))
    +    val keyType = dataType.keyType
    +    val valueType = dataType.valueType
    +    val argsName = ctx.freshName("args")
    +    val keyArgsName = ctx.freshName("keyArgs")
    +    val valArgsName = ctx.freshName("valArgs")
    +
    +    val mapDataClass = classOf[MapData].getName
    +    val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName
    +    val arrayDataClass = classOf[ArrayData].getName
    +
    +    // Java declarations: one slot per child for the input map and for its key/value
    +    // arrays, plus the result's null flag and value.
    +    val init =
    +      s"""
    +        |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}];
    +        |$arrayDataClass[] $keyArgsName = new 
$arrayDataClass[${mapCodes.size}];
    +        |$arrayDataClass[] $valArgsName = new 
$arrayDataClass[${mapCodes.size}];
    +        |boolean ${ev.isNull} = false;
    +        |$mapDataClass ${ev.value} = null;
    +      """.stripMargin
    +
    +    // One snippet per child: run the child's code and store its value in its args slot.
    +    // NOTE(review): `m.isNull` is not read here; this assumes a null child leaves
    +    // `m.value` as Java null so the merge loop below detects it — confirm.
    +    val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
    +      s"""
    +         |${m.code}
    +         |$argsName[$i] = ${m.value.code};
    +       """.stripMargin
    +    }
    +
    +    // The args array is passed as an extra argument so the snippets keep writing to it if
    +    // they are split into separate helper methods (presumably done when there are many
    +    // children — confirm against CodegenContext.splitExpressionsWithCurrentInputs).
    +    val codes = ctx.splitExpressionsWithCurrentInputs(
    +      expressions = assignments,
    +      funcName = "mapConcat",
    +      extraArguments = (s"${mapDataClass}[]", argsName) :: Nil)
    +
    +    val idxName = ctx.freshName("idx")
    +    val numElementsName = ctx.freshName("numElems")
    +    val finKeysName = ctx.freshName("finalKeys")
    +    val finValsName = ctx.freshName("finalValues")
    +
    +    // Each concatenator is Java source for an object exposing
    +    // `concat(ArrayData[], int)`; primitive element types get a specialized version from
    +    // genCodeForPrimitiveArrays, all others come from genCodeForNonPrimitiveArrays.
    +    val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) {
    +      genCodeForPrimitiveArrays(ctx, keyType)
    +    } else {
    +      genCodeForNonPrimitiveArrays(ctx, keyType)
    +    }
    +
    +    val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) {
    +      genCodeForPrimitiveArrays(ctx, valueType)
    +    } else {
    +      genCodeForNonPrimitiveArrays(ctx, valueType)
    +    }
    +
    +    // Merge step: collect each input's key and value arrays while summing entry counts in
    +    // a long; a null input sets the result to null and stops the loop. Otherwise the total
    +    // is checked against MAX_MAP_SIZE before the (int) narrowing and the concatenation.
    +    val mapMerge =
    +      s"""
    +        |long $numElementsName = 0;
    +        |for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
    +        |  if ($argsName[$idxName] == null) {
    +        |    ${ev.isNull} = true;
    +        |    break;
    +        |  }
    +        |  $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
    +        |  $valArgsName[$idxName] = $argsName[$idxName].valueArray();
    +        |  $numElementsName += $argsName[$idxName].numElements();
    +        |}
    +        |
    +        |if (!${ev.isNull}) {
    +        |  if ($numElementsName > $MAX_MAP_SIZE) {
    +        |    throw new RuntimeException("Unsuccessful attempt to concat 
maps with " +
    +        |       $numElementsName + " elements due to exceeding the map 
size limit $MAX_MAP_SIZE.");
    +        |  }
    +        |  $arrayDataClass $finKeysName = 
$keyConcatenator.concat($keyArgsName,
    +        |    (int) $numElementsName);
    +        |  $arrayDataClass $finValsName = 
$valueConcatenator.concat($valArgsName,
    +        |    (int) $numElementsName);
    +        |  ${ev.value} = new $arrayBasedMapDataClass($finKeysName, 
$finValsName);
    +        |}
    +      """.stripMargin
    +
    +    // Emit the declarations, then the per-child evaluation, then the merge, in that order.
    +    ev.copy(
    +      code = code"""
    +        |$init
    +        |$codes
    +        |$mapMerge
    +      """.stripMargin)
    +  }
    +
    +  private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: 
DataType): String = {
    +    val counter = ctx.freshName("counter")
    +    val arrayData = ctx.freshName("arrayData")
    +    val argsName = ctx.freshName("args")
    +    val numElemName = ctx.freshName("numElements")
    +    val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(elementType)
    +
    +    s"""
    +       |new Object() {
    +       |  public ArrayData concat(${classOf[ArrayData].getName}[] 
$argsName, int $numElemName) {
    +       |    ${ctx.createUnsafeArray(arrayData, numElemName, elementType, 
s" $prettyName failed.")}
    +       |    int $counter = 0;
    +       |    for (int y = 0; y < ${children.length}; y++) {
    +       |      for (int z = 0; z < $argsName[y].numElements(); z++) {
    +       |        if ($argsName[y].isNullAt(z)) {
    --- End diff --
    
    We don't need to check for nulls in the keys, nor in the values when
`valueContainsNull` is false.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to