Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/13680#discussion_r73949824
--- Diff:
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
---
@@ -18,27 +18,131 @@
package org.apache.spark.sql.catalyst.util
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+import org.apache.spark.unsafe.Platform
class UnsafeArraySuite extends SparkFunSuite {
- test("from primitive int array") {
- val array = Array(1, 10, 100)
- val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
- assert(unsafe.numElements == 3)
- assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3)
- assert(unsafe.getInt(0) == 1)
- assert(unsafe.getInt(1) == 10)
- assert(unsafe.getInt(2) == 100)
+ val booleanArray = Array(false, true)
+ val shortArray = Array(1.toShort, 10.toShort, 100.toShort)
+ val intArray = Array(1, 10, 100)
+ val longArray = Array(1.toLong, 10.toLong, 100.toLong)
+ val floatArray = Array(1.1.toFloat, 2.2.toFloat, 3.3.toFloat)
+ val doubleArray = Array(1.1, 2.2, 3.3)
+ val stringArray = Array("1", "10", "100")
--- End diff --
@davies could you do me a favor? To be honest, I am not familiar with the
implementation of `Decimal`. Could you please let me know where I am wrong?
I wrote the following test suite and got the generated code for projection.
In the generated code, lines 90 and 91 always use 38 and 18 for precision and
scale, while I pass a different ```BigDecimal``` in the test suite. This code is
generated at
[here](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala#L218)
and
[here](https://github.com/apache/spark/blob/23c58653f900bfb71ef2b3186a95ad2562c33969/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala).
I suspect that these places refer to
[DecimalType.SYSTEM_DEFAULT](https://github.com/apache/spark/blob/d5911d1173fe0872f21cae6c47abf8ff479345a4/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L110).
`DecimalType.SYSTEM_DEFAULT` uses 38 and 18.
Should we always use these values (38 and 18)? Would it be possible to let
us know what other test cases use ```DecimalType``` in
```UnsafeArrayData```?
```java
val decimalArray = Array(BigDecimal("123").setScale(1,
BigDecimal.RoundingMode.FLOOR))
test("read array") {
val unsafeDecimal =
ExpressionEncoder[Array[BigDecimal]].resolveAndBind().
toRow(decimalArray).getArray(0)
assert(unsafeDecimal.isInstanceOf[UnsafeArrayData])
assert(unsafeDecimal.numElements == decimalArray.length)
decimalArray.zipWithIndex.map { case (e, i) =>
print(s"e: $e, ${e.precision}, ${e.scale}\n")
assert(unsafeDecimal.getDecimal(i, e.precision, e.scale) == e)
}
}
```
```java
/* 031 */ public UnsafeRow apply(InternalRow i) {
/* 032 */ holder.reset();
/* 033 */
/* 034 */ rowWriter.zeroOutNullBytes();
/* 035 */
/* 036 */
/* 037 */ boolean isNull1 = i.isNullAt(0);
/* 038 */ scala.math.BigDecimal[] value1 = isNull1 ? null :
((scala.math.BigDecimal[])i.get(0, null));
/* 039 */ ArrayData value = null;
/* 040 */
/* 041 */ if (!isNull1) {
/* 042 */
/* 043 */ Decimal[] convertedArray = null;
/* 044 */ int dataLength = value1.length;
/* 045 */ convertedArray = new Decimal[dataLength];
/* 046 */
/* 047 */ int loopIndex = 0;
/* 048 */ while (loopIndex < dataLength) {
/* 049 */ MapObjects_loopValue0 = (scala.math.BigDecimal)
(value1[loopIndex]);
/* 050 */ MapObjects_loopIsNull1 = MapObjects_loopValue0 == null;
/* 051 */
/* 052 */
/* 053 */ boolean isNull2 = MapObjects_loopIsNull1;
/* 054 */ final Decimal value2 = isNull2 ? null :
org.apache.spark.sql.types.Decimal.apply(MapObjects_loopValue0);
/* 055 */ isNull2 = value2 == null;
/* 056 */ if (isNull2) {
/* 057 */ convertedArray[loopIndex] = null;
/* 058 */ } else {
/* 059 */ convertedArray[loopIndex] = value2;
/* 060 */ }
/* 061 */
/* 062 */ loopIndex += 1;
/* 063 */ }
/* 064 */
/* 065 */ value = new
org.apache.spark.sql.catalyst.util.GenericArrayData(convertedArray);
/* 066 */ }
/* 067 */ if (isNull1) {
/* 068 */ rowWriter.setNullAt(0);
/* 069 */ } else {
/* 070 */ // Remember the current cursor so that we can calculate how
many bytes are
/* 071 */ // written later.
/* 072 */ final int tmpCursor = holder.cursor;
/* 073 */
/* 074 */ if (value instanceof UnsafeArrayData) {
/* 075 */
/* 076 */ final int sizeInBytes = ((UnsafeArrayData)
value).getSizeInBytes();
/* 077 */ // grow the global buffer before writing data.
/* 078 */ holder.grow(sizeInBytes);
/* 079 */ ((UnsafeArrayData) value).writeToMemory(holder.buffer,
holder.cursor);
/* 080 */ holder.cursor += sizeInBytes;
/* 081 */
/* 082 */ } else {
/* 083 */ final int numElements = value.numElements();
/* 084 */ arrayWriter.initialize(holder, numElements, 4);
/* 085 */
/* 086 */ for (int index = 0; index < numElements; index++) {
/* 087 */ if (value.isNullAt(index)) {
/* 088 */ arrayWriter.setNull(index);
/* 089 */ } else {
/* 090 */ final Decimal element = value.getDecimal(index, 38,
18);
/* 091 */ arrayWriter.write(index, element, 38, 18);
/* 092 */ }
/* 093 */ }
/* 094 */ }
/* 095 */
/* 096 */ rowWriter.setOffsetAndSize(0, tmpCursor, holder.cursor -
tmpCursor);
/* 097 */ rowWriter.alignToWords(holder.cursor - tmpCursor);
/* 098 */ }
/* 099 */ result.setTotalSize(holder.totalSize());
/* 100 */ return result;
/* 101 */ }
```
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]