GitHub user michalsenkyr opened a pull request:

    https://github.com/apache/spark/pull/16986

    [SPARK-18891][SQL] Support for Map collection types

    ## What changes were proposed in this pull request?
    
    Add support for arbitrary Scala `Map` types in deserialization, as well as a generic implicit encoder for them.
    
    Used the builder approach as in #16541 to construct any provided `Map` type upon deserialization (a minimal sketch of the pattern follows).
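
    For illustration, here is a plain-Scala sketch of that builder pattern. The `keys`/`values` arrays and the `TreeMap` target are hypothetical stand-ins for the Catalyst key/value arrays and the requested `Map` type; the actual implementation resolves the builder from the target type's companion object:

    ```
    import scala.collection.immutable.TreeMap
    import scala.collection.mutable.Builder

    // Hypothetical inputs standing in for the Catalyst key and value arrays
    val keys   = Array(1, 2, 3)
    val values = Array("a", "b", "c")

    // Obtain a Builder from the target Map's companion object, size-hint it,
    // append key-value tuples, then materialize the requested Map type
    val builder: Builder[(Int, String), TreeMap[Int, String]] = TreeMap.newBuilder[Int, String]
    builder.sizeHint(keys.length)
    var i = 0
    while (i < keys.length) {
      builder += (keys(i) -> values(i))   // corresponds to $plus$eq in the generated code
      i += 1
    }
    val result: TreeMap[Int, String] = builder.result()
    ```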
    
    Please note that this PR also adds (ignored) tests for issue [SPARK-19104 CompileException with Map and Case Class in Spark 2.1.0](https://issues.apache.org/jira/browse/SPARK-19104) but doesn't solve it.
    
    Resulting codegen for `Seq(Map(1 -> 2)).toDS().map(identity).queryExecution.debug.codegen`:
    
    ```
    /* 001 */ public Object generate(Object[] references) {
    /* 002 */   return new GeneratedIterator(references);
    /* 003 */ }
    /* 004 */
    /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
    /* 006 */   private Object[] references;
    /* 007 */   private scala.collection.Iterator[] inputs;
    /* 008 */   private scala.collection.Iterator inputadapter_input;
    /* 009 */   private boolean CollectObjectsToMap_loopIsNull1;
    /* 010 */   private int CollectObjectsToMap_loopValue0;
    /* 011 */   private boolean CollectObjectsToMap_loopIsNull3;
    /* 012 */   private int CollectObjectsToMap_loopValue2;
    /* 013 */   private UnsafeRow deserializetoobject_result;
    /* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
    /* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
    /* 016 */   private scala.collection.immutable.Map mapelements_argValue;
    /* 017 */   private UnsafeRow mapelements_result;
    /* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
    /* 019 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
    /* 020 */   private UnsafeRow serializefromobject_result;
    /* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
    /* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
    /* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
    /* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
    /* 025 */
    /* 026 */   public GeneratedIterator(Object[] references) {
    /* 027 */     this.references = references;
    /* 028 */   }
    /* 029 */
    /* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
    /* 031 */     partitionIndex = index;
    /* 032 */     this.inputs = inputs;
    /* 033 */     wholestagecodegen_init_0();
    /* 034 */     wholestagecodegen_init_1();
    /* 035 */
    /* 036 */   }
    /* 037 */
    /* 038 */   private void wholestagecodegen_init_0() {
    /* 039 */     inputadapter_input = inputs[0];
    /* 040 */
    /* 041 */     deserializetoobject_result = new UnsafeRow(1);
    /* 042 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
    /* 043 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
    /* 044 */
    /* 045 */     mapelements_result = new UnsafeRow(1);
    /* 046 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
    /* 047 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
    /* 048 */     serializefromobject_result = new UnsafeRow(1);
    /* 049 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
    /* 050 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
    /* 051 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
    /* 052 */
    /* 053 */   }
    /* 054 */
    /* 055 */   private void wholestagecodegen_init_1() {
    /* 056 */     this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
    /* 057 */
    /* 058 */   }
    /* 059 */
    /* 060 */   protected void processNext() throws java.io.IOException {
    /* 061 */     while (inputadapter_input.hasNext() && !stopEarly()) {
    /* 062 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
    /* 063 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
    /* 064 */       MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
    /* 065 */
    /* 066 */       boolean deserializetoobject_isNull1 = true;
    /* 067 */       ArrayData deserializetoobject_value1 = null;
    /* 068 */       if (!inputadapter_isNull) {
    /* 069 */         deserializetoobject_isNull1 = false;
    /* 070 */         if (!deserializetoobject_isNull1) {
    /* 071 */           Object deserializetoobject_funcResult = null;
    /* 072 */           deserializetoobject_funcResult = inputadapter_value.keyArray();
    /* 073 */           if (deserializetoobject_funcResult == null) {
    /* 074 */             deserializetoobject_isNull1 = true;
    /* 075 */           } else {
    /* 076 */             deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
    /* 077 */           }
    /* 078 */
    /* 079 */         }
    /* 080 */         deserializetoobject_isNull1 = deserializetoobject_value1 == null;
    /* 081 */       }
    /* 082 */
    /* 083 */       boolean deserializetoobject_isNull3 = true;
    /* 084 */       ArrayData deserializetoobject_value3 = null;
    /* 085 */       if (!inputadapter_isNull) {
    /* 086 */         deserializetoobject_isNull3 = false;
    /* 087 */         if (!deserializetoobject_isNull3) {
    /* 088 */           Object deserializetoobject_funcResult1 = null;
    /* 089 */           deserializetoobject_funcResult1 = inputadapter_value.valueArray();
    /* 090 */           if (deserializetoobject_funcResult1 == null) {
    /* 091 */             deserializetoobject_isNull3 = true;
    /* 092 */           } else {
    /* 093 */             deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
    /* 094 */           }
    /* 095 */
    /* 096 */         }
    /* 097 */         deserializetoobject_isNull3 = deserializetoobject_value3 == null;
    /* 098 */       }
    /* 099 */       scala.collection.immutable.Map deserializetoobject_value = null;
    /* 100 */
    /* 101 */       if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
    /* 102 */         (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
    /* 103 */         throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
    /* 104 */       }
    /* 105 */
    /* 106 */       if (!deserializetoobject_isNull1) {
    /* 107 */         if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
    /* 108 */           throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
    /* 109 */         }
    /* 110 */         int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
    /* 111 */         scala.collection.mutable.Builder CollectObjectsToMap_builderValue5 = scala.collection.immutable.Map$.MODULE$.newBuilder();
    /* 112 */         CollectObjectsToMap_builderValue5.sizeHint(deserializetoobject_dataLength);
    /* 113 */
    /* 114 */         int deserializetoobject_loopIndex = 0;
    /* 115 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
    /* 116 */           CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
    /* 117 */           CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
    /* 118 */           CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
    /* 119 */           CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
    /* 120 */
    /* 121 */           scala.Tuple2 CollectObjectsToMap_loopValue4;
    /* 122 */
    /* 123 */           if (CollectObjectsToMap_loopIsNull1) {
    /* 124 */             throw new RuntimeException("Found null in map key!");
    /* 125 */           }
    /* 126 */
    /* 127 */           if (CollectObjectsToMap_loopIsNull3) {
    /* 128 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, null);
    /* 129 */           } else {
    /* 130 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
    /* 131 */           }
    /* 132 */
    /* 133 */           CollectObjectsToMap_builderValue5.$plus$eq(CollectObjectsToMap_loopValue4);
    /* 134 */
    /* 135 */           deserializetoobject_loopIndex += 1;
    /* 136 */         }
    /* 137 */
    /* 138 */         deserializetoobject_value = (scala.collection.immutable.Map) CollectObjectsToMap_builderValue5.result();
    /* 139 */       }
    /* 140 */
    /* 141 */       boolean mapelements_isNull = true;
    /* 142 */       scala.collection.immutable.Map mapelements_value = null;
    /* 143 */       if (!false) {
    /* 144 */         mapelements_argValue = deserializetoobject_value;
    /* 145 */
    /* 146 */         mapelements_isNull = false;
    /* 147 */         if (!mapelements_isNull) {
    /* 148 */           Object mapelements_funcResult = null;
    /* 149 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
    /* 150 */           if (mapelements_funcResult == null) {
    /* 151 */             mapelements_isNull = true;
    /* 152 */           } else {
    /* 153 */             mapelements_value = (scala.collection.immutable.Map) mapelements_funcResult;
    /* 154 */           }
    /* 155 */
    /* 156 */         }
    /* 157 */         mapelements_isNull = mapelements_value == null;
    /* 158 */       }
    /* 159 */
    /* 160 */       MapData serializefromobject_value = null;
    /* 161 */       if (!mapelements_isNull) {
    /* 162 */         final int serializefromobject_length = mapelements_value.size();
    /* 163 */         final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
    /* 164 */         final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
    /* 165 */         int serializefromobject_index = 0;
    /* 166 */         final scala.collection.Iterator serializefromobject_entries = mapelements_value.iterator();
    /* 167 */         while(serializefromobject_entries.hasNext()) {
    /* 168 */           final scala.Tuple2 serializefromobject_entry = (scala.Tuple2) serializefromobject_entries.next();
    /* 169 */           int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry._1();
    /* 170 */           int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry._2();
    /* 171 */
    /* 172 */           boolean ExternalMapToCatalyst_value_isNull1 = false;
    /* 173 */
    /* 174 */           if (false) {
    /* 175 */             throw new RuntimeException("Cannot use null as map key!");
    /* 176 */           } else {
    /* 177 */             serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
    /* 178 */           }
    /* 179 */
    /* 180 */           if (false) {
    /* 181 */             serializefromobject_convertedValues[serializefromobject_index] = null;
    /* 182 */           } else {
    /* 183 */             serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
    /* 184 */           }
    /* 185 */
    /* 186 */           serializefromobject_index++;
    /* 187 */         }
    /* 188 */
    /* 189 */         serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
    /* 190 */       }
    /* 191 */       serializefromobject_holder.reset();
    /* 192 */
    /* 193 */       serializefromobject_rowWriter.zeroOutNullBytes();
    /* 194 */
    /* 195 */       if (mapelements_isNull) {
    /* 196 */         serializefromobject_rowWriter.setNullAt(0);
    /* 197 */       } else {
    /* 198 */         // Remember the current cursor so that we can calculate how many bytes are
    /* 199 */         // written later.
    /* 200 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
    /* 201 */
    /* 202 */         if (serializefromobject_value instanceof UnsafeMapData) {
    /* 203 */           final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
    /* 204 */           // grow the global buffer before writing data.
    /* 205 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
    /* 206 */           ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
    /* 207 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
    /* 208 */
    /* 209 */         } else {
    /* 210 */           final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
    /* 211 */           final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
    /* 212 */
    /* 213 */           // preserve 8 bytes to write the key array numBytes later.
    /* 214 */           serializefromobject_holder.grow(8);
    /* 215 */           serializefromobject_holder.cursor += 8;
    /* 216 */
    /* 217 */           // Remember the current cursor so that we can write numBytes of key array later.
    /* 218 */           final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
    /* 219 */
    /* 220 */           if (serializefromobject_keys instanceof UnsafeArrayData) {
    /* 221 */             final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
    /* 222 */             // grow the global buffer before writing data.
    /* 223 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
    /* 224 */             ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
    /* 225 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
    /* 226 */
    /* 227 */           } else {
    /* 228 */             final int serializefromobject_numElements = serializefromobject_keys.numElements();
    /* 229 */             serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
    /* 230 */
    /* 231 */             for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
    /* 232 */               if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
    /* 233 */                 serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
    /* 234 */               } else {
    /* 235 */                 final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
    /* 236 */                 serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
    /* 237 */               }
    /* 238 */             }
    /* 239 */           }
    /* 240 */
    /* 241 */           // Write the numBytes of key array into the first 8 bytes.
    /* 242 */           Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
    /* 243 */
    /* 244 */           if (serializefromobject_values instanceof UnsafeArrayData) {
    /* 245 */             final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
    /* 246 */             // grow the global buffer before writing data.
    /* 247 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
    /* 248 */             ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
    /* 249 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
    /* 250 */
    /* 251 */           } else {
    /* 252 */             final int serializefromobject_numElements1 = serializefromobject_values.numElements();
    /* 253 */             serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
    /* 254 */
    /* 255 */             for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
    /* 256 */               if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
    /* 257 */                 serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
    /* 258 */               } else {
    /* 259 */                 final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
    /* 260 */                 serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
    /* 261 */               }
    /* 262 */             }
    /* 263 */           }
    /* 264 */
    /* 265 */         }
    /* 266 */
    /* 267 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
    /* 268 */       }
    /* 269 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
    /* 270 */       append(serializefromobject_result);
    /* 271 */       if (shouldStop()) return;
    /* 272 */     }
    /* 273 */   }
    /* 274 */ }
    ```
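
    The serialization half above (codegen lines /* 160 */ to /* 190 */) boils down to the following plain-Scala sketch. It is simplified: null handling and the Unsafe fast paths are omitted, and the `toCatalyst` driver function is hypothetical; `ArrayBasedMapData` and `GenericArrayData` are the real Catalyst classes used by the generated code:

    ```
    import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}

    // Simplified sketch of converting a Scala Map into Catalyst's map representation:
    // collect keys and values into parallel arrays, then wrap them.
    def toCatalyst(m: Map[Int, Int]): ArrayBasedMapData = {
      val keys   = new Array[Any](m.size)
      val values = new Array[Any](m.size)
      var i = 0
      for ((k, v) <- m) {
        keys(i) = k
        values(i) = v
        i += 1
      }
      new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
    }
    ```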
    
    ## How was this patch tested?
    
    ```
    build/mvn -DskipTests clean package && dev/run-tests
    ```
    
    Additionally, in the Spark shell:
    
    ```
    scala> Seq(collection.mutable.HashMap(1 -> 2, 2 -> 3)).toDS().map(_ += (3 -> 4)).collect()
    res0: Array[scala.collection.mutable.HashMap[Int,Int]] = Array(Map(2 -> 3, 1 -> 2, 3 -> 4))
    ```
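
    As a further, hypothetical illustration (not part of the run above): any `Map` subtype with a standard companion builder should round-trip through a Dataset the same way, assuming the generic implicit encoder resolves for it:

    ```
    // Hypothetical check, assuming the generic implicit encoder covers TreeMap
    import scala.collection.immutable.TreeMap
    import spark.implicits._

    val ds = Seq(TreeMap(1 -> 2, 3 -> 4)).toDS()
    assert(ds.map(identity).collect().head == TreeMap(1 -> 2, 3 -> 4))
    ```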

You can merge this pull request into a Git repository by running:

    $ git pull https://github.com/michalsenkyr/spark dataset-map-builder

Alternatively you can review and apply these changes as the patch at:

    https://github.com/apache/spark/pull/16986.patch

To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:

    This closes #16986
    
----
commit 2da9ffb1e89f765d816fc1ac5f95372dbc934aa4
Author: Michal Senkyr <[email protected]>
Date:   2017-02-12T15:51:51Z

    Arbitrary map support implementation

----


