Lasse,
I have found a way to use VarHandle byte array access at runtime in
code which is compile time compatible with jdk 7. So here is an
updated ArrayUtil class which will use a VarHandle to read long values
in jdk 9+. If that is not available, it will attempt to use
sun.misc.Unsafe. If that cannot be found, it falls back to standard
byte by byte comparison.
I did add an index bounds check for the unsafe implementation and
found it had minimal impact on overall performance.
Using VarHandle (at least on jdk 11) offers very similar performance
to Unsafe across all 3 files I used for benchmarking.
--Baseline 1.8
Benchmark                        (file)           Mode  Cnt      Score      Error  Units
XZCompressionBenchmark.compress  ihe_ovly_pr.dcm  avgt    4      9.558 ±    0.239  ms/op
XZCompressionBenchmark.compress  image1.dcm       avgt    4   6553.304 ±  112.475  ms/op
XZCompressionBenchmark.compress  large.xml        avgt    4  10592.151 ±  291.527  ms/op
--Unsafe
Benchmark                        (file)           Mode  Cnt      Score      Error  Units
XZCompressionBenchmark.compress  ihe_ovly_pr.dcm  avgt    4      7.699 ±    0.058  ms/op
XZCompressionBenchmark.compress  image1.dcm       avgt    4   6001.170 ±  143.814  ms/op
XZCompressionBenchmark.compress  large.xml        avgt    4   7853.963 ±   83.753  ms/op
--VarHandle
Benchmark                        (file)           Mode  Cnt      Score      Error  Units
XZCompressionBenchmark.compress  ihe_ovly_pr.dcm  avgt    4      7.630 ±    0.542  ms/op
XZCompressionBenchmark.compress  image1.dcm       avgt    4   5872.098 ±   71.185  ms/op
XZCompressionBenchmark.compress  large.xml        avgt    4   8239.880 ±  346.036  ms/op
I know you said you were not going to be able to work on xz-java for
a while, but given these benchmark results, which really exceeded my
expectations, could this get some priority to release?
package org.tukaani.xz.common;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.nio.ByteOrder;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* Utilities for optimized array interactions.
*
* @author Brett Okken
*/
public final class ArrayUtil {
/**
* MethodHandle to the actual mismatch method to use at runtime.
*/
private static final MethodHandle MISMATCH;
/**
* If {@code sun.misc.Unsafe} can be loaded, this is MethodHandle
bound to an instance of Unsafe for method {@code long getLong(Object,
long)}.
*/
private static final MethodHandle UNSAFE_GET_LONG;
/**
* MethodHandle to either {@link Long#numberOfLeadingZeros(long)}
or {@link Long#numberOfTrailingZeros(long)} depending on {@link
ByteOrder#nativeOrder()}.
*/
private static final MethodHandle LEADING_ZEROS;
/**
* Populated from reflected read of {@code
sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET}.
*/
private static final long ARRAY_BASE_OFFSET;
/**
* {@code MethodHandle} for a jdk 9+ {@code
byteArrayViewVarHandle} for {@code long[]} using the {@link
ByteOrder#nativeOrder()}.
* The method signature is {@code long get(byte[], int)}.
*/
private static final MethodHandle VAR_HANDLE_GET_LONG;
static {
final Logger logger = Logger.getLogger(ArrayUtil.class.getName());
MethodHandle leadingZeros = null;
MethodHandle varHandleGetLong = null;
MethodHandle unsafeGetLong = null;
long arrayBaseOffset = 0;
MethodHandle mismatch = null;
final MethodHandles.Lookup lookup = MethodHandles.lookup();
final MethodType mismatchType =
MethodType.methodType(int.class, byte[].class, int.class,
byte[].class, int.class, int.class);
try {
//getLong interprets in platform byte order. the concept
of "leading zeros" being bytes
//in encounter order is true for big endian
//for little endian platform, the trailing zeros gives the
encounter order result
leadingZeros = lookup.findStatic(Long.class,
ByteOrder.BIG_ENDIAN ==
ByteOrder.nativeOrder()
?
"numberOfLeadingZeros" : "numberOfTrailingZeros",
MethodType.methodType(int.class, long.class));
//first try to load byteArrayViewVarHandle for a long[]
try {
final Class<?> varHandleClazz =
Class.forName("java.lang.invoke.VarHandle", true, null);
final Method byteArrayViewHandle =
MethodHandles.class.getDeclaredMethod("byteArrayViewVarHandle", new
Class[] {Class.class, ByteOrder.class});
final Object varHandle =
byteArrayViewHandle.invoke(null, long[].class,
ByteOrder.nativeOrder());
final Class<?> accessModeEnum =
Class.forName("java.lang.invoke.VarHandle$AccessMode", true, null);
@SuppressWarnings({ "unchecked", "rawtypes" })
final Object getAccessModeEnum =
Enum.valueOf((Class)accessModeEnum, "GET");
final Method toMethodHandle =
varHandleClazz.getDeclaredMethod("toMethodHandle", accessModeEnum);
varHandleGetLong = (MethodHandle)
toMethodHandle.invoke(varHandle, getAccessModeEnum);
mismatch = lookup.findStatic(ArrayUtil.class,
"varHandleMismatch", mismatchType);
logger.finest("byte[] comparison using VarHandle");
} catch (Throwable t) {
logger.log(Level.FINE, "failed trying to load a
MethodHandle to invoke get on a byteArrayViewVarHandle for a long[]",
t);
unsafeGetLong = null;
mismatch = null;
}
//if byteArrayViewVarHandle for a long[] could not be
loaded, then try to load sun.misc.Unsafe
if (mismatch == null) {
Class<?> unsafeClazz =
Class.forName("sun.misc.Unsafe", true, null);
Constructor<?> unsafeConstructor =
unsafeClazz.getDeclaredConstructor();
unsafeConstructor.setAccessible(true);
Object unsafe = unsafeConstructor.newInstance();
arrayBaseOffset =
unsafeClazz.getField("ARRAY_BYTE_BASE_OFFSET").getLong(null);
MethodHandle virtualGetLong =
lookup.findVirtual(unsafeClazz, "getLong",
MethodType.methodType(long.class, Object.class, long.class));
unsafeGetLong = virtualGetLong.bindTo(unsafe);
// do a test read to confirm unsafe is actually functioning
long val = (long) unsafeGetLong.invokeExact((Object)
new byte[] { 0, 0, 0, 0, 0, 0, 0, 0 }, arrayBaseOffset + 0L);
if (val != 0) {
throw new IllegalStateException("invalid value: " + val);
}
mismatch = lookup.findStatic(ArrayUtil.class,
"unsafeMismatch", mismatchType);
logger.finest("byte[] comparisons using Unsafe");
}
} catch (Throwable t) {
logger.log(Level.FINE, "failed trying to load means to
compare byte[] by longs", t);
logger.finest("byte[] comparisons byte by byte");
varHandleGetLong = null;
unsafeGetLong = null;
leadingZeros = null;
try {
mismatch = lookup.findStatic(ArrayUtil.class,
"legacyMismatch", mismatchType);
} catch (Exception e) {
throw new IllegalStateException(e);
}
}
VAR_HANDLE_GET_LONG = varHandleGetLong;
UNSAFE_GET_LONG = unsafeGetLong;
ARRAY_BASE_OFFSET = arrayBaseOffset;
LEADING_ZEROS = leadingZeros;
MISMATCH = mismatch;
}
/**
* Compares the values in <i>a</i> and <i>b</i> and returns the
index of the first {@code byte} which differs.
* @param a The first {@code byte[]} for comparison.
* @param aFromIndex The offset into <i>a</i> to start reading from.
* @param b The second {@code byte[]} for comparison.
* @param bFromIndex The offset into <i>b</i> to start reading from.
* @param length The number of bytes to compare.
* @return The offset from the starting indexes of the first byte
which differs. If all match, <i>length</i> will be returned.
*/
public static int mismatch(byte[] a, int aFromIndex, byte[] b, int
bFromIndex, int length) {
try {
return (int) MISMATCH.invokeExact(a, aFromIndex, b,
bFromIndex, length);
} catch (RuntimeException e) {
throw e;
} catch (Error e) {
throw e;
} catch (Throwable t) {
throw new RuntimeException(t);
}
}
/**
* Uses {@link #VAR_HANDLE_GET_LONG} to compare 8 bytes at a time.
*/
@SuppressWarnings("unused")
private static int varHandleMismatch(byte[] a, int aFromIndex,
byte[] b, int bFromIndex, int length) throws Throwable {
//while we could do an index check, the VarHandle call
incorporates a check, making any check here duplicative
int i=0;
for (int j=length - 7; i<j; i+=8) {
final long aVal = (long)
VAR_HANDLE_GET_LONG.invokeExact(a, aFromIndex + i);
final long bVal = (long)
VAR_HANDLE_GET_LONG.invokeExact(b, bFromIndex + i);
if (aVal != bVal) {
//this returns a value where bits which match are 0
and bits which differ are 1
final long diff = aVal ^ bVal;
//the first (in native byte order) bit which differs
tells us which byte differed
final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff);
return i + (leadingZeros / Byte.SIZE);
}
}
for ( ; i<length; ++i) {
if (a[aFromIndex + i] != b[bFromIndex + i]) {
return i;
}
}
return length;
}
/**
* Uses {@code UNSAFE_GET_LONG} to compare 8 bytes at a time.
*/
@SuppressWarnings("unused")
private static int unsafeMismatch(byte[] a, int aFromIndex, byte[]
b, int bFromIndex, int length) throws Throwable {
//it is important to check the indexes prior to making the
Unsafe calls, as Unsafe does not validate
//and could result in SIGSEGV if out of bounds
if (aFromIndex < 0 || aFromIndex + length > a.length ||
bFromIndex < 0 || bFromIndex + length > b.length) {
throw new ArrayIndexOutOfBoundsException();
}
int i=0;
for (int j=length - 7; i<j; i+=8) {
final long aVal = (long)
UNSAFE_GET_LONG.invokeExact((Object) a, ARRAY_BASE_OFFSET + aFromIndex
+ i);
final long bVal = (long)
UNSAFE_GET_LONG.invokeExact((Object) b, ARRAY_BASE_OFFSET + bFromIndex
+ i);
if (aVal != bVal) {
//this returns a value where bits which match are 0
and bits which differ are 1
final long diff = aVal ^ bVal;
//the first (in native byte order) bit which differs
tells us which byte differed
final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff);
return i + (leadingZeros / Byte.SIZE);
}
}
for ( ; i<length; ++i) {
if (a[aFromIndex + i] != b[bFromIndex + i]) {
return i;
}
}
return length;
}
/**
* Simply loops over all of the bytes, comparing one at a time.
*/
@SuppressWarnings("unused")
private static int legacyMismatch(byte[] a, int aFromIndex, byte[]
b, int bFromIndex, int length) {
for (int i=0; i<length; ++i) {
if (a[aFromIndex + i] != b[bFromIndex + i]) {
return i;
}
}
return length;
}
private ArrayUtil() {
}
}