This is an automated email from the ASF dual-hosted git repository.

brycemecum pushed a commit to branch maint-18.1.0
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 47ba57030cf05f77791a03b3d00ba48cae5bb2c2
Author: 0xderek <[email protected]>
AuthorDate: Tue Oct 15 08:41:45 2024 +0800

    GH-44353: [Java] Implement `map()` for `UnionMapWriter` (#44390)
    
    ### Rationale for this change
    
    See #44353
    
    `UnionMapWriter` does not override `map()`, so it falls back to the
    implementation in `UnionListWriter`, which throws an exception when used.
    
    ### What changes are included in this PR?
    
    Implement `map()` for `UnionMapWriter`.
    
    ### Are these changes tested?
    
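    Yes. `TestMapVector` now also writes and verifies a record whose map
    value is itself a map, created via `mapWriter.value().map()`.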
    
    ### Are there any user-facing changes?
    
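    Yes. When writing a map key or value, `UnionMapWriter.map()` now returns
    a writer for a nested map instead of falling back to the
    `UnionListWriter` implementation.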
    
    * GitHub Issue: #44353
    
    Authored-by: 0xderek <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 .../src/main/codegen/templates/UnionMapWriter.java | 12 ++++++
 .../org/apache/arrow/vector/TestMapVector.java     | 49 +++++++++++++++++++---
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/java/vector/src/main/codegen/templates/UnionMapWriter.java b/java/vector/src/main/codegen/templates/UnionMapWriter.java
index 606f880377..90b55cb65e 100644
--- a/java/vector/src/main/codegen/templates/UnionMapWriter.java
+++ b/java/vector/src/main/codegen/templates/UnionMapWriter.java
@@ -219,4 +219,16 @@ public class UnionMapWriter extends UnionListWriter {
         return super.map();
     }
   }
+
+  @Override
+  public MapWriter map() {
+    switch (mode) {
+      case KEY:
+        return entryWriter.map(MapVector.KEY_NAME);
+      case VALUE:
+        return entryWriter.map(MapVector.VALUE_NAME);
+      default:
+        return super.map();
+    }
+  }
 }
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java
index 213ffced27..a4197c50b5 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java
@@ -640,11 +640,12 @@ public class TestMapVector {
       MapWriter valueWriter;
 
       // we are essentially writing Map<Long, Map<Long, Long>>
-      // populate map vector with the following three records
+      // populate map vector with the following four records
       // [
       //    null,
       //    [1:[50: 100, 200:400], 2:[75: 175, 150: 250]],
-      //    [3:[10: 20], 4:[15: 20], 5:[25: 30, 35: null]]
+      //    [3:[10: 20], 4:[15: 20], 5:[25: 30, 35: null]],
+      //    [8:[15: 30, 10: 20]]
       // ]
 
       /* write null at index 0 */
@@ -706,11 +707,26 @@ public class TestMapVector {
 
       mapWriter.endMap();
 
-      assertEquals(2, mapVector.getLastSet());
+      /* write one or more maps at index 3 */
+      mapWriter.setPosition(3);
+      mapWriter.startMap();
+
+      mapWriter.startEntry();
+      mapWriter.key().bigInt().writeBigInt(8);
+      valueWriter = mapWriter.value().map();
+      valueWriter.startMap();
+      writeEntry(valueWriter, 15, 30L);
+      writeEntry(valueWriter, 10, 20L);
+      valueWriter.endMap();
+      mapWriter.endEntry();
+
+      mapWriter.endMap();
+
+      assertEquals(3, mapVector.getLastSet());
 
-      mapWriter.setValueCount(3);
+      mapWriter.setValueCount(4);
 
-      assertEquals(3, mapVector.getValueCount());
+      assertEquals(4, mapVector.getValueCount());
 
       // Get mapVector element at index 0
       Object result = mapVector.getObject(0);
@@ -784,19 +800,40 @@ public class TestMapVector {
       assertEquals(35L, getResultKey(innerMap));
       assertNull(innerMap.get(MapVector.VALUE_NAME));
 
+      // Get mapVector element at index 3
+      result = mapVector.getObject(3);
+      resultSet = (ArrayList<?>) result;
+
+      // only 1 map entry at index 3
+      assertEquals(1, resultSet.size());
+
+      resultStruct = (Map<?, ?>) resultSet.get(0);
+      assertEquals(8L, getResultKey(resultStruct));
+      list = (ArrayList<Map<?, ?>>) getResultValue(resultStruct);
+      assertEquals(2, list.size()); // value is a list of 2 maps
+      innerMap = list.get(0);
+      assertEquals(15L, getResultKey(innerMap));
+      assertEquals(30L, getResultValue(innerMap));
+      innerMap = list.get(1);
+      assertEquals(10L, getResultKey(innerMap));
+      assertEquals(20L, getResultValue(innerMap));
+
       /* check underlying bitVector */
       assertTrue(mapVector.isNull(0));
       assertFalse(mapVector.isNull(1));
       assertFalse(mapVector.isNull(2));
+      assertFalse(mapVector.isNull(3));
 
       /* check underlying offsets */
       final ArrowBuf offsetBuffer = mapVector.getOffsetBuffer();
 
-      /* mapVector has 0 entries at index 0, 2 entries at index 1, and 3 entries at index 2 */
+      // mapVector has 0 entries at index 0, 2 entries at index 1, 3 entries at index 2,
+      // and 1 entry at index 3
       assertEquals(0, offsetBuffer.getInt(0 * MapVector.OFFSET_WIDTH));
       assertEquals(0, offsetBuffer.getInt(1 * MapVector.OFFSET_WIDTH));
       assertEquals(2, offsetBuffer.getInt(2 * MapVector.OFFSET_WIDTH));
       assertEquals(5, offsetBuffer.getInt(3 * MapVector.OFFSET_WIDTH));
+      assertEquals(6, offsetBuffer.getInt(4 * MapVector.OFFSET_WIDTH));
     }
   }
 

Reply via email to