This is an automated email from the ASF dual-hosted git repository.

gabor pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/parquet-mr.git


The following commit(s) were added to refs/heads/master by this push:
     new cf294a35c PARQUET-2436: More optimal memory usage in compression codecs (#1280)
cf294a35c is described below

commit cf294a35ca4b635f44357fe4c170f3b66870dc6e
Author: Gabor Szadovszky <[email protected]>
AuthorDate: Tue Feb 27 09:53:53 2024 +0100

    PARQUET-2436: More optimal memory usage in compression codecs (#1280)
    
    Rewritten direct codec implementations to have common memory handling
---
 .../parquet/bytes/ReusingByteBufferAllocator.java  |  52 ++++-
 .../bytes/TestReusingByteBufferAllocator.java      |  85 +++++--
 .../apache/parquet/hadoop/DirectCodecFactory.java  | 252 +++++++++++++++------
 .../java/org/apache/parquet/hadoop/DirectZstd.java | 155 -------------
 .../apache/parquet/hadoop/ParquetFileReader.java   |   8 +-
 .../apache/parquet/hadoop/ParquetFileWriter.java   |   4 +-
 .../parquet/hadoop/TestDirectCodecFactory.java     |  25 +-
 7 files changed, 297 insertions(+), 284 deletions(-)

diff --git a/parquet-common/src/main/java/org/apache/parquet/bytes/ReusingByteBufferAllocator.java b/parquet-common/src/main/java/org/apache/parquet/bytes/ReusingByteBufferAllocator.java
index 83b77b0ca..81ba67fe3 100644
--- a/parquet-common/src/main/java/org/apache/parquet/bytes/ReusingByteBufferAllocator.java
+++ b/parquet-common/src/main/java/org/apache/parquet/bytes/ReusingByteBufferAllocator.java
@@ -25,7 +25,7 @@ import java.nio.ByteBuffer;
  * next {@link #allocate(int)} call. The {@link #close()} shall be called when 
this allocator is not needed anymore to
  * really release the one buffer.
  */
-public class ReusingByteBufferAllocator implements ByteBufferAllocator, 
AutoCloseable {
+public abstract class ReusingByteBufferAllocator implements 
ByteBufferAllocator, AutoCloseable {
 
   private final ByteBufferAllocator allocator;
   private final ByteBufferReleaser releaser = new ByteBufferReleaser(this);
@@ -33,12 +33,46 @@ public class ReusingByteBufferAllocator implements ByteBufferAllocator, AutoClos
   private ByteBuffer bufferOut;
 
   /**
-   * Constructs a new {@link ReusingByteBufferAllocator} object with the 
specified "parent" allocator to be used for
+   * Creates a new strict {@link ReusingByteBufferAllocator} object with the 
specified "parent" allocator to be used for
    * allocating/releasing the one buffer.
+   * <p>
+   * Strict means it is enforced that {@link #release(ByteBuffer)} is invoked 
before a new {@link #allocate(int)} can be
+   * called.
    *
    * @param allocator the allocator to be used for allocating/releasing the 
one buffer
+   * @return a new strict {@link ReusingByteBufferAllocator} object
    */
-  public ReusingByteBufferAllocator(ByteBufferAllocator allocator) {
+  public static ReusingByteBufferAllocator strict(ByteBufferAllocator 
allocator) {
+    return new ReusingByteBufferAllocator(allocator) {
+      @Override
+      void allocateCheck(ByteBuffer bufferOut) {
+        if (bufferOut != null) {
+          throw new IllegalStateException("The single buffer is not yet 
released");
+        }
+      }
+    };
+  }
+
+  /**
+   * Creates a new unsafe {@link ReusingByteBufferAllocator} object with the 
specified "parent" allocator to be used for
+   * allocating/releasing the one buffer.
+   * <p>
+   * Unsafe means it is not enforced that {@link #release(ByteBuffer)} is 
invoked before a new {@link #allocate(int)}
+   * can be called, i.e. no exceptions will be thrown at {@link 
#allocate(int)}.
+   *
+   * @param allocator the allocator to be used for allocating/releasing the 
one buffer
+   * @return a new unsafe {@link ReusingByteBufferAllocator} object
+   */
+  public static ReusingByteBufferAllocator unsafe(ByteBufferAllocator 
allocator) {
+    return new ReusingByteBufferAllocator(allocator) {
+      @Override
+      void allocateCheck(ByteBuffer bufferOut) {
+        // no-op
+      }
+    };
+  }
+
+  private ReusingByteBufferAllocator(ByteBufferAllocator allocator) {
     this.allocator = allocator;
   }
 
@@ -54,13 +88,13 @@ public class ReusingByteBufferAllocator implements ByteBufferAllocator, AutoClos
   /**
    * {@inheritDoc}
    *
-   * @throws IllegalStateException if the one buffer was not released yet
+   * @throws IllegalStateException if strict and the one buffer was not 
released yet
+   * @see #strict(ByteBufferAllocator)
+   * @see #unsafe(ByteBufferAllocator)
    */
   @Override
   public ByteBuffer allocate(int size) {
-    if (bufferOut != null) {
-      throw new IllegalStateException("The single buffer is not yet released");
-    }
+    allocateCheck(bufferOut);
     if (buffer == null) {
       bufferOut = buffer = allocator.allocate(size);
     } else if (buffer.capacity() < size) {
@@ -74,10 +108,12 @@ public class ReusingByteBufferAllocator implements ByteBufferAllocator, AutoClos
     return bufferOut;
   }
 
+  abstract void allocateCheck(ByteBuffer bufferOut);
+
   /**
    * {@inheritDoc}
    *
-   * @throws IllegalStateException    if the one has already been released or 
never allocated
+   * @throws IllegalStateException    if the one buffer has already been 
released or never allocated
    * @throws IllegalArgumentException if the specified buffer is not the one 
allocated by this allocator
    */
   @Override
diff --git a/parquet-common/src/test/java/org/apache/parquet/bytes/TestReusingByteBufferAllocator.java b/parquet-common/src/test/java/org/apache/parquet/bytes/TestReusingByteBufferAllocator.java
index cc585d7d4..e200b5872 100644
--- a/parquet-common/src/test/java/org/apache/parquet/bytes/TestReusingByteBufferAllocator.java
+++ b/parquet-common/src/test/java/org/apache/parquet/bytes/TestReusingByteBufferAllocator.java
@@ -23,7 +23,10 @@ import static org.junit.Assert.assertThrows;
 
 import java.nio.ByteBuffer;
 import java.nio.InvalidMarkException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Random;
+import java.util.function.Function;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -35,6 +38,20 @@ import org.junit.runners.Parameterized.Parameters;
 @RunWith(Parameterized.class)
 public class TestReusingByteBufferAllocator {
 
+  private enum AllocatorType {
+    STRICT(ReusingByteBufferAllocator::strict),
+    UNSAFE(ReusingByteBufferAllocator::unsafe);
+    private final Function<ByteBufferAllocator, ReusingByteBufferAllocator> 
factory;
+
+    AllocatorType(Function<ByteBufferAllocator, ReusingByteBufferAllocator> 
factory) {
+      this.factory = factory;
+    }
+
+    public ReusingByteBufferAllocator create(ByteBufferAllocator allocator) {
+      return factory.apply(allocator);
+    }
+  }
+
   private static final Random RANDOM = new Random(2024_02_22_09_51L);
 
   private TrackingByteBufferAllocator allocator;
@@ -42,26 +59,31 @@ public class TestReusingByteBufferAllocator {
   @Parameter
   public ByteBufferAllocator innerAllocator;
 
-  @Parameters(name = "{0}")
-  public static Object[][] parameters() {
-    return new Object[][] {
-      {
-        new HeapByteBufferAllocator() {
-          @Override
-          public String toString() {
-            return "HEAP";
-          }
+  @Parameter(1)
+  public AllocatorType type;
+
+  @Parameters(name = "{0} {1}")
+  public static List<Object[]> parameters() {
+    List<Object[]> params = new ArrayList<>();
+    for (Object allocator : new Object[] {
+      new HeapByteBufferAllocator() {
+        @Override
+        public String toString() {
+          return "HEAP";
         }
       },
-      {
-        new DirectByteBufferAllocator() {
-          @Override
-          public String toString() {
-            return "DIRECT";
-          }
+      new DirectByteBufferAllocator() {
+        @Override
+        public String toString() {
+          return "DIRECT";
         }
       }
-    };
+    }) {
+      for (Object type : AllocatorType.values()) {
+        params.add(new Object[] {allocator, type});
+      }
+    }
+    return params;
   }
 
   @Before
@@ -76,7 +98,7 @@ public class TestReusingByteBufferAllocator {
 
   @Test
   public void normalUseCase() {
-    try (ReusingByteBufferAllocator reusingAllocator = new 
ReusingByteBufferAllocator(allocator)) {
+    try (ReusingByteBufferAllocator reusingAllocator = type.create(allocator)) 
{
       assertEquals(innerAllocator.isDirect(), reusingAllocator.isDirect());
       for (int i = 0; i < 10; ++i) {
         try (ByteBufferReleaser releaser = reusingAllocator.getReleaser()) {
@@ -84,11 +106,7 @@ public class TestReusingByteBufferAllocator {
           ByteBuffer buf = reusingAllocator.allocate(size);
           releaser.releaseLater(buf);
 
-          assertEquals(0, buf.position());
-          assertEquals(size, buf.capacity());
-          assertEquals(size, buf.remaining());
-          assertEquals(allocator.isDirect(), buf.isDirect());
-          assertThrows(InvalidMarkException.class, buf::reset);
+          validateBuffer(buf, size);
 
           // Let's see if the next allocate would clear the buffer
           buf.position(buf.capacity() / 2);
@@ -102,10 +120,18 @@ public class TestReusingByteBufferAllocator {
     }
   }
 
+  private void validateBuffer(ByteBuffer buf, int size) {
+    assertEquals(0, buf.position());
+    assertEquals(size, buf.capacity());
+    assertEquals(size, buf.remaining());
+    assertEquals(allocator.isDirect(), buf.isDirect());
+    assertThrows(InvalidMarkException.class, buf::reset);
+  }
+
   @Test
   public void validateExceptions() {
     try (ByteBufferReleaser releaser = new ByteBufferReleaser(allocator);
-        ReusingByteBufferAllocator reusingAllocator = new 
ReusingByteBufferAllocator(allocator)) {
+        ReusingByteBufferAllocator reusingAllocator = type.create(allocator)) {
       ByteBuffer fromOther = allocator.allocate(10);
       releaser.releaseLater(fromOther);
 
@@ -114,11 +140,20 @@ public class TestReusingByteBufferAllocator {
       ByteBuffer fromReusing = reusingAllocator.allocate(10);
 
       assertThrows(IllegalArgumentException.class, () -> 
reusingAllocator.release(fromOther));
-      assertThrows(IllegalStateException.class, () -> 
reusingAllocator.allocate(10));
+      switch (type) {
+        case STRICT:
+          assertThrows(IllegalStateException.class, () -> 
reusingAllocator.allocate(5));
+          break;
+        case UNSAFE:
+          fromReusing = reusingAllocator.allocate(5);
+          validateBuffer(fromReusing, 5);
+          break;
+      }
 
       reusingAllocator.release(fromReusing);
+      ByteBuffer fromReusingFinal = fromReusing;
       assertThrows(IllegalStateException.class, () -> 
reusingAllocator.release(fromOther));
-      assertThrows(IllegalStateException.class, () -> 
reusingAllocator.release(fromReusing));
+      assertThrows(IllegalStateException.class, () -> 
reusingAllocator.release(fromReusingFinal));
     }
   }
 }
diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/DirectCodecFactory.java b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/DirectCodecFactory.java
index 6d166c5df..3e2ad10b1 100644
--- a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/DirectCodecFactory.java
+++ b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/DirectCodecFactory.java
@@ -17,6 +17,9 @@
  */
 package org.apache.parquet.hadoop;
 
+import com.github.luben.zstd.Zstd;
+import com.github.luben.zstd.ZstdCompressCtx;
+import com.github.luben.zstd.ZstdDecompressCtx;
 import java.io.IOException;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
@@ -34,8 +37,12 @@ import org.apache.hadoop.io.compress.Decompressor;
 import org.apache.parquet.ParquetRuntimeException;
 import org.apache.parquet.Preconditions;
 import org.apache.parquet.bytes.ByteBufferAllocator;
+import org.apache.parquet.bytes.ByteBufferReleaser;
 import org.apache.parquet.bytes.BytesInput;
+import org.apache.parquet.bytes.ReusingByteBufferAllocator;
+import org.apache.parquet.hadoop.codec.ZstandardCodec;
 import org.apache.parquet.hadoop.metadata.CompressionCodecName;
+import org.apache.parquet.util.AutoCloseables;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.xerial.snappy.Snappy;
@@ -87,25 +94,6 @@ class DirectCodecFactory extends CodecFactory implements AutoCloseable {
         getClass().getSimpleName());
   }
 
-  private ByteBuffer ensure(ByteBuffer buffer, int size) {
-    if (buffer == null) {
-      buffer = allocator.allocate(size);
-    } else if (buffer.capacity() >= size) {
-      buffer.clear();
-    } else {
-      release(buffer);
-      buffer = allocator.allocate(size);
-    }
-    return buffer;
-  }
-
-  ByteBuffer release(ByteBuffer buffer) {
-    if (buffer != null) {
-      allocator.release(buffer);
-    }
-    return null;
-  }
-
   @Override
   protected BytesCompressor createCompressor(final CompressionCodecName 
codecName) {
 
@@ -116,7 +104,7 @@ class DirectCodecFactory extends CodecFactory implements AutoCloseable {
       // avoid using the default Snappy codec since it allocates direct 
buffers at awkward spots.
       return new SnappyCompressor();
     } else if (codecName == CompressionCodecName.ZSTD) {
-      return DirectZstd.createCompressor(configuration, pageSize);
+      return new ZstdCompressor();
     } else {
       // todo: create class similar to the SnappyCompressor for zlib and 
exclude it as
       // snappy is above since it also generates allocateDirect calls.
@@ -132,7 +120,7 @@ class DirectCodecFactory extends CodecFactory implements AutoCloseable {
     } else if (codecName == CompressionCodecName.SNAPPY) {
       return new SnappyDecompressor();
     } else if (codecName == CompressionCodecName.ZSTD) {
-      return DirectZstd.createDecompressor(configuration);
+      return new ZstdDecompressor();
     } else if 
(DirectCodecPool.INSTANCE.codec(codec).supportsDirectDecompression()) {
       return new FullDirectDecompressor(codecName);
     } else {
@@ -186,6 +174,100 @@ class DirectCodecFactory extends CodecFactory implements AutoCloseable {
     }
   }
 
+  private abstract class BaseDecompressor extends BytesDecompressor {
+    private final ReusingByteBufferAllocator inputAllocator;
+    private final ReusingByteBufferAllocator outputAllocator;
+
+    BaseDecompressor() {
+      inputAllocator = ReusingByteBufferAllocator.strict(allocator);
+      // Using unsafe reusing allocator because we give out the output 
ByteBuffer wrapped in a BytesInput. But
+      // that's what BytesInputs are for. It is expected to copy the data from 
the returned BytesInput before
+      // using this decompressor again.
+      outputAllocator = ReusingByteBufferAllocator.unsafe(allocator);
+    }
+
+    @Override
+    public BytesInput decompress(BytesInput bytes, int uncompressedSize) 
throws IOException {
+      try (ByteBufferReleaser releaser = inputAllocator.getReleaser()) {
+        ByteBuffer input = bytes.toByteBuffer(releaser);
+        ByteBuffer output = outputAllocator.allocate(uncompressedSize);
+        int size = decompress(input.slice(), output.slice());
+        if (size != uncompressedSize) {
+          throw new DirectCodecPool.ParquetCompressionCodecException(
+              "Unexpected decompressed size: " + size + " != " + 
uncompressedSize);
+        }
+        output.limit(size);
+        return BytesInput.from(output);
+      }
+    }
+
+    abstract int decompress(ByteBuffer input, ByteBuffer output) throws 
IOException;
+
+    @Override
+    public void decompress(ByteBuffer input, int compressedSize, ByteBuffer 
output, int uncompressedSize)
+        throws IOException {
+      input.limit(input.position() + compressedSize);
+      output.limit(output.position() + uncompressedSize);
+      int size = decompress(input.slice(), output.slice());
+      if (size != uncompressedSize) {
+        throw new DirectCodecPool.ParquetCompressionCodecException(
+            "Unexpected decompressed size: " + size + " != " + 
uncompressedSize);
+      }
+      input.position(input.limit());
+      output.position(output.position() + uncompressedSize);
+    }
+
+    @Override
+    public void release() {
+      try {
+        AutoCloseables.uncheckedClose(outputAllocator, inputAllocator);
+      } finally {
+        closeDecompressor();
+      }
+    }
+
+    abstract void closeDecompressor();
+  }
+
+  private abstract class BaseCompressor extends BytesCompressor {
+    private final ReusingByteBufferAllocator inputAllocator;
+    private final ReusingByteBufferAllocator outputAllocator;
+
+    BaseCompressor() {
+      inputAllocator = ReusingByteBufferAllocator.strict(allocator);
+      // Using unsafe reusing allocator because we give out the output 
ByteBuffer wrapped in a BytesInput. But
+      // that's what BytesInputs are for. It is expected to copy the data from 
the returned BytesInput before
+      // using this compressor again.
+      outputAllocator = ReusingByteBufferAllocator.unsafe(allocator);
+    }
+
+    @Override
+    public BytesInput compress(BytesInput bytes) throws IOException {
+      try (ByteBufferReleaser releaser = inputAllocator.getReleaser()) {
+        ByteBuffer input = bytes.toByteBuffer(releaser);
+        ByteBuffer output = 
outputAllocator.allocate(maxCompressedSize(Math.toIntExact(bytes.size())));
+        int size = compress(input.slice(), output.slice());
+        output.limit(size);
+        return BytesInput.from(output);
+      }
+    }
+
+    abstract int maxCompressedSize(int size);
+
+    abstract int compress(ByteBuffer input, ByteBuffer output) throws 
IOException;
+
+    @Override
+    public void release() {
+      try {
+        AutoCloseables.uncheckedClose(outputAllocator, inputAllocator);
+      } finally {
+        closeCompressor();
+      }
+    }
+
+    abstract void closeCompressor();
+  }
+
   /**
    * Wrapper around new Hadoop compressors that implement a direct memory
    * based version of a particular decompression algorithm. To maintain
@@ -194,38 +276,47 @@ class DirectCodecFactory extends CodecFactory implements AutoCloseable {
    * are currently retrieved and have their decompression method invoked
    * with reflection.
    */
-  public class FullDirectDecompressor extends BytesDecompressor {
+  public class FullDirectDecompressor extends BaseDecompressor {
     private final Object decompressor;
-    private HeapBytesDecompressor extraDecompressor;
 
     public FullDirectDecompressor(CompressionCodecName codecName) {
       CompressionCodec codec = getCodec(codecName);
       this.decompressor = 
DirectCodecPool.INSTANCE.codec(codec).borrowDirectDecompressor();
-      this.extraDecompressor = new HeapBytesDecompressor(codecName);
     }
 
     @Override
     public BytesInput decompress(BytesInput compressedBytes, int 
uncompressedSize) throws IOException {
-      return extraDecompressor.decompress(compressedBytes, uncompressedSize);
+      // Similarly to non-direct decompressors, we reset before use, if 
possible (see HeapBytesDecompressor)
+      if (decompressor instanceof Decompressor) {
+        ((Decompressor) decompressor).reset();
+      }
+      return super.decompress(compressedBytes, uncompressedSize);
     }
 
     @Override
     public void decompress(ByteBuffer input, int compressedSize, ByteBuffer 
output, int uncompressedSize)
         throws IOException {
-      output.clear();
+      // Similarly to non-direct decompressors, we reset before use, if 
possible (see HeapBytesDecompressor)
+      if (decompressor instanceof Decompressor) {
+        ((Decompressor) decompressor).reset();
+      }
+      super.decompress(input, compressedSize, output, uncompressedSize);
+    }
+
+    @Override
+    int decompress(ByteBuffer input, ByteBuffer output) {
+      int startPos = output.position();
       try {
-        DECOMPRESS_METHOD.invoke(decompressor, (ByteBuffer) 
input.limit(compressedSize), (ByteBuffer)
-            output.limit(uncompressedSize));
+        DECOMPRESS_METHOD.invoke(decompressor, input, output);
       } catch (IllegalAccessException | InvocationTargetException e) {
         throw new DirectCodecPool.ParquetCompressionCodecException(e);
       }
-      output.position(uncompressedSize);
+      return output.position() - startPos;
     }
 
     @Override
-    public void release() {
+    void closeDecompressor() {
       DirectCodecPool.INSTANCE.returnDirectDecompressor(decompressor);
-      extraDecompressor.release();
     }
   }
 
@@ -250,75 +341,86 @@ class DirectCodecFactory extends CodecFactory implements AutoCloseable {
     public void release() {}
   }
 
-  public class SnappyDecompressor extends BytesDecompressor {
+  public class SnappyDecompressor extends BaseDecompressor {
+    @Override
+    int decompress(ByteBuffer input, ByteBuffer output) throws IOException {
+      return Snappy.uncompress(input, output);
+    }
 
-    private HeapBytesDecompressor extraDecompressor;
+    @Override
+    void closeDecompressor() {
+      // no-op
+    }
+  }
 
-    public SnappyDecompressor() {
-      this.extraDecompressor = new 
HeapBytesDecompressor(CompressionCodecName.SNAPPY);
+  public class SnappyCompressor extends BaseCompressor {
+
+    @Override
+    int compress(ByteBuffer input, ByteBuffer output) throws IOException {
+      return Snappy.compress(input, output);
     }
 
     @Override
-    public BytesInput decompress(BytesInput bytes, int uncompressedSize) 
throws IOException {
-      return extraDecompressor.decompress(bytes, uncompressedSize);
+    int maxCompressedSize(int size) {
+      return Snappy.maxCompressedLength(size);
     }
 
     @Override
-    public void decompress(ByteBuffer src, int compressedSize, ByteBuffer dst, 
int uncompressedSize)
-        throws IOException {
-      dst.clear();
-      int size = Snappy.uncompress(src, dst);
-      dst.limit(size);
+    public CompressionCodecName getCodecName() {
+      return CompressionCodecName.SNAPPY;
     }
 
     @Override
-    public void release() {}
+    void closeCompressor() {
+      // no-op
+    }
   }
 
-  public class SnappyCompressor extends BytesCompressor {
+  private class ZstdDecompressor extends BaseDecompressor {
+    private final ZstdDecompressCtx context;
 
-    // TODO - this outgoing buffer might be better off not being shared, this 
seems to
-    // only work because of an extra copy currently happening where this 
interface is
-    // be consumed
-    private ByteBuffer incoming;
-    private ByteBuffer outgoing;
+    ZstdDecompressor() {
+      context = new ZstdDecompressCtx();
+    }
 
-    /**
-     * Compress a given buffer of bytes
-     * @param bytes
-     * @return
-     * @throws IOException
-     */
     @Override
-    public BytesInput compress(BytesInput bytes) throws IOException {
-      int maxOutputSize = Snappy.maxCompressedLength((int) bytes.size());
-      ByteBuffer bufferIn = bytes.toByteBuffer();
-      outgoing = ensure(outgoing, maxOutputSize);
-      final int size;
-      if (bufferIn.isDirect()) {
-        size = Snappy.compress(bufferIn, outgoing);
-      } else {
-        // Snappy library requires buffers be direct
-        this.incoming = ensure(this.incoming, (int) bytes.size());
-        this.incoming.put(bufferIn);
-        this.incoming.flip();
-        size = Snappy.compress(this.incoming, outgoing);
-      }
+    int decompress(ByteBuffer input, ByteBuffer output) {
+      return context.decompress(output, input);
+    }
 
-      outgoing.limit(size);
+    @Override
+    void closeDecompressor() {
+      context.close();
+    }
+  }
+
+  private class ZstdCompressor extends BaseCompressor {
+    private final ZstdCompressCtx context;
 
-      return BytesInput.from(outgoing);
+    ZstdCompressor() {
+      context = new ZstdCompressCtx();
+      context.setLevel(configuration.getInt(
+          ZstandardCodec.PARQUET_COMPRESS_ZSTD_LEVEL, 
ZstandardCodec.DEFAULT_PARQUET_COMPRESS_ZSTD_LEVEL));
     }
 
     @Override
     public CompressionCodecName getCodecName() {
-      return CompressionCodecName.SNAPPY;
+      return CompressionCodecName.ZSTD;
     }
 
     @Override
-    public void release() {
-      outgoing = DirectCodecFactory.this.release(outgoing);
-      incoming = DirectCodecFactory.this.release(incoming);
+    int maxCompressedSize(int size) {
+      return Math.toIntExact(Zstd.compressBound(size));
+    }
+
+    @Override
+    int compress(ByteBuffer input, ByteBuffer output) {
+      return context.compress(output, input);
+    }
+
+    @Override
+    void closeCompressor() {
+      context.close();
     }
   }
 
diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/DirectZstd.java b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/DirectZstd.java
deleted file mode 100644
index 73da562a0..000000000
--- a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/DirectZstd.java
+++ /dev/null
@@ -1,155 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * <p>
- * http://www.apache.org/licenses/LICENSE-2.0
- * <p>
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.parquet.hadoop;
-
-import static 
org.apache.parquet.hadoop.codec.ZstandardCodec.DEFAULTPARQUET_COMPRESS_ZSTD_WORKERS;
-import static 
org.apache.parquet.hadoop.codec.ZstandardCodec.DEFAULT_PARQUET_COMPRESS_ZSTD_BUFFERPOOL_ENABLED;
-import static 
org.apache.parquet.hadoop.codec.ZstandardCodec.DEFAULT_PARQUET_COMPRESS_ZSTD_LEVEL;
-import static 
org.apache.parquet.hadoop.codec.ZstandardCodec.PARQUET_COMPRESS_ZSTD_BUFFERPOOL_ENABLED;
-import static 
org.apache.parquet.hadoop.codec.ZstandardCodec.PARQUET_COMPRESS_ZSTD_LEVEL;
-import static 
org.apache.parquet.hadoop.codec.ZstandardCodec.PARQUET_COMPRESS_ZSTD_WORKERS;
-
-import com.github.luben.zstd.BufferPool;
-import com.github.luben.zstd.NoPool;
-import com.github.luben.zstd.RecyclingBufferPool;
-import com.github.luben.zstd.Zstd;
-import com.github.luben.zstd.ZstdOutputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.parquet.bytes.BytesInput;
-import org.apache.parquet.conf.HadoopParquetConfiguration;
-import org.apache.parquet.conf.ParquetConfiguration;
-import org.apache.parquet.hadoop.codec.ZstdDecompressorStream;
-import org.apache.parquet.hadoop.metadata.CompressionCodecName;
-
-/**
- * Utility class to support creating compressor and decompressor instances for 
the ZStandard codec. It is implemented in
- * a way to work around the codec pools implemented in both parquet-mr and 
hadoop. These codec pools may result creating
- * and dereferencing direct byte buffers causing OOM errors in case of many 
parallel compressor/decompressor instances
- * are required working on direct memory.
- *
- * @see DirectCodecFactory.DirectCodecPool
- * @see org.apache.hadoop.io.compress.CodecPool
- */
-class DirectZstd {
-
-  static CodecFactory.BytesCompressor createCompressor(Configuration conf, int 
pageSize) {
-    return createCompressor(new HadoopParquetConfiguration(conf), pageSize);
-  }
-
-  static CodecFactory.BytesCompressor createCompressor(ParquetConfiguration 
conf, int pageSize) {
-    return new ZstdCompressor(
-        getPool(conf),
-        conf.getInt(PARQUET_COMPRESS_ZSTD_LEVEL, 
DEFAULT_PARQUET_COMPRESS_ZSTD_LEVEL),
-        conf.getInt(PARQUET_COMPRESS_ZSTD_WORKERS, 
DEFAULTPARQUET_COMPRESS_ZSTD_WORKERS),
-        pageSize);
-  }
-
-  static CodecFactory.BytesDecompressor createDecompressor(Configuration conf) 
{
-    return createDecompressor(new HadoopParquetConfiguration(conf));
-  }
-
-  static CodecFactory.BytesDecompressor 
createDecompressor(ParquetConfiguration conf) {
-    return new ZstdDecompressor(getPool(conf));
-  }
-
-  private static class ZstdCompressor extends CodecFactory.BytesCompressor {
-    private final BufferPool pool;
-    private final int level;
-    private final int workers;
-    private final int pageSize;
-
-    ZstdCompressor(BufferPool pool, int level, int workers, int pageSize) {
-      this.pool = pool;
-      this.level = level;
-      this.workers = workers;
-      this.pageSize = pageSize;
-    }
-
-    @Override
-    public BytesInput compress(BytesInput bytes) throws IOException {
-      // Since BytesInput does not support direct memory this implementation 
is heap based
-      BytesInputProviderOutputStream stream = new 
BytesInputProviderOutputStream(pageSize);
-      try (ZstdOutputStream zstdStream = new ZstdOutputStream(stream, pool, 
level)) {
-        zstdStream.setWorkers(workers);
-        bytes.writeAllTo(zstdStream);
-      }
-      return stream.getBytesInput();
-    }
-
-    @Override
-    public CompressionCodecName getCodecName() {
-      return CompressionCodecName.ZSTD;
-    }
-
-    @Override
-    public void release() {
-      // Nothing to do here since we release resources where we create them
-    }
-  }
-
-  private static class ZstdDecompressor extends CodecFactory.BytesDecompressor 
{
-    private final BufferPool pool;
-
-    private ZstdDecompressor(BufferPool pool) {
-      this.pool = pool;
-    }
-
-    @Override
-    public BytesInput decompress(BytesInput bytes, int uncompressedSize) 
throws IOException {
-      // Since BytesInput does not support direct memory this implementation 
is heap based
-      try (ZstdDecompressorStream decompressorStream = new 
ZstdDecompressorStream(bytes.toInputStream(), pool)) {
-        // We need to copy the bytes from the input stream, so we can close it 
here (BytesInput does not support
-        // closing)
-        return BytesInput.copy(BytesInput.from(decompressorStream, 
uncompressedSize));
-      }
-    }
-
-    @Override
-    public void decompress(ByteBuffer input, int compressedSize, ByteBuffer 
output, int uncompressedSize)
-        throws IOException {
-      Zstd.decompress(output, input);
-    }
-
-    @Override
-    public void release() {
-      // Nothing to do here since we release resources where we create them
-    }
-  }
-
-  private static class BytesInputProviderOutputStream extends 
ByteArrayOutputStream {
-    BytesInputProviderOutputStream(int initialCapacity) {
-      super(initialCapacity);
-    }
-
-    BytesInput getBytesInput() {
-      return BytesInput.from(buf, 0, count);
-    }
-  }
-
-  private static BufferPool getPool(ParquetConfiguration conf) {
-    if (conf.getBoolean(
-        PARQUET_COMPRESS_ZSTD_BUFFERPOOL_ENABLED, 
DEFAULT_PARQUET_COMPRESS_ZSTD_BUFFERPOOL_ENABLED)) {
-      return RecyclingBufferPool.INSTANCE;
-    } else {
-      return NoPool.INSTANCE;
-    }
-  }
-}
diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileReader.java b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileReader.java
index 628b6dcf1..be43cf647 100644
--- a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileReader.java
+++ b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileReader.java
@@ -781,7 +781,7 @@ public class ParquetFileReader implements Closeable {
 
     if (options.usePageChecksumVerification()) {
       this.crc = new CRC32();
-      this.crcAllocator = new 
ReusingByteBufferAllocator(options.getAllocator());
+      this.crcAllocator = 
ReusingByteBufferAllocator.strict(options.getAllocator());
     } else {
       this.crc = null;
       this.crcAllocator = null;
@@ -840,7 +840,7 @@ public class ParquetFileReader implements Closeable {
 
     if (options.usePageChecksumVerification()) {
       this.crc = new CRC32();
-      this.crcAllocator = new 
ReusingByteBufferAllocator(options.getAllocator());
+      this.crcAllocator = 
ReusingByteBufferAllocator.strict(options.getAllocator());
     } else {
       this.crc = null;
       this.crcAllocator = null;
@@ -879,7 +879,7 @@ public class ParquetFileReader implements Closeable {
 
     if (options.usePageChecksumVerification()) {
       this.crc = new CRC32();
-      this.crcAllocator = new 
ReusingByteBufferAllocator(options.getAllocator());
+      this.crcAllocator = 
ReusingByteBufferAllocator.strict(options.getAllocator());
     } else {
       this.crc = null;
       this.crcAllocator = null;
@@ -921,7 +921,7 @@ public class ParquetFileReader implements Closeable {
 
     if (options.usePageChecksumVerification()) {
       this.crc = new CRC32();
-      this.crcAllocator = new 
ReusingByteBufferAllocator(options.getAllocator());
+      this.crcAllocator = 
ReusingByteBufferAllocator.strict(options.getAllocator());
     } else {
       this.crc = null;
       this.crcAllocator = null;
diff --git 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileWriter.java 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileWriter.java
index abc408779..5c7612652 100644
--- 
a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileWriter.java
+++ 
b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetFileWriter.java
@@ -473,7 +473,7 @@ public class ParquetFileWriter implements AutoCloseable {
     this.pageWriteChecksumEnabled = pageWriteChecksumEnabled;
     this.crc = pageWriteChecksumEnabled ? new CRC32() : null;
     this.crcAllocator = pageWriteChecksumEnabled
-        ? new ReusingByteBufferAllocator(allocator == null ? new 
HeapByteBufferAllocator() : allocator)
+        ? ReusingByteBufferAllocator.strict(allocator == null ? new 
HeapByteBufferAllocator() : allocator)
         : null;
 
     this.metadataConverter = new 
ParquetMetadataConverter(statisticsTruncateLength);
@@ -546,7 +546,7 @@ public class ParquetFileWriter implements AutoCloseable {
     this.pageWriteChecksumEnabled = 
ParquetOutputFormat.getPageWriteChecksumEnabled(configuration);
     this.crc = pageWriteChecksumEnabled ? new CRC32() : null;
     this.crcAllocator = pageWriteChecksumEnabled
-        ? new ReusingByteBufferAllocator(allocator == null ? new 
HeapByteBufferAllocator() : allocator)
+        ? ReusingByteBufferAllocator.strict(allocator == null ? new 
HeapByteBufferAllocator() : allocator)
         : null;
     this.metadataConverter = new 
ParquetMetadataConverter(ParquetProperties.DEFAULT_STATISTICS_TRUNCATE_LENGTH);
     this.fileEncryptor = null;
diff --git 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestDirectCodecFactory.java
 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestDirectCodecFactory.java
index e5b87a7e9..3a1779588 100644
--- 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestDirectCodecFactory.java
+++ 
b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestDirectCodecFactory.java
@@ -27,10 +27,11 @@ import java.util.Random;
 import java.util.Set;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.compress.CompressionCodec;
-import org.apache.parquet.bytes.ByteBufferAllocator;
+import org.apache.parquet.bytes.ByteBufferReleaser;
 import org.apache.parquet.bytes.BytesInput;
 import org.apache.parquet.bytes.DirectByteBufferAllocator;
 import org.apache.parquet.bytes.HeapByteBufferAllocator;
+import org.apache.parquet.bytes.TrackingByteBufferAllocator;
 import 
org.apache.parquet.compression.CompressionCodecFactory.BytesInputCompressor;
 import 
org.apache.parquet.compression.CompressionCodecFactory.BytesInputDecompressor;
 import org.apache.parquet.hadoop.metadata.CompressionCodecName;
@@ -52,16 +53,15 @@ public class TestDirectCodecFactory {
   private final int pageSize = 64 * 1024;
 
   private void test(int size, CompressionCodecName codec, boolean 
useOnHeapCompression, Decompression decomp) {
-    ByteBuffer rawBuf = null;
-    ByteBuffer outBuf = null;
-    ByteBufferAllocator allocator = null;
-    try {
-      allocator = new DirectByteBufferAllocator();
+    try (TrackingByteBufferAllocator allocator = 
TrackingByteBufferAllocator.wrap(new DirectByteBufferAllocator());
+        ByteBufferReleaser releaser = new ByteBufferReleaser(allocator)) {
       final CodecFactory codecFactory =
           CodecFactory.createDirectCodecFactory(new Configuration(), 
allocator, pageSize);
-      rawBuf = allocator.allocate(size);
+      ByteBuffer rawBuf = allocator.allocate(size);
+      releaser.releaseLater(rawBuf);
       final byte[] rawArr = new byte[size];
-      outBuf = allocator.allocate(size * 2);
+      ByteBuffer outBuf = allocator.allocate(size * 2);
+      releaser.releaseLater(outBuf);
       final Random r = new Random();
       final byte[] random = new byte[1024];
       int pos = 0;
@@ -121,19 +121,14 @@ public class TestDirectCodecFactory {
           break;
         }
       }
+      c.release();
+      d.release();
     } catch (Exception e) {
       final String msg = String.format(
           "Failure while testing Codec: %s, OnHeapCompressionInput: %s, 
Decompression Mode: %s, Data Size: %d",
           codec.name(), useOnHeapCompression, decomp.name(), size);
       LOG.error(msg);
       throw new RuntimeException(msg, e);
-    } finally {
-      if (rawBuf != null) {
-        allocator.release(rawBuf);
-      }
-      if (outBuf != null) {
-        allocator.release(rawBuf);
-      }
     }
   }
 


Reply via email to