Github user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/21073#discussion_r197666362
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -475,6 +474,231 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
override def prettyName: String = "map_entries"
}
+/**
+ * Returns the union of all the given maps.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(map, ...) - Returns the union of all the given maps",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
+ [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]]
+ """, since = "2.4.0")
+case class MapConcat(children: Seq[Expression]) extends Expression {
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ var funcName = s"function $prettyName"
+ if (children.exists(!_.dataType.isInstanceOf[MapType])) {
+ TypeCheckResult.TypeCheckFailure(
+ s"input to $funcName should all be of type map, but it's " +
+ children.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+ } else {
+ TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName)
+ }
+ }
+
+ override def dataType: MapType = {
+ val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption
+ .getOrElse(MapType(StringType, StringType))
+ val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType])
+ .exists(_.valueContainsNull)
+ if (dt.valueContainsNull != valueContainsNull) {
+ dt.copy(valueContainsNull = valueContainsNull)
+ } else {
+ dt
+ }
+ }
+
+ override def nullable: Boolean = children.exists(_.nullable)
+
+ override def eval(input: InternalRow): Any = {
+ val maps = children.map(_.eval(input))
+ if (maps.contains(null)) {
+ return null
+ }
+ val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray())
+ val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray())
+
+ val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements())
+ if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
+ s"elements due to exceeding the map size limit " +
+ s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+ }
+ val finalKeyArray = new Array[AnyRef](numElements.toInt)
+ val finalValueArray = new Array[AnyRef](numElements.toInt)
+ var position = 0
+ for (i <- keyArrayDatas.indices) {
+ val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType)
+ val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType)
+ Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length)
+ Array.copy(valueArray, 0, finalValueArray, position, valueArray.length)
+ position += keyArray.length
+ }
+
+ new ArrayBasedMapData(new GenericArrayData(finalKeyArray),
+ new GenericArrayData(finalValueArray))
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val mapCodes = children.map(_.genCode(ctx))
+ val keyType = dataType.keyType
+ val valueType = dataType.valueType
+ val argsName = ctx.freshName("args")
+ val keyArgsName = ctx.freshName("keyArgs")
+ val valArgsName = ctx.freshName("valArgs")
+
+ val mapDataClass = classOf[MapData].getName
+ val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName
+ val arrayDataClass = classOf[ArrayData].getName
+
+ val init =
+ s"""
+ |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}];
+ |$arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}];
+ |$arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}];
+ |boolean ${ev.isNull} = false;
+ |$mapDataClass ${ev.value} = null;
+ """.stripMargin
+
+ val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
+ s"""
+ |${m.code}
+ |$argsName[$i] = ${m.value};
+ |if (${m.isNull}) {
+ | ${ev.isNull} = true;
+ |}
+ """.stripMargin
+ }
+
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = assignments,
+ funcName = "getMapConcatInputs",
+ extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", ev.isNull.code) :: Nil,
+ returnType = "boolean",
+ makeSplitFunction = body =>
+ s"""
+ |$body
+ |return ${ev.isNull};
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"${ev.isNull} = $funcCall;").mkString("\n")
+ )
+
+ val idxName = ctx.freshName("idx")
+ val numElementsName = ctx.freshName("numElems")
+ val finKeysName = ctx.freshName("finalKeys")
+ val finValsName = ctx.freshName("finalValues")
+
+ val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) {
+ genCodeForPrimitiveArrays(ctx, keyType, false)
+ } else {
+ genCodeForNonPrimitiveArrays(ctx, keyType)
+ }
+
+ val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) {
+ genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
+ } else {
+ genCodeForNonPrimitiveArrays(ctx, valueType)
+ }
+
+ val mapMerge =
+ s"""
+ |if (!${ev.isNull}) {
+ | long $numElementsName = 0;
+ | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
+ | $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
+ | $valArgsName[$idxName] = $argsName[$idxName].valueArray();
+ | $numElementsName += $argsName[$idxName].numElements();
+ | }
+ | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+ | throw new RuntimeException("Unsuccessful attempt to concat maps with " +
+ | $numElementsName + " elements due to exceeding the map size limit " +
+ | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
+ | }
+ | $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName,
+ | (int) $numElementsName);
+ | $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName,
+ | (int) $numElementsName);
+ | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName);
+ |}
+ """.stripMargin
+
+ ev.copy(
+ code = code"""
+ |$init
+ |$codes
+ |$mapMerge
+ """.stripMargin)
+ }
+
+ private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType,
+ checkForNull: Boolean): String = {
+ val counter = ctx.freshName("counter")
+ val arrayData = ctx.freshName("arrayData")
+ val argsName = ctx.freshName("args")
+ val numElemName = ctx.freshName("numElements")
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+
+ val setterCode1 =
+ s"""
+ |$arrayData.set$primitiveValueTypeName(
+ | $counter,
+ | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}
+ |);""".stripMargin.stripPrefix("\n")
+
+ val setterCode = if (checkForNull) {
+ s"""
+ |if ($argsName[y].isNullAt(z)) {
+ | $arrayData.setNullAt($counter);
+ |} else {
+ | $setterCode1
+ |}""".stripMargin.stripPrefix("\n")
--- End diff --
ditto.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]