Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/21912#discussion_r213707288
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -385,107 +385,120 @@ case class MapEntries(child: Expression) extends
UnaryExpression with ExpectsInp
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
+ val arrayData = ctx.freshName("arrayData")
val numElements = ctx.freshName("numElements")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val isKeyPrimitive =
CodeGenerator.isPrimitiveType(childDataType.keyType)
val isValuePrimitive =
CodeGenerator.isPrimitiveType(childDataType.valueType)
+
+ val wordSize = UnsafeRow.WORD_SIZE
+ val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize
* 2
+ val elementSize = if (isKeyPrimitive && isValuePrimitive) {
+ Some(structSize + wordSize)
+ } else {
+ None
+ }
+
+ val allocation = CodeGenerator.createArrayData(arrayData,
childDataType.keyType, numElements,
+ s" $prettyName failed.", elementSize = elementSize)
+
val code = if (isKeyPrimitive && isValuePrimitive) {
- genCodeForPrimitiveElements(ctx, keys, values, ev.value,
numElements)
+ val genCodeForPrimitive = genCodeForPrimitiveElements(
+ ctx, arrayData, keys, values, ev.value, numElements, structSize)
+ s"""
+ |if ($arrayData instanceof UnsafeArrayData) {
+ | $genCodeForPrimitive
+ |} else {
+ | ${genCodeForAnyElements(ctx, arrayData, keys, values,
ev.value, numElements)}
+ |}
+ """.stripMargin
} else {
- genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
+ s"${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value,
numElements)}"
}
+
s"""
|final int $numElements = $c.numElements();
|final ArrayData $keys = $c.keyArray();
|final ArrayData $values = $c.valueArray();
+ |$allocation
|$code
""".stripMargin
})
}
- private def getKey(varName: String) = CodeGenerator.getValue(varName,
childDataType.keyType, "z")
+ private def getKey(varName: String, index: String) =
+ CodeGenerator.getValue(varName, childDataType.keyType, index)
- private def getValue(varName: String) = {
- CodeGenerator.getValue(varName, childDataType.valueType, "z")
- }
+ private def getValue(varName: String, index: String) =
+ CodeGenerator.getValue(varName, childDataType.valueType, index)
private def genCodeForPrimitiveElements(
ctx: CodegenContext,
+ arrayData: String,
keys: String,
values: String,
- arrayData: String,
- numElements: String): String = {
- val unsafeRow = ctx.freshName("unsafeRow")
+ resultArrayData: String,
+ numElements: String,
+ structSize: Int): String = {
val unsafeArrayData = ctx.freshName("unsafeArrayData")
+ val baseObject = ctx.freshName("baseObject")
+ val unsafeRow = ctx.freshName("unsafeRow")
val structsOffset = ctx.freshName("structsOffset")
+ val offset = ctx.freshName("offset")
+ val z = ctx.freshName("z")
val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
val baseOffset = Platform.BYTE_ARRAY_OFFSET
val wordSize = UnsafeRow.WORD_SIZE
- val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize *
2
- val structSizeAsLong = structSize + "L"
- val keyTypeName =
CodeGenerator.primitiveTypeName(childDataType.keyType)
- val valueTypeName =
CodeGenerator.primitiveTypeName(childDataType.valueType)
-
- val valueAssignment = s"$unsafeRow.set$valueTypeName(1,
${getValue(values)});"
- val valueAssignmentChecked = if (childDataType.valueContainsNull) {
- s"""
- |if ($values.isNullAt(z)) {
- | $unsafeRow.setNullAt(1);
- |} else {
- | $valueAssignment
- |}
- """.stripMargin
- } else {
- valueAssignment
- }
+ val structSizeAsLong = s"${structSize}L"
- val assignmentLoop = (byteArray: String) =>
- s"""
- |final int $structsOffset = $calculateHeader($numElements) +
$numElements * $wordSize;
- |UnsafeRow $unsafeRow = new UnsafeRow(2);
- |for (int z = 0; z < $numElements; z++) {
- | long offset = $structsOffset + z * $structSizeAsLong;
- | $unsafeArrayData.setLong(z, (offset << 32) +
$structSizeAsLong);
- | $unsafeRow.pointTo($byteArray, $baseOffset + offset,
$structSize);
- | $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
- | $valueAssignmentChecked
- |}
- |$arrayData = $unsafeArrayData;
- """.stripMargin
+ val setKey =
+ CodeGenerator.setArrayElement(unsafeRow, childDataType.keyType, "0",
getKey(keys, z))
--- End diff --
@cloud-fan Good catch. We will use `setColumn` here.
When I checked the source files, I found that this is not straightforward,
since there are two differences between `setColumn` and `setArrayElement`:
1. `value()` is called for `StructType`, `ArrayType`, and others in
`setColumn`
2. `setDecimal()` is not supported in `Array`
If we add a boolean flag to distinguish `column` from `array`, we can
unify the two into one method. Should we do this unification?
If not, should we update `setColumn` to generate an `if` statement for the
null check, as `setArrayElement` does? Or should we update `setColumn` in a
separate PR?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]