This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/orc.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new 4e14cb301 ORC-817, ORC-1088: Support ZStandard compression using zstd-jni
4e14cb301 is described below

commit 4e14cb301765c376f0bf747c7be4844ea5798f53
Author: sychen <[email protected]>
AuthorDate: Tue Jan 16 15:06:29 2024 -0800

    ORC-817, ORC-1088: Support ZStandard compression using zstd-jni
    
    ### What changes were proposed in this pull request?
    Original PR: https://github.com/apache/orc/pull/988
    Original author: dchristle
    
    This PR adds support for using the
    [zstd-jni](https://github.com/luben/zstd-jni) library as the
    implementation of ORC zstd, with better performance than
    [aircompressor](https://github.com/airlift/aircompressor).
    (https://github.com/apache/orc/pull/988#issuecomment-1884443205)
    
    This PR also exposes the compression level and "long mode" settings to
    ORC users. These settings allow users to select speed/compression
    trade-offs that the original aircompressor implementation did not support.
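    As an illustration, the new settings can be supplied as table properties
    or Hadoop configuration entries. The keys and their defaults (level 1,
    windowlog 0) come from the OrcConf changes in this patch; the values
    below are only example choices:

    ```
    # Example overrides for the ZStandard codec (illustrative values)
    orc.compression.zstd.level=3
    orc.compression.zstd.windowlog=27
    ```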
    
    - Add a zstd-jni dependency and a new CompressionCodec, ZstdCodec, that
      uses it. Add an ORC conf to set the compression level.
    - Add an ORC conf to enable long mode, and add configuration setters for
      windowLog.
    - Add tests that verify the correctness of writing and reading across
      compression levels, window sizes, and long-mode use.
    - Add a test for compatibility between the Zstd aircompressor and
      zstd-jni implementations.
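    The runtime selection between the two implementations can be sketched as
    follows. This is an illustrative stand-alone sketch, not the actual ORC
    source; `chooseZstdImpl` is a hypothetical helper that mirrors the
    fallback logic this patch adds to `WriterImpl.createCodec`:

```java
public class ZstdImplChooser {
  /**
   * Mirrors the selection logic added in this patch:
   * -Dorc.compression.zstd.impl=java forces the pure-Java aircompressor
   * codec; otherwise zstd-jni is preferred, falling back to aircompressor
   * when the native library could not be loaded for this platform.
   */
  static String chooseZstdImpl(String implProperty, boolean nativeLoaded) {
    if ("java".equalsIgnoreCase(implProperty)) {
      return "aircompressor";
    }
    return nativeLoaded ? "zstd-jni" : "aircompressor";
  }

  public static void main(String[] args) {
    System.out.println(chooseZstdImpl("java", true));   // forced pure-Java path
    System.out.println(chooseZstdImpl(null, true));     // native available
    System.out.println(chooseZstdImpl(null, false));    // native load failed
  }
}
```

    Note that either way the on-disk format is the same zstd frame format, so
    files written by one implementation remain readable by the other.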
    
    ### Why are the changes needed?
    These changes make sense for a few reasons:
    
    ORC users will gain all the improvements from the main zstd library. It
    is under active development and receives regular speed and compression
    improvements. In contrast, aircompressor's zstd implementation is older
    and stale.
    
    ORC users will be able to use the entire speed/compression tradeoff
    space. Today, aircompressor's implementation has only one of eight
    compression strategies
    ([link](https://github.com/airlift/aircompressor/blob/c5e6972bd37e1d3834514957447028060a268eea/src/main/java/io/airlift/compress/zstd/CompressionParameters.java#L143)).
    This means only a small range of faster but less compressive strategies
    can be exposed to ORC users. ORC storage with high compression (e.g.
    for large-but-infrequ [...]
    
    It will harmonize the Java ORC implementation with other projects in
    the Hadoop ecosystem. Parquet, Spark, and even the C++ ORC
    reader/writers all rely on the official zstd implementation, either via
    zstd-jni or directly. In this way, the Java reader/writer code is an
    outlier.
    
    Detecting and fixing bugs or regressions will generally happen much
    faster, given the larger number of users and the active developer
    communities of zstd and zstd-jni.
    
    The largest tradeoff is that zstd-jni wraps compiled code. That said,
    many microprocessor architectures are already targeted and bundled into
    zstd-jni, so this should be a rare hurdle.
    
    ### How was this patch tested?
    - Unit tests for reading and writing ORC files using a variety of
      compression levels and window logs all pass.
    - A unit test that compresses and decompresses between aircompressor
      and zstd-jni passes. Note that the current aircompressor
      implementation uses a small subset of levels, so the test only
      compares data using the default compression settings.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #1743 from cxzl25/ORC-817.
    
    Lead-authored-by: sychen <[email protected]>
    Co-authored-by: Dongjoon Hyun <[email protected]>
    Co-authored-by: David Christle <[email protected]>
    Co-authored-by: Yiqun Zhang <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
    (cherry picked from commit 33be57124f2d1ddcbf4bbb097613f346aef6ae6c)
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 java/core/pom.xml                                  |   4 +
 java/core/src/java/org/apache/orc/OrcConf.java     |   8 +
 java/core/src/java/org/apache/orc/OrcFile.java     |  32 +++
 .../java/org/apache/orc/impl/PhysicalFsWriter.java |  12 +-
 .../src/java/org/apache/orc/impl/WriterImpl.java   |  24 +-
 .../src/java/org/apache/orc/impl/ZstdCodec.java    | 252 +++++++++++++++++++++
 .../src/test/org/apache/orc/TestVectorOrcFile.java |  97 +++++---
 .../src/test/org/apache/orc/impl/TestZstd.java     |  87 ++++++-
 java/pom.xml                                       |   6 +
 9 files changed, 477 insertions(+), 45 deletions(-)

diff --git a/java/core/pom.xml b/java/core/pom.xml
index a750911b7..755c005b1 100644
--- a/java/core/pom.xml
+++ b/java/core/pom.xml
@@ -51,6 +51,10 @@
       <groupId>io.airlift</groupId>
       <artifactId>aircompressor</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.github.luben</groupId>
+      <artifactId>zstd-jni</artifactId>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-client-api</artifactId>
diff --git a/java/core/src/java/org/apache/orc/OrcConf.java b/java/core/src/java/org/apache/orc/OrcConf.java
index 7e5296f52..d4bebe2cd 100644
--- a/java/core/src/java/org/apache/orc/OrcConf.java
+++ b/java/core/src/java/org/apache/orc/OrcConf.java
@@ -72,6 +72,14 @@ public enum OrcConf {
       "Define the compression strategy to use while writing data.\n" +
           "This changes the compression level of higher level compression\n" +
           "codec (like ZLIB)."),
+  COMPRESSION_ZSTD_LEVEL("orc.compression.zstd.level",
+      "hive.exec.orc.compression.zstd.level", 1,
+      "Define the compression level to use with ZStandard codec "
+          + "while writing data. The valid range is 1~22"),
+  COMPRESSION_ZSTD_WINDOWLOG("orc.compression.zstd.windowlog",
+      "hive.exec.orc.compression.zstd.windowlog", 0,
+      "Set the maximum allowed back-reference distance for "
+          + "ZStandard codec, expressed as power of 2."),
   BLOCK_PADDING_TOLERANCE("orc.block.padding.tolerance",
       "hive.exec.orc.block.padding.tolerance", 0.05,
       "Define the tolerance for block padding as a decimal fraction of\n" +
diff --git a/java/core/src/java/org/apache/orc/OrcFile.java b/java/core/src/java/org/apache/orc/OrcFile.java
index e41e79945..dfe3088fb 100644
--- a/java/core/src/java/org/apache/orc/OrcFile.java
+++ b/java/core/src/java/org/apache/orc/OrcFile.java
@@ -426,6 +426,27 @@ public class OrcFile {
     }
   }
 
+  public static class ZstdCompressOptions {
+    private int compressionZstdLevel;
+    private int compressionZstdWindowLog;
+
+    public int getCompressionZstdLevel() {
+      return compressionZstdLevel;
+    }
+
+    public void setCompressionZstdLevel(int compressionZstdLevel) {
+      this.compressionZstdLevel = compressionZstdLevel;
+    }
+
+    public int getCompressionZstdWindowLog() {
+      return compressionZstdWindowLog;
+    }
+
+    public void setCompressionZstdWindowLog(int compressionZstdWindowLog) {
+      this.compressionZstdWindowLog = compressionZstdWindowLog;
+    }
+  }
+
   /**
    * Options for creating ORC file writers.
    */
@@ -447,6 +468,7 @@ public class OrcFile {
     private WriterCallback callback;
     private EncodingStrategy encodingStrategy;
     private CompressionStrategy compressionStrategy;
+    private ZstdCompressOptions zstdCompressOptions;
     private double paddingTolerance;
     private String bloomFilterColumns;
     private double bloomFilterFpp;
@@ -493,6 +515,12 @@ public class OrcFile {
           OrcConf.COMPRESSION_STRATEGY.getString(tableProperties, conf);
       compressionStrategy = CompressionStrategy.valueOf(compString);
 
+      zstdCompressOptions = new ZstdCompressOptions();
+      zstdCompressOptions.setCompressionZstdLevel(
+              OrcConf.COMPRESSION_ZSTD_LEVEL.getInt(tableProperties, conf));
+      zstdCompressOptions.setCompressionZstdWindowLog(
+              OrcConf.COMPRESSION_ZSTD_WINDOWLOG.getInt(tableProperties, conf));
+
       paddingTolerance =
           OrcConf.BLOCK_PADDING_TOLERANCE.getDouble(tableProperties, conf);
 
@@ -938,6 +966,10 @@ public class OrcFile {
       return encodingStrategy;
     }
 
+    public ZstdCompressOptions getZstdCompressOptions() {
+      return zstdCompressOptions;
+    }
+
     public double getPaddingTolerance() {
       return paddingTolerance;
     }
diff --git a/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java b/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java
index deaf63446..4eb5f8562 100644
--- a/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java
+++ b/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java
@@ -115,8 +115,18 @@ public class PhysicalFsWriter implements PhysicalWriter {
     }
     CompressionCodec codec = OrcCodecPool.getCodec(opts.getCompress());
     if (codec != null){
-      compress.withCodec(codec, codec.getDefaultOptions());
+      CompressionCodec.Options tempOptions = codec.getDefaultOptions();
+      if (codec instanceof ZstdCodec &&
+              codec.getDefaultOptions() instanceof ZstdCodec.ZstdOptions options) {
+        OrcFile.ZstdCompressOptions zstdCompressOptions = opts.getZstdCompressOptions();
+        if (zstdCompressOptions != null) {
+          options.setLevel(zstdCompressOptions.getCompressionZstdLevel());
+          options.setWindowLog(zstdCompressOptions.getCompressionZstdWindowLog());
+        }
+      }
+      compress.withCodec(codec, tempOptions);
     }
+
     this.compressionStrategy = opts.getCompressionStrategy();
     this.maxPadding = (int) (opts.getPaddingTolerance() * defaultStripeSize);
     this.blockSize = opts.getBlockSize();
diff --git a/java/core/src/java/org/apache/orc/impl/WriterImpl.java b/java/core/src/java/org/apache/orc/impl/WriterImpl.java
index 776baa28a..c028228ef 100644
--- a/java/core/src/java/org/apache/orc/impl/WriterImpl.java
+++ b/java/core/src/java/org/apache/orc/impl/WriterImpl.java
@@ -18,6 +18,7 @@
 
 package org.apache.orc.impl;
 
+import com.github.luben.zstd.util.Native;
 import com.google.protobuf.ByteString;
 import io.airlift.compress.lz4.Lz4Compressor;
 import io.airlift.compress.lz4.Lz4Decompressor;
@@ -273,6 +274,17 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback {
     return Math.min(kb256, Math.max(kb4, pow2));
   }
 
+  static {
+    try {
+      if (!"java".equalsIgnoreCase(System.getProperty("orc.compression.zstd.impl"))) {
+        Native.load();
+      }
+    } catch (UnsatisfiedLinkError | ExceptionInInitializerError e) {
+      LOG.warn("Unable to load zstd-jni library for your platform. " +
+            "Using builtin-java classes where applicable");
+    }
+  }
+
   public static CompressionCodec createCodec(CompressionKind kind) {
     switch (kind) {
       case NONE:
@@ -288,8 +300,16 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback {
         return new AircompressorCodec(kind, new Lz4Compressor(),
             new Lz4Decompressor());
       case ZSTD:
-        return new AircompressorCodec(kind, new ZstdCompressor(),
-            new ZstdDecompressor());
+        if ("java".equalsIgnoreCase(System.getProperty("orc.compression.zstd.impl"))) {
+          return new AircompressorCodec(kind, new ZstdCompressor(),
+                  new ZstdDecompressor());
+        }
+        if (Native.isLoaded()) {
+          return new ZstdCodec();
+        } else {
+          return new AircompressorCodec(kind, new ZstdCompressor(),
+              new ZstdDecompressor());
+        }
       case BROTLI:
         return new BrotliCodec();
       default:
diff --git a/java/core/src/java/org/apache/orc/impl/ZstdCodec.java b/java/core/src/java/org/apache/orc/impl/ZstdCodec.java
new file mode 100644
index 000000000..cdbf1f3fd
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/ZstdCodec.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.orc.impl;
+
+import com.github.luben.zstd.Zstd;
+import com.github.luben.zstd.ZstdCompressCtx;
+import org.apache.orc.CompressionCodec;
+import org.apache.orc.CompressionKind;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+public class ZstdCodec implements CompressionCodec {
+  private ZstdOptions zstdOptions = null;
+  private ZstdCompressCtx zstdCompressCtx = null;
+
+  public ZstdCodec(int level, int windowLog) {
+    this.zstdOptions = new ZstdOptions(level, windowLog);
+  }
+
+  public ZstdCodec() {
+    this(1, 0);
+  }
+
+  public ZstdOptions getZstdOptions() {
+    return zstdOptions;
+  }
+
+  // Thread local buffer
+  private static final ThreadLocal<byte[]> threadBuffer =
+          ThreadLocal.withInitial(() -> null);
+
+  protected static byte[] getBuffer(int size) {
+    byte[] result = threadBuffer.get();
+    if (result == null || result.length < size || result.length > size * 2) {
+      result = new byte[size];
+      threadBuffer.set(result);
+    }
+    return result;
+  }
+
+  static class ZstdOptions implements Options {
+    private int level;
+    private int windowLog;
+
+    ZstdOptions(int level, int windowLog) {
+      this.level = level;
+      this.windowLog = windowLog;
+    }
+
+    @Override
+    public ZstdOptions copy() {
+      return new ZstdOptions(level, windowLog);
+    }
+
+    @Override
+    public Options setSpeed(SpeedModifier newValue) {
+      return this;
+    }
+
+    /**
+     * Sets the Zstandard long mode maximum back-reference distance, expressed
+     * as a power of 2.
+     * <p>
+     * The value must be between ZSTD_WINDOWLOG_MIN (10) and ZSTD_WINDOWLOG_MAX
+     * (30 and 31 on 32/64-bit architectures, respectively).
+     * <p>
+     * A value of 0 is a special value indicating to use the default
+     * ZSTD_WINDOWLOG_LIMIT_DEFAULT of 27, which corresponds to back-reference
+     * window size of 128MiB.
+     *
+     * @param newValue The desired power-of-2 value back-reference distance.
+     * @return ZstdOptions
+     */
+    public ZstdOptions setWindowLog(int newValue) {
+      if ((newValue < Zstd.windowLogMin() || newValue > Zstd.windowLogMax()) && newValue != 0) {
+        throw new IllegalArgumentException(
+            String.format(
+                "Zstd compression window size should be in the range %d to %d,"
+                    + " or set to the default value of 0.",
+                Zstd.windowLogMin(),
+                Zstd.windowLogMax()));
+      }
+      windowLog = newValue;
+      return this;
+    }
+
+    /**
+     * Sets the Zstandard compression codec compression level directly using
+     * the integer setting. This value is typically between 0 and 22, with
+     * larger numbers indicating more aggressive compression and lower speed.
+     * <p>
+     * This method provides additional granularity beyond the setSpeed method
+     * so that users can select a specific level.
+     *
+     * @param newValue The level value of compression to set.
+     * @return ZstdOptions
+     */
+    public ZstdOptions setLevel(int newValue) {
+      if (newValue < Zstd.minCompressionLevel() || newValue > Zstd.maxCompressionLevel()) {
+        throw new IllegalArgumentException(
+            String.format(
+                "Zstd compression level should be in the range %d to %d",
+                Zstd.minCompressionLevel(),
+                Zstd.maxCompressionLevel()));
+      }
+      level = newValue;
+      return this;
+    }
+
+    @Override
+    public ZstdOptions setData(DataKind newValue) {
+      return this; // We don't support setting DataKind in ZstdCodec.
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+
+      ZstdOptions that = (ZstdOptions) o;
+
+      if (level != that.level) return false;
+      return windowLog == that.windowLog;
+    }
+
+    @Override
+    public int hashCode() {
+      int result = level;
+      result = 31 * result + windowLog;
+      return result;
+    }
+  }
+
+  private static final ZstdOptions DEFAULT_OPTIONS =
+      new ZstdOptions(1, 0);
+
+  @Override
+  public Options getDefaultOptions() {
+    return DEFAULT_OPTIONS;
+  }
+
+  /**
+   * Compresses an input ByteBuffer into an output ByteBuffer using Zstandard
+   * compression. If the maximum bound of the number of output bytes exceeds
+   * the output ByteBuffer size, the remaining bytes are written to the overflow
+   * ByteBuffer.
+   *
+   * @param in       the bytes to compress
+   * @param out      the compressed bytes
+   * @param overflow put any additional bytes here
+   * @param options  the options to control compression
+   * @return true if the compressed output is smaller than the input
+   */
+  @Override
+  public boolean compress(ByteBuffer in, ByteBuffer out,
+      ByteBuffer overflow,
+      Options options) throws IOException {
+    ZstdOptions zso = (ZstdOptions) options;
+
+    zstdCompressCtx = new ZstdCompressCtx();
+    zstdCompressCtx.setLevel(zso.level);
+    zstdCompressCtx.setLong(zso.windowLog);
+    zstdCompressCtx.setChecksum(false);
+
+    try {
+      int inBytes = in.remaining();
+      byte[] compressed = getBuffer((int) Zstd.compressBound(inBytes));
+
+      int outBytes = zstdCompressCtx.compressByteArray(compressed, 0, compressed.length,
+              in.array(), in.arrayOffset() + in.position(), inBytes);
+      if (Zstd.isError(outBytes)) {
+        throw new IOException(String.format("Error code %s!", outBytes));
+      }
+      if (outBytes < inBytes) {
+        int remaining = out.remaining();
+        if (remaining >= outBytes) {
+          System.arraycopy(compressed, 0, out.array(), out.arrayOffset() +
+                  out.position(), outBytes);
+          out.position(out.position() + outBytes);
+        } else {
+          System.arraycopy(compressed, 0, out.array(), out.arrayOffset() +
+                  out.position(), remaining);
+          out.position(out.limit());
+          System.arraycopy(compressed, remaining, overflow.array(),
+                  overflow.arrayOffset(), outBytes - remaining);
+          overflow.position(outBytes - remaining);
+        }
+        return true;
+      } else {
+        return false;
+      }
+    } finally {
+      zstdCompressCtx.close();
+    }
+  }
+
+  @Override
+  public void decompress(ByteBuffer in, ByteBuffer out) throws IOException {
+    int srcOffset = in.arrayOffset() + in.position();
+    int srcSize = in.remaining();
+    int dstOffset = out.arrayOffset() + out.position();
+    int dstSize = out.remaining();
+
+    long decompressOut =
+        Zstd.decompressByteArray(out.array(), dstOffset, dstSize, in.array(),
+            srcOffset, srcSize);
+    if (Zstd.isError(decompressOut)) {
+      throw new IOException(String.format("Error code %s!", decompressOut));
+    }
+    in.position(in.limit());
+    out.position(dstOffset + (int) decompressOut);
+    out.flip();
+  }
+
+  @Override
+  public void reset() {
+
+  }
+
+  @Override
+  public void destroy() {
+    if (zstdCompressCtx != null) {
+      zstdCompressCtx.close();
+    }
+  }
+
+  @Override
+  public CompressionKind getKind() {
+    return CompressionKind.ZSTD;
+  }
+
+  @Override
+  public void close() {
+    OrcCodecPool.returnCodec(CompressionKind.ZSTD, this);
+  }
+}
diff --git a/java/core/src/test/org/apache/orc/TestVectorOrcFile.java b/java/core/src/test/org/apache/orc/TestVectorOrcFile.java
index 2dacb8d60..c24514f69 100644
--- a/java/core/src/test/org/apache/orc/TestVectorOrcFile.java
+++ b/java/core/src/test/org/apache/orc/TestVectorOrcFile.java
@@ -2272,51 +2272,67 @@ public class TestVectorOrcFile {
   }
 
   /**
-   * Read and write a randomly generated zstd file.
+   * Write a randomly generated zstd-compressed file, read it back, and check
+   * that the output matches the input.
+   * <p>
+   * Checks correctness across a variety of valid settings:
+   * <p>
+   *  * Negative, low, moderate, and high compression levels
+   *  * Valid window sizes in [10-31], and default value of 0.
+   *
+   * @throws Exception
    */
   @ParameterizedTest
   @MethodSource("data")
   public void testZstd(Version fileFormat) throws Exception {
     TypeDescription schema =
         TypeDescription.fromString("struct<x:bigint,y:int,z:bigint>");
-    try (Writer writer = OrcFile.createWriter(testFilePath,
-        OrcFile.writerOptions(conf)
-            .setSchema(schema)
-            .compress(CompressionKind.ZSTD)
-            .bufferSize(1000)
-            .version(fileFormat))) {
-      VectorizedRowBatch batch = schema.createRowBatch();
-      Random rand = new Random(3);
-      batch.size = 1000;
-      for (int b = 0; b < 10; ++b) {
-        for (int r = 0; r < 1000; ++r) {
-          ((LongColumnVector) batch.cols[0]).vector[r] = rand.nextInt();
-          ((LongColumnVector) batch.cols[1]).vector[r] = b * 1000 + r;
-          ((LongColumnVector) batch.cols[2]).vector[r] = rand.nextLong();
+
+    for (Integer level : new ArrayList<>(Arrays.asList(-4, -1, 0, 1, 3, 8, 12, 17, 22))) {
+      for (Integer windowLog : new ArrayList<>(Arrays.asList(0, 10, 20, 31))) {
+        OrcConf.COMPRESSION_ZSTD_LEVEL.setInt(conf, level);
+        OrcConf.COMPRESSION_ZSTD_WINDOWLOG.setInt(conf, windowLog);
+        try (Writer writer = OrcFile.createWriter(testFilePath,
+                OrcFile.writerOptions(conf)
+                        .setSchema(schema)
+                        .compress(CompressionKind.ZSTD)
+                        .bufferSize(1000)
+                        .version(fileFormat))) {
+          VectorizedRowBatch batch = schema.createRowBatch();
+          Random rand = new Random(3);
+          batch.size = 1000;
+          for (int b = 0; b < 10; ++b) {
+            for (int r = 0; r < 1000; ++r) {
+              ((LongColumnVector) batch.cols[0]).vector[r] = rand.nextInt();
+              ((LongColumnVector) batch.cols[1]).vector[r] = b * 1000 + r;
+              ((LongColumnVector) batch.cols[2]).vector[r] = rand.nextLong();
+            }
+            writer.addRowBatch(batch);
+          }
         }
-        writer.addRowBatch(batch);
-      }
-    }
-    try (Reader reader = OrcFile.createReader(testFilePath,
-           OrcFile.readerOptions(conf).filesystem(fs));
-         RecordReader rows = reader.rows()) {
-      assertEquals(CompressionKind.ZSTD, reader.getCompressionKind());
-      VectorizedRowBatch batch = reader.getSchema().createRowBatch(1000);
-      Random rand = new Random(3);
-      for (int b = 0; b < 10; ++b) {
-        rows.nextBatch(batch);
-        assertEquals(1000, batch.size);
-        for (int r = 0; r < batch.size; ++r) {
-          assertEquals(rand.nextInt(),
-              ((LongColumnVector) batch.cols[0]).vector[r]);
-          assertEquals(b * 1000 + r,
-              ((LongColumnVector) batch.cols[1]).vector[r]);
-          assertEquals(rand.nextLong(),
-              ((LongColumnVector) batch.cols[2]).vector[r]);
+        try (Reader reader = OrcFile.createReader(testFilePath,
+                OrcFile.readerOptions(conf).filesystem(fs));
+             RecordReader rows = reader.rows()) {
+          assertEquals(CompressionKind.ZSTD, reader.getCompressionKind());
+          VectorizedRowBatch batch = reader.getSchema().createRowBatch(1000);
+          Random rand = new Random(3);
+          for (int b = 0; b < 10; ++b) {
+            rows.nextBatch(batch);
+            assertEquals(1000, batch.size);
+            for (int r = 0; r < batch.size; ++r) {
+              assertEquals(rand.nextInt(),
+                      ((LongColumnVector) batch.cols[0]).vector[r]);
+              assertEquals(b * 1000 + r,
+                      ((LongColumnVector) batch.cols[1]).vector[r]);
+              assertEquals(rand.nextLong(),
+                      ((LongColumnVector) batch.cols[2]).vector[r]);
+            }
+          }
+          rows.nextBatch(batch);
+          assertEquals(0, batch.size);
         }
+        fs.delete(testFilePath, false);
       }
-      rows.nextBatch(batch);
-      assertEquals(0, batch.size);
     }
   }
 
@@ -2333,7 +2349,7 @@ public class TestVectorOrcFile {
     WriterOptions opts = OrcFile.writerOptions(conf)
         .setSchema(schema).stripeSize(1000).bufferSize(100).version(fileFormat);
 
-    CompressionCodec snappyCodec, zlibCodec;
+    CompressionCodec snappyCodec, zlibCodec, zstdCodec;
     snappyCodec = writeBatchesAndGetCodec(10, 1000, opts.compress(CompressionKind.SNAPPY), batch);
     assertEquals(1, OrcCodecPool.getPoolSize(CompressionKind.SNAPPY));
     Reader reader = OrcFile.createReader(testFilePath, OrcFile.readerOptions(conf).filesystem(fs));
@@ -2355,6 +2371,13 @@ public class TestVectorOrcFile {
     assertEquals(1, OrcCodecPool.getPoolSize(CompressionKind.ZLIB));
     assertSame(zlibCodec, codec);
 
+    zstdCodec = writeBatchesAndGetCodec(10, 1000, opts.compress(CompressionKind.ZSTD), batch);
+    assertNotSame(zlibCodec, zstdCodec);
+    assertEquals(1, OrcCodecPool.getPoolSize(CompressionKind.ZSTD));
+    codec = writeBatchesAndGetCodec(10, 1000, opts.compress(CompressionKind.ZSTD), batch);
+    assertEquals(1, OrcCodecPool.getPoolSize(CompressionKind.ZSTD));
+    assertSame(zstdCodec, codec);
+
     assertSame(snappyCodec, OrcCodecPool.getCodec(CompressionKind.SNAPPY));
     CompressionCodec snappyCodec2 = writeBatchesAndGetCodec(
         10, 1000, opts.compress(CompressionKind.SNAPPY), batch);
diff --git a/java/core/src/test/org/apache/orc/impl/TestZstd.java b/java/core/src/test/org/apache/orc/impl/TestZstd.java
index bd04b6ce1..b424feb82 100644
--- a/java/core/src/test/org/apache/orc/impl/TestZstd.java
+++ b/java/core/src/test/org/apache/orc/impl/TestZstd.java
@@ -18,15 +18,19 @@
 
 package org.apache.orc.impl;
 
+import com.github.luben.zstd.Zstd;
+import com.github.luben.zstd.ZstdException;
 import io.airlift.compress.zstd.ZstdCompressor;
 import io.airlift.compress.zstd.ZstdDecompressor;
+import java.nio.ByteBuffer;
+import java.util.Random;
 import org.apache.orc.CompressionCodec;
 import org.apache.orc.CompressionKind;
 import org.junit.jupiter.api.Test;
 
-import java.nio.ByteBuffer;
-
+import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.fail;
 
 public class TestZstd {
 
@@ -34,12 +38,85 @@ public class TestZstd {
   public void testNoOverflow() throws Exception {
     ByteBuffer in = ByteBuffer.allocate(10);
     ByteBuffer out = ByteBuffer.allocate(10);
-    in.put(new byte[]{1,2,3,4,5,6,7,10});
+    ByteBuffer jniOut = ByteBuffer.allocate(10);
+    in.put(new byte[]{1, 2, 3, 4, 5, 6, 7, 10});
     in.flip();
     CompressionCodec codec = new AircompressorCodec(
-        CompressionKind.ZSTD, new ZstdCompressor(), new ZstdDecompressor());
            CompressionKind.ZSTD, new ZstdCompressor(), new ZstdDecompressor());
     assertFalse(codec.compress(in, out, null,
-        codec.getDefaultOptions()));
+            codec.getDefaultOptions()));
+    CompressionCodec zstdCodec = new ZstdCodec();
+    assertFalse(zstdCodec.compress(in, jniOut, null,
+            zstdCodec.getDefaultOptions()));
   }
 
+  @Test
+  public void testCorrupt() throws Exception {
+    ByteBuffer buf = ByteBuffer.allocate(1000);
+    buf.put(new byte[] {127, 125, 1, 99, 98, 1});
+    buf.flip();
+    CompressionCodec codec = new ZstdCodec();
+    ByteBuffer out = ByteBuffer.allocate(1000);
+    try {
+      codec.decompress(buf, out);
+      fail();
+    } catch (ZstdException ioe) {
+      // EXPECTED
+    }
+  }
+
+  /**
+   * Test compatibility of zstd-jni and aircompressor Zstd implementations
+   * by checking that bytes compressed with one can be decompressed by the
+   * other when using the default options.
+   */
+  @Test
+  public void testZstdAircompressorJniCompressDecompress() throws Exception {
+    int inputSize = 27182;
+    Random rd = new Random();
+
+    CompressionCodec zstdAircompressorCodec = new AircompressorCodec(
+        CompressionKind.ZSTD, new ZstdCompressor(), new ZstdDecompressor());
+    CompressionCodec zstdJniCodec = new ZstdCodec();
+
+    ByteBuffer sourceCompressorIn = ByteBuffer.allocate(inputSize);
+    ByteBuffer sourceCompressorOut =
+        ByteBuffer.allocate((int) Zstd.compressBound(inputSize));
+    ByteBuffer destCompressorOut = ByteBuffer.allocate(inputSize);
+
+    // Use an array half filled with a constant value & half filled with
+    // random values.
+    byte[] constantBytes = new byte[inputSize / 2];
+    java.util.Arrays.fill(constantBytes, 0, inputSize / 2, (byte) 2);
+    sourceCompressorIn.put(constantBytes);
+    byte[] randomBytes = new byte[inputSize - inputSize / 2];
+    rd.nextBytes(randomBytes);
+    sourceCompressorIn.put(randomBytes);
+    sourceCompressorIn.flip();
+
+    // Verify that input -> aircompressor compression -> zstd-jni
+    // decompression returns the input.
+    zstdAircompressorCodec.compress(sourceCompressorIn, sourceCompressorOut,
+        null, zstdAircompressorCodec.getDefaultOptions());
+    sourceCompressorOut.flip();
+
+    zstdJniCodec.decompress(sourceCompressorOut, destCompressorOut);
+    assertEquals(sourceCompressorIn, destCompressorOut,
+        "aircompressor compression with zstd-jni decompression did not return"
+            + " the input!");
+
+    sourceCompressorIn.rewind();
+    sourceCompressorOut.clear();
+    destCompressorOut.clear();
+
+    // Verify that input -> zstd-jni compression -> aircompressor
+    // decompression returns the input.
+    zstdJniCodec.compress(sourceCompressorIn, sourceCompressorOut, null,
+        zstdJniCodec.getDefaultOptions());
+    sourceCompressorOut.flip();
+    zstdAircompressorCodec.decompress(sourceCompressorOut, destCompressorOut);
+    assertEquals(sourceCompressorIn, destCompressorOut,
+        "zstd-jni compression with aircompressor decompression did not return"
+            + " the input!");
+  }
 }
diff --git a/java/pom.xml b/java/pom.xml
index f4f017762..d9c9e98a2 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -83,6 +83,7 @@
     <storage-api.version>2.8.1</storage-api.version>
     <surefire.version>3.0.0-M5</surefire.version>
     <test.tmp.dir>${project.build.directory}/testing-tmp</test.tmp.dir>
+    <zstd-jni.version>1.5.5-11</zstd-jni.version>
   </properties>
 
   <dependencyManagement>
@@ -155,6 +156,11 @@
         <artifactId>aircompressor</artifactId>
         <version>0.25</version>
       </dependency>
+      <dependency>
+        <groupId>com.github.luben</groupId>
+        <artifactId>zstd-jni</artifactId>
+        <version>${zstd-jni.version}</version>
+      </dependency>
       <dependency>
         <groupId>org.apache.commons</groupId>
         <artifactId>commons-csv</artifactId>
