package org.tukaani.xz.common;
import static java.lang.invoke.MethodType.methodType;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Locale;
import java.util.Properties;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Utilities for optimized array interactions.
 *
 * <p>
 * The means of comparing arrays can be controlled by setting the system
 * property {@code org.tukaani.xz.ArrayComparison} to a value from
 * {@link ArrayComparison}. The value is matched ignoring case; an empty or
 * unrecognized value selects the default behavior for the platform.
 * </p>
 *
 * @author Brett Okken
 */
public final class ArrayUtil {

    /**
     * Enumerated options for controlling implementation of how to compare
     * arrays.
     */
    public static enum ArrayComparison {
        /**
         * Uses {@code VarHandle} for {@code int[]} access.
         * <p>
         * This is default behavior on jdk9+ for 32 bit x86.
         * </p>
         */
        VH_INT,

        /**
         * Uses {@code VarHandle} for {@code int[]} access after attempting
         * to align the reads on 4 byte boundaries.
         */
        VH_INT_ALIGN,

        /**
         * Uses {@code VarHandle} for {@code long[]} access.
         * <p>
         * This is default behavior on jdk9+ for 64 bit x86.
         * </p>
         */
        VH_LONG,

        /**
         * Uses {@code VarHandle} for {@code long[]} access after attempting
         * to align the reads on 8 byte boundaries.
         */
        VH_LONG_ALIGN,

        /**
         * Uses {@code Arrays.mismatch()} to perform vectorized comparison.
         * <p>
         * This is default behavior on jdk9+ for non-x86.
         * </p>
         */
        VECTOR,

        /**
         * Uses {@code sun.misc.Unsafe.getInt()} for unaligned {@code int[]}
         * access.
         * <p>
         * This is default behavior on jdk 8 and prior for 32 bit x86.
         * </p>
         */
        UNSAFE_GET_INT,

        /**
         * Uses {@code sun.misc.Unsafe.getLong()} for unaligned {@code long[]}
         * access.
         * <p>
         * This is default behavior on jdk 8 and prior for 64 bit x86.
         * </p>
         */
        UNSAFE_GET_LONG,

        /**
         * Performs byte-by-byte comparison.
         */
        LEGACY;

        /**
         * Parses the value of the {@code org.tukaani.xz.ArrayComparison}
         * system property.
         *
         * @param prop The property value, may be {@code null}.
         * @return The option whose name matches <i>prop</i> ignoring case,
         *         or {@code null} if <i>prop</i> is {@code null}, empty, or
         *         not the name of an option (the last case is logged at
         *         {@link Level#INFO}).
         */
        static ArrayComparison getFromProperty(String prop) {
            if (prop == null || prop.isEmpty()) {
                return null;
            }
            try {
                return ArrayComparison.valueOf(prop.toUpperCase(Locale.US));
            } catch (IllegalArgumentException e) {
                final Logger logger =
                        Logger.getLogger(ArrayUtil.class.getName());
                logger.log(Level.INFO,
                        "Invalid ArrayComparison option, using default behavior",
                        e);
                return null;
            }
        }
    }

    /**
     * MethodHandle to the actual mismatch method to use at runtime. Its
     * type is {@code int (byte[], int, int, int)}.
     */
    private static final MethodHandle MISMATCH;

    /**
     * The method this is bound to at runtime depends on the chosen
     * implementation for {@code byte[]} comparison.
     * <p>
     * For {@code long} based comparisons, it will be bound to either
     * {@link Long#numberOfLeadingZeros(long)} or
     * {@link Long#numberOfTrailingZeros(long)} depending on
     * {@link ByteOrder#nativeOrder()}.
     * </p>
     * <p>
     * For {@code int} based comparisons it will be bound to either
     * {@link Integer#numberOfLeadingZeros(int)} or
     * {@link Integer#numberOfTrailingZeros(int)} depending on
     * {@link ByteOrder#nativeOrder()}.
     * </p>
     */
    private static final MethodHandle LEADING_ZEROS;

    /**
     * Populated from reflected read of
     * {@code sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET} if one of the unsafe
     * implementations is used.
     */
    private static final long ARRAY_BASE_OFFSET;

    /**
     * The method this is bound to at runtime depends on the chosen
     * implementation for {@code byte[]} comparison.
     * <p>
     * For {@link ArrayComparison#VECTOR} and
     * {@link ArrayComparison#LEGACY} this will be {@code null}.
     * </p>
     * <p>
     * For {@link ArrayComparison#VH_INT} and
     * {@link ArrayComparison#VH_INT_ALIGN} this will be a jdk 9+
     * {@code byteArrayViewVarHandle} for {@code int[]} using the
     * {@link ByteOrder#nativeOrder()}. The method signature is
     * {@code int get(byte[], int)}.
     * </p>
     * <p>
     * For {@link ArrayComparison#VH_LONG} and
     * {@link ArrayComparison#VH_LONG_ALIGN} this will be a jdk 9+
     * {@code byteArrayViewVarHandle} for {@code long[]} using the
     * {@link ByteOrder#nativeOrder()}. The method signature is
     * {@code long get(byte[], int)}.
     * </p>
     * <p>
     * For {@link ArrayComparison#UNSAFE_GET_INT} this is bound to
     * {@code sun.misc.Unsafe.getInt(Object, long)}.
     * </p>
     * <p>
     * For {@link ArrayComparison#UNSAFE_GET_LONG} this is bound to
     * {@code sun.misc.Unsafe.getLong(Object, long)}.
     * </p>
     */
    private static final MethodHandle GET_PRIMITIVE;

    /**
     * MethodHandle to the jdk 9+
     * {@code Arrays.mismatch(byte[] a, int aFromIndex, int aToIndex,
     * byte[] b, int bFromIndex, int bToIndex)}. This is {@code null} unless
     * {@link ArrayComparison#VECTOR} is in use.
     */
    private static final MethodHandle ARRAYS_MISMATCH;

    static {
        final Logger logger = Logger.getLogger(ArrayUtil.class.getName());
        MethodHandle leadingZeros = null;
        MethodHandle getPrimitive = null;
        MethodHandle arraysMismatch = null;
        long arrayBaseOffset = 0;
        MethodHandle mismatch = null;
        final MethodHandles.Lookup lookup = MethodHandles.lookup();
        // Type shared by every private *Mismatch implementation below.
        final MethodType mismatchType = methodType(
                int.class, byte[].class, int.class, int.class, int.class);
        try {
            final Properties props = System.getProperties();
            final ArrayComparison algo = ArrayComparison.getFromProperty(
                    props.getProperty("org.tukaani.xz.ArrayComparison"));
            final String arch = props.getProperty("os.arch", "");
            final boolean unaligned =
                    arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64)$");

            // If unaligned, or explicitly configured, try VarHandles.
            if ((unaligned && algo == null)
                    || algo == ArrayComparison.VH_LONG
                    || algo == ArrayComparison.VH_LONG_ALIGN
                    || algo == ArrayComparison.VH_INT
                    || algo == ArrayComparison.VH_INT_ALIGN) {
                try {
                    // VarHandle is reached reflectively, so this class still
                    // loads on runtimes that lack it.
                    final Class<?> varHandleClazz = Class.forName(
                            "java.lang.invoke.VarHandle", true, null);
                    final Method byteArrayViewHandle =
                            MethodHandles.class.getDeclaredMethod(
                                    "byteArrayViewVarHandle",
                                    Class.class, ByteOrder.class);
                    final boolean doLong =
                            (algo == null && arch.contains("64"))
                            || algo == ArrayComparison.VH_LONG
                            || algo == ArrayComparison.VH_LONG_ALIGN;
                    final Object varHandle = byteArrayViewHandle.invoke(null,
                            doLong ? long[].class : int[].class,
                            ByteOrder.nativeOrder());
                    final Class<?> accessModeEnum = Class.forName(
                            "java.lang.invoke.VarHandle$AccessMode",
                            true, null);
                    @SuppressWarnings({ "unchecked", "rawtypes" })
                    final Object getAccessModeEnum =
                            Enum.valueOf((Class) accessModeEnum, "GET");
                    final Method toMethodHandle = varHandleClazz
                            .getDeclaredMethod("toMethodHandle",
                                               accessModeEnum);
                    getPrimitive = (MethodHandle) toMethodHandle.invoke(
                            varHandle, getAccessModeEnum);
                    leadingZeros = findLeadingZeros(lookup, doLong);
                    final String mismatchMethod;
                    if (doLong) {
                        mismatchMethod =
                                (algo == null || algo == ArrayComparison.VH_LONG)
                                ? "varHandleMismatch"
                                : "alignedVarHandleMismatch";
                        logger.finest("byte[] comparison using long VarHandle");
                    } else {
                        mismatchMethod =
                                (algo == null || algo == ArrayComparison.VH_INT)
                                ? "intVarHandleMismatch"
                                : "alignedIntVarHandleMismatch";
                        logger.finest("byte[] comparison using int VarHandle");
                    }
                    mismatch = lookup.findStatic(
                            ArrayUtil.class, mismatchMethod, mismatchType);
                } catch (Throwable t) {
                    logger.log(Level.FINE,
                            "failed trying to load a MethodHandle to invoke "
                            + "get on a byteArrayViewVarHandle",
                            t);
                    mismatch = null;
                }
            }

            if (mismatch == null && ((!unaligned && algo == null)
                    || algo == ArrayComparison.VECTOR)) {
                try {
                    final MethodType arraysType = methodType(
                            int.class, byte[].class, int.class, int.class,
                            byte[].class, int.class, int.class);
                    arraysMismatch = lookup.findStatic(
                            Arrays.class, "mismatch", arraysType);
                    mismatch = lookup.findStatic(
                            ArrayUtil.class, "arraysMismatch", mismatchType);
                    logger.finest("byte[] comparisons using Arrays.mismatch");
                } catch (Throwable t) {
                    logger.log(Level.FINE,
                            "failed trying to load a MethodHandle to invoke "
                            + "Arrays.mismatch",
                            t);
                    arraysMismatch = null;
                }
            }

            // If byteArrayViewVarHandle could not be loaded, then try to
            // load sun.misc.Unsafe for unaligned archs only (or when it was
            // explicitly requested).
            if (mismatch == null && ((unaligned && algo == null)
                    || algo == ArrayComparison.UNSAFE_GET_LONG
                    || algo == ArrayComparison.UNSAFE_GET_INT)) {
                final Class<?> unsafeClazz =
                        Class.forName("sun.misc.Unsafe", true, null);
                final Constructor<?> unsafeConstructor =
                        unsafeClazz.getDeclaredConstructor();
                unsafeConstructor.setAccessible(true);
                final Object unsafe = unsafeConstructor.newInstance();
                arrayBaseOffset = unsafeClazz
                        .getField("ARRAY_BYTE_BASE_OFFSET").getLong(null);
                if (algo == ArrayComparison.UNSAFE_GET_LONG
                        || (algo == null && arch.contains("64"))) {
                    leadingZeros = findLeadingZeros(lookup, true);
                    final MethodHandle virtualGetLong = lookup.findVirtual(
                            unsafeClazz, "getLong",
                            methodType(long.class, Object.class, long.class));
                    getPrimitive = virtualGetLong.bindTo(unsafe);
                    // Do a test read to confirm Unsafe is actually functioning.
                    final long val = (long) getPrimitive.invokeExact(
                            (Object) new byte[] { 0, 0, 0, 0, 0, 0, 0, 0 },
                            arrayBaseOffset + 0L);
                    if (val != 0) {
                        throw new IllegalStateException("invalid value: " + val);
                    }
                    mismatch = lookup.findStatic(
                            ArrayUtil.class, "unsafeMismatch", mismatchType);
                    logger.finest("byte[] comparisons using Unsafe.getLong");
                } else {
                    leadingZeros = findLeadingZeros(lookup, false);
                    final MethodHandle virtualGetInt = lookup.findVirtual(
                            unsafeClazz, "getInt",
                            methodType(int.class, Object.class, long.class));
                    getPrimitive = virtualGetInt.bindTo(unsafe);
                    // Do a test read to confirm Unsafe is actually functioning.
                    final int val = (int) getPrimitive.invokeExact(
                            (Object) new byte[] { 0, 0, 0, 0 },
                            arrayBaseOffset + 0L);
                    if (val != 0) {
                        throw new IllegalStateException("invalid value: " + val);
                    }
                    mismatch = lookup.findStatic(
                            ArrayUtil.class, "intUnsafeMismatch", mismatchType);
                    logger.finest("byte[] comparisons using Unsafe.getInt");
                }
            }
        } catch (Throwable t) {
            logger.log(Level.FINE,
                    "failed trying to load means to compare byte[] by longs",
                    t);
        }

        if (mismatch == null) {
            // Discard any state left behind by a partially failed attempt.
            getPrimitive = null;
            leadingZeros = null;
            arraysMismatch = null;
            logger.finest("byte[] comparisons byte by byte");
            try {
                mismatch = lookup.findStatic(
                        ArrayUtil.class, "legacyMismatch", mismatchType);
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

        GET_PRIMITIVE = getPrimitive;
        ARRAY_BASE_OFFSET = arrayBaseOffset;
        LEADING_ZEROS = leadingZeros;
        ARRAYS_MISMATCH = arraysMismatch;
        MISMATCH = mismatch;
    }

    /**
     * Looks up the zero-bit counting method used to locate the first
     * differing byte inside a multi-byte word.
     * <p>
     * The word is read in platform byte order, so on big endian platforms
     * the first byte in encounter order is the most significant one
     * (leading zeros), while on little endian platforms it is the least
     * significant one (trailing zeros).
     * </p>
     *
     * @param lookup The lookup used to resolve the method.
     * @param doLong {@code true} to resolve the {@link Long} variant,
     *               {@code false} for the {@link Integer} variant.
     * @return A handle of type {@code int (long)} or {@code int (int)}.
     */
    private static MethodHandle findLeadingZeros(
            MethodHandles.Lookup lookup, boolean doLong)
            throws NoSuchMethodException, IllegalAccessException {
        return lookup.findStatic(doLong ? Long.class : Integer.class,
                ByteOrder.BIG_ENDIAN == ByteOrder.nativeOrder()
                        ? "numberOfLeadingZeros" : "numberOfTrailingZeros",
                methodType(int.class, doLong ? long.class : int.class));
    }

    /**
     * Compares the values in <i>bytes</i>, starting at <i>aFromIndex</i> and
     * <i>bFromIndex</i> and returns the zero-based index of the first
     * {@code byte} which differs.
     * <p>
     * Ranges which extend outside of <i>bytes</i> generally result in a
     * {@link RuntimeException} (such as an
     * {@link IndexOutOfBoundsException}); the exact type depends on the
     * implementation selected at class initialization.
     * </p>
     *
     * @param bytes The {@code byte[]} for comparison.
     * @param aFromIndex The first offset into <i>bytes</i> to start reading
     *                   from.
     * @param bFromIndex The second offset into <i>bytes</i> to start reading
     *                   from.
     * @param length The number of bytes to compare.
     * @return The offset from the starting indexes of the first byte which
     *         differs. If all match, <i>length</i> will be returned.
     */
    public static int mismatch(
            byte[] bytes, int aFromIndex, int bFromIndex, int length) {
        try {
            return (int) MISMATCH.invokeExact(bytes, aFromIndex,
                                              bFromIndex, length);
        } catch (RuntimeException | Error e) {
            throw e;
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    /**
     * If <i>aFromIndex</i> and <i>bFromIndex</i> are similarly mis-aligned,
     * will do single byte comparisons to obtain 8 byte alignment, then call
     * {@link #varHandleMismatch(byte[], int, int, int)}. If both are odd but
     * differently aligned, a single byte is compared first so the wider
     * reads at least start on even offsets.
     */
    @SuppressWarnings("unused")
    private static int alignedVarHandleMismatch(
            byte[] a, int aFromIndex, int bFromIndex, int length)
            throws Throwable {
        // Nothing to compare: return without touching the array, exactly
        // like legacyMismatch. The single byte adjustment below would
        // otherwise read a[aFromIndex], which may be past the end.
        if (length <= 0) {
            return length;
        }
        // While we could do an index check, the VarHandle call incorporates
        // a check, making any check here duplicative.
        final int aFromAlignment = aFromIndex & 7;
        final int bFromAlignment = bFromIndex & 7;
        // If they are aligned, just go.
        if (aFromAlignment == 0 && bFromAlignment == 0) {
            return varHandleMismatch(a, aFromIndex, bFromIndex, length);
        }
        int i = 0;
        if (aFromAlignment == bFromAlignment) {
            // Both are similarly out of alignment, so adjust both.
            for (int j = Math.min(8 - aFromAlignment, length); i < j; ++i) {
                if (a[aFromIndex + i] != a[bFromIndex + i]) {
                    return i;
                }
            }
        } else if ((aFromAlignment & 1) == 1 && (bFromAlignment & 1) == 1) {
            // They both have an odd alignment, so adjust by one.
            if (a[aFromIndex] != a[bFromIndex]) {
                return 0;
            }
            ++i;
        }
        return i + varHandleMismatch(a, aFromIndex + i, bFromIndex + i,
                                     length - i);
    }

    /**
     * Uses {@link #GET_PRIMITIVE} (a {@code long[]} view VarHandle) to
     * compare 8 bytes at a time, then compares any remaining tail byte by
     * byte.
     */
    private static int varHandleMismatch(
            byte[] a, int aFromIndex, int bFromIndex, int length)
            throws Throwable {
        // While we could do an index check, the VarHandle call incorporates
        // a check, making any check here duplicative.
        int i = 0;
        for (int j = length - 7; i < j; i += 8) {
            final long aVal = (long) GET_PRIMITIVE.invokeExact(
                    a, aFromIndex + i);
            final long bVal = (long) GET_PRIMITIVE.invokeExact(
                    a, bFromIndex + i);
            if (aVal != bVal) {
                // This returns a value where bits which match are 0 and
                // bits which differ are 1.
                final long diff = aVal ^ bVal;
                // The first (in native byte order) bit which differs tells
                // us which byte differed.
                final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff);
                return Math.min(i + (leadingZeros >>> 3), length);
            }
        }
        for ( ; i < length; ++i) {
            if (a[aFromIndex + i] != a[bFromIndex + i]) {
                return i;
            }
        }
        return length;
    }

    /**
     * If <i>aFromIndex</i> and <i>bFromIndex</i> are similarly mis-aligned,
     * will do single byte comparisons to obtain 4 byte alignment, then call
     * {@link #intVarHandleMismatch(byte[], int, int, int)}. If both are odd
     * but differently aligned, a single byte is compared first so the wider
     * reads at least start on even offsets.
     */
    @SuppressWarnings("unused")
    private static int alignedIntVarHandleMismatch(
            byte[] a, int aFromIndex, int bFromIndex, int length)
            throws Throwable {
        // Nothing to compare: return without touching the array, exactly
        // like legacyMismatch. The single byte adjustment below would
        // otherwise read a[aFromIndex], which may be past the end.
        if (length <= 0) {
            return length;
        }
        // While we could do an index check, the VarHandle call incorporates
        // a check, making any check here duplicative.
        final int aFromAlignment = aFromIndex & 3;
        final int bFromAlignment = bFromIndex & 3;
        // If they are aligned, just go.
        if (aFromAlignment == 0 && bFromAlignment == 0) {
            return intVarHandleMismatch(a, aFromIndex, bFromIndex, length);
        }
        int i = 0;
        if (aFromAlignment == bFromAlignment) {
            // Both are similarly out of alignment, so adjust both.
            for (int j = Math.min(4 - aFromAlignment, length); i < j; ++i) {
                if (a[aFromIndex + i] != a[bFromIndex + i]) {
                    return i;
                }
            }
        } else if ((aFromAlignment & 1) == 1 && (bFromAlignment & 1) == 1) {
            // They both have an odd alignment, so adjust by one.
            if (a[aFromIndex] != a[bFromIndex]) {
                return 0;
            }
            ++i;
        }
        return i + intVarHandleMismatch(a, aFromIndex + i, bFromIndex + i,
                                        length - i);
    }

    /**
     * Uses {@link #GET_PRIMITIVE} (an {@code int[]} view VarHandle) to
     * compare 4 bytes at a time, then compares any remaining tail byte by
     * byte.
     */
    private static int intVarHandleMismatch(
            byte[] a, int aFromIndex, int bFromIndex, int length)
            throws Throwable {
        // While we could do an index check, the VarHandle call incorporates
        // a check, making any check here duplicative.
        int i = 0;
        for (int j = length - 3; i < j; i += 4) {
            final int aVal = (int) GET_PRIMITIVE.invokeExact(a, aFromIndex + i);
            final int bVal = (int) GET_PRIMITIVE.invokeExact(a, bFromIndex + i);
            if (aVal != bVal) {
                // This returns a value where bits which match are 0 and
                // bits which differ are 1.
                final int diff = aVal ^ bVal;
                // The first (in native byte order) bit which differs tells
                // us which byte differed.
                final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff);
                return Math.min(i + (leadingZeros >>> 3), length);
            }
        }
        for ( ; i < length; ++i) {
            if (a[aFromIndex + i] != a[bFromIndex + i]) {
                return i;
            }
        }
        return length;
    }

    /**
     * Uses {@link #GET_PRIMITIVE} (bound to
     * {@code sun.misc.Unsafe.getLong}) to compare 8 bytes at a time, then
     * compares any remaining tail byte by byte.
     *
     * @throws ArrayIndexOutOfBoundsException if <i>length</i> is negative or
     *         either range does not fit within <i>a</i>.
     */
    @SuppressWarnings("unused")
    private static int unsafeMismatch(
            byte[] a, int aFromIndex, int bFromIndex, int length)
            throws Throwable {
        // It is important to check the indexes prior to making the Unsafe
        // calls, as Unsafe does not validate and could result in SIGSEGV if
        // out of bounds.
        if (length < 0 || aFromIndex < 0 || bFromIndex < 0
                || Math.max(aFromIndex, bFromIndex) > a.length - length) {
            throw new ArrayIndexOutOfBoundsException();
        }
        int i = 0;
        for (int j = length - 7; i < j; i += 8) {
            final long aVal = (long) GET_PRIMITIVE.invokeExact(
                    (Object) a, ARRAY_BASE_OFFSET + aFromIndex + i);
            final long bVal = (long) GET_PRIMITIVE.invokeExact(
                    (Object) a, ARRAY_BASE_OFFSET + bFromIndex + i);
            if (aVal != bVal) {
                // This returns a value where bits which match are 0 and
                // bits which differ are 1.
                final long diff = aVal ^ bVal;
                // The first (in native byte order) bit which differs tells
                // us which byte differed.
                final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff);
                return i + (leadingZeros >>> 3);
            }
        }
        for ( ; i < length; ++i) {
            if (a[aFromIndex + i] != a[bFromIndex + i]) {
                return i;
            }
        }
        return length;
    }

    /**
     * Uses {@link #GET_PRIMITIVE} (bound to
     * {@code sun.misc.Unsafe.getInt}) to compare 4 bytes at a time, then
     * compares any remaining tail byte by byte.
     *
     * @throws ArrayIndexOutOfBoundsException if <i>length</i> is negative or
     *         either range does not fit within <i>a</i>.
     */
    @SuppressWarnings("unused")
    private static int intUnsafeMismatch(
            byte[] a, int aFromIndex, int bFromIndex, int length)
            throws Throwable {
        // It is important to check the indexes prior to making the Unsafe
        // calls, as Unsafe does not validate and could result in SIGSEGV if
        // out of bounds.
        if (length < 0 || aFromIndex < 0 || bFromIndex < 0
                || Math.max(aFromIndex, bFromIndex) > a.length - length) {
            throw new ArrayIndexOutOfBoundsException();
        }
        int i = 0;
        for (int j = length - 3; i < j; i += 4) {
            final int aVal = (int) GET_PRIMITIVE.invokeExact(
                    (Object) a, ARRAY_BASE_OFFSET + aFromIndex + i);
            final int bVal = (int) GET_PRIMITIVE.invokeExact(
                    (Object) a, ARRAY_BASE_OFFSET + bFromIndex + i);
            if (aVal != bVal) {
                // This returns a value where bits which match are 0 and
                // bits which differ are 1.
                final int diff = aVal ^ bVal;
                // The first (in native byte order) bit which differs tells
                // us which byte differed.
                final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff);
                return i + (leadingZeros >>> 3);
            }
        }
        for ( ; i < length; ++i) {
            if (a[aFromIndex + i] != a[bFromIndex + i]) {
                return i;
            }
        }
        return length;
    }

    /**
     * Uses {@link #ARRAYS_MISMATCH} to compare the two ranges of <i>a</i>
     * starting at <i>aFromIndex</i> and <i>bFromIndex</i>, translating the
     * "no mismatch" result of {@code -1} into <i>length</i>.
     */
    @SuppressWarnings("unused")
    private static int arraysMismatch(
            byte[] a, int aFromIndex, int bFromIndex, int length)
            throws Throwable {
        final int m = (int) ARRAYS_MISMATCH.invokeExact(
                a, aFromIndex, aFromIndex + length,
                a, bFromIndex, bFromIndex + length);
        return m == -1 ? length : m;
    }

    /**
     * Simply loops over all of the bytes, comparing one at a time.
     */
    @SuppressWarnings("unused")
    private static int legacyMismatch(
            byte[] a, int aFromIndex, int bFromIndex, int length) {
        for (int i = 0; i < length; ++i) {
            if (a[aFromIndex + i] != a[bFromIndex + i]) {
                return i;
            }
        }
        return length;
    }

    private ArrayUtil() {
    }
}