Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21236#discussion_r187690481 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -118,6 +119,161 @@ case class MapValues(child: Expression) override def prettyName: String = "map_values" } +/** + * Returns an unordered array of all entries in the given map. + */ +@ExpressionDescription( + usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.", + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b')); + [(1,"a"),(2,"b")] + """, + since = "2.4.0") +case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) + + lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] + + override def dataType: DataType = { + ArrayType( + StructType( + StructField("key", childDataType.keyType, false) :: + StructField("value", childDataType.valueType, childDataType.valueContainsNull) :: + Nil), + false) + } + + override protected def nullSafeEval(input: Any): Any = { + val childMap = input.asInstanceOf[MapData] + val keys = childMap.keyArray() + val values = childMap.valueArray() + val length = childMap.numElements() + val resultData = new Array[AnyRef](length) + var i = 0; + while (i < length) { + val key = keys.get(i, childDataType.keyType) + val value = values.get(i, childDataType.valueType) + val row = new GenericInternalRow(Array[Any](key, value)) + resultData.update(i, row) + i += 1 + } + new GenericArrayData(resultData) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numElements = ctx.freshName("numElements") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) + val isValuePrimitive = 
CodeGenerator.isPrimitiveType(childDataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + } else { + genCodeForAnyElements(ctx, keys, values, ev.value, numElements) + } + s""" + |final int $numElements = $c.numElements(); + |final ArrayData $keys = $c.keyArray(); + |final ArrayData $values = $c.valueArray(); + |$code + """.stripMargin + }) + } + + private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") + + private def getValue(varName: String) = { + CodeGenerator.getValue(varName, childDataType.valueType, "z") + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val byteArraySize = ctx.freshName("byteArraySize") + val data = ctx.freshName("byteArray") + val unsafeRow = ctx.freshName("unsafeRow") + val unsafeArrayData = ctx.freshName("unsafeArrayData") + val structsOffset = ctx.freshName("structsOffset") + val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray" + val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val longSize = LongType.defaultSize + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2 + val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + + val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" + val valueAssignmentChecked = if (childDataType.valueContainsNull) { + s""" + |if ($values.isNullAt(z)) { + | $unsafeRow.setNullAt(1); + |} else { + | $valueAssignment + |} + """.stripMargin + } else { + valueAssignment + } + + s""" + |final long $byteArraySize = $calculateArraySize($numElements, ${longSize + structSize}); + |if ($byteArraySize > 
${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)} + |} else { + | final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize; + | final byte[] $data = new byte[(int)$byteArraySize]; + | UnsafeArrayData $unsafeArrayData = new UnsafeArrayData(); + | Platform.putLong($data, $baseOffset, $numElements); + | $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize); + | UnsafeRow $unsafeRow = new UnsafeRow(2); + | for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSize; --- End diff -- nit: `$structSize` -> `${structSize}L` (so that `z * structSize` in the generated Java is evaluated as a `long` rather than an `int`)
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org