This is an automated email from the ASF dual-hosted git repository.

gabor pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/parquet-mr.git


The following commit(s) were added to refs/heads/master by this push:
     new d8396086b PARQUET-2431: Handle ByteBufferAllocator gracefully (#1274)
d8396086b is described below

commit d8396086b3e3fefc6829f8640917c3bbde0fa9c4
Author: Gabor Szadovszky <[email protected]>
AuthorDate: Mon Feb 19 10:07:16 2024 +0100

    PARQUET-2431: Handle ByteBufferAllocator gracefully (#1274)
---
 .../cli/commands/CheckParquet251Command.java       |   1 +
 .../parquet/cli/commands/ShowPagesCommand.java     |   1 +
 .../parquet/column/impl/ColumnReaderBase.java      |   2 +-
 .../apache/parquet/column/page/PageReadStore.java  |   7 +-
 .../apache/parquet/column/values/ValuesWriter.java |   3 +-
 .../values/dictionary/DictionaryValuesWriter.java  |  13 +-
 .../rle/RunLengthBitPackingHybridEncoder.java      |   3 +-
 .../column/values/dictionary/TestDictionary.java   | 560 +++++++++++----------
 .../bytes/CapacityByteArrayOutputStream.java       |   1 +
 .../parquet/bytes/TrackingByteBufferAllocator.java | 163 ++++++
 .../org/apache/parquet/util/AutoCloseables.java    |  73 +++
 .../bytes/TestCapacityByteArrayOutputStream.java   | 272 +++++-----
 .../parquet/hadoop/ColumnChunkPageReadStore.java   |  30 ++
 .../parquet/hadoop/ColumnIndexValidator.java       |   1 +
 .../hadoop/InternalParquetRecordReader.java        |  24 +-
 .../hadoop/InternalParquetRecordWriter.java        |   1 +
 .../apache/parquet/hadoop/ParquetFileReader.java   |   8 +
 .../org/apache/parquet/hadoop/ParquetWriter.java   |  12 +
 .../parquet/hadoop/rewrite/ParquetRewriter.java    |   1 +
 .../crypto/TestPropertiesDrivenEncryption.java     |  18 +-
 .../apache/parquet/encodings/FileEncodingsIT.java  |  16 +
 .../filter2/recordlevel/PhoneBookWriter.java       |  45 +-
 .../hadoop/TestColumnChunkPageWriteStore.java      |  19 +-
 .../parquet/hadoop/TestColumnIndexFiltering.java   |  35 +-
 .../apache/parquet/hadoop/TestParquetReader.java   |  18 +-
 .../apache/parquet/hadoop/TestParquetWriter.java   |  23 +
 .../parquet/hadoop/TestStoreBloomFilter.java       |  65 +--
 27 files changed, 936 insertions(+), 479 deletions(-)

diff --git a/parquet-cli/src/main/java/org/apache/parquet/cli/commands/CheckParquet251Command.java b/parquet-cli/src/main/java/org/apache/parquet/cli/commands/CheckParquet251Command.java
index 4af86b596..d01776f9e 100644
--- a/parquet-cli/src/main/java/org/apache/parquet/cli/commands/CheckParquet251Command.java
+++ b/parquet-cli/src/main/java/org/apache/parquet/cli/commands/CheckParquet251Command.java
@@ -113,6 +113,7 @@ public class CheckParquet251Command extends BaseCommand {
             pages != null;
             pages = reader.readNextRowGroup()) {
           validator.validate(columns, pages);
+          pages.close();
         }
       } catch (BadStatsException e) {
         return e.getMessage();
diff --git a/parquet-cli/src/main/java/org/apache/parquet/cli/commands/ShowPagesCommand.java b/parquet-cli/src/main/java/org/apache/parquet/cli/commands/ShowPagesCommand.java
index faee61815..f8a5b0007 100644
--- a/parquet-cli/src/main/java/org/apache/parquet/cli/commands/ShowPagesCommand.java
+++ b/parquet-cli/src/main/java/org/apache/parquet/cli/commands/ShowPagesCommand.java
@@ -131,6 +131,7 @@ public class ShowPagesCommand extends BaseCommand {
           }
         }
         rowGroupNum += 1;
+        pageStore.close();
       }
 
       // TODO: Show total column size and overall size per value in the column 
summary line
diff --git a/parquet-column/src/main/java/org/apache/parquet/column/impl/ColumnReaderBase.java b/parquet-column/src/main/java/org/apache/parquet/column/impl/ColumnReaderBase.java
index aee1e53e2..98bb56eb3 100644
--- a/parquet-column/src/main/java/org/apache/parquet/column/impl/ColumnReaderBase.java
+++ b/parquet-column/src/main/java/org/apache/parquet/column/impl/ColumnReaderBase.java
@@ -729,7 +729,7 @@ abstract class ColumnReaderBase implements ColumnReader {
 
     if (CorruptDeltaByteArrays.requiresSequentialReads(writerVersion, 
dataEncoding)
         && previousReader != null
-        && previousReader instanceof RequiresPreviousReader) {
+        && dataColumn instanceof RequiresPreviousReader) {
       // previous reader can only be set if reading sequentially
       ((RequiresPreviousReader) dataColumn).setPreviousReader(previousReader);
     }
diff --git a/parquet-column/src/main/java/org/apache/parquet/column/page/PageReadStore.java b/parquet-column/src/main/java/org/apache/parquet/column/page/PageReadStore.java
index 1930169a7..1dd87937c 100644
--- a/parquet-column/src/main/java/org/apache/parquet/column/page/PageReadStore.java
+++ b/parquet-column/src/main/java/org/apache/parquet/column/page/PageReadStore.java
@@ -27,7 +27,7 @@ import org.apache.parquet.column.ColumnDescriptor;
  * <p>
  * TODO: rename to RowGroup?
  */
-public interface PageReadStore {
+public interface PageReadStore extends AutoCloseable {
 
   /**
    * @param descriptor the descriptor of the column
@@ -58,4 +58,9 @@ public interface PageReadStore {
   default Optional<PrimitiveIterator.OfLong> getRowIndexes() {
     return Optional.empty();
   }
+
+  @Override
+  default void close() {
+    // No-op default implementation for compatibility
+  }
 }
diff --git a/parquet-column/src/main/java/org/apache/parquet/column/values/ValuesWriter.java b/parquet-column/src/main/java/org/apache/parquet/column/values/ValuesWriter.java
index d48277655..ecea4a752 100755
--- a/parquet-column/src/main/java/org/apache/parquet/column/values/ValuesWriter.java
+++ b/parquet-column/src/main/java/org/apache/parquet/column/values/ValuesWriter.java
@@ -26,7 +26,7 @@ import org.apache.parquet.io.api.Binary;
 /**
  * base class to implement an encoding for a given column
  */
-public abstract class ValuesWriter {
+public abstract class ValuesWriter implements AutoCloseable {
 
   /**
    * used to decide if we want to work to the next page
@@ -58,6 +58,7 @@ public abstract class ValuesWriter {
    * Called to close the values writer. Any output stream is closed and can no 
longer be used.
    * All resources are released.
    */
+  @Override
   public void close() {}
 
   /**
diff --git a/parquet-column/src/main/java/org/apache/parquet/column/values/dictionary/DictionaryValuesWriter.java b/parquet-column/src/main/java/org/apache/parquet/column/values/dictionary/DictionaryValuesWriter.java
index 561c9b22d..53526ae8d 100644
--- a/parquet-column/src/main/java/org/apache/parquet/column/values/dictionary/DictionaryValuesWriter.java
+++ b/parquet-column/src/main/java/org/apache/parquet/column/values/dictionary/DictionaryValuesWriter.java
@@ -52,6 +52,7 @@ import org.apache.parquet.column.values.plain.PlainValuesWriter;
 import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridEncoder;
 import org.apache.parquet.io.ParquetEncodingException;
 import org.apache.parquet.io.api.Binary;
+import org.apache.parquet.util.AutoCloseables;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -98,7 +99,7 @@ public abstract class DictionaryValuesWriter extends ValuesWriter implements Req
   protected ByteBufferAllocator allocator;
   /* Track the list of writers used so they can be appropriately closed when 
necessary
   (currently used for off-heap memory which is not garbage collected) */
-  private List<RunLengthBitPackingHybridEncoder> encoders = new ArrayList<>();
+  private List<AutoCloseable> toClose = new ArrayList<>();
 
   protected DictionaryValuesWriter(
       int maxDictionaryByteSize,
@@ -114,7 +115,7 @@ public abstract class DictionaryValuesWriter extends ValuesWriter implements Req
   protected DictionaryPage dictPage(ValuesWriter dictPageWriter) {
     DictionaryPage ret =
         new DictionaryPage(dictPageWriter.getBytes(), lastUsedDictionarySize, 
encodingForDictionaryPage);
-    dictPageWriter.close();
+    toClose.add(dictPageWriter);
     return ret;
   }
 
@@ -164,7 +165,7 @@ public abstract class DictionaryValuesWriter extends ValuesWriter implements Req
 
     RunLengthBitPackingHybridEncoder encoder =
         new RunLengthBitPackingHybridEncoder(bitWidth, initialSlabSize, 
maxDictionaryByteSize, this.allocator);
-    encoders.add(encoder);
+    toClose.add(encoder);
     IntIterator iterator = encodedValues.iterator();
     try {
       while (iterator.hasNext()) {
@@ -198,10 +199,8 @@ public abstract class DictionaryValuesWriter extends ValuesWriter implements Req
   @Override
   public void close() {
     encodedValues = null;
-    for (RunLengthBitPackingHybridEncoder encoder : encoders) {
-      encoder.close();
-    }
-    encoders.clear();
+    AutoCloseables.uncheckedClose(toClose);
+    toClose.clear();
   }
 
   @Override
diff --git a/parquet-column/src/main/java/org/apache/parquet/column/values/rle/RunLengthBitPackingHybridEncoder.java b/parquet-column/src/main/java/org/apache/parquet/column/values/rle/RunLengthBitPackingHybridEncoder.java
index 59ba67c64..e33824bff 100644
--- a/parquet-column/src/main/java/org/apache/parquet/column/values/rle/RunLengthBitPackingHybridEncoder.java
+++ b/parquet-column/src/main/java/org/apache/parquet/column/values/rle/RunLengthBitPackingHybridEncoder.java
@@ -55,7 +55,7 @@ import org.slf4j.LoggerFactory;
  * <p>
  * Only supports positive values (including 0) // TODO: is that ok? Should we 
make a signed version?
  */
-public class RunLengthBitPackingHybridEncoder {
+public class RunLengthBitPackingHybridEncoder implements AutoCloseable {
   private static final Logger LOG = 
LoggerFactory.getLogger(RunLengthBitPackingHybridEncoder.class);
 
   private final BytePacker packer;
@@ -279,6 +279,7 @@ public class RunLengthBitPackingHybridEncoder {
     reset(true);
   }
 
+  @Override
   public void close() {
     reset(false);
     baos.close();
diff --git a/parquet-column/src/test/java/org/apache/parquet/column/values/dictionary/TestDictionary.java b/parquet-column/src/test/java/org/apache/parquet/column/values/dictionary/TestDictionary.java
index c4117bcbc..6f7116bc3 100644
--- a/parquet-column/src/test/java/org/apache/parquet/column/values/dictionary/TestDictionary.java
+++ b/parquet-column/src/test/java/org/apache/parquet/column/values/dictionary/TestDictionary.java
@@ -32,6 +32,7 @@ import java.nio.charset.StandardCharsets;
 import org.apache.parquet.bytes.ByteBufferInputStream;
 import org.apache.parquet.bytes.BytesInput;
 import org.apache.parquet.bytes.DirectByteBufferAllocator;
+import org.apache.parquet.bytes.TrackingByteBufferAllocator;
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.column.Dictionary;
 import org.apache.parquet.column.Encoding;
@@ -49,23 +50,36 @@ import org.apache.parquet.column.values.plain.PlainValuesReader;
 import org.apache.parquet.column.values.plain.PlainValuesWriter;
 import org.apache.parquet.io.api.Binary;
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
+import org.junit.After;
 import org.junit.Assert;
+import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mockito;
 
 public class TestDictionary {
 
+  private TrackingByteBufferAllocator allocator;
+
+  @Before
+  public void initAllocator() {
+    allocator = TrackingByteBufferAllocator.wrap(new 
DirectByteBufferAllocator());
+  }
+
+  @After
+  public void closeAllocator() {
+    allocator.close();
+  }
+
   private <I extends DictionaryValuesWriter> FallbackValuesWriter<I, 
PlainValuesWriter> plainFallBack(
       I dvw, int initialSize) {
-    return FallbackValuesWriter.of(
-        dvw, new PlainValuesWriter(initialSize, initialSize * 5, new 
DirectByteBufferAllocator()));
+    return FallbackValuesWriter.of(dvw, new PlainValuesWriter(initialSize, 
initialSize * 5, allocator));
   }
 
   private FallbackValuesWriter<PlainBinaryDictionaryValuesWriter, 
PlainValuesWriter>
       newPlainBinaryDictionaryValuesWriter(int maxDictionaryByteSize, int 
initialSize) {
     return plainFallBack(
         new PlainBinaryDictionaryValuesWriter(
-            maxDictionaryByteSize, PLAIN_DICTIONARY, PLAIN_DICTIONARY, new 
DirectByteBufferAllocator()),
+            maxDictionaryByteSize, PLAIN_DICTIONARY, PLAIN_DICTIONARY, 
allocator),
         initialSize);
   }
 
@@ -73,7 +87,7 @@ public class TestDictionary {
       int maxDictionaryByteSize, int initialSize) {
     return plainFallBack(
         new PlainLongDictionaryValuesWriter(
-            maxDictionaryByteSize, PLAIN_DICTIONARY, PLAIN_DICTIONARY, new 
DirectByteBufferAllocator()),
+            maxDictionaryByteSize, PLAIN_DICTIONARY, PLAIN_DICTIONARY, 
allocator),
         initialSize);
   }
 
@@ -81,7 +95,7 @@ public class TestDictionary {
       newPlainIntegerDictionaryValuesWriter(int maxDictionaryByteSize, int 
initialSize) {
     return plainFallBack(
         new PlainIntegerDictionaryValuesWriter(
-            maxDictionaryByteSize, PLAIN_DICTIONARY, PLAIN_DICTIONARY, new 
DirectByteBufferAllocator()),
+            maxDictionaryByteSize, PLAIN_DICTIONARY, PLAIN_DICTIONARY, 
allocator),
         initialSize);
   }
 
@@ -89,7 +103,7 @@ public class TestDictionary {
       newPlainDoubleDictionaryValuesWriter(int maxDictionaryByteSize, int 
initialSize) {
     return plainFallBack(
         new PlainDoubleDictionaryValuesWriter(
-            maxDictionaryByteSize, PLAIN_DICTIONARY, PLAIN_DICTIONARY, new 
DirectByteBufferAllocator()),
+            maxDictionaryByteSize, PLAIN_DICTIONARY, PLAIN_DICTIONARY, 
allocator),
         initialSize);
   }
 
@@ -97,67 +111,69 @@ public class TestDictionary {
       newPlainFloatDictionaryValuesWriter(int maxDictionaryByteSize, int 
initialSize) {
     return plainFallBack(
         new PlainFloatDictionaryValuesWriter(
-            maxDictionaryByteSize, PLAIN_DICTIONARY, PLAIN_DICTIONARY, new 
DirectByteBufferAllocator()),
+            maxDictionaryByteSize, PLAIN_DICTIONARY, PLAIN_DICTIONARY, 
allocator),
         initialSize);
   }
 
   @Test
   public void testBinaryDictionary() throws IOException {
     int COUNT = 100;
-    ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(200, 10000);
-    writeRepeated(COUNT, cw, "a");
-    BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    writeRepeated(COUNT, cw, "b");
-    BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    // now we will fall back
-    writeDistinct(COUNT, cw, "c");
-    BytesInput bytes3 = getBytesAndCheckEncoding(cw, PLAIN);
-
-    DictionaryValuesReader cr = initDicReader(cw, BINARY);
-    checkRepeated(COUNT, bytes1, cr, "a");
-    checkRepeated(COUNT, bytes2, cr, "b");
-    BinaryPlainValuesReader cr2 = new BinaryPlainValuesReader();
-    checkDistinct(COUNT, bytes3, cr2, "c");
+    try (ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(200, 10000)) {
+      writeRepeated(COUNT, cw, "a");
+      BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      writeRepeated(COUNT, cw, "b");
+      BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      // now we will fall back
+      writeDistinct(COUNT, cw, "c");
+      BytesInput bytes3 = getBytesAndCheckEncoding(cw, PLAIN);
+
+      DictionaryValuesReader cr = initDicReader(cw, BINARY);
+      checkRepeated(COUNT, bytes1, cr, "a");
+      checkRepeated(COUNT, bytes2, cr, "b");
+      BinaryPlainValuesReader cr2 = new BinaryPlainValuesReader();
+      checkDistinct(COUNT, bytes3, cr2, "c");
+    }
   }
 
   @Test
   public void testSkipInBinaryDictionary() throws Exception {
-    ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(1000, 10000);
-    writeRepeated(100, cw, "a");
-    writeDistinct(100, cw, "b");
-    assertEquals(PLAIN_DICTIONARY, cw.getEncoding());
-
-    // Test skip and skip-n with dictionary encoding
-    ByteBufferInputStream stream = cw.getBytes().toInputStream();
-    DictionaryValuesReader cr = initDicReader(cw, BINARY);
-    cr.initFromPage(200, stream);
-    for (int i = 0; i < 100; i += 2) {
-      assertEquals(Binary.fromString("a" + i % 10), cr.readBytes());
-      cr.skip();
-    }
-    int skipCount;
-    for (int i = 0; i < 100; i += skipCount + 1) {
-      skipCount = (100 - i) / 2;
-      assertEquals(Binary.fromString("b" + i), cr.readBytes());
-      cr.skip(skipCount);
-    }
-
-    // Ensure fallback
-    writeDistinct(1000, cw, "c");
-    assertEquals(PLAIN, cw.getEncoding());
+    try (ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(1000, 10000)) {
+      writeRepeated(100, cw, "a");
+      writeDistinct(100, cw, "b");
+      assertEquals(PLAIN_DICTIONARY, cw.getEncoding());
+
+      // Test skip and skip-n with dictionary encoding
+      ByteBufferInputStream stream = cw.getBytes().toInputStream();
+      DictionaryValuesReader cr = initDicReader(cw, BINARY);
+      cr.initFromPage(200, stream);
+      for (int i = 0; i < 100; i += 2) {
+        assertEquals(Binary.fromString("a" + i % 10), cr.readBytes());
+        cr.skip();
+      }
+      int skipCount;
+      for (int i = 0; i < 100; i += skipCount + 1) {
+        skipCount = (100 - i) / 2;
+        assertEquals(Binary.fromString("b" + i), cr.readBytes());
+        cr.skip(skipCount);
+      }
 
-    // Test skip and skip-n with plain encoding (after fallback)
-    ValuesReader plainReader = new BinaryPlainValuesReader();
-    plainReader.initFromPage(1200, cw.getBytes().toInputStream());
-    plainReader.skip(200);
-    for (int i = 0; i < 100; i += 2) {
-      assertEquals("c" + i, plainReader.readBytes().toStringUsingUTF8());
-      plainReader.skip();
-    }
-    for (int i = 100; i < 1000; i += skipCount + 1) {
-      skipCount = (1000 - i) / 2;
-      assertEquals(Binary.fromString("c" + i), plainReader.readBytes());
-      plainReader.skip(skipCount);
+      // Ensure fallback
+      writeDistinct(1000, cw, "c");
+      assertEquals(PLAIN, cw.getEncoding());
+
+      // Test skip and skip-n with plain encoding (after fallback)
+      ValuesReader plainReader = new BinaryPlainValuesReader();
+      plainReader.initFromPage(1200, cw.getBytes().toInputStream());
+      plainReader.skip(200);
+      for (int i = 0; i < 100; i += 2) {
+        assertEquals("c" + i, plainReader.readBytes().toStringUsingUTF8());
+        plainReader.skip();
+      }
+      for (int i = 100; i < 1000; i += skipCount + 1) {
+        skipCount = (1000 - i) / 2;
+        assertEquals(Binary.fromString("c" + i), plainReader.readBytes());
+        plainReader.skip(skipCount);
+      }
     }
   }
 
@@ -165,31 +181,32 @@ public class TestDictionary {
   public void testBinaryDictionaryFallBack() throws IOException {
     int slabSize = 100;
     int maxDictionaryByteSize = 50;
-    final ValuesWriter cw = 
newPlainBinaryDictionaryValuesWriter(maxDictionaryByteSize, slabSize);
-    int fallBackThreshold = maxDictionaryByteSize;
-    int dataSize = 0;
-    for (long i = 0; i < 100; i++) {
-      Binary binary = Binary.fromString("str" + i);
-      cw.writeBytes(binary);
-      dataSize += (binary.length() + 4);
-      if (dataSize < fallBackThreshold) {
-        assertEquals(PLAIN_DICTIONARY, cw.getEncoding());
-      } else {
-        assertEquals(PLAIN, cw.getEncoding());
+    try (final ValuesWriter cw = 
newPlainBinaryDictionaryValuesWriter(maxDictionaryByteSize, slabSize)) {
+      int fallBackThreshold = maxDictionaryByteSize;
+      int dataSize = 0;
+      for (long i = 0; i < 100; i++) {
+        Binary binary = Binary.fromString("str" + i);
+        cw.writeBytes(binary);
+        dataSize += (binary.length() + 4);
+        if (dataSize < fallBackThreshold) {
+          assertEquals(PLAIN_DICTIONARY, cw.getEncoding());
+        } else {
+          assertEquals(PLAIN, cw.getEncoding());
+        }
       }
-    }
 
-    // Fallbacked to Plain encoding, therefore use PlainValuesReader to read 
it back
-    ValuesReader reader = new BinaryPlainValuesReader();
-    reader.initFromPage(100, cw.getBytes().toInputStream());
+      // Fallbacked to Plain encoding, therefore use PlainValuesReader to read 
it back
+      ValuesReader reader = new BinaryPlainValuesReader();
+      reader.initFromPage(100, cw.getBytes().toInputStream());
 
-    for (long i = 0; i < 100; i++) {
-      assertEquals(Binary.fromString("str" + i), reader.readBytes());
-    }
+      for (long i = 0; i < 100; i++) {
+        assertEquals(Binary.fromString("str" + i), reader.readBytes());
+      }
 
-    // simulate cutting the page
-    cw.reset();
-    assertEquals(0, cw.getBufferedSize());
+      // simulate cutting the page
+      cw.reset();
+      assertEquals(0, cw.getBufferedSize());
+    }
   }
 
   @Test
@@ -199,98 +216,103 @@ public class TestDictionary {
     // make the writer happy
     Mockito.when(mock.copy()).thenReturn(Binary.fromString(" world"));
 
-    final ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(100, 100);
-    cw.writeBytes(Binary.fromString("hello"));
-    cw.writeBytes(mock);
+    try (final ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(100, 
100)) {
+      cw.writeBytes(Binary.fromString("hello"));
+      cw.writeBytes(mock);
 
-    assertEquals(PLAIN, cw.getEncoding());
+      assertEquals(PLAIN, cw.getEncoding());
+    }
   }
 
   @Test
   public void testBinaryDictionaryChangedValues() throws IOException {
     int COUNT = 100;
-    ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(200, 10000);
-    writeRepeatedWithReuse(COUNT, cw, "a");
-    BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    writeRepeatedWithReuse(COUNT, cw, "b");
-    BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    // now we will fall back
-    writeDistinct(COUNT, cw, "c");
-    BytesInput bytes3 = getBytesAndCheckEncoding(cw, PLAIN);
-
-    DictionaryValuesReader cr = initDicReader(cw, BINARY);
-    checkRepeated(COUNT, bytes1, cr, "a");
-    checkRepeated(COUNT, bytes2, cr, "b");
-    BinaryPlainValuesReader cr2 = new BinaryPlainValuesReader();
-    checkDistinct(COUNT, bytes3, cr2, "c");
+    try (ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(200, 10000)) {
+      writeRepeatedWithReuse(COUNT, cw, "a");
+      BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      writeRepeatedWithReuse(COUNT, cw, "b");
+      BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      // now we will fall back
+      writeDistinct(COUNT, cw, "c");
+      BytesInput bytes3 = getBytesAndCheckEncoding(cw, PLAIN);
+
+      DictionaryValuesReader cr = initDicReader(cw, BINARY);
+      checkRepeated(COUNT, bytes1, cr, "a");
+      checkRepeated(COUNT, bytes2, cr, "b");
+      BinaryPlainValuesReader cr2 = new BinaryPlainValuesReader();
+      checkDistinct(COUNT, bytes3, cr2, "c");
+    }
   }
 
   @Test
   public void testFirstPageFallBack() throws IOException {
     int COUNT = 1000;
-    ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(10000, 10000);
-    writeDistinct(COUNT, cw, "a");
-    // not efficient so falls back
-    BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN);
-    writeRepeated(COUNT, cw, "b");
-    // still plain because we fell back on first page
-    BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN);
+    try (ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(10000, 10000)) 
{
+      writeDistinct(COUNT, cw, "a");
+      // not efficient so falls back
+      BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN);
+      writeRepeated(COUNT, cw, "b");
+      // still plain because we fell back on first page
+      BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN);
 
-    ValuesReader cr = new BinaryPlainValuesReader();
-    checkDistinct(COUNT, bytes1, cr, "a");
-    checkRepeated(COUNT, bytes2, cr, "b");
+      ValuesReader cr = new BinaryPlainValuesReader();
+      checkDistinct(COUNT, bytes1, cr, "a");
+      checkRepeated(COUNT, bytes2, cr, "b");
+    }
   }
 
   @Test
   public void testSecondPageFallBack() throws IOException {
     int COUNT = 1000;
-    ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(1000, 10000);
-    writeRepeated(COUNT, cw, "a");
-    BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    writeDistinct(COUNT, cw, "b");
-    // not efficient so falls back
-    BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN);
-    writeRepeated(COUNT, cw, "a");
-    // still plain because we fell back on previous page
-    BytesInput bytes3 = getBytesAndCheckEncoding(cw, PLAIN);
-
-    ValuesReader cr = initDicReader(cw, BINARY);
-    checkRepeated(COUNT, bytes1, cr, "a");
-    cr = new BinaryPlainValuesReader();
-    checkDistinct(COUNT, bytes2, cr, "b");
-    checkRepeated(COUNT, bytes3, cr, "a");
+    try (ValuesWriter cw = newPlainBinaryDictionaryValuesWriter(1000, 10000)) {
+      writeRepeated(COUNT, cw, "a");
+      BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      writeDistinct(COUNT, cw, "b");
+      // not efficient so falls back
+      BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN);
+      writeRepeated(COUNT, cw, "a");
+      // still plain because we fell back on previous page
+      BytesInput bytes3 = getBytesAndCheckEncoding(cw, PLAIN);
+
+      ValuesReader cr = initDicReader(cw, BINARY);
+      checkRepeated(COUNT, bytes1, cr, "a");
+      cr = new BinaryPlainValuesReader();
+      checkDistinct(COUNT, bytes2, cr, "b");
+      checkRepeated(COUNT, bytes3, cr, "a");
+    }
   }
 
   @Test
   public void testLongDictionary() throws IOException {
     int COUNT = 1000;
     int COUNT2 = 2000;
-    final FallbackValuesWriter<PlainLongDictionaryValuesWriter, 
PlainValuesWriter> cw =
-        newPlainLongDictionaryValuesWriter(10000, 10000);
-    for (long i = 0; i < COUNT; i++) {
-      cw.writeLong(i % 50);
-    }
-    BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    assertEquals(50, cw.initialWriter.getDictionarySize());
+    try (final FallbackValuesWriter<PlainLongDictionaryValuesWriter, 
PlainValuesWriter> cw =
+        newPlainLongDictionaryValuesWriter(10000, 10000)) {
+      for (long i = 0; i < COUNT; i++) {
+        cw.writeLong(i % 50);
+      }
+      BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      assertEquals(50, cw.initialWriter.getDictionarySize());
 
-    for (long i = COUNT2; i > 0; i--) {
-      cw.writeLong(i % 50);
-    }
-    BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    assertEquals(50, cw.initialWriter.getDictionarySize());
+      for (long i = COUNT2; i > 0; i--) {
+        cw.writeLong(i % 50);
+      }
+      BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      assertEquals(50, cw.initialWriter.getDictionarySize());
 
-    DictionaryValuesReader cr = initDicReader(cw, PrimitiveTypeName.INT64);
+      DictionaryValuesReader cr = initDicReader(cw, PrimitiveTypeName.INT64);
 
-    cr.initFromPage(COUNT, bytes1.toInputStream());
-    for (long i = 0; i < COUNT; i++) {
-      long back = cr.readLong();
-      assertEquals(i % 50, back);
-    }
+      cr.initFromPage(COUNT, bytes1.toInputStream());
+      for (long i = 0; i < COUNT; i++) {
+        long back = cr.readLong();
+        assertEquals(i % 50, back);
+      }
 
-    cr.initFromPage(COUNT2, bytes2.toInputStream());
-    for (long i = COUNT2; i > 0; i--) {
-      long back = cr.readLong();
-      assertEquals(i % 50, back);
+      cr.initFromPage(COUNT2, bytes2.toInputStream());
+      for (long i = COUNT2; i > 0; i--) {
+        long back = cr.readLong();
+        assertEquals(i % 50, back);
+      }
     }
   }
 
@@ -336,18 +358,19 @@ public class TestDictionary {
   public void testLongDictionaryFallBack() throws IOException {
     int slabSize = 100;
     int maxDictionaryByteSize = 50;
-    final FallbackValuesWriter<PlainLongDictionaryValuesWriter, 
PlainValuesWriter> cw =
-        newPlainLongDictionaryValuesWriter(maxDictionaryByteSize, slabSize);
-    // Fallbacked to Plain encoding, therefore use PlainValuesReader to read 
it back
-    ValuesReader reader = new PlainValuesReader.LongPlainValuesReader();
+    try (final FallbackValuesWriter<PlainLongDictionaryValuesWriter, 
PlainValuesWriter> cw =
+        newPlainLongDictionaryValuesWriter(maxDictionaryByteSize, slabSize)) {
+      // Fallbacked to Plain encoding, therefore use PlainValuesReader to read 
it back
+      ValuesReader reader = new PlainValuesReader.LongPlainValuesReader();
 
-    roundTripLong(cw, reader, maxDictionaryByteSize);
-    // simulate cutting the page
-    cw.reset();
-    assertEquals(0, cw.getBufferedSize());
-    cw.resetDictionary();
+      roundTripLong(cw, reader, maxDictionaryByteSize);
+      // simulate cutting the page
+      cw.reset();
+      assertEquals(0, cw.getBufferedSize());
+      cw.resetDictionary();
 
-    roundTripLong(cw, reader, maxDictionaryByteSize);
+      roundTripLong(cw, reader, maxDictionaryByteSize);
+    }
   }
 
   @Test
@@ -355,34 +378,35 @@ public class TestDictionary {
 
     int COUNT = 1000;
     int COUNT2 = 2000;
-    final FallbackValuesWriter<PlainDoubleDictionaryValuesWriter, 
PlainValuesWriter> cw =
-        newPlainDoubleDictionaryValuesWriter(10000, 10000);
+    try (final FallbackValuesWriter<PlainDoubleDictionaryValuesWriter, 
PlainValuesWriter> cw =
+        newPlainDoubleDictionaryValuesWriter(10000, 10000)) {
 
-    for (double i = 0; i < COUNT; i++) {
-      cw.writeDouble(i % 50);
-    }
+      for (double i = 0; i < COUNT; i++) {
+        cw.writeDouble(i % 50);
+      }
 
-    BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    assertEquals(50, cw.initialWriter.getDictionarySize());
+      BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      assertEquals(50, cw.initialWriter.getDictionarySize());
 
-    for (double i = COUNT2; i > 0; i--) {
-      cw.writeDouble(i % 50);
-    }
-    BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    assertEquals(50, cw.initialWriter.getDictionarySize());
+      for (double i = COUNT2; i > 0; i--) {
+        cw.writeDouble(i % 50);
+      }
+      BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      assertEquals(50, cw.initialWriter.getDictionarySize());
 
-    final DictionaryValuesReader cr = initDicReader(cw, DOUBLE);
+      final DictionaryValuesReader cr = initDicReader(cw, DOUBLE);
 
-    cr.initFromPage(COUNT, bytes1.toInputStream());
-    for (double i = 0; i < COUNT; i++) {
-      double back = cr.readDouble();
-      assertEquals(i % 50, back, 0.0);
-    }
+      cr.initFromPage(COUNT, bytes1.toInputStream());
+      for (double i = 0; i < COUNT; i++) {
+        double back = cr.readDouble();
+        assertEquals(i % 50, back, 0.0);
+      }
 
-    cr.initFromPage(COUNT2, bytes2.toInputStream());
-    for (double i = COUNT2; i > 0; i--) {
-      double back = cr.readDouble();
-      assertEquals(i % 50, back, 0.0);
+      cr.initFromPage(COUNT2, bytes2.toInputStream());
+      for (double i = COUNT2; i > 0; i--) {
+        double back = cr.readDouble();
+        assertEquals(i % 50, back, 0.0);
+      }
     }
   }
 
@@ -428,19 +452,20 @@ public class TestDictionary {
   public void testDoubleDictionaryFallBack() throws IOException {
     int slabSize = 100;
     int maxDictionaryByteSize = 50;
-    final FallbackValuesWriter<PlainDoubleDictionaryValuesWriter, 
PlainValuesWriter> cw =
-        newPlainDoubleDictionaryValuesWriter(maxDictionaryByteSize, slabSize);
+    try (final FallbackValuesWriter<PlainDoubleDictionaryValuesWriter, 
PlainValuesWriter> cw =
+        newPlainDoubleDictionaryValuesWriter(maxDictionaryByteSize, slabSize)) 
{
 
-    // Fallbacked to Plain encoding, therefore use PlainValuesReader to read 
it back
-    ValuesReader reader = new PlainValuesReader.DoublePlainValuesReader();
+      // Fallbacked to Plain encoding, therefore use PlainValuesReader to read 
it back
+      ValuesReader reader = new PlainValuesReader.DoublePlainValuesReader();
 
-    roundTripDouble(cw, reader, maxDictionaryByteSize);
-    // simulate cutting the page
-    cw.reset();
-    assertEquals(0, cw.getBufferedSize());
-    cw.resetDictionary();
+      roundTripDouble(cw, reader, maxDictionaryByteSize);
+      // simulate cutting the page
+      cw.reset();
+      assertEquals(0, cw.getBufferedSize());
+      cw.resetDictionary();
 
-    roundTripDouble(cw, reader, maxDictionaryByteSize);
+      roundTripDouble(cw, reader, maxDictionaryByteSize);
+    }
   }
 
   @Test
@@ -448,33 +473,34 @@ public class TestDictionary {
 
     int COUNT = 2000;
     int COUNT2 = 4000;
-    final FallbackValuesWriter<PlainIntegerDictionaryValuesWriter, 
PlainValuesWriter> cw =
-        newPlainIntegerDictionaryValuesWriter(10000, 10000);
+    try (final FallbackValuesWriter<PlainIntegerDictionaryValuesWriter, 
PlainValuesWriter> cw =
+        newPlainIntegerDictionaryValuesWriter(10000, 10000)) {
 
-    for (int i = 0; i < COUNT; i++) {
-      cw.writeInteger(i % 50);
-    }
-    BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    assertEquals(50, cw.initialWriter.getDictionarySize());
+      for (int i = 0; i < COUNT; i++) {
+        cw.writeInteger(i % 50);
+      }
+      BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      assertEquals(50, cw.initialWriter.getDictionarySize());
 
-    for (int i = COUNT2; i > 0; i--) {
-      cw.writeInteger(i % 50);
-    }
-    BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    assertEquals(50, cw.initialWriter.getDictionarySize());
+      for (int i = COUNT2; i > 0; i--) {
+        cw.writeInteger(i % 50);
+      }
+      BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      assertEquals(50, cw.initialWriter.getDictionarySize());
 
-    DictionaryValuesReader cr = initDicReader(cw, INT32);
+      DictionaryValuesReader cr = initDicReader(cw, INT32);
 
-    cr.initFromPage(COUNT, bytes1.toInputStream());
-    for (int i = 0; i < COUNT; i++) {
-      int back = cr.readInteger();
-      assertEquals(i % 50, back);
-    }
+      cr.initFromPage(COUNT, bytes1.toInputStream());
+      for (int i = 0; i < COUNT; i++) {
+        int back = cr.readInteger();
+        assertEquals(i % 50, back);
+      }
 
-    cr.initFromPage(COUNT2, bytes2.toInputStream());
-    for (int i = COUNT2; i > 0; i--) {
-      int back = cr.readInteger();
-      assertEquals(i % 50, back);
+      cr.initFromPage(COUNT2, bytes2.toInputStream());
+      for (int i = COUNT2; i > 0; i--) {
+        int back = cr.readInteger();
+        assertEquals(i % 50, back);
+      }
     }
   }
 
@@ -520,19 +546,20 @@ public class TestDictionary {
   public void testIntDictionaryFallBack() throws IOException {
     int slabSize = 100;
     int maxDictionaryByteSize = 50;
-    final FallbackValuesWriter<PlainIntegerDictionaryValuesWriter, 
PlainValuesWriter> cw =
-        newPlainIntegerDictionaryValuesWriter(maxDictionaryByteSize, slabSize);
+    try (final FallbackValuesWriter<PlainIntegerDictionaryValuesWriter, 
PlainValuesWriter> cw =
+        newPlainIntegerDictionaryValuesWriter(maxDictionaryByteSize, 
slabSize)) {
 
-    // Fallbacked to Plain encoding, therefore use PlainValuesReader to read 
it back
-    ValuesReader reader = new PlainValuesReader.IntegerPlainValuesReader();
+      // Fallbacked to Plain encoding, therefore use PlainValuesReader to read 
it back
+      ValuesReader reader = new PlainValuesReader.IntegerPlainValuesReader();
 
-    roundTripInt(cw, reader, maxDictionaryByteSize);
-    // simulate cutting the page
-    cw.reset();
-    assertEquals(0, cw.getBufferedSize());
-    cw.resetDictionary();
+      roundTripInt(cw, reader, maxDictionaryByteSize);
+      // simulate cutting the page
+      cw.reset();
+      assertEquals(0, cw.getBufferedSize());
+      cw.resetDictionary();
 
-    roundTripInt(cw, reader, maxDictionaryByteSize);
+      roundTripInt(cw, reader, maxDictionaryByteSize);
+    }
   }
 
   @Test
@@ -540,33 +567,34 @@ public class TestDictionary {
 
     int COUNT = 2000;
     int COUNT2 = 4000;
-    final FallbackValuesWriter<PlainFloatDictionaryValuesWriter, 
PlainValuesWriter> cw =
-        newPlainFloatDictionaryValuesWriter(10000, 10000);
+    try (final FallbackValuesWriter<PlainFloatDictionaryValuesWriter, 
PlainValuesWriter> cw =
+        newPlainFloatDictionaryValuesWriter(10000, 10000)) {
 
-    for (float i = 0; i < COUNT; i++) {
-      cw.writeFloat(i % 50);
-    }
-    BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    assertEquals(50, cw.initialWriter.getDictionarySize());
+      for (float i = 0; i < COUNT; i++) {
+        cw.writeFloat(i % 50);
+      }
+      BytesInput bytes1 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      assertEquals(50, cw.initialWriter.getDictionarySize());
 
-    for (float i = COUNT2; i > 0; i--) {
-      cw.writeFloat(i % 50);
-    }
-    BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    assertEquals(50, cw.initialWriter.getDictionarySize());
+      for (float i = COUNT2; i > 0; i--) {
+        cw.writeFloat(i % 50);
+      }
+      BytesInput bytes2 = getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      assertEquals(50, cw.initialWriter.getDictionarySize());
 
-    DictionaryValuesReader cr = initDicReader(cw, FLOAT);
+      DictionaryValuesReader cr = initDicReader(cw, FLOAT);
 
-    cr.initFromPage(COUNT, bytes1.toInputStream());
-    for (float i = 0; i < COUNT; i++) {
-      float back = cr.readFloat();
-      assertEquals(i % 50, back, 0.0f);
-    }
+      cr.initFromPage(COUNT, bytes1.toInputStream());
+      for (float i = 0; i < COUNT; i++) {
+        float back = cr.readFloat();
+        assertEquals(i % 50, back, 0.0f);
+      }
 
-    cr.initFromPage(COUNT2, bytes2.toInputStream());
-    for (float i = COUNT2; i > 0; i--) {
-      float back = cr.readFloat();
-      assertEquals(i % 50, back, 0.0f);
+      cr.initFromPage(COUNT2, bytes2.toInputStream());
+      for (float i = COUNT2; i > 0; i--) {
+        float back = cr.readFloat();
+        assertEquals(i % 50, back, 0.0f);
+      }
     }
   }
 
@@ -612,40 +640,42 @@ public class TestDictionary {
   public void testFloatDictionaryFallBack() throws IOException {
     int slabSize = 100;
     int maxDictionaryByteSize = 50;
-    final FallbackValuesWriter<PlainFloatDictionaryValuesWriter, 
PlainValuesWriter> cw =
-        newPlainFloatDictionaryValuesWriter(maxDictionaryByteSize, slabSize);
+    try (final FallbackValuesWriter<PlainFloatDictionaryValuesWriter, 
PlainValuesWriter> cw =
+        newPlainFloatDictionaryValuesWriter(maxDictionaryByteSize, slabSize)) {
 
-    // Fallbacked to Plain encoding, therefore use PlainValuesReader to read 
it back
-    ValuesReader reader = new PlainValuesReader.FloatPlainValuesReader();
+      // Fallbacked to Plain encoding, therefore use PlainValuesReader to read 
it back
+      ValuesReader reader = new PlainValuesReader.FloatPlainValuesReader();
 
-    roundTripFloat(cw, reader, maxDictionaryByteSize);
-    // simulate cutting the page
-    cw.reset();
-    assertEquals(0, cw.getBufferedSize());
-    cw.resetDictionary();
+      roundTripFloat(cw, reader, maxDictionaryByteSize);
+      // simulate cutting the page
+      cw.reset();
+      assertEquals(0, cw.getBufferedSize());
+      cw.resetDictionary();
 
-    roundTripFloat(cw, reader, maxDictionaryByteSize);
+      roundTripFloat(cw, reader, maxDictionaryByteSize);
+    }
   }
 
   @Test
   public void testZeroValues() throws IOException {
-    FallbackValuesWriter<PlainIntegerDictionaryValuesWriter, 
PlainValuesWriter> cw =
-        newPlainIntegerDictionaryValuesWriter(100, 100);
-    cw.writeInteger(34);
-    cw.writeInteger(34);
-    getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
-    DictionaryValuesReader reader = initDicReader(cw, INT32);
-
-    // pretend there are 100 nulls. what matters is offset = bytes.length.
-    ByteBuffer bytes = ByteBuffer.wrap(new byte[] {0x00, 0x01, 0x02, 0x03}); 
// data doesn't matter
-    ByteBufferInputStream stream = ByteBufferInputStream.wrap(bytes);
-    stream.skipFully(stream.available());
-    reader.initFromPage(100, stream);
-
-    // Testing the deprecated behavior of using byte arrays directly
-    reader = initDicReader(cw, INT32);
-    int offset = bytes.remaining();
-    reader.initFromPage(100, bytes, offset);
+    try (FallbackValuesWriter<PlainIntegerDictionaryValuesWriter, 
PlainValuesWriter> cw =
+        newPlainIntegerDictionaryValuesWriter(100, 100)) {
+      cw.writeInteger(34);
+      cw.writeInteger(34);
+      getBytesAndCheckEncoding(cw, PLAIN_DICTIONARY);
+      DictionaryValuesReader reader = initDicReader(cw, INT32);
+
+      // pretend there are 100 nulls. what matters is offset = bytes.length.
+      ByteBuffer bytes = ByteBuffer.wrap(new byte[] {0x00, 0x01, 0x02, 0x03}); 
// data doesn't matter
+      ByteBufferInputStream stream = ByteBufferInputStream.wrap(bytes);
+      stream.skipFully(stream.available());
+      reader.initFromPage(100, stream);
+
+      // Testing the deprecated behavior of using byte arrays directly
+      reader = initDicReader(cw, INT32);
+      int offset = bytes.remaining();
+      reader.initFromPage(100, bytes, offset);
+    }
   }
 
   private DictionaryValuesReader initDicReader(ValuesWriter cw, 
PrimitiveTypeName type) throws IOException {
diff --git 
a/parquet-common/src/main/java/org/apache/parquet/bytes/CapacityByteArrayOutputStream.java
 
b/parquet-common/src/main/java/org/apache/parquet/bytes/CapacityByteArrayOutputStream.java
index 528248a93..2031e625a 100644
--- 
a/parquet-common/src/main/java/org/apache/parquet/bytes/CapacityByteArrayOutputStream.java
+++ 
b/parquet-common/src/main/java/org/apache/parquet/bytes/CapacityByteArrayOutputStream.java
@@ -334,6 +334,7 @@ public class CapacityByteArrayOutputStream extends 
OutputStream {
     for (ByteBuffer slab : slabs) {
       allocator.release(slab);
     }
+    slabs.clear();
     try {
       super.close();
     } catch (IOException e) {
diff --git 
a/parquet-common/src/main/java/org/apache/parquet/bytes/TrackingByteBufferAllocator.java
 
b/parquet-common/src/main/java/org/apache/parquet/bytes/TrackingByteBufferAllocator.java
new file mode 100644
index 000000000..c6dd3431f
--- /dev/null
+++ 
b/parquet-common/src/main/java/org/apache/parquet/bytes/TrackingByteBufferAllocator.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.parquet.bytes;
+
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A wrapper {@link ByteBufferAllocator} implementation that tracks whether 
all allocated buffers are released. It
+ * throws the related exception at {@link #close()} if any buffer remains 
un-released. It also clears the buffers at
+ * release so if they continued being used it'll generate errors.
+ * <p>To be used for testing purposes.
+ */
+public final class TrackingByteBufferAllocator implements ByteBufferAllocator, 
AutoCloseable {
+
+  /**
+   * The stacktraces of the allocation are not stored by default because it 
significantly decreases the unit test
+   * execution performance
+   *
+   * @see ByteBufferAllocationStacktraceException
+   */
+  private static final boolean DEBUG = false;
+
+  public static TrackingByteBufferAllocator wrap(ByteBufferAllocator 
allocator) {
+    return new TrackingByteBufferAllocator(allocator);
+  }
+
+  private static class Key {
+
+    private final int hashCode;
+    private final ByteBuffer buffer;
+
+    Key(ByteBuffer buffer) {
+      hashCode = System.identityHashCode(buffer);
+      this.buffer = buffer;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      Key key = (Key) o;
+      return this.buffer == key.buffer;
+    }
+
+    @Override
+    public int hashCode() {
+      return hashCode;
+    }
+  }
+
+  public static class LeakDetectorHeapByteBufferAllocatorException extends 
RuntimeException {
+
+    private LeakDetectorHeapByteBufferAllocatorException(String msg) {
+      super(msg);
+    }
+
+    private LeakDetectorHeapByteBufferAllocatorException(String msg, Throwable 
cause) {
+      super(msg, cause);
+    }
+
+    private LeakDetectorHeapByteBufferAllocatorException(
+        String message, Throwable cause, boolean enableSuppression, boolean 
writableStackTrace) {
+      super(message, cause, enableSuppression, writableStackTrace);
+    }
+  }
+
+  public static class ByteBufferAllocationStacktraceException extends 
LeakDetectorHeapByteBufferAllocatorException {
+
+    private static final ByteBufferAllocationStacktraceException 
WITHOUT_STACKTRACE =
+        new ByteBufferAllocationStacktraceException(false);
+
+    private static ByteBufferAllocationStacktraceException create() {
+      return DEBUG ? new ByteBufferAllocationStacktraceException() : 
WITHOUT_STACKTRACE;
+    }
+
+    private ByteBufferAllocationStacktraceException() {
+      super("Allocation stacktrace of the first ByteBuffer:");
+    }
+
+    private ByteBufferAllocationStacktraceException(boolean unused) {
+      super(
+          "Set org.apache.parquet.bytes.TrackingByteBufferAllocator.DEBUG = 
true for more info",
+          null,
+          false,
+          false);
+    }
+  }
+
+  public static class ReleasingUnallocatedByteBufferException extends 
LeakDetectorHeapByteBufferAllocatorException {
+
+    private ReleasingUnallocatedByteBufferException() {
+      super("Releasing a ByteBuffer instance that is not allocated by this 
allocator or already been released");
+    }
+  }
+
+  public static class LeakedByteBufferException extends 
LeakDetectorHeapByteBufferAllocatorException {
+
+    private LeakedByteBufferException(int count, 
ByteBufferAllocationStacktraceException e) {
+      super(count + " ByteBuffer object(s) is/are remained unreleased after 
closing this allocator.", e);
+    }
+  }
+
+  private final Map<Key, ByteBufferAllocationStacktraceException> allocated = 
new HashMap<>();
+  private final ByteBufferAllocator allocator;
+
+  private TrackingByteBufferAllocator(ByteBufferAllocator allocator) {
+    this.allocator = allocator;
+  }
+
+  @Override
+  public ByteBuffer allocate(int size) {
+    ByteBuffer buffer = allocator.allocate(size);
+    allocated.put(new Key(buffer), 
ByteBufferAllocationStacktraceException.create());
+    return buffer;
+  }
+
+  @Override
+  public void release(ByteBuffer b) throws 
ReleasingUnallocatedByteBufferException {
+    if (allocated.remove(new Key(b)) == null) {
+      throw new ReleasingUnallocatedByteBufferException();
+    }
+    allocator.release(b);
+    // Clearing the buffer so subsequent access would probably generate errors
+    b.clear();
+  }
+
+  @Override
+  public boolean isDirect() {
+    return allocator.isDirect();
+  }
+
+  @Override
+  public void close() throws LeakedByteBufferException {
+    if (!allocated.isEmpty()) {
+      LeakedByteBufferException ex = new LeakedByteBufferException(
+          allocated.size(), allocated.values().iterator().next());
+      allocated.clear(); // Drop the references to the ByteBuffers, so they 
can be gc'd
+      throw ex;
+    }
+  }
+}
diff --git 
a/parquet-common/src/main/java/org/apache/parquet/util/AutoCloseables.java 
b/parquet-common/src/main/java/org/apache/parquet/util/AutoCloseables.java
new file mode 100644
index 000000000..6934c40bd
--- /dev/null
+++ b/parquet-common/src/main/java/org/apache/parquet/util/AutoCloseables.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.parquet.util;
+
+import org.apache.parquet.ParquetRuntimeException;
+
+/**
+ * Utility class to handle {@link AutoCloseable} objects.
+ */
+public final class AutoCloseables {
+
+  public static class ParquetCloseResourceException extends 
ParquetRuntimeException {
+
+    private ParquetCloseResourceException(Exception e) {
+      super("Unable to close resource", e);
+    }
+  }
+
+  /**
+   * Invokes the {@link AutoCloseable#close()} method of each specified 
objects in a way that guarantees that all the
+   * methods will be invoked even if an exception is occurred before.
+   *
+   * @param autoCloseables the objects to be closed
+   * @throws Exception the compound exception built from the exceptions thrown 
by the close methods
+   */
+  public static void close(Iterable<AutoCloseable> autoCloseables) throws 
Exception {
+    Exception root = null;
+    for (AutoCloseable autoCloseable : autoCloseables) {
+      try {
+        autoCloseable.close();
+      } catch (Exception e) {
+        if (root == null) {
+          root = e;
+        } else {
+          root.addSuppressed(e);
+        }
+      }
+    }
+    if (root != null) {
+      throw root;
+    }
+  }
+
+  /**
+   * Works similarly to {@link #close(Iterable)} but it wraps the thrown 
exception (if any) into a
+   * {@link ParquetCloseResourceException}.
+   */
+  public static void uncheckedClose(Iterable<AutoCloseable> autoCloseables) 
throws ParquetCloseResourceException {
+    try {
+      close(autoCloseables);
+    } catch (Exception e) {
+      throw new ParquetCloseResourceException(e);
+    }
+  }
+
+  private AutoCloseables() {}
+}
diff --git 
a/parquet-encoding/src/test/java/org/apache/parquet/bytes/TestCapacityByteArrayOutputStream.java
 
b/parquet-encoding/src/test/java/org/apache/parquet/bytes/TestCapacityByteArrayOutputStream.java
index 667498671..583b90209 100644
--- 
a/parquet-encoding/src/test/java/org/apache/parquet/bytes/TestCapacityByteArrayOutputStream.java
+++ 
b/parquet-encoding/src/test/java/org/apache/parquet/bytes/TestCapacityByteArrayOutputStream.java
@@ -25,172 +25,194 @@ import static org.junit.Assert.assertTrue;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.util.Arrays;
+import org.junit.After;
+import org.junit.Before;
 import org.junit.Test;
 
 public class TestCapacityByteArrayOutputStream {
 
+  private TrackingByteBufferAllocator allocator;
+
+  @Before
+  public void initAllocator() {
+    allocator = TrackingByteBufferAllocator.wrap(new 
HeapByteBufferAllocator());
+  }
+
+  @After
+  public void closeAllocator() {
+    allocator.close();
+  }
+
   @Test
   public void testWrite() throws Throwable {
-    CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10);
-    final int expectedSize = 54;
-    for (int i = 0; i < expectedSize; i++) {
-      capacityByteArrayOutputStream.write(i);
-      assertEquals(i + 1, capacityByteArrayOutputStream.size());
+    try (CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10)) {
+      final int expectedSize = 54;
+      for (int i = 0; i < expectedSize; i++) {
+        capacityByteArrayOutputStream.write(i);
+        assertEquals(i + 1, capacityByteArrayOutputStream.size());
+      }
+      validate(capacityByteArrayOutputStream, expectedSize);
     }
-    validate(capacityByteArrayOutputStream, expectedSize);
   }
 
   @Test
   public void testWriteArray() throws Throwable {
-    CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10);
-    int v = 23;
-    writeArraysOf3(capacityByteArrayOutputStream, v);
-    validate(capacityByteArrayOutputStream, v * 3);
+    try (CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10)) {
+      int v = 23;
+      writeArraysOf3(capacityByteArrayOutputStream, v);
+      validate(capacityByteArrayOutputStream, v * 3);
+    }
   }
 
   @Test
   public void testWriteArrayExpand() throws Throwable {
-    CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(2);
-    assertEquals(0, capacityByteArrayOutputStream.getCapacity());
+    try (CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(2)) {
+      assertEquals(0, capacityByteArrayOutputStream.getCapacity());
 
-    byte[] toWrite = {(byte) (1), (byte) (2), (byte) (3), (byte) (4)};
-    int toWriteOffset = 0;
-    int writeLength = 2;
-    // write 2 bytes array
-    capacityByteArrayOutputStream.write(toWrite, toWriteOffset, writeLength);
-    toWriteOffset += writeLength;
-    assertEquals(2, capacityByteArrayOutputStream.size());
-    assertEquals(2, capacityByteArrayOutputStream.getCapacity());
+      byte[] toWrite = {(byte) (1), (byte) (2), (byte) (3), (byte) (4)};
+      int toWriteOffset = 0;
+      int writeLength = 2;
+      // write 2 bytes array
+      capacityByteArrayOutputStream.write(toWrite, toWriteOffset, writeLength);
+      toWriteOffset += writeLength;
+      assertEquals(2, capacityByteArrayOutputStream.size());
+      assertEquals(2, capacityByteArrayOutputStream.getCapacity());
 
-    // write 1 byte array, expand capacity to 4
-    writeLength = 1;
-    capacityByteArrayOutputStream.write(toWrite, toWriteOffset, writeLength);
-    toWriteOffset += writeLength;
-    assertEquals(3, capacityByteArrayOutputStream.size());
-    assertEquals(4, capacityByteArrayOutputStream.getCapacity());
+      // write 1 byte array, expand capacity to 4
+      writeLength = 1;
+      capacityByteArrayOutputStream.write(toWrite, toWriteOffset, writeLength);
+      toWriteOffset += writeLength;
+      assertEquals(3, capacityByteArrayOutputStream.size());
+      assertEquals(4, capacityByteArrayOutputStream.getCapacity());
 
-    // write 1 byte array, not expand
-    capacityByteArrayOutputStream.write(toWrite, toWriteOffset, writeLength);
-    assertEquals(4, capacityByteArrayOutputStream.size());
-    assertEquals(4, capacityByteArrayOutputStream.getCapacity());
-    final byte[] byteArray = 
BytesInput.from(capacityByteArrayOutputStream).toByteArray();
-    assertArrayEquals(toWrite, byteArray);
+      // write 1 byte array, not expand
+      capacityByteArrayOutputStream.write(toWrite, toWriteOffset, writeLength);
+      assertEquals(4, capacityByteArrayOutputStream.size());
+      assertEquals(4, capacityByteArrayOutputStream.getCapacity());
+      final byte[] byteArray =
+          BytesInput.from(capacityByteArrayOutputStream).toByteArray();
+      assertArrayEquals(toWrite, byteArray);
+    }
   }
 
   @Test
   public void testWriteArrayAndInt() throws Throwable {
-    CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10);
-    for (int i = 0; i < 23; i++) {
-      byte[] toWrite = {(byte) (i * 3), (byte) (i * 3 + 1)};
-      capacityByteArrayOutputStream.write(toWrite);
-      capacityByteArrayOutputStream.write((byte) (i * 3 + 2));
-      assertEquals((i + 1) * 3, capacityByteArrayOutputStream.size());
+    try (CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10)) {
+      for (int i = 0; i < 23; i++) {
+        byte[] toWrite = {(byte) (i * 3), (byte) (i * 3 + 1)};
+        capacityByteArrayOutputStream.write(toWrite);
+        capacityByteArrayOutputStream.write((byte) (i * 3 + 2));
+        assertEquals((i + 1) * 3, capacityByteArrayOutputStream.size());
+      }
+      validate(capacityByteArrayOutputStream, 23 * 3);
     }
-    validate(capacityByteArrayOutputStream, 23 * 3);
   }
 
   protected CapacityByteArrayOutputStream newCapacityBAOS(int initialSize) {
-    return new CapacityByteArrayOutputStream(initialSize, 1000000, new 
HeapByteBufferAllocator());
+    return new CapacityByteArrayOutputStream(initialSize, 1000000, allocator);
   }
 
   @Test
   public void testReset() throws Throwable {
-    CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10);
-    for (int i = 0; i < 54; i++) {
-      capacityByteArrayOutputStream.write(i);
-      assertEquals(i + 1, capacityByteArrayOutputStream.size());
-    }
-    capacityByteArrayOutputStream.reset();
-    for (int i = 0; i < 54; i++) {
-      capacityByteArrayOutputStream.write(54 + i);
-      assertEquals(i + 1, capacityByteArrayOutputStream.size());
-    }
-    final byte[] byteArray = 
BytesInput.from(capacityByteArrayOutputStream).toByteArray();
-    assertEquals(54, byteArray.length);
-    for (int i = 0; i < 54; i++) {
-      assertEquals(i + " in " + Arrays.toString(byteArray), 54 + i, 
byteArray[i]);
+    try (CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10)) {
+      for (int i = 0; i < 54; i++) {
+        capacityByteArrayOutputStream.write(i);
+        assertEquals(i + 1, capacityByteArrayOutputStream.size());
+      }
+      capacityByteArrayOutputStream.reset();
+      for (int i = 0; i < 54; i++) {
+        capacityByteArrayOutputStream.write(54 + i);
+        assertEquals(i + 1, capacityByteArrayOutputStream.size());
+      }
+      final byte[] byteArray =
+          BytesInput.from(capacityByteArrayOutputStream).toByteArray();
+      assertEquals(54, byteArray.length);
+      for (int i = 0; i < 54; i++) {
+        assertEquals(i + " in " + Arrays.toString(byteArray), 54 + i, 
byteArray[i]);
+      }
     }
   }
 
   @Test
   public void testWriteArrayBiggerThanSlab() throws Throwable {
-    CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10);
-    int v = 23;
-    writeArraysOf3(capacityByteArrayOutputStream, v);
-    int n = v * 3;
-    byte[] toWrite = { // bigger than 2 slabs of size of 10
-      (byte) n,
-      (byte) (n + 1),
-      (byte) (n + 2),
-      (byte) (n + 3),
-      (byte) (n + 4),
-      (byte) (n + 5),
-      (byte) (n + 6),
-      (byte) (n + 7),
-      (byte) (n + 8),
-      (byte) (n + 9),
-      (byte) (n + 10),
-      (byte) (n + 11),
-      (byte) (n + 12),
-      (byte) (n + 13),
-      (byte) (n + 14),
-      (byte) (n + 15),
-      (byte) (n + 16),
-      (byte) (n + 17),
-      (byte) (n + 18),
-      (byte) (n + 19),
-      (byte) (n + 20)
-    };
-    capacityByteArrayOutputStream.write(toWrite);
-    n = n + toWrite.length;
-    assertEquals(n, capacityByteArrayOutputStream.size());
-    validate(capacityByteArrayOutputStream, n);
-    capacityByteArrayOutputStream.reset();
-    // check it works after reset too
-    capacityByteArrayOutputStream.write(toWrite);
-    assertEquals(toWrite.length, capacityByteArrayOutputStream.size());
-    byte[] byteArray = 
BytesInput.from(capacityByteArrayOutputStream).toByteArray();
-    assertEquals(toWrite.length, byteArray.length);
-    for (int i = 0; i < toWrite.length; i++) {
-      assertEquals(toWrite[i], byteArray[i]);
+    try (CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10)) {
+      int v = 23;
+      writeArraysOf3(capacityByteArrayOutputStream, v);
+      int n = v * 3;
+      byte[] toWrite = { // bigger than 2 slabs of size of 10
+        (byte) n,
+        (byte) (n + 1),
+        (byte) (n + 2),
+        (byte) (n + 3),
+        (byte) (n + 4),
+        (byte) (n + 5),
+        (byte) (n + 6),
+        (byte) (n + 7),
+        (byte) (n + 8),
+        (byte) (n + 9),
+        (byte) (n + 10),
+        (byte) (n + 11),
+        (byte) (n + 12),
+        (byte) (n + 13),
+        (byte) (n + 14),
+        (byte) (n + 15),
+        (byte) (n + 16),
+        (byte) (n + 17),
+        (byte) (n + 18),
+        (byte) (n + 19),
+        (byte) (n + 20)
+      };
+      capacityByteArrayOutputStream.write(toWrite);
+      n = n + toWrite.length;
+      assertEquals(n, capacityByteArrayOutputStream.size());
+      validate(capacityByteArrayOutputStream, n);
+      capacityByteArrayOutputStream.reset();
+      // check it works after reset too
+      capacityByteArrayOutputStream.write(toWrite);
+      assertEquals(toWrite.length, capacityByteArrayOutputStream.size());
+      byte[] byteArray = 
BytesInput.from(capacityByteArrayOutputStream).toByteArray();
+      assertEquals(toWrite.length, byteArray.length);
+      for (int i = 0; i < toWrite.length; i++) {
+        assertEquals(toWrite[i], byteArray[i]);
+      }
     }
   }
 
   @Test
   public void testWriteArrayManySlabs() throws Throwable {
-    CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10);
-    int it = 500;
-    int v = 23;
-    for (int j = 0; j < it; j++) {
-      for (int i = 0; i < v; i++) {
-        byte[] toWrite = {(byte) (i * 3), (byte) (i * 3 + 1), (byte) (i * 3 + 
2)};
-        capacityByteArrayOutputStream.write(toWrite);
-        assertEquals((i + 1) * 3 + v * 3 * j, 
capacityByteArrayOutputStream.size());
+    try (CapacityByteArrayOutputStream capacityByteArrayOutputStream = 
newCapacityBAOS(10)) {
+      int it = 500;
+      int v = 23;
+      for (int j = 0; j < it; j++) {
+        for (int i = 0; i < v; i++) {
+          byte[] toWrite = {(byte) (i * 3), (byte) (i * 3 + 1), (byte) (i * 3 
+ 2)};
+          capacityByteArrayOutputStream.write(toWrite);
+          assertEquals((i + 1) * 3 + v * 3 * j, 
capacityByteArrayOutputStream.size());
+        }
       }
+      byte[] byteArray = 
BytesInput.from(capacityByteArrayOutputStream).toByteArray();
+      assertEquals(v * 3 * it, byteArray.length);
+      for (int i = 0; i < v * 3 * it; i++) {
+        assertEquals(i % (v * 3), byteArray[i]);
+      }
+      // verifying we have not created 500 * 23 / 10 slabs
+      assertTrue(
+          "slab count: " + capacityByteArrayOutputStream.getSlabCount(),
+          capacityByteArrayOutputStream.getSlabCount() <= 20);
+      capacityByteArrayOutputStream.reset();
+      writeArraysOf3(capacityByteArrayOutputStream, v);
+      validate(capacityByteArrayOutputStream, v * 3);
+      // verifying we use less slabs now
+      assertTrue(
+          "slab count: " + capacityByteArrayOutputStream.getSlabCount(),
+          capacityByteArrayOutputStream.getSlabCount() <= 2);
     }
-    byte[] byteArray = 
BytesInput.from(capacityByteArrayOutputStream).toByteArray();
-    assertEquals(v * 3 * it, byteArray.length);
-    for (int i = 0; i < v * 3 * it; i++) {
-      assertEquals(i % (v * 3), byteArray[i]);
-    }
-    // verifying we have not created 500 * 23 / 10 slabs
-    assertTrue(
-        "slab count: " + capacityByteArrayOutputStream.getSlabCount(),
-        capacityByteArrayOutputStream.getSlabCount() <= 20);
-    capacityByteArrayOutputStream.reset();
-    writeArraysOf3(capacityByteArrayOutputStream, v);
-    validate(capacityByteArrayOutputStream, v * 3);
-    // verifying we use less slabs now
-    assertTrue(
-        "slab count: " + capacityByteArrayOutputStream.getSlabCount(),
-        capacityByteArrayOutputStream.getSlabCount() <= 2);
   }
 
   @Test
   public void testReplaceByte() throws Throwable {
     // test replace the first value
-    {
-      CapacityByteArrayOutputStream cbaos = newCapacityBAOS(5);
+    try (CapacityByteArrayOutputStream cbaos = newCapacityBAOS(5)) {
       cbaos.write(10);
       assertEquals(0, cbaos.getCurrentIndex());
       cbaos.setByte(0, (byte) 7);
@@ -200,8 +222,7 @@ public class TestCapacityByteArrayOutputStream {
     }
 
     // test replace value in the first slab
-    {
-      CapacityByteArrayOutputStream cbaos = newCapacityBAOS(5);
+    try (CapacityByteArrayOutputStream cbaos = newCapacityBAOS(5)) {
       cbaos.write(10);
       cbaos.write(13);
       cbaos.write(15);
@@ -215,8 +236,7 @@ public class TestCapacityByteArrayOutputStream {
     }
 
     // test replace in *not* the first slab
-    {
-      CapacityByteArrayOutputStream cbaos = newCapacityBAOS(5);
+    try (CapacityByteArrayOutputStream cbaos = newCapacityBAOS(5)) {
 
       // advance part way through the 3rd slab
       for (int i = 0; i < 12; i++) {
@@ -232,8 +252,7 @@ public class TestCapacityByteArrayOutputStream {
     }
 
     // test replace last value of a slab
-    {
-      CapacityByteArrayOutputStream cbaos = newCapacityBAOS(5);
+    try (CapacityByteArrayOutputStream cbaos = newCapacityBAOS(5)) {
 
       // advance part way through the 3rd slab
       for (int i = 0; i < 12; i++) {
@@ -249,8 +268,7 @@ public class TestCapacityByteArrayOutputStream {
     }
 
     // test replace last value
-    {
-      CapacityByteArrayOutputStream cbaos = newCapacityBAOS(5);
+    try (CapacityByteArrayOutputStream cbaos = newCapacityBAOS(5)) {
 
       // advance part way through the 3rd slab
       for (int i = 0; i < 12; i++) {
diff --git 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ColumnChunkPageReadStore.java
 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ColumnChunkPageReadStore.java
index 5c376c8ce..f5cc76162 100644
--- 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ColumnChunkPageReadStore.java
+++ 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ColumnChunkPageReadStore.java
@@ -21,6 +21,7 @@ package org.apache.parquet.hadoop;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayDeque;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -28,6 +29,7 @@ import java.util.Optional;
 import java.util.PrimitiveIterator;
 import java.util.Queue;
 import org.apache.parquet.ParquetReadOptions;
+import org.apache.parquet.bytes.ByteBufferAllocator;
 import org.apache.parquet.bytes.BytesInput;
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.column.page.DataPage;
@@ -77,6 +79,7 @@ class ColumnChunkPageReadStore implements PageReadStore, 
DictionaryPageReadStore
     private final BlockCipher.Decryptor blockDecryptor;
     private final byte[] dataPageAAD;
     private final byte[] dictionaryPageAAD;
+    private final List<ByteBuffer> toRelease = new ArrayList<>();
 
     ColumnChunkPageReader(
         BytesInputDecompressor decompressor,
@@ -156,6 +159,7 @@ class ColumnChunkPageReadStore implements PageReadStore, 
DictionaryPageReadStore
 
               ByteBuffer decompressedBuffer =
                   
options.getAllocator().allocate(dataPageV1.getUncompressedSize());
+              toRelease.add(decompressedBuffer);
               long start = System.nanoTime();
               decompressor.decompress(
                   byteBuffer,
@@ -238,6 +242,7 @@ class ColumnChunkPageReadStore implements PageReadStore, 
DictionaryPageReadStore
                     - dataPageV2.getRepetitionLevels().size());
                 ByteBuffer decompressedBuffer =
                     options.getAllocator().allocate(uncompressedSize);
+                toRelease.add(decompressedBuffer);
                 long start = System.nanoTime();
                 decompressor.decompress(
                     byteBuffer, (int) compressedSize, decompressedBuffer, 
uncompressedSize);
@@ -344,6 +349,14 @@ class ColumnChunkPageReadStore implements PageReadStore, 
DictionaryPageReadStore
         throw new ParquetDecodingException("Could not decompress dictionary 
page", e);
       }
     }
+
+    private void releaseBuffers() {
+      ByteBufferAllocator allocator = options.getAllocator();
+      for (ByteBuffer buffer : toRelease) {
+        allocator.release(buffer);
+      }
+      toRelease.clear();
+    }
   }
 
   private final Map<ColumnDescriptor, ColumnChunkPageReader> readers =
@@ -351,6 +364,8 @@ class ColumnChunkPageReadStore implements PageReadStore, 
DictionaryPageReadStore
   private final long rowCount;
   private final long rowIndexOffset;
   private final RowRanges rowRanges;
+  private ByteBufferAllocator allocator;
+  private List<ByteBuffer> toRelease;
 
   public ColumnChunkPageReadStore(long rowCount) {
     this(rowCount, -1);
@@ -406,4 +421,19 @@ class ColumnChunkPageReadStore implements PageReadStore, 
DictionaryPageReadStore
       throw new RuntimeException(path + " was added twice");
     }
   }
+
+  void setBuffersToRelease(ByteBufferAllocator allocator, List<ByteBuffer> 
toRelease) {
+    this.allocator = allocator;
+    this.toRelease = toRelease;
+  }
+
+  @Override
+  public void close() {
+    for (ColumnChunkPageReader reader : readers.values()) {
+      reader.releaseBuffers();
+    }
+    for (ByteBuffer buffer : toRelease) {
+      allocator.release(buffer);
+    }
+  }
 }
diff --git 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ColumnIndexValidator.java
 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ColumnIndexValidator.java
index 8f18c081f..92f7db413 100644
--- 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ColumnIndexValidator.java
+++ 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ColumnIndexValidator.java
@@ -616,6 +616,7 @@ public class ColumnIndexValidator {
             pageValidator.finishPage();
           }
         }
+        rowGroup.close();
         rowGroup = reader.readNextRowGroup();
         rowGroupNumber++;
       }
diff --git 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordReader.java
 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordReader.java
index 271423ce7..c9842c937 100644
--- 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordReader.java
+++ 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordReader.java
@@ -82,6 +82,7 @@ class InternalParquetRecordReader<T> {
   private long totalCountLoadedSoFar = 0;
 
   private UnmaterializableRecordCounter unmaterializableRecordCounter;
+  private PageReadStore currentRowGroup;
 
   /**
    * @param readSupport Object which helps reads files of the given type, e.g. 
Thrift, Avro.
@@ -130,29 +131,40 @@ class InternalParquetRecordReader<T> {
         }
       }
 
+      if (currentRowGroup != null) {
+        currentRowGroup.close();
+      }
+
       LOG.info("at row " + current + ". reading next block");
       long t0 = System.currentTimeMillis();
-      PageReadStore pages = reader.readNextFilteredRowGroup();
-      if (pages == null) {
+      currentRowGroup = reader.readNextFilteredRowGroup();
+      if (currentRowGroup == null) {
         throw new IOException(
             "expecting more rows but reached last block. Read " + current + " 
out of " + total);
       }
-      resetRowIndexIterator(pages);
+      resetRowIndexIterator(currentRowGroup);
       long timeSpentReading = System.currentTimeMillis() - t0;
       totalTimeSpentReadingBytes += timeSpentReading;
       BenchmarkCounter.incrementTime(timeSpentReading);
       if (LOG.isInfoEnabled())
-        LOG.info("block read in memory in {} ms. row count = {}", 
timeSpentReading, pages.getRowCount());
+        LOG.info(
+            "block read in memory in {} ms. row count = {}",
+            timeSpentReading,
+            currentRowGroup.getRowCount());
       LOG.debug("initializing Record assembly with requested schema {}", 
requestedSchema);
       MessageColumnIO columnIO = columnIOFactory.getColumnIO(requestedSchema, 
fileSchema, strictTypeChecking);
-      recordReader = columnIO.getRecordReader(pages, recordConverter, 
filterRecords ? filter : FilterCompat.NOOP);
+      recordReader = columnIO.getRecordReader(
+          currentRowGroup, recordConverter, filterRecords ? filter : 
FilterCompat.NOOP);
       startedAssemblingCurrentBlockAt = System.currentTimeMillis();
-      totalCountLoadedSoFar += pages.getRowCount();
+      totalCountLoadedSoFar += currentRowGroup.getRowCount();
       ++currentBlock;
     }
   }
 
   public void close() throws IOException {
+    if (currentRowGroup != null) {
+      currentRowGroup.close();
+    }
     if (reader != null) {
       reader.close();
     }
diff --git 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordWriter.java
 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordWriter.java
index 2541a7ff3..77bfb6099 100644
--- 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordWriter.java
+++ 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordWriter.java
@@ -198,6 +198,7 @@ class InternalParquetRecordWriter<T> {
       this.nextRowGroupSize = 
Math.min(parquetFileWriter.getNextRowGroupSize(), rowGroupSizeThreshold);
     }
 
+    columnStore.close();
     columnStore = null;
     pageStore = null;
   }
diff --git 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileReader.java 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileReader.java
index 19403c329..6bd71ee8b 100644
--- 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileReader.java
+++ 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileReader.java
@@ -1057,6 +1057,7 @@ public class ParquetFileReader implements Closeable {
     for (ConsecutivePartList consecutiveChunks : allParts) {
       consecutiveChunks.readAll(f, builder);
     }
+    rowGroup.setBuffersToRelease(options.getAllocator(), builder.toRelease);
     for (Chunk chunk : builder.build()) {
       readChunkPages(chunk, block, rowGroup);
     }
@@ -1214,6 +1215,7 @@ public class ParquetFileReader implements Closeable {
     for (ConsecutivePartList consecutiveChunks : allParts) {
       consecutiveChunks.readAll(f, builder);
     }
+    rowGroup.setBuffersToRelease(options.getAllocator(), builder.toRelease);
     for (Chunk chunk : builder.build()) {
       readChunkPages(chunk, block, rowGroup);
     }
@@ -1585,6 +1587,7 @@ public class ParquetFileReader implements Closeable {
     private ChunkDescriptor lastDescriptor;
     private final long rowCount;
     private SeekableInputStream f;
+    private List<ByteBuffer> toRelease = new ArrayList<>();
 
     public ChunkListBuilder(long rowCount) {
       this.rowCount = rowCount;
@@ -1596,6 +1599,10 @@ public class ParquetFileReader implements Closeable {
       this.f = f;
     }
 
+    void addBuffersToRelease(List<ByteBuffer> toRelease) {
+      this.toRelease.addAll(toRelease);
+    }
+
     void setOffsetIndex(ChunkDescriptor descriptor, OffsetIndex offsetIndex) {
       map.computeIfAbsent(descriptor, d -> new ChunkData()).offsetIndex = 
offsetIndex;
     }
@@ -2006,6 +2013,7 @@ public class ParquetFileReader implements Closeable {
       if (lastAllocationSize > 0) {
         buffers.add(options.getAllocator().allocate(lastAllocationSize));
       }
+      builder.addBuffersToRelease(buffers);
 
       long readStart = System.nanoTime();
       for (ByteBuffer buffer : buffers) {
diff --git 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetWriter.java 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetWriter.java
index 1838d1db4..c609a11df 100644
--- a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetWriter.java
+++ b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetWriter.java
@@ -24,6 +24,7 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
+import org.apache.parquet.bytes.ByteBufferAllocator;
 import org.apache.parquet.column.ParquetProperties;
 import org.apache.parquet.column.ParquetProperties.WriterVersion;
 import org.apache.parquet.compression.CompressionCodecFactory;
@@ -876,6 +877,17 @@ public class ParquetWriter<T> implements Closeable {
       return self();
     }
 
+    /**
+     * Sets the ByteBuffer allocator instance to be used for allocating memory 
for writing.
+     *
+     * @param allocator the allocator instance
+     * @return this builder for method chaining
+     */
+    public SELF withAllocator(ByteBufferAllocator allocator) {
+      encodingPropsBuilder.withAllocator(allocator);
+      return self();
+    }
+
     /**
      * Set a property that will be available to the read path. For writers 
that use a Hadoop
      * configuration, this is the recommended way to add configuration values.
diff --git 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/rewrite/ParquetRewriter.java
 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/rewrite/ParquetRewriter.java
index fac19df17..ed3dbc2ae 100644
--- 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/rewrite/ParquetRewriter.java
+++ 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/rewrite/ParquetRewriter.java
@@ -771,6 +771,7 @@ public class ParquetRewriter implements Closeable {
       cStore.endRecord();
     }
 
+    pageReadStore.close();
     cStore.flush();
     cPageStore.flushToFileWriter(writer);
 
diff --git 
a/parquet-hadoop/src/test/java/org/apache/parquet/crypto/TestPropertiesDrivenEncryption.java
 
b/parquet-hadoop/src/test/java/org/apache/parquet/crypto/TestPropertiesDrivenEncryption.java
index e8f2bef43..76b8be1f7 100644
--- 
a/parquet-hadoop/src/test/java/org/apache/parquet/crypto/TestPropertiesDrivenEncryption.java
+++ 
b/parquet-hadoop/src/test/java/org/apache/parquet/crypto/TestPropertiesDrivenEncryption.java
@@ -42,9 +42,9 @@ import java.util.concurrent.TimeUnit;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
-import org.apache.parquet.bytes.ByteBufferAllocator;
 import org.apache.parquet.bytes.DirectByteBufferAllocator;
 import org.apache.parquet.bytes.HeapByteBufferAllocator;
+import org.apache.parquet.bytes.TrackingByteBufferAllocator;
 import org.apache.parquet.column.ParquetProperties.WriterVersion;
 import org.apache.parquet.crypto.keytools.KeyToolkit;
 import org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory;
@@ -675,13 +675,15 @@ public class TestPropertiesDrivenEncryption {
     }
 
     int rowNum = 0;
-    final ByteBufferAllocator allocator =
-        this.isDecryptionDirectMemory ? new DirectByteBufferAllocator() : new 
HeapByteBufferAllocator();
-    try (ParquetReader<Group> reader = ParquetReader.builder(new 
GroupReadSupport(), file)
-        .withConf(hadoopConfig)
-        .withAllocator(allocator)
-        .withDecryption(fileDecryptionProperties)
-        .build()) {
+    try (TrackingByteBufferAllocator allocator = 
TrackingByteBufferAllocator.wrap(
+            this.isDecryptionDirectMemory
+                ? new DirectByteBufferAllocator()
+                : new HeapByteBufferAllocator());
+        ParquetReader<Group> reader = ParquetReader.builder(new 
GroupReadSupport(), file)
+            .withConf(hadoopConfig)
+            .withAllocator(allocator)
+            .withDecryption(fileDecryptionProperties)
+            .build()) {
       for (Group group = reader.read(); group != null; group = reader.read()) {
         SingleRow rowExpected = data.get(rowNum++);
 
diff --git 
a/parquet-hadoop/src/test/java/org/apache/parquet/encodings/FileEncodingsIT.java
 
b/parquet-hadoop/src/test/java/org/apache/parquet/encodings/FileEncodingsIT.java
index 416495601..f2e6e16fc 100644
--- 
a/parquet-hadoop/src/test/java/org/apache/parquet/encodings/FileEncodingsIT.java
+++ 
b/parquet-hadoop/src/test/java/org/apache/parquet/encodings/FileEncodingsIT.java
@@ -31,6 +31,8 @@ import java.util.Random;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.parquet.bytes.BytesInput;
+import org.apache.parquet.bytes.HeapByteBufferAllocator;
+import org.apache.parquet.bytes.TrackingByteBufferAllocator;
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.column.ParquetProperties.WriterVersion;
 import org.apache.parquet.column.impl.ColumnReaderImpl;
@@ -58,6 +60,8 @@ import 
org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
 import org.apache.parquet.schema.Types;
 import org.apache.parquet.statistics.RandomValues;
 import org.apache.parquet.statistics.TestStatistics;
+import org.junit.After;
+import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
@@ -97,6 +101,7 @@ public class FileEncodingsIT {
   // Parameters
   private PrimitiveTypeName paramTypeName;
   private CompressionCodecName compression;
+  private TrackingByteBufferAllocator allocator;
 
   @Parameterized.Parameters
   public static Collection<Object[]> getParameters() {
@@ -151,6 +156,16 @@ public class FileEncodingsIT {
     fixedBinaryGenerator = new RandomValues.FixedGenerator(random.nextLong(), 
FIXED_LENGTH);
   }
 
+  @Before
+  public void initAllocator() {
+    allocator = TrackingByteBufferAllocator.wrap(new 
HeapByteBufferAllocator());
+  }
+
+  @After
+  public void closeAllocator() {
+    allocator.close();
+  }
+
   @Test
   public void testFileEncodingsWithoutDictionary() throws Exception {
     final boolean DISABLE_DICTIONARY = false;
@@ -241,6 +256,7 @@ public class FileEncodingsIT {
     GroupWriteSupport.setSchema(schema, configuration);
 
     ParquetWriter<Group> writer = ExampleParquetWriter.builder(file)
+        .withAllocator(allocator)
         .withCompressionCodec(compression)
         .withRowGroupSize(rowGroupSize)
         .withPageSize(pageSize)
diff --git 
a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java
 
b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java
index 99ebd73c6..97d836aec 100644
--- 
a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java
+++ 
b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java
@@ -26,6 +26,9 @@ import java.util.ArrayList;
 import java.util.List;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
+import org.apache.parquet.bytes.ByteBufferAllocator;
+import org.apache.parquet.bytes.HeapByteBufferAllocator;
+import org.apache.parquet.bytes.TrackingByteBufferAllocator;
 import org.apache.parquet.example.data.Group;
 import org.apache.parquet.example.data.simple.SimpleGroup;
 import org.apache.parquet.filter2.compat.FilterCompat.Filter;
@@ -320,29 +323,33 @@ public class PhoneBookWriter {
     }
   }
 
-  public static ParquetReader<Group> createReader(Path file, Filter filter) 
throws IOException {
+  public static ParquetReader<Group> createReader(Path file, Filter filter, 
ByteBufferAllocator allocator)
+      throws IOException {
     Configuration conf = new Configuration();
     GroupWriteSupport.setSchema(schema, conf);
 
     return ParquetReader.builder(new GroupReadSupport(), file)
         .withConf(conf)
         .withFilter(filter)
+        .withAllocator(allocator)
         .build();
   }
 
   public static List<Group> readFile(File f, Filter filter) throws IOException 
{
-    ParquetReader<Group> reader = createReader(new Path(f.getAbsolutePath()), 
filter);
+    try (TrackingByteBufferAllocator allocator = 
TrackingByteBufferAllocator.wrap(new HeapByteBufferAllocator());
+        ParquetReader<Group> reader = createReader(new 
Path(f.getAbsolutePath()), filter, allocator)) {
 
-    Group current;
-    List<Group> users = new ArrayList<Group>();
+      Group current;
+      List<Group> users = new ArrayList<Group>();
 
-    current = reader.read();
-    while (current != null) {
-      users.add(current);
       current = reader.read();
-    }
+      while (current != null) {
+        users.add(current);
+        current = reader.read();
+      }
 
-    return users;
+      return users;
+    }
   }
 
   public static List<User> readUsers(ParquetReader.Builder<Group> builder) 
throws IOException {
@@ -356,18 +363,18 @@ public class PhoneBookWriter {
    */
   public static List<User> readUsers(ParquetReader.Builder<Group> builder, 
boolean validateRowIndexes)
       throws IOException {
-    ParquetReader<Group> reader = 
builder.set(GroupWriteSupport.PARQUET_EXAMPLE_SCHEMA, schema.toString())
-        .build();
-
-    List<User> users = new ArrayList<>();
-    for (Group group = reader.read(); group != null; group = reader.read()) {
-      User u = userFromGroup(group);
-      users.add(u);
-      if (validateRowIndexes) {
-        assertEquals("Row index should be equal to User id", u.id, 
reader.getCurrentRowIndex());
+    try (ParquetReader<Group> reader = 
builder.set(GroupWriteSupport.PARQUET_EXAMPLE_SCHEMA, schema.toString())
+        .build()) {
+      List<User> users = new ArrayList<>();
+      for (Group group = reader.read(); group != null; group = reader.read()) {
+        User u = userFromGroup(group);
+        users.add(u);
+        if (validateRowIndexes) {
+          assertEquals("Row index should be equal to User id", u.id, 
reader.getCurrentRowIndex());
+        }
       }
+      return users;
     }
-    return users;
   }
 
   public static void main(String[] args) throws IOException {
diff --git 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnChunkPageWriteStore.java
 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnChunkPageWriteStore.java
index ec22ce086..58b48e0ea 100644
--- 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnChunkPageWriteStore.java
+++ 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnChunkPageWriteStore.java
@@ -50,6 +50,7 @@ import org.apache.parquet.bytes.BytesInput;
 import org.apache.parquet.bytes.DirectByteBufferAllocator;
 import org.apache.parquet.bytes.HeapByteBufferAllocator;
 import org.apache.parquet.bytes.LittleEndianDataInputStream;
+import org.apache.parquet.bytes.TrackingByteBufferAllocator;
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.column.Encoding;
 import org.apache.parquet.column.page.DataPageV2;
@@ -75,6 +76,7 @@ import org.apache.parquet.schema.MessageType;
 import org.apache.parquet.schema.MessageTypeParser;
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
 import org.apache.parquet.schema.Types;
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.InOrder;
@@ -124,15 +126,21 @@ public class TestColumnChunkPageWriteStore {
   private int pageSize = 1024;
   private int initialSize = 1024;
   private Configuration conf;
+  private TrackingByteBufferAllocator allocator;
 
   @Before
   public void initConfiguration() {
     this.conf = new Configuration();
   }
 
+  @After
+  public void closeAllocator() {
+    allocator.close();
+  }
+
   @Test
   public void test() throws Exception {
-    test(conf, new HeapByteBufferAllocator());
+    test(conf, allocator = TrackingByteBufferAllocator.wrap(new 
HeapByteBufferAllocator()));
   }
 
   @Test
@@ -141,7 +149,7 @@ public class TestColumnChunkPageWriteStore {
     // we want to test the path with direct buffers so we need to enable this 
config as well
     // even though this file is not encrypted
     config.set(ParquetInputFormat.OFF_HEAP_DECRYPT_BUFFER_ENABLED, "true");
-    test(config, new DirectByteBufferAllocator());
+    test(config, allocator = TrackingByteBufferAllocator.wrap(new 
DirectByteBufferAllocator()));
   }
 
   public void test(Configuration config, ByteBufferAllocator allocator) throws 
Exception {
@@ -269,10 +277,11 @@ public class TestColumnChunkPageWriteStore {
     int fakeCount = 3;
     BinaryStatistics fakeStats = new BinaryStatistics();
 
-    // TODO - look back at this, an allocator was being passed here in the 
ByteBuffer changes
-    // see comment at this constructor
     ColumnChunkPageWriteStore store = new ColumnChunkPageWriteStore(
-        compressor(UNCOMPRESSED), schema, new HeapByteBufferAllocator(), 
Integer.MAX_VALUE);
+        compressor(UNCOMPRESSED),
+        schema,
+        allocator = TrackingByteBufferAllocator.wrap(new 
HeapByteBufferAllocator()),
+        Integer.MAX_VALUE);
 
     for (ColumnDescriptor col : schema.getColumns()) {
       PageWriter pageWriter = store.getPageWriter(col);
diff --git 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java
 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java
index 60415bfcf..154dd6f5c 100644
--- 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java
+++ 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java
@@ -65,6 +65,8 @@ import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
+import org.apache.parquet.bytes.HeapByteBufferAllocator;
+import org.apache.parquet.bytes.TrackingByteBufferAllocator;
 import org.apache.parquet.column.ParquetProperties;
 import org.apache.parquet.column.ParquetProperties.WriterVersion;
 import org.apache.parquet.crypto.ColumnEncryptionProperties;
@@ -87,7 +89,9 @@ import org.apache.parquet.hadoop.metadata.ColumnPath;
 import org.apache.parquet.io.api.Binary;
 import org.apache.parquet.schema.MessageType;
 import org.apache.parquet.schema.Types;
+import org.junit.After;
 import org.junit.AfterClass;
+import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -142,6 +146,17 @@ public class TestColumnIndexFiltering {
 
   private final Path file;
   private final boolean isEncrypted;
+  private TrackingByteBufferAllocator allocator;
+
+  @Before
+  public void initAllocator() {
+    allocator = TrackingByteBufferAllocator.wrap(new 
HeapByteBufferAllocator());
+  }
+
+  @After
+  public void closeAllocator() {
+    allocator.close();
+  }
 
   public TestColumnIndexFiltering(Path file, boolean isEncrypted) {
     this.file = file;
@@ -245,6 +260,7 @@ public class TestColumnIndexFiltering {
     FileDecryptionProperties decryptionProperties = 
getFileDecryptionProperties();
     return PhoneBookWriter.readUsers(
         ParquetReader.builder(new GroupReadSupport(), file)
+            .withAllocator(allocator)
             .withFilter(filter)
             .withDecryption(decryptionProperties)
             .useDictionaryFilter(useOtherFiltering)
@@ -336,14 +352,17 @@ public class TestColumnIndexFiltering {
     int pageSize = DATA.size() / 10; // Ensure that several pages will be 
created
     int rowGroupSize = pageSize * 6 * 5; // Ensure that there are more 
row-groups created
 
-    PhoneBookWriter.write(
-        ExampleParquetWriter.builder(file)
-            .withWriteMode(OVERWRITE)
-            .withRowGroupSize(rowGroupSize)
-            .withPageSize(pageSize)
-            .withEncryption(encryptionProperties)
-            .withWriterVersion(parquetVersion),
-        DATA);
+    try (TrackingByteBufferAllocator allocator = 
TrackingByteBufferAllocator.wrap(new HeapByteBufferAllocator())) {
+      PhoneBookWriter.write(
+          ExampleParquetWriter.builder(file)
+              .withAllocator(allocator)
+              .withWriteMode(OVERWRITE)
+              .withRowGroupSize(rowGroupSize)
+              .withPageSize(pageSize)
+              .withEncryption(encryptionProperties)
+              .withWriterVersion(parquetVersion),
+          DATA);
+    }
   }
 
   private static FileEncryptionProperties getFileEncryptionProperties() {
diff --git 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetReader.java 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetReader.java
index 7eae14f61..db14f6915 100644
--- 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetReader.java
+++ 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetReader.java
@@ -35,13 +35,17 @@ import java.util.List;
 import java.util.Set;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
+import org.apache.parquet.bytes.HeapByteBufferAllocator;
+import org.apache.parquet.bytes.TrackingByteBufferAllocator;
 import org.apache.parquet.column.ParquetProperties;
 import org.apache.parquet.example.data.Group;
 import org.apache.parquet.filter2.compat.FilterCompat;
 import org.apache.parquet.filter2.recordlevel.PhoneBookWriter;
 import org.apache.parquet.hadoop.example.ExampleParquetWriter;
 import org.apache.parquet.hadoop.example.GroupReadSupport;
+import org.junit.After;
 import org.junit.AfterClass;
+import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -58,6 +62,7 @@ public class TestParquetReader {
 
   private final Path file;
   private final long fileSize;
+  private TrackingByteBufferAllocator allocator;
 
   private static Path createPathFromCP(String path) {
     try {
@@ -150,6 +155,7 @@ public class TestParquetReader {
       throws IOException {
     return PhoneBookWriter.readUsers(
         ParquetReader.builder(new GroupReadSupport(), file)
+            .withAllocator(allocator)
             .withFilter(filter)
             .useDictionaryFilter(useOtherFiltering)
             .useStatsFilter(useOtherFiltering)
@@ -159,9 +165,19 @@ public class TestParquetReader {
         true);
   }
 
+  @Before
+  public void initAllocator() {
+    allocator = TrackingByteBufferAllocator.wrap(new 
HeapByteBufferAllocator());
+  }
+
+  @After
+  public void closeAllocator() {
+    allocator.close();
+  }
+
   @Test
   public void testCurrentRowIndex() throws Exception {
-    ParquetReader<Group> reader = PhoneBookWriter.createReader(file, 
FilterCompat.NOOP);
+    ParquetReader<Group> reader = PhoneBookWriter.createReader(file, 
FilterCompat.NOOP, allocator);
     // Fetch row index without processing any row.
     assertEquals(reader.getCurrentRowIndex(), -1);
     reader.read();
diff --git 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriter.java 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriter.java
index fa9ee865d..55d132c19 100644
--- 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriter.java
+++ 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriter.java
@@ -48,6 +48,8 @@ import net.openhft.hashing.LongHashFunction;
 import org.apache.commons.lang3.RandomStringUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
+import org.apache.parquet.bytes.HeapByteBufferAllocator;
+import org.apache.parquet.bytes.TrackingByteBufferAllocator;
 import org.apache.parquet.column.Encoding;
 import org.apache.parquet.column.ParquetProperties;
 import org.apache.parquet.column.ParquetProperties.WriterVersion;
@@ -70,7 +72,9 @@ import org.apache.parquet.schema.GroupType;
 import org.apache.parquet.schema.InvalidSchemaException;
 import org.apache.parquet.schema.MessageType;
 import org.apache.parquet.schema.Types;
+import org.junit.After;
 import org.junit.Assert;
+import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -109,6 +113,18 @@ public class TestParquetWriter {
     }
   }
 
+  private TrackingByteBufferAllocator allocator;
+
+  @Before
+  public void initAllocator() {
+    allocator = TrackingByteBufferAllocator.wrap(new 
HeapByteBufferAllocator());
+  }
+
+  @After
+  public void closeAllocator() {
+    allocator.close();
+  }
+
   @Test
   public void test() throws Exception {
     Configuration conf = new Configuration();
@@ -135,6 +151,7 @@ public class TestParquetWriter {
       for (WriterVersion version : WriterVersion.values()) {
         Path file = new Path(root, version.name() + "_" + modulo);
         ParquetWriter<Group> writer = ExampleParquetWriter.builder(new 
TestOutputFile(file, conf))
+            .withAllocator(allocator)
             .withCompressionCodec(UNCOMPRESSED)
             .withRowGroupSize(1024)
             .withPageSize(1024)
@@ -204,6 +221,7 @@ public class TestParquetWriter {
     TestUtils.assertThrows(
         "Should reject a schema with an empty group", 
InvalidSchemaException.class, (Callable<Void>) () -> {
           ExampleParquetWriter.builder(new Path(file.toString()))
+              .withAllocator(allocator)
               .withType(Types.buildMessage()
                   .addField(new GroupType(REQUIRED, "invalid_group"))
                   .named("invalid_message"))
@@ -235,6 +253,7 @@ public class TestParquetWriter {
     file.delete();
     Path path = new Path(file.getAbsolutePath());
     try (ParquetWriter<Group> writer = ExampleParquetWriter.builder(path)
+        .withAllocator(allocator)
         .withPageRowCountLimit(10)
         .withConf(conf)
         .build()) {
@@ -271,6 +290,7 @@ public class TestParquetWriter {
     file.delete();
     Path path = new Path(file.getAbsolutePath());
     try (ParquetWriter<Group> writer = ExampleParquetWriter.builder(path)
+        .withAllocator(allocator)
         .withPageRowCountLimit(10)
         .withConf(conf)
         .withDictionaryEncoding(false)
@@ -321,6 +341,7 @@ public class TestParquetWriter {
       file.delete();
       Path path = new Path(file.getAbsolutePath());
       try (ParquetWriter<Group> writer = ExampleParquetWriter.builder(path)
+          .withAllocator(allocator)
           .withPageRowCountLimit(10)
           .withConf(conf)
           .withDictionaryEncoding(false)
@@ -381,6 +402,7 @@ public class TestParquetWriter {
     file.delete();
     Path path = new Path(file.getAbsolutePath());
     try (ParquetWriter<Group> writer = ExampleParquetWriter.builder(path)
+        .withAllocator(allocator)
         .withConf(conf)
         .withDictionaryEncoding(false)
         .withBloomFilterEnabled("name", true)
@@ -482,6 +504,7 @@ public class TestParquetWriter {
     temp.delete();
     Path path = new Path(file.getAbsolutePath());
     try (ParquetWriter<Group> writer = ExampleParquetWriter.builder(path)
+        .withAllocator(allocator)
         .withConf(conf)
         // Set row group size to 1, to make sure we flush every time
         // minRowCountForPageSizeCheck or maxRowCountForPageSizeCheck is 
exceeded
diff --git 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestStoreBloomFilter.java
 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestStoreBloomFilter.java
index 85e0112b3..701dcc419 100644
--- 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestStoreBloomFilter.java
+++ 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestStoreBloomFilter.java
@@ -31,6 +31,8 @@ import java.util.List;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.parquet.ParquetReadOptions;
+import org.apache.parquet.bytes.HeapByteBufferAllocator;
+import org.apache.parquet.bytes.TrackingByteBufferAllocator;
 import org.apache.parquet.column.EncodingStats;
 import org.apache.parquet.column.ParquetProperties;
 import org.apache.parquet.filter2.recordlevel.PhoneBookWriter;
@@ -77,27 +79,29 @@ public class TestStoreBloomFilter {
 
   @Test
   public void testStoreBloomFilter() throws IOException {
-    ParquetFileReader reader = new ParquetFileReader(
-        HadoopInputFile.fromPath(file, new Configuration()),
-        ParquetReadOptions.builder().build());
-    List<BlockMetaData> blocks = reader.getRowGroups();
-    blocks.forEach(block -> {
-      try {
-        // column `id` isn't fully encoded in dictionary, it will generate 
`BloomFilter`
-        ColumnChunkMetaData idMeta = block.getColumns().get(0);
-        EncodingStats idEncoding = idMeta.getEncodingStats();
-        Assert.assertTrue(idEncoding.hasNonDictionaryEncodedPages());
-        Assert.assertNotNull(reader.readBloomFilter(idMeta));
+    try (TrackingByteBufferAllocator allocator = 
TrackingByteBufferAllocator.wrap(new HeapByteBufferAllocator());
+        ParquetFileReader reader = new ParquetFileReader(
+            HadoopInputFile.fromPath(file, new Configuration()),
+            ParquetReadOptions.builder().withAllocator(allocator).build())) {
+      List<BlockMetaData> blocks = reader.getRowGroups();
+      blocks.forEach(block -> {
+        try {
+          // column `id` isn't fully encoded in dictionary, it will generate 
`BloomFilter`
+          ColumnChunkMetaData idMeta = block.getColumns().get(0);
+          EncodingStats idEncoding = idMeta.getEncodingStats();
+          Assert.assertTrue(idEncoding.hasNonDictionaryEncodedPages());
+          Assert.assertNotNull(reader.readBloomFilter(idMeta));
 
-        // column `name` is fully encoded in dictionary, it won't generate 
`BloomFilter`
-        ColumnChunkMetaData nameMeta = block.getColumns().get(1);
-        EncodingStats nameEncoding = nameMeta.getEncodingStats();
-        Assert.assertFalse(nameEncoding.hasNonDictionaryEncodedPages());
-        Assert.assertNull(reader.readBloomFilter(nameMeta));
-      } catch (IOException e) {
-        e.printStackTrace();
-      }
-    });
+          // column `name` is fully encoded in dictionary, it won't generate 
`BloomFilter`
+          ColumnChunkMetaData nameMeta = block.getColumns().get(1);
+          EncodingStats nameEncoding = nameMeta.getEncodingStats();
+          Assert.assertFalse(nameEncoding.hasNonDictionaryEncodedPages());
+          Assert.assertNull(reader.readBloomFilter(nameMeta));
+        } catch (IOException e) {
+          e.printStackTrace();
+        }
+      });
+    }
   }
 
   private static Path createTempFile(String version) {
@@ -118,14 +122,17 @@ public class TestStoreBloomFilter {
       throws IOException {
     int pageSize = DATA.size() / 100; // Ensure that several pages will be 
created
     int rowGroupSize = pageSize * 4; // Ensure that there are more row-groups 
created
-    PhoneBookWriter.write(
-        ExampleParquetWriter.builder(file)
-            .withWriteMode(OVERWRITE)
-            .withRowGroupSize(rowGroupSize)
-            .withPageSize(pageSize)
-            .withBloomFilterNDV("id", 10000L)
-            .withBloomFilterNDV("name", 10000L)
-            .withWriterVersion(parquetVersion),
-        DATA);
+    try (TrackingByteBufferAllocator allocator = 
TrackingByteBufferAllocator.wrap(new HeapByteBufferAllocator())) {
+      PhoneBookWriter.write(
+          ExampleParquetWriter.builder(file)
+              .withAllocator(allocator)
+              .withWriteMode(OVERWRITE)
+              .withRowGroupSize(rowGroupSize)
+              .withPageSize(pageSize)
+              .withBloomFilterNDV("id", 10000L)
+              .withBloomFilterNDV("name", 10000L)
+              .withWriterVersion(parquetVersion),
+          DATA);
+    }
   }
 }

Reply via email to