This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new f9b7ac2e92 GH-37841: [Java] Dictionary decoding not using the 
compression factory from the ArrowReader (#38371)
f9b7ac2e92 is described below

commit f9b7ac2e922bceed8bab09b1e28d7261cbe8b41d
Author: Vibhatha Lakmal Abeykoon <[email protected]>
AuthorDate: Thu Feb 1 23:08:21 2024 +0530

    GH-37841: [Java] Dictionary decoding not using the compression factory from 
the ArrowReader (#38371)
    
    ### Rationale for this change
    
    This PR addresses https://github.com/apache/arrow/issues/37841.
    
    ### What changes are included in this PR?
    
    Adding compression-based write and read for Dictionary data.
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    No.
    * Closes: #37841
    
    Lead-authored-by: Vibhatha Lakmal Abeykoon <[email protected]>
    Co-authored-by: vibhatha <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 .../TestArrowReaderWriterWithCompression.java      | 206 ++++++++++++++++++---
 .../org/apache/arrow/vector/ipc/ArrowReader.java   |   2 +-
 .../org/apache/arrow/vector/ipc/ArrowWriter.java   |  23 ++-
 3 files changed, 201 insertions(+), 30 deletions(-)

diff --git 
a/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java
 
b/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java
index 6104cb1a13..af28333746 100644
--- 
a/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java
+++ 
b/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java
@@ -18,7 +18,9 @@
 package org.apache.arrow.compression;
 
 import java.io.ByteArrayOutputStream;
+import java.io.IOException;
 import java.nio.channels.Channels;
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -27,63 +29,223 @@ import java.util.Optional;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.GenerateSampleData;
+import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.compression.CompressionUtil;
 import org.apache.arrow.vector.compression.NoCompressionCodec;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.ipc.ArrowFileReader;
 import org.apache.arrow.vector.ipc.ArrowFileWriter;
+import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
 import org.apache.arrow.vector.ipc.message.IpcOption;
 import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.FieldType;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
+import org.junit.After;
 import org.junit.Assert;
-import org.junit.Test;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
 
 public class TestArrowReaderWriterWithCompression {
 
-  @Test
-  public void testArrowFileZstdRoundTrip() throws Exception {
-    // Prepare sample data
-    final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+  private BufferAllocator allocator;
+  private ByteArrayOutputStream out;
+  private VectorSchemaRoot root;
+
+  @BeforeEach
+  public void setup() {
+    if (allocator == null) {
+      allocator = new RootAllocator(Integer.MAX_VALUE);
+    }
+    out = new ByteArrayOutputStream();
+    root = null;
+  }
+
+  @After
+  public void tearDown() {
+    if (root != null) {
+      root.close();
+    }
+    if (allocator != null) {
+      allocator.close();
+    }
+    if (out != null) {
+      out.reset();
+    }
+
+  }
+
+  private void createAndWriteArrowFile(DictionaryProvider provider,
+      CompressionUtil.CodecType codecType) throws IOException {
     List<Field> fields = new ArrayList<>();
     fields.add(new Field("col", FieldType.notNullable(new ArrowType.Utf8()), 
new ArrayList<>()));
-    VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(fields), 
allocator);
+    root = VectorSchemaRoot.create(new Schema(fields), allocator);
+
     final int rowCount = 10;
     GenerateSampleData.generateTestData(root.getVector(0), rowCount);
     root.setRowCount(rowCount);
 
-    // Write an in-memory compressed arrow file
-    ByteArrayOutputStream out = new ByteArrayOutputStream();
-    try (final ArrowFileWriter writer =
-           new ArrowFileWriter(root, null, Channels.newChannel(out), new 
HashMap<>(),
-             IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, 
CompressionUtil.CodecType.ZSTD, Optional.of(7))) {
+    try (final ArrowFileWriter writer = new ArrowFileWriter(root, provider, 
Channels.newChannel(out),
+        new HashMap<>(), IpcOption.DEFAULT, 
CommonsCompressionFactory.INSTANCE, codecType, Optional.of(7))) {
       writer.start();
       writer.writeBatch();
       writer.end();
     }
+  }
+
+  private void createAndWriteArrowStream(DictionaryProvider provider,
+                                       CompressionUtil.CodecType codecType) 
throws IOException {
+    List<Field> fields = new ArrayList<>();
+    fields.add(new Field("col", FieldType.notNullable(new ArrowType.Utf8()), 
new ArrayList<>()));
+    root = VectorSchemaRoot.create(new Schema(fields), allocator);
+
+    final int rowCount = 10;
+    GenerateSampleData.generateTestData(root.getVector(0), rowCount);
+    root.setRowCount(rowCount);
+
+    try (final ArrowStreamWriter writer = new ArrowStreamWriter(root, 
provider, Channels.newChannel(out),
+            IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, codecType, 
Optional.of(7))) {
+      writer.start();
+      writer.writeBatch();
+      writer.end();
+    }
+  }
 
-    // Read the in-memory compressed arrow file with CommonsCompressionFactory 
provided
+  private Dictionary createDictionary(VarCharVector dictionaryVector) {
+    setVector(dictionaryVector,
+        "foo".getBytes(StandardCharsets.UTF_8),
+        "bar".getBytes(StandardCharsets.UTF_8),
+        "baz".getBytes(StandardCharsets.UTF_8));
+
+    return new Dictionary(dictionaryVector,
+        new DictionaryEncoding(/*id=*/1L, /*ordered=*/false, 
/*indexType=*/null));
+  }
+
+  @Test
+  public void testArrowFileZstdRoundTrip() throws Exception {
+    createAndWriteArrowFile(null, CompressionUtil.CodecType.ZSTD);
+    // with compression
+    try (ArrowFileReader reader =
+        new ArrowFileReader(new 
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+            CommonsCompressionFactory.INSTANCE)) {
+      Assertions.assertEquals(1, reader.getRecordBlocks().size());
+      Assertions.assertTrue(reader.loadNextBatch());
+      Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
+      Assertions.assertFalse(reader.loadNextBatch());
+    }
+    // without compression
     try (ArrowFileReader reader =
-           new ArrowFileReader(new 
ByteArrayReadableSeekableByteChannel(out.toByteArray()),
-             allocator, CommonsCompressionFactory.INSTANCE)) {
-      Assert.assertEquals(1, reader.getRecordBlocks().size());
+        new ArrowFileReader(new 
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+            NoCompressionCodec.Factory.INSTANCE)) {
+      Assertions.assertEquals(1, reader.getRecordBlocks().size());
+      Exception exception = Assert.assertThrows(IllegalArgumentException.class,
+          reader::loadNextBatch);
+      Assertions.assertEquals("Please add arrow-compression module to use 
CommonsCompressionFactory for ZSTD",
+              exception.getMessage());
+    }
+  }
+
+  @Test
+  public void testArrowStreamZstdRoundTrip() throws Exception {
+    createAndWriteArrowStream(null, CompressionUtil.CodecType.ZSTD);
+    // with compression
+    try (ArrowStreamReader reader =
+                 new ArrowStreamReader(new 
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+                         CommonsCompressionFactory.INSTANCE)) {
       Assert.assertTrue(reader.loadNextBatch());
       Assert.assertTrue(root.equals(reader.getVectorSchemaRoot()));
       Assert.assertFalse(reader.loadNextBatch());
     }
+    // without compression
+    try (ArrowStreamReader reader =
+                 new ArrowStreamReader(new 
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+                         NoCompressionCodec.Factory.INSTANCE)) {
+      Exception exception = Assert.assertThrows(IllegalArgumentException.class,
+              reader::loadNextBatch);
+      Assert.assertEquals(
+              "Please add arrow-compression module to use 
CommonsCompressionFactory for ZSTD",
+              exception.getMessage()
+      );
+    }
+  }
 
-    // Read the in-memory compressed arrow file without CompressionFactory 
provided
+  @Test
+  public void testArrowFileZstdRoundTripWithDictionary() throws Exception {
+    VarCharVector dictionaryVector = (VarCharVector)
+        FieldType.nullable(new 
ArrowType.Utf8()).createNewSingleVector("f1_file", allocator, null);
+    Dictionary dictionary = createDictionary(dictionaryVector);
+    DictionaryProvider.MapDictionaryProvider provider = new 
DictionaryProvider.MapDictionaryProvider();
+    provider.put(dictionary);
+
+    createAndWriteArrowFile(provider, CompressionUtil.CodecType.ZSTD);
+
+    // with compression
+    try (ArrowFileReader reader =
+        new ArrowFileReader(new 
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+            CommonsCompressionFactory.INSTANCE)) {
+      Assertions.assertEquals(1, reader.getRecordBlocks().size());
+      Assertions.assertTrue(reader.loadNextBatch());
+      Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
+      Assertions.assertFalse(reader.loadNextBatch());
+    }
+    // without compression
     try (ArrowFileReader reader =
-           new ArrowFileReader(new 
ByteArrayReadableSeekableByteChannel(out.toByteArray()),
-             allocator, NoCompressionCodec.Factory.INSTANCE)) {
-      Assert.assertEquals(1, reader.getRecordBlocks().size());
+        new ArrowFileReader(new 
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+            NoCompressionCodec.Factory.INSTANCE)) {
+      Assertions.assertEquals(1, reader.getRecordBlocks().size());
+      Exception exception = Assert.assertThrows(IllegalArgumentException.class,
+          reader::loadNextBatch);
+      Assertions.assertEquals("Please add arrow-compression module to use 
CommonsCompressionFactory for ZSTD",
+              exception.getMessage());
+    }
+    dictionaryVector.close();
+  }
+
+  @Test
+  public void testArrowStreamZstdRoundTripWithDictionary() throws Exception {
+    VarCharVector dictionaryVector = (VarCharVector)
+            FieldType.nullable(new 
ArrowType.Utf8()).createNewSingleVector("f1_stream", allocator, null);
+    Dictionary dictionary = createDictionary(dictionaryVector);
+    DictionaryProvider.MapDictionaryProvider provider = new 
DictionaryProvider.MapDictionaryProvider();
+    provider.put(dictionary);
+
+    createAndWriteArrowStream(provider, CompressionUtil.CodecType.ZSTD);
+
+    // with compression
+    try (ArrowStreamReader reader =
+                 new ArrowStreamReader(new 
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+                         CommonsCompressionFactory.INSTANCE)) {
+      Assertions.assertTrue(reader.loadNextBatch());
+      Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
+      Assertions.assertFalse(reader.loadNextBatch());
+    }
+    // without compression
+    try (ArrowStreamReader reader =
+                 new ArrowStreamReader(new 
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+                         NoCompressionCodec.Factory.INSTANCE)) {
+      Exception exception = Assert.assertThrows(IllegalArgumentException.class,
+              reader::loadNextBatch);
+      Assertions.assertEquals("Please add arrow-compression module to use 
CommonsCompressionFactory for ZSTD",
+              exception.getMessage());
+    }
+    dictionaryVector.close();
+  }
 
-      Exception exception = 
Assert.assertThrows(IllegalArgumentException.class, () -> 
reader.loadNextBatch());
-      String expectedMessage = "Please add arrow-compression module to use 
CommonsCompressionFactory for ZSTD";
-      Assert.assertEquals(expectedMessage, exception.getMessage());
+  public static void setVector(VarCharVector vector, byte[]... values) {
+    final int length = values.length;
+    vector.allocateNewSafe();
+    for (int i = 0; i < length; i++) {
+      if (values[i] != null) {
+        vector.set(i, values[i]);
+      }
     }
+    vector.setValueCount(length);
   }
 
 }
diff --git 
a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java 
b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java
index 04c57d7e82..01f4e925c6 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java
@@ -251,7 +251,7 @@ public abstract class ArrowReader implements 
DictionaryProvider, AutoCloseable {
     VectorSchemaRoot root = new VectorSchemaRoot(
         Collections.singletonList(vector.getField()),
         Collections.singletonList(vector), 0);
-    VectorLoader loader = new VectorLoader(root);
+    VectorLoader loader = new VectorLoader(root, this.compressionFactory);
     try {
       loader.load(dictionaryBatch.getDictionary());
     } finally {
diff --git 
a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java 
b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
index a33c55de53..1cc201ae56 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
@@ -61,9 +61,14 @@ public abstract class ArrowWriter implements AutoCloseable {
   private final DictionaryProvider dictionaryProvider;
   private final Set<Long> dictionaryIdsUsed = new HashSet<>();
 
+  private final CompressionCodec.Factory compressionFactory;
+  private final CompressionUtil.CodecType codecType;
+  private final Optional<Integer> compressionLevel;
   private boolean started = false;
   private boolean ended = false;
 
+  private final CompressionCodec codec;
+
   protected IpcOption option;
 
   protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, 
WritableByteChannel out) {
@@ -89,16 +94,19 @@ public abstract class ArrowWriter implements AutoCloseable {
   protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, 
WritableByteChannel out, IpcOption option,
                         CompressionCodec.Factory compressionFactory, 
CompressionUtil.CodecType codecType,
                         Optional<Integer> compressionLevel) {
-    this.unloader = new VectorUnloader(
-        root, /*includeNullCount*/ true,
-        compressionLevel.isPresent() ?
-            compressionFactory.createCodec(codecType, compressionLevel.get()) :
-            compressionFactory.createCodec(codecType),
-        /*alignBuffers*/ true);
     this.out = new WriteChannel(out);
     this.option = option;
     this.dictionaryProvider = provider;
 
+    this.compressionFactory = compressionFactory;
+    this.codecType = codecType;
+    this.compressionLevel = compressionLevel;
+    this.codec = this.compressionLevel.isPresent() ?
+            this.compressionFactory.createCodec(this.codecType, 
this.compressionLevel.get()) :
+            this.compressionFactory.createCodec(this.codecType);
+    this.unloader = new VectorUnloader(root, /*includeNullCount*/ true, codec,
+        /*alignBuffers*/ true);
+
     List<Field> fields = new ArrayList<>(root.getSchema().getFields().size());
 
     
MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), 
option.metadataVersion);
@@ -133,7 +141,8 @@ public abstract class ArrowWriter implements AutoCloseable {
         Collections.singletonList(vector.getField()),
         Collections.singletonList(vector),
         count);
-    VectorUnloader unloader = new VectorUnloader(dictRoot);
+    VectorUnloader unloader = new VectorUnloader(dictRoot, 
/*includeNullCount*/ true, this.codec,
+        /*alignBuffers*/ true);
     ArrowRecordBatch batch = unloader.getRecordBatch();
     ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, 
false);
     try {

Reply via email to the list.