GitHub user michalsenkyr opened a pull request:
https://github.com/apache/spark/pull/16986
[SPARK-18891][SQL] Support for Map collection types
## What changes were proposed in this pull request?
Add support for arbitrary Scala `Map` types in deserialization, as well as a generic implicit encoder.
The builder approach from #16541 is used to construct whichever `Map` type is requested upon deserialization.
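For illustration only (this is not code from the patch; the method name and the simplified types below are made up), the builder-based deserialization boils down to reading the keys and values as parallel arrays and feeding them into the target collection's `Builder`, which is what the generated `CollectObjectsToMap` code below does for `scala.collection.immutable.Map`:
```
// Illustrative sketch of the builder approach (not part of this patch)
def buildMap(keys: Array[Int], values: Array[Int]): Map[Int, Int] = {
  require(keys.length == values.length, "Inconsistent lengths of key-value arrays")
  val builder = Map.newBuilder[Int, Int] // any Map companion's builder works the same way
  builder.sizeHint(keys.length)
  var i = 0
  while (i < keys.length) {
    builder += (keys(i) -> values(i))
    i += 1
  }
  builder.result()
}
```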
Please note that this PR also adds (ignored) tests for issue [SPARK-19104
CompileException with Map and Case Class in Spark
2.1.0](https://issues.apache.org/jira/browse/SPARK-19104) but doesn't solve it.
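For context, the SPARK-19104 case is roughly the following spark-shell pattern (the class name is just for illustration), which is expected to keep failing until that issue is addressed:
```
case class MapHolder(m: Map[Int, Int])
// Still fails with a CompileException, per SPARK-19104:
Seq(MapHolder(Map(1 -> 2))).toDS().map(identity).collect()
```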
Resulting codegen for `Seq(Map(1 ->
2)).toDS().map(identity).queryExecution.debug.codegen`:
```
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private boolean CollectObjectsToMap_loopIsNull1;
/* 010 */   private int CollectObjectsToMap_loopValue0;
/* 011 */   private boolean CollectObjectsToMap_loopIsNull3;
/* 012 */   private int CollectObjectsToMap_loopValue2;
/* 013 */   private UnsafeRow deserializetoobject_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 016 */   private scala.collection.immutable.Map mapelements_argValue;
/* 017 */   private UnsafeRow mapelements_result;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 019 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 020 */   private UnsafeRow serializefromobject_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
/* 025 */
/* 026 */   public GeneratedIterator(Object[] references) {
/* 027 */     this.references = references;
/* 028 */   }
/* 029 */
/* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */     partitionIndex = index;
/* 032 */     this.inputs = inputs;
/* 033 */     wholestagecodegen_init_0();
/* 034 */     wholestagecodegen_init_1();
/* 035 */
/* 036 */   }
/* 037 */
/* 038 */   private void wholestagecodegen_init_0() {
/* 039 */     inputadapter_input = inputs[0];
/* 040 */
/* 041 */     deserializetoobject_result = new UnsafeRow(1);
/* 042 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 043 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 044 */
/* 045 */     mapelements_result = new UnsafeRow(1);
/* 046 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 047 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 048 */     serializefromobject_result = new UnsafeRow(1);
/* 049 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 050 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 051 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 052 */
/* 053 */   }
/* 054 */
/* 055 */   private void wholestagecodegen_init_1() {
/* 056 */     this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 057 */
/* 058 */   }
/* 059 */
/* 060 */   protected void processNext() throws java.io.IOException {
/* 061 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 062 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 063 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 064 */       MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
/* 065 */
/* 066 */       boolean deserializetoobject_isNull1 = true;
/* 067 */       ArrayData deserializetoobject_value1 = null;
/* 068 */       if (!inputadapter_isNull) {
/* 069 */         deserializetoobject_isNull1 = false;
/* 070 */         if (!deserializetoobject_isNull1) {
/* 071 */           Object deserializetoobject_funcResult = null;
/* 072 */           deserializetoobject_funcResult = inputadapter_value.keyArray();
/* 073 */           if (deserializetoobject_funcResult == null) {
/* 074 */             deserializetoobject_isNull1 = true;
/* 075 */           } else {
/* 076 */             deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
/* 077 */           }
/* 078 */
/* 079 */         }
/* 080 */         deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 081 */       }
/* 082 */
/* 083 */       boolean deserializetoobject_isNull3 = true;
/* 084 */       ArrayData deserializetoobject_value3 = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull3 = false;
/* 087 */         if (!deserializetoobject_isNull3) {
/* 088 */           Object deserializetoobject_funcResult1 = null;
/* 089 */           deserializetoobject_funcResult1 = inputadapter_value.valueArray();
/* 090 */           if (deserializetoobject_funcResult1 == null) {
/* 091 */             deserializetoobject_isNull3 = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull3 = deserializetoobject_value3 == null;
/* 098 */       }
/* 099 */       scala.collection.immutable.Map deserializetoobject_value = null;
/* 100 */
/* 101 */       if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
/* 102 */           (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
/* 103 */         throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
/* 104 */       }
/* 105 */
/* 106 */       if (!deserializetoobject_isNull1) {
/* 107 */         if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
/* 108 */           throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
/* 109 */         }
/* 110 */         int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
/* 111 */         scala.collection.mutable.Builder CollectObjectsToMap_builderValue5 = scala.collection.immutable.Map$.MODULE$.newBuilder();
/* 112 */         CollectObjectsToMap_builderValue5.sizeHint(deserializetoobject_dataLength);
/* 113 */
/* 114 */         int deserializetoobject_loopIndex = 0;
/* 115 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 116 */           CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
/* 117 */           CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
/* 118 */           CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
/* 119 */           CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
/* 120 */
/* 121 */           scala.Tuple2 CollectObjectsToMap_loopValue4;
/* 122 */
/* 123 */           if (CollectObjectsToMap_loopIsNull1) {
/* 124 */             throw new RuntimeException("Found null in map key!");
/* 125 */           }
/* 126 */
/* 127 */           if (CollectObjectsToMap_loopIsNull3) {
/* 128 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, null);
/* 129 */           } else {
/* 130 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
/* 131 */           }
/* 132 */
/* 133 */           CollectObjectsToMap_builderValue5.$plus$eq(CollectObjectsToMap_loopValue4);
/* 134 */
/* 135 */           deserializetoobject_loopIndex += 1;
/* 136 */         }
/* 137 */
/* 138 */         deserializetoobject_value = (scala.collection.immutable.Map) CollectObjectsToMap_builderValue5.result();
/* 139 */       }
/* 140 */
/* 141 */       boolean mapelements_isNull = true;
/* 142 */       scala.collection.immutable.Map mapelements_value = null;
/* 143 */       if (!false) {
/* 144 */         mapelements_argValue = deserializetoobject_value;
/* 145 */
/* 146 */         mapelements_isNull = false;
/* 147 */         if (!mapelements_isNull) {
/* 148 */           Object mapelements_funcResult = null;
/* 149 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 150 */           if (mapelements_funcResult == null) {
/* 151 */             mapelements_isNull = true;
/* 152 */           } else {
/* 153 */             mapelements_value = (scala.collection.immutable.Map) mapelements_funcResult;
/* 154 */           }
/* 155 */
/* 156 */         }
/* 157 */         mapelements_isNull = mapelements_value == null;
/* 158 */       }
/* 159 */
/* 160 */       MapData serializefromobject_value = null;
/* 161 */       if (!mapelements_isNull) {
/* 162 */         final int serializefromobject_length = mapelements_value.size();
/* 163 */         final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
/* 164 */         final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
/* 165 */         int serializefromobject_index = 0;
/* 166 */         final scala.collection.Iterator serializefromobject_entries = mapelements_value.iterator();
/* 167 */         while(serializefromobject_entries.hasNext()) {
/* 168 */           final scala.Tuple2 serializefromobject_entry = (scala.Tuple2) serializefromobject_entries.next();
/* 169 */           int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry._1();
/* 170 */           int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry._2();
/* 171 */
/* 172 */           boolean ExternalMapToCatalyst_value_isNull1 = false;
/* 173 */
/* 174 */           if (false) {
/* 175 */             throw new RuntimeException("Cannot use null as map key!");
/* 176 */           } else {
/* 177 */             serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
/* 178 */           }
/* 179 */
/* 180 */           if (false) {
/* 181 */             serializefromobject_convertedValues[serializefromobject_index] = null;
/* 182 */           } else {
/* 183 */             serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
/* 184 */           }
/* 185 */
/* 186 */           serializefromobject_index++;
/* 187 */         }
/* 188 */
/* 189 */         serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
/* 190 */       }
/* 191 */       serializefromobject_holder.reset();
/* 192 */
/* 193 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 194 */
/* 195 */       if (mapelements_isNull) {
/* 196 */         serializefromobject_rowWriter.setNullAt(0);
/* 197 */       } else {
/* 198 */         // Remember the current cursor so that we can calculate how many bytes are
/* 199 */         // written later.
/* 200 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 201 */
/* 202 */         if (serializefromobject_value instanceof UnsafeMapData) {
/* 203 */           final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
/* 204 */           // grow the global buffer before writing data.
/* 205 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 206 */           ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 207 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 208 */
/* 209 */         } else {
/* 210 */           final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
/* 211 */           final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
/* 212 */
/* 213 */           // preserve 8 bytes to write the key array numBytes later.
/* 214 */           serializefromobject_holder.grow(8);
/* 215 */           serializefromobject_holder.cursor += 8;
/* 216 */
/* 217 */           // Remember the current cursor so that we can write numBytes of key array later.
/* 218 */           final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
/* 219 */
/* 220 */           if (serializefromobject_keys instanceof UnsafeArrayData) {
/* 221 */             final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
/* 222 */             // grow the global buffer before writing data.
/* 223 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
/* 224 */             ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 225 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
/* 226 */
/* 227 */           } else {
/* 228 */             final int serializefromobject_numElements = serializefromobject_keys.numElements();
/* 229 */             serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 230 */
/* 231 */             for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
/* 232 */               if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
/* 233 */                 serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
/* 234 */               } else {
/* 235 */                 final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
/* 236 */                 serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
/* 237 */               }
/* 238 */             }
/* 239 */           }
/* 240 */
/* 241 */           // Write the numBytes of key array into the first 8 bytes.
/* 242 */           Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
/* 243 */
/* 244 */           if (serializefromobject_values instanceof UnsafeArrayData) {
/* 245 */             final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
/* 246 */             // grow the global buffer before writing data.
/* 247 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
/* 248 */             ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 249 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
/* 250 */
/* 251 */           } else {
/* 252 */             final int serializefromobject_numElements1 = serializefromobject_values.numElements();
/* 253 */             serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
/* 254 */
/* 255 */             for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
/* 256 */               if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
/* 257 */                 serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
/* 258 */               } else {
/* 259 */                 final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
/* 260 */                 serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
/* 261 */               }
/* 262 */             }
/* 263 */           }
/* 264 */
/* 265 */         }
/* 266 */
/* 267 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 268 */       }
/* 269 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 270 */       append(serializefromobject_result);
/* 271 */       if (shouldStop()) return;
/* 272 */     }
/* 273 */   }
/* 274 */ }
```
## How was this patch tested?
```
build/mvn -DskipTests clean package && dev/run-tests
```
Additionally in Spark shell:
```
scala> Seq(collection.mutable.HashMap(1 -> 2, 2 -> 3)).toDS().map(_ += (3 -> 4)).collect()
res0: Array[scala.collection.mutable.HashMap[Int,Int]] = Array(Map(2 -> 3, 1 -> 2, 3 -> 4))
```
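For completeness, a minimal standalone sketch of the same round-trip outside the shell (the object and app name are illustrative; it assumes the new generic implicit encoder is resolved via `spark.implicits._`):
```
import org.apache.spark.sql.SparkSession

object MutableMapDatasetExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("map-encoder-example").master("local[*]").getOrCreate()
    import spark.implicits._

    // Round-trip a mutable HashMap through a Dataset, as in the shell session above
    val ds = Seq(collection.mutable.HashMap(1 -> 2, 2 -> 3)).toDS()
    val updated = ds.map(_ += (3 -> 4))
    updated.collect().foreach(println)

    spark.stop()
  }
}
```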
You can merge this pull request into a Git repository by running:
$ git pull https://github.com/michalsenkyr/spark dataset-map-builder
Alternatively you can review and apply these changes as the patch at:
https://github.com/apache/spark/pull/16986.patch
To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:
This closes #16986
----
commit 2da9ffb1e89f765d816fc1ac5f95372dbc934aa4
Author: Michal Senkyr <[email protected]>
Date: 2017-02-12T15:51:51Z
Arbitrary map support implementation
----