This is an automated email from the ASF dual-hosted git repository.

aleksey pushed a commit to branch 15202-4.0
in repository https://gitbox.apache.org/repos/asf/cassandra.git
commit be64c1c26a879bad97390ad3d38e2ed03806c8c8
Author: Jeff Jirsa <[email protected]>
AuthorDate: Sat Mar 16 17:30:54 2019 +0000

    Make repair coordination less expensive by moving MerkleTrees off heap

    patch by Aleksey Yeschenko and Jeff Jirsa; reviewed by Benedict Elliott Smith
    and Marcus Eriksson for CASSANDRA-15202

    Co-authored-by: Aleksey Yeschenko <[email protected]>
    Co-authored-by: Jeff Jirsa <[email protected]>
---
 src/java/org/apache/cassandra/config/Config.java    |    2 +
 .../cassandra/config/DatabaseDescriptor.java        |   11 +
 .../cassandra/dht/ByteOrderedPartitioner.java       |    6 +
 src/java/org/apache/cassandra/dht/IPartitioner.java |    5 +
 .../apache/cassandra/dht/Murmur3Partitioner.java    |   33 +
 .../apache/cassandra/dht/RandomPartitioner.java     |   32 +
 src/java/org/apache/cassandra/dht/Token.java        |   40 +-
 .../org/apache/cassandra/repair/RepairJob.java      |    3 +
 .../org/apache/cassandra/repair/Validator.java      |   73 +-
 .../repair/messages/ValidationComplete.java         |    9 +
 .../cassandra/service/ActiveRepairService.java      |   10 +
 .../org/apache/cassandra/utils/ByteBufferUtil.java  |   30 +-
 .../org/apache/cassandra/utils/FBUtilities.java     |   16 +
 .../apache/cassandra/utils/FastByteOperations.java  |   46 +-
 .../org/apache/cassandra/utils/MerkleTree.java      | 1406 +++++++++++++---
 .../org/apache/cassandra/utils/MerkleTrees.java     |   63 +-
 .../apache/cassandra/repair/LocalSyncTaskTest.java  |    4 -
 .../org/apache/cassandra/repair/RepairJobTest.java  |    4 -
 .../org/apache/cassandra/repair/ValidatorTest.java  |   18 +-
 .../repair/asymmetric/DifferenceHolderTest.java     |    6 +-
 .../org/apache/cassandra/utils/MerkleTreeTest.java  |  127 +-
 .../apache/cassandra/utils/MerkleTreesTest.java     |   99 +-
 22 files changed, 1344 insertions(+), 699 deletions(-)

diff --git a/src/java/org/apache/cassandra/config/Config.java b/src/java/org/apache/cassandra/config/Config.java
index 34a5ce8..b56419d 100644
--- a/src/java/org/apache/cassandra/config/Config.java
+++ b/src/java/org/apache/cassandra/config/Config.java
@@ -130,6 +130,8 @@ public class Config
     public volatile Integer repair_session_max_tree_depth = null;
     public volatile Integer repair_session_space_in_mb = null;
 
+    public volatile boolean prefer_offheap_merkle_trees = true;
+
     public int storage_port = 7000;
     public int ssl_storage_port = 7001;
     public String listen_address;
diff --git a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
index 0166c5f..49761da 100644
--- a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
+++ b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
@@ -2907,4 +2907,15 @@ public class DatabaseDescriptor
     {
         return strictRuntimeChecks;
     }
+
+    public static boolean getOffheapMerkleTreesEnabled()
+    {
+        return conf.prefer_offheap_merkle_trees;
+    }
+
+    public static void setOffheapMerkleTreesEnabled(boolean value)
+    {
+        logger.info("Setting prefer_offheap_merkle_trees to {}", value);
+        conf.prefer_offheap_merkle_trees = value;
+    }
 }
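Not part of the patch, but for context: the new flag is plumbed straight through DatabaseDescriptor, so toggling it at runtime is a one-liner. A minimal sketch (the wrapper class and main() are hypothetical scaffolding, and DatabaseDescriptor must already be initialized, as it is in the daemon or any tool context):

    import org.apache.cassandra.config.DatabaseDescriptor;

    public final class OffHeapTreesToggleSketch
    {
        public static void main(String[] args)
        {
            // Disable off-heap merkle trees; the setter logs the change.
            DatabaseDescriptor.setOffheapMerkleTreesEnabled(false);
            assert !DatabaseDescriptor.getOffheapMerkleTreesEnabled();

            // Restore the default (true, per the Config change above).
            DatabaseDescriptor.setOffheapMerkleTreesEnabled(true);
        }
    }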
diff --git a/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java b/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java
index 08088f7..a6314dc 100644
--- a/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java
+++ b/src/java/org/apache/cassandra/dht/ByteOrderedPartitioner.java
@@ -234,6 +234,12 @@ public class ByteOrderedPartitioner implements IPartitioner
             return new BytesToken(bytes);
         }
 
+        @Override
+        public int byteSize(Token token)
+        {
+            return ((BytesToken) token).token.length;
+        }
+
         public String toString(Token token)
         {
             BytesToken bytesToken = (BytesToken) token;
diff --git a/src/java/org/apache/cassandra/dht/IPartitioner.java b/src/java/org/apache/cassandra/dht/IPartitioner.java
index f433f20..ef8ced2 100644
--- a/src/java/org/apache/cassandra/dht/IPartitioner.java
+++ b/src/java/org/apache/cassandra/dht/IPartitioner.java
@@ -135,4 +135,9 @@ public interface IPartitioner
     {
         return Optional.empty();
     }
+
+    default public int getMaxTokenSize()
+    {
+        return Integer.MIN_VALUE;
+    }
 }
diff --git a/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java b/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java
index 0f922e3..52d0efb 100644
--- a/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java
+++ b/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java
@@ -17,6 +17,7 @@
  */
 package org.apache.cassandra.dht;
 
+import java.io.IOException;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
@@ -25,10 +26,12 @@ import java.util.concurrent.ThreadLocalRandom;
 
 import org.apache.cassandra.db.DecoratedKey;
 import org.apache.cassandra.db.PreHashedDecoratedKey;
+import org.apache.cassandra.db.TypeSizes;
 import org.apache.cassandra.db.marshal.AbstractType;
 import org.apache.cassandra.db.marshal.PartitionerDefinedOrder;
 import org.apache.cassandra.db.marshal.LongType;
 import org.apache.cassandra.exceptions.ConfigurationException;
+import org.apache.cassandra.io.util.DataOutputPlus;
 import org.apache.cassandra.utils.ByteBufferUtil;
 import org.apache.cassandra.utils.MurmurHash;
 import org.apache.cassandra.utils.ObjectSizes;
@@ -42,6 +45,7 @@ public class Murmur3Partitioner implements IPartitioner
 {
     public static final LongToken MINIMUM = new LongToken(Long.MIN_VALUE);
     public static final long MAXIMUM = Long.MAX_VALUE;
+    private static final int MAXIMUM_TOKEN_SIZE = TypeSizes.sizeof(MAXIMUM);
 
     private static final int HEAP_SIZE = (int) ObjectSizes.measureDeep(MINIMUM);
 
@@ -224,6 +228,11 @@ public class Murmur3Partitioner implements IPartitioner
         return new LongToken(normalize(hash[0]));
     }
 
+    public int getMaxTokenSize()
+    {
+        return MAXIMUM_TOKEN_SIZE;
+    }
+
     private long[] getHash(ByteBuffer key)
     {
         long[] hash = new long[2];
@@ -300,11 +309,35 @@ public class Murmur3Partitioner implements IPartitioner
             return ByteBufferUtil.bytes(longToken.token);
         }
 
+        @Override
+        public void serialize(Token token, DataOutputPlus out) throws IOException
+        {
+            out.writeLong(((LongToken) token).token);
+        }
+
+        @Override
+        public void serialize(Token token, ByteBuffer out)
+        {
+            out.putLong(((LongToken) token).token);
+        }
+
+        @Override
+        public int byteSize(Token token)
+        {
+            return 8;
+        }
+
         public Token fromByteArray(ByteBuffer bytes)
         {
             return new LongToken(ByteBufferUtil.toLong(bytes));
         }
 
+        @Override
+        public Token fromByteBuffer(ByteBuffer bytes, int position, int length)
+        {
+            return new LongToken(bytes.getLong(position));
+        }
+
         public String toString(Token token)
         {
             return token.toString();
diff --git a/src/java/org/apache/cassandra/dht/RandomPartitioner.java b/src/java/org/apache/cassandra/dht/RandomPartitioner.java
index 4e63475..0457a89 100644
--- a/src/java/org/apache/cassandra/dht/RandomPartitioner.java
+++ b/src/java/org/apache/cassandra/dht/RandomPartitioner.java
@@ -17,6 +17,7 @@
  */
 package org.apache.cassandra.dht;
 
+import java.io.IOException;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
@@ -31,6 +32,7 @@ import org.apache.cassandra.db.DecoratedKey;
 import org.apache.cassandra.db.marshal.AbstractType;
 import org.apache.cassandra.db.marshal.IntegerType;
 import org.apache.cassandra.db.marshal.PartitionerDefinedOrder;
+import org.apache.cassandra.io.util.DataOutputPlus;
 import org.apache.cassandra.utils.ByteBufferUtil;
 import org.apache.cassandra.utils.FBUtilities;
 import org.apache.cassandra.utils.GuidGenerator;
@@ -46,6 +48,7 @@ public class RandomPartitioner implements IPartitioner
     public static final BigInteger ZERO = new BigInteger("0");
     public static final BigIntegerToken MINIMUM = new BigIntegerToken("-1");
     public static final BigInteger MAXIMUM = new BigInteger("2").pow(127);
+    public static final int MAXIMUM_TOKEN_SIZE = MAXIMUM.bitLength() / 8 + 1;
 
     /**
      * Maintain a separate threadlocal message digest, exclusively for token hashing. This is necessary because
@@ -162,11 +165,35 @@ public class RandomPartitioner implements IPartitioner
             return ByteBuffer.wrap(bigIntegerToken.token.toByteArray());
         }
 
+        @Override
+        public void serialize(Token token, DataOutputPlus out) throws IOException
+        {
+            out.write(((BigIntegerToken) token).token.toByteArray());
+        }
+
+        @Override
+        public void serialize(Token token, ByteBuffer out)
+        {
+            out.put(((BigIntegerToken) token).token.toByteArray());
+        }
+
+        @Override
+        public int byteSize(Token token)
+        {
+            return ((BigIntegerToken) token).token.bitLength() / 8 + 1;
+        }
+
         public Token fromByteArray(ByteBuffer bytes)
         {
             return new BigIntegerToken(new BigInteger(ByteBufferUtil.getArray(bytes)));
         }
 
+        @Override
+        public Token fromByteBuffer(ByteBuffer bytes, int position, int length)
+        {
+            return new BigIntegerToken(new BigInteger(ByteBufferUtil.getArray(bytes, position, length)));
+        }
+
         public String toString(Token token)
         {
             BigIntegerToken bigIntegerToken = (BigIntegerToken) token;
@@ -252,6 +279,11 @@ public class RandomPartitioner implements IPartitioner
         return new BigIntegerToken(hashToBigInteger(key));
     }
 
+    public int getMaxTokenSize()
+    {
+        return MAXIMUM_TOKEN_SIZE;
+    }
+
     public Map<Token, Float> describeOwnership(List<Token> sortedTokens)
     {
         Map<Token, Float> ownerships = new HashMap<Token, Float>();
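Not part of the patch, but to illustrate the factory-level API it adds: byteSize() lets callers pre-size a buffer, and serialize()/fromByteBuffer() avoid the intermediate byte[] copy that toByteArray()/fromByteArray() implied. A minimal round-trip sketch for Murmur3 (the wrapper class is hypothetical scaffolding; serialize(Token, ByteBuffer) declares IOException on the base TokenFactory, hence the throws clause):

    import java.io.IOException;
    import java.nio.ByteBuffer;

    import org.apache.cassandra.dht.Murmur3Partitioner;
    import org.apache.cassandra.dht.Token;

    public final class TokenRoundTripSketch
    {
        static Token roundTrip(Token token) throws IOException
        {
            Token.TokenFactory factory = Murmur3Partitioner.instance.getTokenFactory();
            ByteBuffer buffer = ByteBuffer.allocate(factory.byteSize(token)); // 8 bytes for Murmur3
            factory.serialize(token, buffer);
            // fromByteBuffer() reads at an absolute position, so no flip() is needed
            return factory.fromByteBuffer(buffer, 0, buffer.capacity());
        }

        public static void main(String[] args) throws IOException
        {
            Token token = new Murmur3Partitioner.LongToken(42L);
            assert token.equals(roundTrip(token));
        }
    }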
diff --git a/src/java/org/apache/cassandra/dht/Token.java b/src/java/org/apache/cassandra/dht/Token.java
index 20b45ef..ccb66fd 100644
--- a/src/java/org/apache/cassandra/dht/Token.java
+++ b/src/java/org/apache/cassandra/dht/Token.java
@@ -26,7 +26,6 @@ import org.apache.cassandra.db.PartitionPosition;
 import org.apache.cassandra.db.TypeSizes;
 import org.apache.cassandra.exceptions.ConfigurationException;
 import org.apache.cassandra.io.util.DataOutputPlus;
-import org.apache.cassandra.utils.ByteBufferUtil;
 
 public abstract class Token implements RingPosition<Token>, Serializable
 {
@@ -40,8 +39,30 @@ public abstract class Token implements RingPosition<Token>, Serializable
         public abstract Token fromByteArray(ByteBuffer bytes);
         public abstract String toString(Token token); // serialize as string, not necessarily human-readable
         public abstract Token fromString(String string); // deserialize
-
         public abstract void validate(String token) throws ConfigurationException;
+
+        public void serialize(Token token, DataOutputPlus out) throws IOException
+        {
+            out.write(toByteArray(token));
+        }
+
+        public void serialize(Token token, ByteBuffer out) throws IOException
+        {
+            out.put(toByteArray(token));
+        }
+
+        public Token fromByteBuffer(ByteBuffer bytes, int position, int length)
+        {
+            bytes = bytes.duplicate();
+            bytes.position(position)
+                 .limit(position + length);
+            return fromByteArray(bytes);
+        }
+
+        public int byteSize(Token token)
+        {
+            return toByteArray(token).remaining();
+        }
     }
 
     public static class TokenSerializer implements IPartitionerDependentSerializer<Token>
@@ -49,23 +70,28 @@ public abstract class Token implements RingPosition<Token>, Serializable
         public void serialize(Token token, DataOutputPlus out, int version) throws IOException
         {
             IPartitioner p = token.getPartitioner();
-            ByteBuffer b = p.getTokenFactory().toByteArray(token);
-            ByteBufferUtil.writeWithLength(b, out);
+            out.writeInt(p.getTokenFactory().byteSize(token));
+            p.getTokenFactory().serialize(token, out);
         }
 
         public Token deserialize(DataInput in, IPartitioner p, int version) throws IOException
         {
-            int size = in.readInt();
+            int size = deserializeSize(in);
             byte[] bytes = new byte[size];
             in.readFully(bytes);
             return p.getTokenFactory().fromByteArray(ByteBuffer.wrap(bytes));
         }
 
+        public int deserializeSize(DataInput in) throws IOException
+        {
+            return in.readInt();
+        }
+
         public long serializedSize(Token object, int version)
         {
             IPartitioner p = object.getPartitioner();
-            ByteBuffer b = p.getTokenFactory().toByteArray(object);
-            return TypeSizes.sizeof(b.remaining()) + b.remaining();
+            int byteSize = p.getTokenFactory().byteSize(object);
+            return TypeSizes.sizeof(byteSize) + byteSize;
         }
     }
diff --git a/src/java/org/apache/cassandra/repair/RepairJob.java b/src/java/org/apache/cassandra/repair/RepairJob.java
index a67aac0..f682bfb 100644
--- a/src/java/org/apache/cassandra/repair/RepairJob.java
+++ b/src/java/org/apache/cassandra/repair/RepairJob.java
@@ -236,7 +236,10 @@ public class RepairJob extends AbstractFuture<RepairResult> implements Runnable
                 }
                 syncTasks.add(task);
             }
+            trees.get(i).trees.release();
         }
+        trees.get(trees.size() - 1).trees.release();
+
         return syncTasks;
     }
diff --git a/src/java/org/apache/cassandra/repair/Validator.java b/src/java/org/apache/cassandra/repair/Validator.java
index c8becaa..9a89fa6 100644
--- a/src/java/org/apache/cassandra/repair/Validator.java
+++ b/src/java/org/apache/cassandra/repair/Validator.java
@@ -19,6 +19,7 @@ package org.apache.cassandra.repair;
 
 import java.nio.ByteBuffer;
 import java.nio.charset.Charset;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Random;
@@ -29,7 +30,6 @@
 import com.google.common.hash.HashCode;
 import com.google.common.hash.HashFunction;
 import com.google.common.hash.Hasher;
 import com.google.common.hash.Hashing;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -46,6 +46,7 @@ import org.apache.cassandra.net.Message;
 import org.apache.cassandra.net.MessagingService;
 import org.apache.cassandra.repair.messages.ValidationComplete;
 import org.apache.cassandra.streaming.PreviewKind;
+import org.apache.cassandra.service.ActiveRepairService;
 import org.apache.cassandra.tracing.Tracing;
 import org.apache.cassandra.utils.FBUtilities;
 import org.apache.cassandra.utils.MerkleTree;
@@ -149,7 +150,7 @@ public class Validator implements Runnable
             }
         }
         logger.debug("Prepared AEService trees of size {} for {}", trees.size(), desc);
-        ranges = tree.invalids();
+        ranges = tree.rangeIterator();
     }
 
     /**
@@ -172,7 +173,7 @@ public class Validator implements Runnable
         if (!findCorrectRange(lastKey.getToken()))
         {
             // add the empty hash, and move to the next range
-            ranges = trees.invalids();
+            ranges = trees.rangeIterator();
             findCorrectRange(lastKey.getToken());
         }
@@ -368,9 +369,7 @@ public class Validator implements Runnable
      */
     public void complete()
     {
-        completeTree();
-
-        StageManager.getStage(Stage.ANTI_ENTROPY).execute(this);
+        assert ranges != null : "Validator was not prepared()";
 
         if (logger.isDebugEnabled())
         {
@@ -380,20 +379,8 @@ public class Validator implements Runnable
             logger.debug("Validated {} partitions for {}. Partition sizes are:", validated, desc.sessionId);
             trees.logRowSizePerLeaf(logger);
         }
-    }
-
-    @VisibleForTesting
-    public void completeTree()
-    {
-        assert ranges != null : "Validator was not prepared()";
-
-        ranges = trees.invalids();
-        while (ranges.hasNext())
-        {
-            range = ranges.next();
-            range.ensureHashInitialised();
-        }
+
+        StageManager.getStage(Stage.ANTI_ENTROPY).execute(this);
     }
 
     /**
@@ -404,8 +391,7 @@ public class Validator implements Runnable
     public void fail()
     {
         logger.error("Failed creating a merkle tree for {}, {} (see log for details)", desc, initiator);
-        // send fail message only to nodes >= version 2.0
-        MessagingService.instance().send(Message.out(REPAIR_REQ, new ValidationComplete(desc)), initiator);
+        respond(new ValidationComplete(desc));
     }
 
     /**
@@ -413,12 +399,51 @@ public class Validator implements Runnable
      */
     public void run()
     {
-        // respond to the request that triggered this validation
-        if (!initiator.equals(FBUtilities.getBroadcastAddressAndPort()))
+        if (initiatorIsRemote())
         {
             logger.info("{} Sending completed merkle tree to {} for {}.{}", previewKind.logPrefix(desc.sessionId), initiator, desc.keyspace, desc.columnFamily);
             Tracing.traceRepair("Sending completed merkle tree to {} for {}.{}", initiator, desc.keyspace, desc.columnFamily);
         }
-        MessagingService.instance().send(Message.out(REPAIR_REQ, new ValidationComplete(desc, trees)), initiator);
+        else
+        {
+            logger.info("{} Local completed merkle tree for {} for {}.{}", previewKind.logPrefix(desc.sessionId), initiator, desc.keyspace, desc.columnFamily);
+            Tracing.traceRepair("Local completed merkle tree for {} for {}.{}", initiator, desc.keyspace, desc.columnFamily);
+        }
+        respond(new ValidationComplete(desc, trees));
+    }
+
+    private boolean initiatorIsRemote()
+    {
+        return !FBUtilities.getBroadcastAddressAndPort().equals(initiator);
+    }
+
+    private void respond(ValidationComplete response)
+    {
+        if (initiatorIsRemote())
+        {
+            MessagingService.instance().send(Message.out(REPAIR_REQ, response), initiator);
+            return;
+        }
+
+        /*
+         * For local initiators, DO NOT send the message to self over loopback. This is a wasted ser/de loop
+         * and a ton of garbage. Instead, move the trees off heap and invoke the message handler. We could do it
+         * directly, since this method will only be called from {@code Stage.ANTI_ENTROPY}, but we instead
+         * execute a {@code Runnable} on the stage - in case that assumption ever changes by accident.
+         */
+        StageManager.getStage(Stage.ANTI_ENTROPY).execute(() ->
+        {
+            ValidationComplete movedResponse = response;
+            try
+            {
+                movedResponse = response.tryMoveOffHeap();
+            }
+            catch (IOException e)
+            {
+                logger.error("Failed to move local merkle tree for {} off heap", desc, e);
+            }
+            ActiveRepairService.instance.handleMessage(initiator, movedResponse);
+        });
     }
 }
diff --git a/src/java/org/apache/cassandra/repair/messages/ValidationComplete.java b/src/java/org/apache/cassandra/repair/messages/ValidationComplete.java
index 704bffb..b8aa736 100644
--- a/src/java/org/apache/cassandra/repair/messages/ValidationComplete.java
+++ b/src/java/org/apache/cassandra/repair/messages/ValidationComplete.java
@@ -56,6 +56,15 @@ public class ValidationComplete extends RepairMessage
         return trees != null;
     }
 
+    /**
+     * @return a new {@link ValidationComplete} instance with all trees moved off heap, or {@code this}
+     * if it's a failure response.
+     */
+    public ValidationComplete tryMoveOffHeap() throws IOException
+    {
+        return trees == null ? this : new ValidationComplete(desc, trees.tryMoveOffHeap());
+    }
+
     @Override
     public boolean equals(Object o)
     {
diff --git a/src/java/org/apache/cassandra/service/ActiveRepairService.java b/src/java/org/apache/cassandra/service/ActiveRepairService.java
index abfd6d9..689771c 100644
--- a/src/java/org/apache/cassandra/service/ActiveRepairService.java
+++ b/src/java/org/apache/cassandra/service/ActiveRepairService.java
@@ -248,6 +248,16 @@ public class ActiveRepairService implements IEndpointStateChangeSubscriber, IFailureDetectionEventListener
         return session;
     }
 
+    public boolean getOffheapMerkleTreesEnabled()
+    {
+        return DatabaseDescriptor.getOffheapMerkleTreesEnabled();
+    }
+
+    public void setOffheapMerkleTreesEnabled(boolean enabled)
+    {
+        DatabaseDescriptor.setOffheapMerkleTreesEnabled(enabled);
+    }
+
     private <T extends AbstractFuture & IEndpointStateChangeSubscriber & IFailureDetectionEventListener> void registerOnFdAndGossip(final T task)
diff --git a/src/java/org/apache/cassandra/utils/ByteBufferUtil.java b/src/java/org/apache/cassandra/utils/ByteBufferUtil.java
index 788300c..518436e 100644
--- a/src/java/org/apache/cassandra/utils/ByteBufferUtil.java
+++ b/src/java/org/apache/cassandra/utils/ByteBufferUtil.java
@@ -101,6 +101,16 @@ public class ByteBufferUtil
         return FastByteOperations.compareUnsigned(o1, o2, 0, o2.length);
     }
 
+    public static int compare(ByteBuffer o1, int s1, int l1, byte[] o2)
+    {
+        return FastByteOperations.compareUnsigned(o1, s1, l1, o2, 0, o2.length);
+    }
+
+    public static int compare(byte[] o1, ByteBuffer o2, int s2, int l2)
+    {
+        return FastByteOperations.compareUnsigned(o1, 0, o1.length, o2, s2, l2);
+    }
+
     /**
      * Decode a String representation.
      * This method assumes that the encoding charset is UTF_8.
@@ -161,16 +171,25 @@ public class ByteBufferUtil
      */
     public static byte[] getArray(ByteBuffer buffer)
     {
-        int length = buffer.remaining();
+        return getArray(buffer, buffer.position(), buffer.remaining());
+    }
+
+    /**
+     * You should almost never use this. Instead, use the write* methods to avoid copies.
+     */
+    public static byte[] getArray(ByteBuffer buffer, int position, int length)
+    {
         if (buffer.hasArray())
         {
-            int boff = buffer.arrayOffset() + buffer.position();
+            int boff = buffer.arrayOffset() + position;
             return Arrays.copyOfRange(buffer.array(), boff, boff + length);
         }
+
         // else, DirectByteBuffer.get() is the fastest route
         byte[] bytes = new byte[length];
-        buffer.duplicate().get(bytes);
-
+        ByteBuffer dup = buffer.duplicate();
+        dup.position(position).limit(position + length);
+        dup.get(bytes);
         return bytes;
     }
@@ -631,6 +650,7 @@ public class ByteBufferUtil
         assert bytes1.limit() >= offset1 + length : "The first byte array isn't long enough for the specified offset and length.";
         assert bytes2.limit() >= offset2 + length : "The second byte array isn't long enough for the specified offset and length.";
+
         for (int i = 0; i < length; i++)
         {
             byte byte1 = bytes1.get(offset1 + i);
@@ -669,7 +689,7 @@ public class ByteBufferUtil
         return buf.capacity() > buf.remaining() || !buf.hasArray() ? ByteBuffer.wrap(getArray(buf)) : buf;
     }
 
-    // Doesn't change bb position
+    // doesn't change bb position
     public static int getShortLength(ByteBuffer bb, int position)
     {
         int length = (bb.get(position) & 0xFF) << 8;
diff --git a/src/java/org/apache/cassandra/utils/FBUtilities.java b/src/java/org/apache/cassandra/utils/FBUtilities.java
index c37dcca..0f7b2ca 100644
--- a/src/java/org/apache/cassandra/utils/FBUtilities.java
+++ b/src/java/org/apache/cassandra/utils/FBUtilities.java
@@ -289,6 +289,22 @@ public class FBUtilities
         return out;
     }
 
+    /**
+     * Bitwise XOR of the inputs, in place on the left array.
+     *
+     * Assumes the inputs are the same length.
+     */
+    static void xorOntoLeft(byte[] left, byte[] right)
+    {
+        if (left == null || right == null)
+            return;
+
+        assert left.length == right.length;
+
+        for (int i = 0; i < left.length; i++)
+            left[i] = (byte) ((left[i] & 0xFF) ^ (right[i] & 0xFF));
+    }
+
     public static void sortSampledKeys(List<DecoratedKey> keys, Range<Token> range)
     {
         if (range.left.compareTo(range.right) >= 0)
diff --git a/src/java/org/apache/cassandra/utils/FastByteOperations.java b/src/java/org/apache/cassandra/utils/FastByteOperations.java
index 6581736..060dee5 100644
--- a/src/java/org/apache/cassandra/utils/FastByteOperations.java
+++ b/src/java/org/apache/cassandra/utils/FastByteOperations.java
@@ -55,6 +55,16 @@ public class FastByteOperations
         return -BestHolder.BEST.compare(b2, b1, s1, l1);
     }
 
+    public static int compareUnsigned(ByteBuffer b1, int s1, int l1, byte[] b2, int s2, int l2)
+    {
+        return BestHolder.BEST.compare(b1, s1, l1, b2, s2, l2);
+    }
+
+    public static int compareUnsigned(byte[] b1, int s1, int l1, ByteBuffer b2, int s2, int l2)
+    {
+        return -BestHolder.BEST.compare(b2, s2, l2, b1, s1, l1);
+    }
+
     public static int compareUnsigned(ByteBuffer b1, ByteBuffer b2)
     {
         return BestHolder.BEST.compare(b1, b2);
@@ -77,6 +87,8 @@ public class FastByteOperations
 
     abstract public int compare(ByteBuffer buffer1, byte[] buffer2, int offset2, int length2);
 
+    abstract public int compare(ByteBuffer buffer1, int offset1, int length1, byte[] buffer2, int offset2, int length2);
+
     abstract public int compare(ByteBuffer buffer1, ByteBuffer buffer2);
 
     abstract public void copy(ByteBuffer src, int srcPosition, byte[] trg, int trgPosition, int length);
@@ -187,25 +199,24 @@ public class FastByteOperations
 
         public int compare(ByteBuffer buffer1, byte[] buffer2, int offset2, int length2)
         {
+            return compare(buffer1, buffer1.position(), buffer1.remaining(), buffer2, offset2, length2);
+        }
+
+        public int compare(ByteBuffer buffer1, int position1, int length1, byte[] buffer2, int offset2, int length2)
+        {
             Object obj1;
             long offset1;
             if (buffer1.hasArray())
             {
                 obj1 = buffer1.array();
-                offset1 = BYTE_ARRAY_BASE_OFFSET + buffer1.arrayOffset();
+                offset1 = BYTE_ARRAY_BASE_OFFSET + buffer1.arrayOffset() + position1;
             }
             else
             {
                 obj1 = null;
-                offset1 = theUnsafe.getLong(buffer1, DIRECT_BUFFER_ADDRESS_OFFSET);
-            }
-            int length1;
-            {
-                int position = buffer1.position();
-                int limit = buffer1.limit();
-                length1 = limit - position;
-                offset1 += position;
+                offset1 = theUnsafe.getLong(buffer1, DIRECT_BUFFER_ADDRESS_OFFSET) + position1;
             }
+
             return compareTo(obj1, offset1, length1, buffer2, BYTE_ARRAY_BASE_OFFSET + offset2, length2);
         }
@@ -397,11 +408,28 @@ public class FastByteOperations
             return length1 - length2;
         }
 
+        public int compare(ByteBuffer buffer1, int position1, int length1, byte[] buffer2, int offset2, int length2)
+        {
+            if (buffer1.hasArray())
+                return compare(buffer1.array(), buffer1.arrayOffset() + position1, length1, buffer2, offset2, length2);
+
+            if (position1 != buffer1.position())
+            {
+                buffer1 = buffer1.duplicate();
+                buffer1.position(position1);
+            }
+
+            return compare(buffer1, ByteBuffer.wrap(buffer2, offset2, length2));
+        }
+
         public int compare(ByteBuffer buffer1, byte[] buffer2, int offset2, int length2)
         {
             if (buffer1.hasArray())
+            {
                 return compare(buffer1.array(), buffer1.arrayOffset() + buffer1.position(), buffer1.remaining(), buffer2, offset2, length2);
+            }
+
             return compare(buffer1, ByteBuffer.wrap(buffer2, offset2, length2));
         }
diff --git a/src/java/org/apache/cassandra/utils/MerkleTree.java b/src/java/org/apache/cassandra/utils/MerkleTree.java
index d131ff5..9d9eadb 100644
--- a/src/java/org/apache/cassandra/utils/MerkleTree.java
+++ b/src/java/org/apache/cassandra/utils/MerkleTree.java
@@ -19,25 +19,33 @@ package org.apache.cassandra.utils;
 
 import java.io.DataInput;
 import java.io.IOException;
-import java.io.Serializable;
+import java.nio.ByteBuffer;
 import java.util.*;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.PeekingIterator;
+import com.google.common.primitives.Ints;
+import com.google.common.primitives.Shorts;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.cassandra.db.TypeSizes;
+import org.apache.cassandra.config.DatabaseDescriptor;
 import org.apache.cassandra.dht.IPartitioner;
-import org.apache.cassandra.dht.IPartitionerDependentSerializer;
+import org.apache.cassandra.dht.Murmur3Partitioner;
+import org.apache.cassandra.dht.RandomPartitioner;
 import org.apache.cassandra.dht.Range;
 import org.apache.cassandra.dht.Token;
 import org.apache.cassandra.exceptions.ConfigurationException;
-import org.apache.cassandra.io.IVersionedSerializer;
 import org.apache.cassandra.io.util.DataInputPlus;
 import org.apache.cassandra.io.util.DataOutputPlus;
+import org.apache.cassandra.io.util.FileUtils;
+import org.apache.cassandra.utils.concurrent.Ref;
+import org.apache.cassandra.utils.memory.MemoryUtil;
+
+import static org.apache.cassandra.db.TypeSizes.sizeof;
+import static org.apache.cassandra.utils.ByteBufferUtil.compare;
 
 /**
  * A MerkleTree implemented as a binary tree.
@@ -59,84 +67,51 @@ import org.apache.cassandra.io.util.DataOutputPlus;
  * If two MerkleTrees have the same hashdepth, they represent a perfect tree
  * of the same depth, and can always be compared, regardless of size or splits.
  */
-public class MerkleTree implements Serializable
+public class MerkleTree
 {
-    private static Logger logger = LoggerFactory.getLogger(MerkleTree.class);
+    private static final Logger logger = LoggerFactory.getLogger(MerkleTree.class);
+
+    private static final int HASH_SIZE = 32; // SHA-256 = 32 bytes.
+    private static final byte[] EMPTY_HASH = new byte[HASH_SIZE];
+
+    /*
+     * Thread-local byte array, large enough to host SHA-256 bytes or MM3/Random partitioners' tokens
+     */
+    private static final ThreadLocal<byte[]> byteArray = ThreadLocal.withInitial(() -> new byte[HASH_SIZE]);
 
-    public static final MerkleTreeSerializer serializer = new MerkleTreeSerializer();
-    private static final long serialVersionUID = 2L;
+    private static byte[] getTempArray(int minimumSize)
+    {
+        return minimumSize <= HASH_SIZE ? byteArray.get() : new byte[minimumSize];
+    }
 
     public static final byte RECOMMENDED_DEPTH = Byte.MAX_VALUE - 1;
 
-    public static final int CONSISTENT = 0;
-    public static final int FULLY_INCONSISTENT = 1;
-    public static final int PARTIALLY_INCONSISTENT = 2;
-    private static final byte[] EMPTY_HASH = new byte[0];
+    @SuppressWarnings("WeakerAccess")
+    static final int CONSISTENT = 0;
+    static final int FULLY_INCONSISTENT = 1;
+    @SuppressWarnings("WeakerAccess")
+    static final int PARTIALLY_INCONSISTENT = 2;
 
-    public final byte hashdepth;
+    private final int hashdepth;
 
     /** The top level range that this MerkleTree covers. */
-    public final Range<Token> fullRange;
+    final Range<Token> fullRange;
     private final IPartitioner partitioner;
 
     private long maxsize;
     private long size;
-    private Hashable root;
+    private Node root;
 
-    public static class MerkleTreeSerializer implements IVersionedSerializer<MerkleTree>
+    /**
+     * @param partitioner The partitioner in use.
+     * @param range the range this tree covers
+     * @param hashdepth The maximum depth of the tree. 100/(2^depth) is the %
+     *        of the key space covered by each subrange of a fully populated tree.
+     * @param maxsize The maximum number of subranges in the tree.
+     */
+    public MerkleTree(IPartitioner partitioner, Range<Token> range, int hashdepth, long maxsize)
     {
-        public void serialize(MerkleTree mt, DataOutputPlus out, int version) throws IOException
-        {
-            out.writeByte(mt.hashdepth);
-            out.writeLong(mt.maxsize);
-            out.writeLong(mt.size);
-            out.writeUTF(mt.partitioner.getClass().getCanonicalName());
-            // full range
-            Token.serializer.serialize(mt.fullRange.left, out, version);
-            Token.serializer.serialize(mt.fullRange.right, out, version);
-            Hashable.serializer.serialize(mt.root, out, version);
-        }
-
-        public MerkleTree deserialize(DataInputPlus in, int version) throws IOException
-        {
-            byte hashdepth = in.readByte();
-            long maxsize = in.readLong();
-            long size = in.readLong();
-            IPartitioner partitioner;
-            try
-            {
-                partitioner = FBUtilities.newPartitioner(in.readUTF());
-            }
-            catch (ConfigurationException e)
-            {
-                throw new IOException(e);
-            }
-
-            // full range
-            Token left = Token.serializer.deserialize(in, partitioner, version);
-            Token right = Token.serializer.deserialize(in, partitioner, version);
-            Range<Token> fullRange = new Range<>(left, right);
-
-            MerkleTree mt = new MerkleTree(partitioner, fullRange, hashdepth, maxsize);
-            mt.size = size;
-            mt.root = Hashable.serializer.deserialize(in, partitioner, version);
-            return mt;
-        }
-
-        public long serializedSize(MerkleTree mt, int version)
-        {
-            long size = 1 // mt.hashdepth
-                      + TypeSizes.sizeof(mt.maxsize)
-                      + TypeSizes.sizeof(mt.size)
-                      + TypeSizes.sizeof(mt.partitioner.getClass().getCanonicalName());
-
-            // full range
-            size += Token.serializer.serializedSize(mt.fullRange.left, version);
-            size += Token.serializer.serializedSize(mt.fullRange.right, version);
-
-            size += Hashable.serializer.serializedSize(mt.root, version);
-            return size;
-        }
+        this(new OnHeapLeaf(), partitioner, range, hashdepth, maxsize, 1);
     }
 
    /**
@@ -145,60 +120,56 @@ public class MerkleTree implements Serializable
     * @param hashdepth The maximum depth of the tree. 100/(2^depth) is the %
    *        of the key space covered by each subrange of a fully populated tree.
     * @param maxsize The maximum number of subranges in the tree.
+    * @param size The size of the tree. Typically 1, unless deserialized from an existing tree.
     */
-    public MerkleTree(IPartitioner partitioner, Range<Token> range, byte hashdepth, long maxsize)
+    private MerkleTree(Node root, IPartitioner partitioner, Range<Token> range, int hashdepth, long maxsize, long size)
    {
        assert hashdepth < Byte.MAX_VALUE;
+
+        this.root = root;
        this.fullRange = Preconditions.checkNotNull(range);
        this.partitioner = Preconditions.checkNotNull(partitioner);
        this.hashdepth = hashdepth;
        this.maxsize = maxsize;
-
-        size = 1;
-        root = new Leaf(null);
-    }
-
-
-    static byte inc(byte in)
-    {
-        assert in < Byte.MAX_VALUE;
-        return (byte)(in + 1);
+        this.size = size;
    }

    /**
     * Initializes this tree by splitting it until hashdepth is reached,
     * or until an additional level of splits would violate maxsize.
     *
-     * NB: Replaces all nodes in the tree.
+     * NB: Replaces all nodes in the tree, and always builds on the heap.
     */
    public void init()
    {
        // determine the depth to which we can safely split the tree
-        byte sizedepth = (byte)(Math.log10(maxsize) / Math.log10(2));
-        byte depth = (byte)Math.min(sizedepth, hashdepth);
+        int sizedepth = (int) (Math.log10(maxsize) / Math.log10(2));
+        int depth = Math.min(sizedepth, hashdepth);
+
+        root = initHelper(fullRange.left, fullRange.right, 0, depth);
+        size = (long) Math.pow(2, depth);
+    }

-        root = initHelper(fullRange.left, fullRange.right, (byte)0, depth);
-        size = (long)Math.pow(2, depth);
+    public void release()
+    {
+        if (root instanceof OffHeapNode)
+            ((OffHeapNode) root).release();
+        root = null;
    }

-    private Hashable initHelper(Token left, Token right, byte depth, byte max)
+    private OnHeapNode initHelper(Token left, Token right, int depth, int max)
    {
        if (depth == max)
            // we've reached the leaves
-            return new Leaf();
+            return new OnHeapLeaf();

        Token midpoint = partitioner.midpoint(left, right);

        if (midpoint.equals(left) || midpoint.equals(right))
-            return new Leaf();
-
-        Hashable lchild = initHelper(left, midpoint, inc(depth), max);
-        Hashable rchild = initHelper(midpoint, right, inc(depth), max);
-        return new Inner(midpoint, lchild, rchild);
-    }
+            return new OnHeapLeaf();

-    Hashable root()
-    {
-        return root;
+        OnHeapNode leftChild = initHelper(left, midpoint, depth + 1, max);
+        OnHeapNode rightChild = initHelper(midpoint, right, depth + 1, max);
+        return new OnHeapInner(midpoint, leftChild, rightChild);
    }

    public IPartitioner partitioner()
@@ -233,20 +204,17 @@ public class MerkleTree implements Serializable
    public static List<TreeRange> difference(MerkleTree ltree, MerkleTree rtree)
    {
        if (!ltree.fullRange.equals(rtree.fullRange))
-            throw new IllegalArgumentException("Difference only make sense on tree covering the same range (but " + ltree.fullRange + " != " + rtree.fullRange + ")");
+            throw new IllegalArgumentException("Differences only make sense on trees covering the same range (but " + ltree.fullRange + " != " + rtree.fullRange + ')');

        List<TreeRange> diff = new ArrayList<>();
-        TreeDifference active = new TreeDifference(ltree.fullRange.left, ltree.fullRange.right, (byte)0);
+        TreeRange active = new TreeRange(ltree.fullRange.left, ltree.fullRange.right, 0);

-        Hashable lnode = ltree.find(active);
-        Hashable rnode = rtree.find(active);
-        byte[] lhash = lnode.hash();
-        byte[] rhash = rnode.hash();
-        active.setSize(lnode.sizeOfRange(), rnode.sizeOfRange());
+        Node lnode = ltree.find(active);
+        Node rnode = rtree.find(active);

-        if (lhash != null && rhash != null && !Arrays.equals(lhash, rhash))
+        if (lnode.hashesDiffer(rnode))
        {
-            if(lnode instanceof Leaf || rnode instanceof Leaf)
+            if (lnode instanceof Leaf || rnode instanceof Leaf)
            {
                logger.debug("Digest mismatch detected among leaf nodes {}, {}", lnode, rnode);
                diff.add(active);
@@ -261,14 +229,12 @@ public class MerkleTree implements Serializable
            }
        }
-        else if (lhash == null || rhash == null)
-            diff.add(active);
+
        return diff;
    }

    /**
-     * TODO: This function could be optimized into a depth first traversal of
-     * the two trees in parallel.
+     * TODO: This function could be optimized into a depth first traversal of the two trees in parallel.
     *
     * Takes two trees and a range for which they have hashes, but are inconsistent.
     * @return FULLY_INCONSISTENT if active is inconsistent, PARTIALLY_INCONSISTENT if only a subrange is inconsistent.
@@ -289,54 +255,39 @@
            return FULLY_INCONSISTENT;
        }

-        TreeDifference left = new TreeDifference(active.left, midpoint, inc(active.depth));
-        TreeDifference right = new TreeDifference(midpoint, active.right, inc(active.depth));
+        TreeRange left = new TreeRange(active.left, midpoint, active.depth + 1);
+        TreeRange right = new TreeRange(midpoint, active.right, active.depth + 1);
        logger.debug("({}) Hashing sub-ranges [{}, {}] for {} divided by midpoint {}", active.depth, left, right, active, midpoint);
-        byte[] lhash, rhash;
-        Hashable lnode, rnode;
+        Node lnode, rnode;

        // see if we should recurse left
        lnode = ltree.find(left);
        rnode = rtree.find(left);
-        lhash = lnode.hash();
-        rhash = rnode.hash();
-        left.setSize(lnode.sizeOfRange(), rnode.sizeOfRange());
-        left.setRows(lnode.rowsInRange(), rnode.rowsInRange());

        int ldiff = CONSISTENT;
-        boolean lreso = lhash != null && rhash != null;
-        if (lreso && !Arrays.equals(lhash, rhash))
+        if (lnode.hashesDiffer(rnode))
        {
            logger.debug("({}) Inconsistent digest on left sub-range {}: [{}, {}]", active.depth, left, lnode, rnode);
-            if (lnode instanceof Leaf) ldiff = FULLY_INCONSISTENT;
-            else ldiff = differenceHelper(ltree, rtree, diff, left);
-        }
-        else if (!lreso)
-        {
-            logger.debug("({}) Left sub-range fully inconsistent {}", active.depth, left);
-            ldiff = FULLY_INCONSISTENT;
+
+            if (lnode instanceof Leaf)
+                ldiff = FULLY_INCONSISTENT;
+            else
+                ldiff = differenceHelper(ltree, rtree, diff, left);
        }

        // see if we should recurse right
        lnode = ltree.find(right);
        rnode = rtree.find(right);
-        lhash = lnode.hash();
-        rhash = rnode.hash();
-        right.setSize(lnode.sizeOfRange(), rnode.sizeOfRange());
-        right.setRows(lnode.rowsInRange(), rnode.rowsInRange());

        int rdiff = CONSISTENT;
-        boolean rreso = lhash != null && rhash != null;
-        if (rreso && !Arrays.equals(lhash, rhash))
+        if (lnode.hashesDiffer(rnode))
        {
            logger.debug("({}) Inconsistent digest on right sub-range {}: [{}, {}]", active.depth, right, lnode, rnode);
-            if (rnode instanceof Leaf) rdiff = FULLY_INCONSISTENT;
-            else rdiff = differenceHelper(ltree, rtree, diff, right);
-        }
-        else if (!rreso)
-        {
-            logger.debug("({}) Right sub-range fully inconsistent {}", active.depth, right);
-            rdiff = FULLY_INCONSISTENT;
+
+            if (rnode instanceof Leaf)
+                rdiff = FULLY_INCONSISTENT;
+            else
+                rdiff = differenceHelper(ltree, rtree, diff, right);
        }

        if (ldiff == FULLY_INCONSISTENT && rdiff == FULLY_INCONSISTENT)
@@ -367,32 +318,36 @@ public class MerkleTree implements Serializable
     */
    public TreeRange get(Token t)
    {
-        return getHelper(root, fullRange.left, fullRange.right, (byte)0, t);
+        return getHelper(root, fullRange.left, fullRange.right, t);
    }

-    TreeRange getHelper(Hashable hashable, Token pleft, Token pright, byte depth, Token t)
+    private TreeRange getHelper(Node node, Token pleft, Token pright, Token t)
    {
+        int depth = 0;
+
        while (true)
        {
-            if (hashable instanceof Leaf)
+            if (node instanceof Leaf)
            {
                // we've reached a hash: wrap it up and deliver it
-                return new TreeRange(this, pleft, pright, depth, hashable);
+                return new TreeRange(this, pleft, pright, depth, node);
            }
-            // else: node.
-
-            Inner node = (Inner) hashable;
-            depth = inc(depth);

-            if (Range.contains(pleft, node.token, t))
-            { // left child contains token
-                hashable = node.lchild;
-                pright = node.token;
+            assert node instanceof Inner;
+            Inner inner = (Inner) node;
+
+            if (Range.contains(pleft, inner.token(), t)) // left child contains token
+            {
+                pright = inner.token();
+                node = inner.left();
            }
-            else
-            { // else: right child contains token
-                hashable = node.rchild;
-                pleft = node.token;
+            else // right child contains token
+            {
+                pleft = inner.token();
+                node = inner.right();
            }
+
+            depth++;
        }
    }

@@ -405,20 +360,21 @@ public class MerkleTree implements Serializable
        invalidateHelper(root, fullRange.left, t);
    }

-    private void invalidateHelper(Hashable hashable, Token pleft, Token t)
+    private void invalidateHelper(Node node, Token pleft, Token t)
    {
-        hashable.hash(null);
-        if (hashable instanceof Leaf)
+        node.hash(EMPTY_HASH);
+
+        if (node instanceof Leaf)
            return;
-        // else: node.

-        Inner node = (Inner)hashable;
-        if (Range.contains(pleft, node.token, t))
-            // left child contains token
-            invalidateHelper(node.lchild, pleft, t);
+        assert node instanceof Inner;
+        Inner inner = (Inner) node;
+        // TODO: reset computed flag on OnHeapInners
+
+        if (Range.contains(pleft, inner.token(), t))
+            invalidateHelper(inner.left(), pleft, t); // left child contains token
        else
-            // right child contains token
-            invalidateHelper(node.rchild, node.token, t);
+            invalidateHelper(inner.right(), inner.token(), t); // right child contains token
    }

    /**
@@ -428,67 +384,106 @@ public class MerkleTree implements Serializable
     * NB: Currently does not support wrapping ranges that do not end with
     * partitioner.getMinimumToken().
     *
-     * @return Null if any subrange of the range is invalid, or if the exact
+     * @return {@link #EMPTY_HASH} if any subrange of the range is invalid, or if the exact
     * range cannot be calculated using this tree.
     */
+    @VisibleForTesting
    public byte[] hash(Range<Token> range)
    {
        return find(range).hash();
    }

    /**
-     * Find the {@link Hashable} node that matches the given {@code range}.
+     * Exceptions that stop recursion early when we are sure that no answer
+     * can be found.
+     */
+    static abstract class StopRecursion extends Exception
+    {
+        static class TooDeep extends StopRecursion {}
+        static class BadRange extends StopRecursion {}
+    }
+
+    /**
+     * Find the {@link Node} node that matches the given {@code range}.
     *
     * @param range Range to find
-     * @return {@link Hashable} found. If nothing found, return {@link Leaf} with null hash.
+     * @return {@link Node} found. If nothing found, return {@link Leaf} with empty hash.
     */
-    private Hashable find(Range<Token> range)
+    private Node find(Range<Token> range)
+    {
+        try
+        {
+            return findHelper(root, new Range<>(fullRange.left, fullRange.right), range);
+        }
+        catch (StopRecursion e)
+        {
+            return new OnHeapLeaf();
+        }
+    }
+
+    interface Consumer<E extends Exception>
+    {
+        void accept(Node node) throws E;
+    }
+
+    @VisibleForTesting
+    <E extends Exception> boolean ifHashesRange(Range<Token> range, Consumer<E> consumer) throws E
    {
        try
        {
-            return findHelper(root, new Range<Token>(fullRange.left, fullRange.right), range);
+            Node node = findHelper(root, new Range<>(fullRange.left, fullRange.right), range);
+            boolean hasHash = !node.hasEmptyHash();
+            if (hasHash)
+                consumer.accept(node);
+            return hasHash;
        }
        catch (StopRecursion e)
        {
-            return new Leaf();
+            return false;
        }
    }

+    @VisibleForTesting
+    boolean hashesRange(Range<Token> range)
+    {
+        return ifHashesRange(range, n -> {});
+    }
+
    /**
     * @throws StopRecursion If no match could be found for the range.
     */
-    private Hashable findHelper(Hashable current, Range<Token> activeRange, Range<Token> find) throws StopRecursion
+    private Node findHelper(Node current, Range<Token> activeRange, Range<Token> find) throws StopRecursion
    {
        while (true)
        {
            if (current instanceof Leaf)
            {
                if (!find.contains(activeRange))
-                    // we are not fully contained in this range!
-                    throw new StopRecursion.BadRange();
+                    throw new StopRecursion.BadRange(); // we are not fully contained in this range!
+
                return current;
            }
-            // else: node.

-            Inner node = (Inner) current;
-            Range<Token> leftRange = new Range<>(activeRange.left, node.token);
-            Range<Token> rightRange = new Range<>(node.token, activeRange.right);
+            assert current instanceof Inner;
+            Inner inner = (Inner) current;

-            if (find.contains(activeRange))
-                // this node is fully contained in the range
-                return node.calc();
+            Range<Token> leftRange = new Range<>(activeRange.left, inner.token());
+            Range<Token> rightRange = new Range<>(inner.token(), activeRange.right);
+
+            if (find.contains(activeRange)) // this node is fully contained in the range
+                return inner.compute();

            // else: one of our children contains the range

-            if (leftRange.contains(find))
-            { // left child contains/matches the range
-                current = node.lchild;
+            if (leftRange.contains(find)) // left child contains/matches the range
+            {
                activeRange = leftRange;
+                current = inner.left();
            }
-            else if (rightRange.contains(find))
-            { // right child contains/matches the range
-                current = node.rchild;
+            else if (rightRange.contains(find)) // right child contains/matches the range
+            {
                activeRange = rightRange;
+                current = inner.right();
            }
            else
            {
@@ -506,12 +501,12 @@ public class MerkleTree implements Serializable
     */
    public boolean split(Token t)
    {
-        if (!(size < maxsize))
+        if (size >= maxsize)
            return false;

        try
        {
-            root = splitHelper(root, fullRange.left, fullRange.right, (byte)0, t);
+            root = splitHelper(root, fullRange.left, fullRange.right, 0, t);
        }
        catch (StopRecursion.TooDeep e)
        {
@@ -520,12 +515,12 @@ public class MerkleTree implements Serializable
        return true;
    }

-    private Hashable splitHelper(Hashable hashable, Token pleft, Token pright, byte depth, Token t) throws StopRecursion.TooDeep
+    private OnHeapNode splitHelper(Node node, Token pleft, Token pright, int depth, Token t) throws StopRecursion.TooDeep
    {
        if (depth >= hashdepth)
            throw new StopRecursion.TooDeep();

-        if (hashable instanceof Leaf)
+        if (node instanceof Leaf)
        {
            Token midpoint = partitioner.midpoint(pleft, pright);

@@ -536,47 +531,47 @@ public class MerkleTree implements Serializable
            // split
            size++;
-            return new Inner(midpoint, new Leaf(), new Leaf());
+            return new OnHeapInner(midpoint, new OnHeapLeaf(), new OnHeapLeaf());
        }
        // else: node.

        // recurse on the matching child
-        Inner node = (Inner)hashable;
+        assert node instanceof OnHeapInner;
+        OnHeapInner inner = (OnHeapInner) node;

-        if (Range.contains(pleft, node.token, t))
-            // left child contains token
-            node.lchild(splitHelper(node.lchild, pleft, node.token, inc(depth), t));
-        else
-            // else: right child contains token
-            node.rchild(splitHelper(node.rchild, node.token, pright, inc(depth), t));
-        return node;
+        if (Range.contains(pleft, inner.token(), t)) // left child contains token
+            inner.left(splitHelper(inner.left(), pleft, inner.token(), depth + 1, t));
+        else // else: right child contains token
+            inner.right(splitHelper(inner.right(), inner.token(), pright, depth + 1, t));
+
+        return inner;
    }

    /**
     * Returns a lazy iterator of invalid TreeRanges that need to be filled
     * in order to make the given Range valid.
     */
-    public TreeRangeIterator invalids()
+    TreeRangeIterator rangeIterator()
    {
        return new TreeRangeIterator(this);
    }

-    public EstimatedHistogram histogramOfRowSizePerLeaf()
+    EstimatedHistogram histogramOfRowSizePerLeaf()
    {
        HistogramBuilder histbuild = new HistogramBuilder();
        for (TreeRange range : new TreeRangeIterator(this))
        {
-            histbuild.add(range.hashable.sizeOfRange);
+            histbuild.add(range.node.sizeOfRange());
        }
        return histbuild.buildWithStdevRangesAroundMean();
    }

-    public EstimatedHistogram histogramOfRowCountPerLeaf()
+    EstimatedHistogram histogramOfRowCountPerLeaf()
    {
        HistogramBuilder histbuild = new HistogramBuilder();
        for (TreeRange range : new TreeRangeIterator(this))
        {
-            histbuild.add(range.hashable.rowsInRange);
+            histbuild.add(range.node.partitionsInRange());
        }
        return histbuild.buildWithStdevRangesAroundMean();
    }
@@ -586,7 +581,7 @@ public class MerkleTree implements Serializable
        long count = 0;
        for (TreeRange range : new TreeRangeIterator(this))
        {
-            count += range.hashable.rowsInRange;
+            count += range.node.partitionsInRange();
        }
        return count;
    }
@@ -597,61 +592,23 @@ public class MerkleTree implements Serializable
        StringBuilder buff = new StringBuilder();
        buff.append("#<MerkleTree root=");
        root.toString(buff, 8);
-        buff.append(">");
+        buff.append('>');
        return buff.toString();
    }

-    public static class TreeDifference extends TreeRange
+    @Override
+    public boolean equals(Object other)
    {
-        private static final long serialVersionUID = 6363654174549968183L;
-
-        private long sizeOnLeft;
-        private long sizeOnRight;
-        private long rowsOnLeft;
-        private long rowsOnRight;
-
-        void setSize(long sizeOnLeft, long sizeOnRight)
-        {
-            this.sizeOnLeft = sizeOnLeft;
-            this.sizeOnRight = sizeOnRight;
-        }
-
-        void setRows(long rowsOnLeft, long rowsOnRight)
-        {
-            this.rowsOnLeft = rowsOnLeft;
-            this.rowsOnRight = rowsOnRight;
-        }
-
-        public long sizeOnLeft()
-        {
-            return sizeOnLeft;
-        }
-
-        public long sizeOnRight()
-        {
-            return sizeOnRight;
-        }
-
-        public long rowsOnLeft()
-        {
-            return rowsOnLeft;
-        }
-
-        public long rowsOnRight()
-        {
-            return rowsOnRight;
-        }
-
-        public TreeDifference(Token left, Token right, byte depth)
-        {
-            super(null, left, right, depth, null);
-        }
-
-        public long totalRows()
-        {
-            return rowsOnLeft + rowsOnRight;
-        }
-
+        if (!(other instanceof MerkleTree))
+            return false;
+        MerkleTree that = (MerkleTree) other;
+
+        return this.root.equals(that.root)
+            && this.fullRange.equals(that.fullRange)
+            && this.partitioner == that.partitioner
+            && this.hashdepth == that.hashdepth
+            && this.maxsize == that.maxsize
+            && this.size == that.size;
    }

    /**
@@ -664,28 +621,27 @@ public class MerkleTree implements Serializable
     */
    public static class TreeRange extends Range<Token>
    {
-        public static final long serialVersionUID = 1L;
        private final MerkleTree tree;
-        public final byte depth;
-        private final Hashable hashable;
+        public final int depth;
+        private final Node node;

-        TreeRange(MerkleTree tree, Token left, Token right, byte depth, Hashable hashable)
+        TreeRange(MerkleTree tree, Token left, Token right, int depth, Node node)
        {
            super(left, right);
            this.tree = tree;
            this.depth = depth;
-            this.hashable = hashable;
+            this.node = node;
        }

-        public void hash(byte[] hash)
+        TreeRange(Token left, Token right, int depth)
        {
-            assert tree != null : "Not intended for modification!";
-            hashable.hash(hash);
+            this(null, left, right, depth, null);
        }

-        public byte[] hash()
+        public void hash(byte[] hash)
        {
-            return hashable.hash();
+            assert tree != null : "Not intended for modification!";
+            node.hash(hash);
        }

        /**
@@ -694,18 +650,9 @@ public class MerkleTree implements Serializable
        public void addHash(RowHash entry)
        {
            assert tree != null : "Not intended for modification!";
-            assert hashable instanceof Leaf;
-
-            hashable.addHash(entry.hash, entry.size);
-        }
-
-        public void ensureHashInitialised()
-        {
-            assert tree != null : "Not intended for modification!";
-            assert hashable instanceof Leaf;
+            assert node instanceof OnHeapLeaf;

-            if (hashable.hash == null)
-                hashable.hash = EMPTY_HASH;
+            ((OnHeapLeaf) node).addHash(entry.hash, entry.size);
        }

        public void addAll(Iterator<RowHash> entries)
@@ -717,9 +664,7 @@ public class MerkleTree implements Serializable
        @Override
        public String toString()
        {
-            StringBuilder buff = new StringBuilder("#<TreeRange ");
-            buff.append(super.toString()).append(" depth=").append(depth);
-            return buff.append(">").toString();
+            return "#<TreeRange " + super.toString() + " depth=" + depth + '>';
        }
    }

@@ -740,8 +685,8 @@ public class MerkleTree implements Serializable

        TreeRangeIterator(MerkleTree tree)
        {
-            tovisit = new ArrayDeque<TreeRange>();
-            tovisit.add(new TreeRange(tree, tree.fullRange.left, tree.fullRange.right, (byte)0, tree.root));
+            tovisit = new ArrayDeque<>();
+            tovisit.add(new TreeRange(tree, tree.fullRange.left, tree.fullRange.right, 0, tree.root));
            this.tree = tree;
        }

@@ -756,7 +701,7 @@ public class MerkleTree implements Serializable
        {
            TreeRange active = tovisit.pop();

-            if (active.hashable instanceof Leaf)
+            if (active.node instanceof Leaf)
            {
                // found a leaf invalid range
                if (active.isWrapAround() && !tovisit.isEmpty())
@@ -765,9 +710,9 @@ public class MerkleTree implements Serializable
                return active;
            }

-            Inner node = (Inner)active.hashable;
-            TreeRange left = new TreeRange(tree, active.left, node.token, inc(active.depth), node.lchild);
-            TreeRange right = new TreeRange(tree, node.token, active.right, inc(active.depth), node.rchild);
+            Inner node = (Inner) active.node;
+            TreeRange left = new TreeRange(tree, active.left, node.token(), active.depth + 1, node.left());
+            TreeRange right = new TreeRange(tree, node.token(), active.right, active.depth + 1, node.right());

            if (right.isWrapAround())
            {
@@ -792,123 +737,357 @@ public class MerkleTree implements Serializable
    }

    /**
-     * An inner node in the MerkleTree. Inners can contain cached hash values, which
-     * are the binary hash of their two children.
+     * Hash value representing a row, to be used to pass hashes to the MerkleTree.
+     * The byte[] hash value should contain a digest of the key and value of the row
+     * created using a very strong hash function.
     */
-    static class Inner extends Hashable
+    public static class RowHash
    {
-        public static final long serialVersionUID = 1L;
-        static final byte IDENT = 2;
        public final Token token;
-        private Hashable lchild;
-        private Hashable rchild;
-
-        private static final InnerSerializer serializer = new InnerSerializer();
+        public final byte[] hash;
+        public final long size;

-        /**
-         * Constructs an Inner with the given token and children, and a null hash.
-         */
-        public Inner(Token token, Hashable lchild, Hashable rchild)
+        public RowHash(Token token, byte[] hash, long size)
        {
-            super(null);
            this.token = token;
-            this.lchild = lchild;
-            this.rchild = rchild;
+            this.hash = hash;
+            this.size = size;
        }

-        public Hashable lchild()
+        @Override
+        public String toString()
        {
-            return lchild;
+            return "#<RowHash " + token + ' ' + (hash == null ? "null" : Hex.bytesToHex(hash)) + " @ " + size + " bytes>";
        }
+    }
+
+    public void serialize(DataOutputPlus out, int version) throws IOException
+    {
+        out.writeByte(hashdepth);
+        out.writeLong(maxsize);
+        out.writeLong(size);
+        out.writeUTF(partitioner.getClass().getCanonicalName());
+        Token.serializer.serialize(fullRange.left, out, version);
+        Token.serializer.serialize(fullRange.right, out, version);
+        root.serialize(out, version);
+    }
+
+    public long serializedSize(int version)
+    {
+        long size = 1 // mt.hashdepth
+                  + sizeof(maxsize)
+                  + sizeof(this.size)
+                  + sizeof(partitioner.getClass().getCanonicalName());
+        size += Token.serializer.serializedSize(fullRange.left, version);
+        size += Token.serializer.serializedSize(fullRange.right, version);
+        size += root.serializedSize(version);
+        return size;
+    }
+
+    public static MerkleTree deserialize(DataInputPlus in, int version) throws IOException
+    {
+        return deserialize(in, DatabaseDescriptor.getOffheapMerkleTreesEnabled(), version);
+    }
+
+    public static MerkleTree deserialize(DataInputPlus in, boolean offHeapRequested, int version) throws IOException
+    {
+        int hashDepth = in.readByte();
+        long maxSize = in.readLong();
+        int innerNodeCount = Ints.checkedCast(in.readLong());

-        public Hashable rchild()
+        IPartitioner partitioner;
+        try
+        {
+            partitioner = FBUtilities.newPartitioner(in.readUTF());
+        }
+        catch (ConfigurationException e)
        {
-            return rchild;
+            throw new IOException(e);
        }

-        public void lchild(Hashable child)
+        Token left = Token.serializer.deserialize(in, partitioner, version);
+        Token right = Token.serializer.deserialize(in, partitioner, version);
+        Range<Token> fullRange = new Range<>(left, right);
+
+        Node root = deserializeTree(in, partitioner, innerNodeCount, offHeapRequested, version);
+        return new MerkleTree(root, partitioner, fullRange, hashDepth, maxSize, innerNodeCount);
+    }
+
+    private static boolean warnedOnce;
+
+    private static Node deserializeTree(DataInputPlus in, IPartitioner partitioner, int innerNodeCount, boolean offHeapRequested, int version) throws IOException
+    {
+        boolean offHeapSupported = partitioner instanceof Murmur3Partitioner || partitioner instanceof RandomPartitioner;
+
+        if (offHeapRequested && !offHeapSupported && !warnedOnce)
        {
-            lchild = child;
+            logger.warn("Configuration requests off-heap merkle trees, but the partitioner does not support them. Ignoring.");
+            warnedOnce = true;
        }

-        public void rchild(Hashable child)
+        return offHeapRequested && offHeapSupported
+             ? deserializeOffHeap(in, partitioner, innerNodeCount, version)
+             : OnHeapNode.deserialize(in, partitioner, version);
+    }
+
+    /*
+     * Coordinating multiple trees from multiple replicas can get expensive.
+     * On the deserialization path, we know in advance what the tree looks like,
+     * so we can pre-size an offheap buffer and deserialize into that.
+     */
+
+    MerkleTree tryMoveOffHeap() throws IOException
+    {
+        boolean offHeapEnabled = DatabaseDescriptor.getOffheapMerkleTreesEnabled();
+        boolean offHeapSupported = partitioner instanceof Murmur3Partitioner || partitioner instanceof RandomPartitioner;
+
+        if (offHeapEnabled && !offHeapSupported && !warnedOnce)
        {
-            rchild = child;
+            logger.warn("Configuration requests off-heap merkle trees, but the partitioner does not support them. Ignoring.");
+            warnedOnce = true;
        }

-        Hashable calc()
+        return root instanceof OnHeapNode && offHeapEnabled && offHeapSupported ? moveOffHeap() : this;
+    }
+
+    private MerkleTree moveOffHeap() throws IOException
+    {
+        assert root instanceof OnHeapNode;
+        int bufferSize = offHeapBufferSize(Ints.checkedCast(size), partitioner);
+        logger.debug("Allocating direct buffer of size {} to move merkle tree off heap", bufferSize);
+        ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize);
+        int pointer = ((OnHeapNode) root).serializeOffHeap(buffer, partitioner);
+        OffHeapNode newRoot = fromPointer(pointer, buffer, partitioner).attachRef();
+        return new MerkleTree(newRoot, partitioner, fullRange, hashdepth, maxsize, size);
+    }
+
+    private static OffHeapNode deserializeOffHeap(DataInputPlus in, IPartitioner partitioner, int innerNodeCount, int version) throws IOException
+    {
+        int bufferSize = offHeapBufferSize(innerNodeCount, partitioner);
+        logger.debug("Allocating direct buffer of size {} for merkle tree deserialization", bufferSize);
+        ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize);
+        int pointer = OffHeapNode.deserialize(in, buffer, partitioner, version);
+        return fromPointer(pointer, buffer, partitioner).attachRef();
+    }
+
+    private static OffHeapNode fromPointer(int pointer, ByteBuffer buffer, IPartitioner partitioner)
+    {
+        return pointer >= 0 ? new OffHeapInner(buffer, pointer, partitioner) : new OffHeapLeaf(buffer, ~pointer);
+    }
+
+    private static int offHeapBufferSize(int innerNodeCount, IPartitioner partitioner)
+    {
+        return innerNodeCount * OffHeapInner.maxOffHeapSize(partitioner) + (innerNodeCount + 1) * OffHeapLeaf.maxOffHeapSize();
+    }
+
+    interface Node
+    {
+        byte[] hash();
+
+        boolean hasEmptyHash();
+
+        void hash(byte[] hash);
+
+        boolean hashesDiffer(Node other);
+
+        default Node compute()
        {
-            if (hash == null)
-            {
-                // hash and size haven't been calculated; calc children then compute
-                Hashable lnode = lchild.calc();
-                Hashable rnode = rchild.calc();
-                // cache the computed value
-                hash(lnode.hash, rnode.hash);
-                sizeOfRange = lnode.sizeOfRange + rnode.sizeOfRange;
-                rowsInRange = lnode.rowsInRange + rnode.rowsInRange;
-            }
            return this;
        }

-        /**
-         * Recursive toString.
-         */
-        public void toString(StringBuilder buff, int maxdepth)
+        default long sizeOfRange()
        {
-            buff.append("#<").append(getClass().getSimpleName());
-            buff.append(" ").append(token);
-            buff.append(" hash=").append(Hashable.toString(hash()));
-            buff.append(" children=[");
-            if (maxdepth < 1)
-            {
-                buff.append("#");
-            }
-            else
-            {
-                if (lchild == null)
-                    buff.append("null");
-                else
-                    lchild.toString(buff, maxdepth-1);
-                buff.append(" ");
-                if (rchild == null)
-                    buff.append("null");
-                else
-                    rchild.toString(buff, maxdepth-1);
-            }
-            buff.append("]>");
+            return 0;
+        }
+
+        default long partitionsInRange()
+        {
+            return 0;
        }

+        void serialize(DataOutputPlus out, int version) throws IOException;
+        int serializedSize(int version);
+
+        void toString(StringBuilder buff, int maxdepth);
+
+        static String toString(byte[] hash)
+        {
+            return hash == null
+                 ? "null"
+                 : '[' + Hex.bytesToHex(hash) + ']';
+        }
+
+        boolean equals(Node node);
+    }
+
+    static abstract class OnHeapNode implements Node
+    {
+        long sizeOfRange;
+        long partitionsInRange;
+
+        protected byte[] hash;
+
+        OnHeapNode(byte[] hash)
+        {
+            if (hash == null)
+                throw new IllegalArgumentException();
+
+            this.hash = hash;
+        }
+
+        public byte[] hash()
+        {
+            return hash;
+        }
+
+        public boolean hasEmptyHash()
+        {
+            //noinspection ArrayEquality
+            return hash == EMPTY_HASH;
+        }
+
+        public void hash(byte[] hash)
+        {
+            if (hash == null)
+                throw new IllegalArgumentException();
+
+            this.hash = hash;
+        }
+
+        public boolean hashesDiffer(Node other)
+        {
+            return other instanceof OnHeapNode
+                 ? hashesDiffer((OnHeapNode) other)
+                 : hashesDiffer((OffHeapNode) other);
+        }
+
+        private boolean hashesDiffer(OnHeapNode other)
+        {
+            return !Arrays.equals(hash(), other.hash());
+        }
+
+        private boolean hashesDiffer(OffHeapNode other)
+        {
+            return compare(hash(), other.buffer(), other.hashBytesOffset(), HASH_SIZE) != 0;
+        }
+
        @Override
-        public String toString()
+        public long sizeOfRange()
        {
-            StringBuilder buff = new StringBuilder();
-            toString(buff, 1);
-            return buff.toString();
+            return sizeOfRange;
        }

-        private static class InnerSerializer implements IPartitionerDependentSerializer<Inner>
+        @Override
+        public long partitionsInRange()
        {
-            public void serialize(Inner inner, DataOutputPlus out, int version) throws IOException
-            {
-                Token.serializer.serialize(inner.token, out, version);
-                Hashable.serializer.serialize(inner.lchild, out, version);
-                Hashable.serializer.serialize(inner.rchild, out, version);
-            }
+            return partitionsInRange;
+        }

-            public Inner deserialize(DataInput in, IPartitioner p, int version) throws IOException
+        static OnHeapNode deserialize(DataInputPlus in, IPartitioner p, int version) throws IOException
+        {
+            byte ident = in.readByte();
+
+            switch (ident)
            {
-                Token token = Token.serializer.deserialize(in, p, version);
-                Hashable lchild = Hashable.serializer.deserialize(in, p, version);
-                Hashable rchild = Hashable.serializer.deserialize(in, p, version);
-                return new Inner(token, lchild, rchild);
+                case Inner.IDENT:
+                    return OnHeapInner.deserializeWithoutIdent(in, p, version);
+                case Leaf.IDENT:
+                    return OnHeapLeaf.deserializeWithoutIdent(in);
+                default:
+                    throw new IOException("Unexpected node type: " + ident);
            }
+        }
+
+        abstract int serializeOffHeap(ByteBuffer buffer, IPartitioner p) throws IOException;
+    }
+
+    static abstract class OffHeapNode implements Node
+    {
+        protected final ByteBuffer buffer;
+        protected final int offset;
+
+        OffHeapNode(ByteBuffer buffer, int offset)
+        {
+            this.buffer = buffer;
+            this.offset = offset;
+        }
+
+        ByteBuffer buffer()
+        {
+            return buffer;
+        }
+
+        public byte[] hash()
+        {
+            final int position = buffer.position();
+            buffer.position(hashBytesOffset());
+            byte[] array = new byte[HASH_SIZE];
+            buffer.get(array);
+            buffer.position(position);
+            return array;
+        }
+
+        public boolean hasEmptyHash()
+        {
+            return compare(buffer(), hashBytesOffset(), HASH_SIZE, EMPTY_HASH) == 0;
+        }
+
+        public void hash(byte[] hash)
+        {
+            throw new UnsupportedOperationException();
+        }

-            public long serializedSize(Inner inner, int version)
+        public boolean hashesDiffer(Node other)
+        {
+            return other instanceof OnHeapNode
+                 ? hashesDiffer((OnHeapNode) other)
+                 : hashesDiffer((OffHeapNode) other);
+        }
+
+        private boolean hashesDiffer(OnHeapNode other)
+        {
+            return compare(buffer(), hashBytesOffset(), HASH_SIZE, other.hash()) != 0;
+        }
+
+        private boolean hashesDiffer(OffHeapNode other)
+        {
+            int thisOffset = hashBytesOffset();
+            int otherOffset = other.hashBytesOffset();
+
+            for (int i = 0; i < HASH_SIZE; i += 8)
+                if (buffer().getLong(thisOffset + i) != other.buffer().getLong(otherOffset + i))
+                    return true;
+
+            return false;
+        }
+
+        OffHeapNode attachRef()
+        {
+            if (Ref.DEBUG_ENABLED)
+                MemoryUtil.setAttachment(buffer, new Ref<>(this, null));
+            return this;
+        }
+
+        void release()
+        {
+            Object attachment = MemoryUtil.getAttachment(buffer);
+            if (attachment instanceof Ref)
+                ((Ref) attachment).release();
+            FileUtils.clean(buffer);
+        }
+
+        abstract int hashBytesOffset();
+
+        static int deserialize(DataInputPlus in, ByteBuffer buffer, IPartitioner p, int version) throws IOException
+        {
+            byte ident = in.readByte();
+
+            switch (ident)
            {
-                return Token.serializer.serializedSize(inner.token, version)
-                     + Hashable.serializer.serializedSize(inner.lchild, version)
-                     + Hashable.serializer.serializedSize(inner.rchild, version);
+                case Inner.IDENT:
+                    return OffHeapInner.deserializeWithoutIdent(in, buffer, p, version);
+                case Leaf.IDENT:
+                    return OffHeapLeaf.deserializeWithoutIdent(in, buffer);
+                default:
+                    throw new IOException("Unexpected node type: " + ident);
            }
        }
    }
@@ -922,237 +1101,440 @@ public class MerkleTree implements Serializable
     * tree extending below the Leaf is generated in memory, but only the root
     * is stored in the Leaf.
     */
-    static class Leaf extends Hashable
+    interface Leaf extends Node
    {
-        public static final long serialVersionUID = 1L;
        static final byte IDENT = 1;
-        private static final LeafSerializer serializer = new LeafSerializer();
+
+        default void serialize(DataOutputPlus out, int version) throws IOException
+        {
+            byte[] hash = hash();
+            assert hash.length == HASH_SIZE;
+
+            out.writeByte(Leaf.IDENT);
+
+            if (!hasEmptyHash())
+            {
+                out.writeByte(HASH_SIZE);
+                out.write(hash);
+            }
+            else
+            {
+                out.writeByte(0);
+            }
+        }
+
+        default int serializedSize(int version)
+        {
+            return 2 + (hasEmptyHash() ? 0 : HASH_SIZE);
+        }
+
+        default void toString(StringBuilder buff, int maxdepth)
+        {
+            buff.append(toString());
+        }
+
+        default boolean equals(Node other)
+        {
+            return other instanceof Leaf && !hashesDiffer(other);
+        }
+    }
+
+    static class OnHeapLeaf extends OnHeapNode implements Leaf
+    {
+        OnHeapLeaf()
+        {
+            super(EMPTY_HASH);
+        }
+
+        OnHeapLeaf(byte[] hash)
+        {
+            super(hash);
+        }

        /**
-         * Constructs a null hash.
+         * Mixes the given value into our hash. If our hash is null,
+         * our hash will become the given value.
*/ - public Leaf() + void addHash(byte[] partitionHash, long partitionSize) { - super(null); + if (hasEmptyHash()) + hash(partitionHash); + else + FBUtilities.xorOntoLeft(hash, partitionHash); + + sizeOfRange += partitionSize; + partitionsInRange += 1; } - public Leaf(byte[] hash) + static OnHeapLeaf deserializeWithoutIdent(DataInputPlus in) throws IOException { - super(hash); + if (in.readByte() > 0) + { + byte[] hash = new byte[HASH_SIZE]; + in.readFully(hash); + return new OnHeapLeaf(hash); + } + else + { + return new OnHeapLeaf(); + } } - public void toString(StringBuilder buff, int maxdepth) + int serializeOffHeap(ByteBuffer buffer, IPartitioner p) { - buff.append(toString()); + if (buffer.remaining() < OffHeapLeaf.maxOffHeapSize()) + throw new IllegalStateException("Insufficient remaining bytes to deserialize a Leaf node off-heap"); + + if (hash.length != HASH_SIZE) + throw new IllegalArgumentException("Hash of unexpected size when serializing a Leaf off-heap: " + hash.length); + + final int position = buffer.position(); + buffer.put(hash); + return ~position; } @Override public String toString() { - return "#<Leaf " + Hashable.toString(hash()) + ">"; + return "#<OnHeapLeaf " + Node.toString(hash()) + '>'; } + } + + static class OffHeapLeaf extends OffHeapNode implements Leaf + { + static final int HASH_BYTES_OFFSET = 0; - private static class LeafSerializer implements IPartitionerDependentSerializer<Leaf> + OffHeapLeaf(ByteBuffer buffer, int offset) { - public void serialize(Leaf leaf, DataOutputPlus out, int version) throws IOException - { - if (leaf.hash == null) - { - out.writeByte(-1); - } - else - { - out.writeByte(leaf.hash.length); - out.write(leaf.hash); - } - } + super(buffer, offset); + } + + public int hashBytesOffset() + { + return offset + HASH_BYTES_OFFSET; + } + + static int deserializeWithoutIdent(DataInput in, ByteBuffer buffer) throws IOException + { + if (buffer.remaining() < maxOffHeapSize()) + throw new IllegalStateException("Insufficient remaining bytes to deserialize a Leaf node off-heap"); - public Leaf deserialize(DataInput in, IPartitioner p, int version) throws IOException + final int position = buffer.position(); + + int hashLength = in.readByte(); + if (hashLength > 0) { - int hashLen = in.readByte(); - byte[] hash = hashLen < 0 ? null : new byte[hashLen]; - if (hash != null) - in.readFully(hash); - return new Leaf(hash); - } + if (hashLength != HASH_SIZE) + throw new IllegalStateException("Hash of unexpected size when deserializing an off-heap Leaf node: " + hashLength); - public long serializedSize(Leaf leaf, int version) + byte[] hashBytes = getTempArray(HASH_SIZE); + in.readFully(hashBytes, 0, HASH_SIZE); + buffer.put(hashBytes, 0, HASH_SIZE); + } + else { - long size = 1; - if (leaf.hash != null) - size += leaf.hash().length; - return size; + buffer.put(EMPTY_HASH, 0, HASH_SIZE); } + + return ~position; } - } - /** - * Hash value representing a row, to be used to pass hashes to the MerkleTree. - * The byte[] hash value should contain a digest of the key and value of the row - * created using a very strong hash function. 
- */ - public static class RowHash - { - public final Token token; - public final byte[] hash; - public final long size; - public RowHash(Token token, byte[] hash, long size) + static int maxOffHeapSize() { - this.token = token; - this.hash = hash; - this.size = size; + return HASH_SIZE; } @Override public String toString() { - return "#<RowHash " + token + " " + Hashable.toString(hash) + " @ " + size + " bytes>"; + return "#<OffHeapLeaf " + Node.toString(hash()) + '>'; } } /** - * Abstract class containing hashing logic, and containing a single hash field. + * An inner node in the MerkleTree. Inners can contain cached hash values, which + * are the binary hash of their two children. */ - static abstract class Hashable implements Serializable + interface Inner extends Node { - private static final long serialVersionUID = 1L; - private static final IPartitionerDependentSerializer<Hashable> serializer = new HashableSerializer(); + static final byte IDENT = 2; - protected byte[] hash; - protected long sizeOfRange; - protected long rowsInRange; + public Token token(); - protected Hashable(byte[] hash) + public Node left(); + public Node right(); + + default void serialize(DataOutputPlus out, int version) throws IOException { - this.hash = hash; + out.writeByte(Inner.IDENT); + Token.serializer.serialize(token(), out, version); + left().serialize(out, version); + right().serialize(out, version); } - public byte[] hash() + default int serializedSize(int version) { - return hash; + return 1 + + (int) Token.serializer.serializedSize(token(), version) + + left().serializedSize(version) + + right().serializedSize(version); } - public long sizeOfRange() + default void toString(StringBuilder buff, int maxdepth) { - return sizeOfRange; + buff.append("#<").append(getClass().getSimpleName()) + .append(' ').append(token()) + .append(" hash=").append(Node.toString(hash())) + .append(" children=["); + + if (maxdepth < 1) + { + buff.append('#'); + } + else + { + Node left = left(); + if (left == null) + buff.append("null"); + else + left.toString(buff, maxdepth - 1); + + buff.append(' '); + + Node right = right(); + if (right == null) + buff.append("null"); + else + right.toString(buff, maxdepth - 1); + } + + buff.append("]>"); } - public long rowsInRange() + default boolean equals(Node other) { - return rowsInRange; + if (!(other instanceof Inner)) + return false; + Inner that = (Inner) other; + return !hashesDiffer(other) && this.left().equals(that.left()) && this.right().equals(that.right()); } + } + + static class OnHeapInner extends OnHeapNode implements Inner + { + private final Token token; - void hash(byte[] hash) + private OnHeapNode left; + private OnHeapNode right; + + private boolean computed; + + OnHeapInner(Token token, OnHeapNode left, OnHeapNode right) { - this.hash = hash; + super(EMPTY_HASH); + + this.token = token; + this.left = left; + this.right = right; } - Hashable calc() + public Token token() { - return this; + return token; } - /** - * Sets the value of this hash to binaryHash of its children. - * @param lefthash Hash of left child. - * @param righthash Hash of right child. - */ - void hash(byte[] lefthash, byte[] righthash) + public OnHeapNode left() { - hash = binaryHash(lefthash, righthash); + return left; } - /** - * Mixes the given value into our hash. If our hash is null, - * our hash will become the given value. 
- */ - void addHash(byte[] righthash, long sizeOfRow) + public OnHeapNode right() { - if (hash == null) - hash = righthash; - else - hash = binaryHash(hash, righthash); - this.sizeOfRange += sizeOfRow; - this.rowsInRange += 1; + return right; } - /** - * The primitive with which all hashing should be accomplished: hashes - * a left and right value together. - */ - static byte[] binaryHash(final byte[] left, final byte[] right) + void left(OnHeapNode child) { - return FBUtilities.xor(left, right); + left = child; } - public abstract void toString(StringBuilder buff, int maxdepth); - - public static String toString(byte[] hash) + void right(OnHeapNode child) { - if (hash == null) - return "null"; - return "[" + Hex.bytesToHex(hash) + "]"; + right = child; } - private static class HashableSerializer implements IPartitionerDependentSerializer<Hashable> + @Override + public Node compute() { - public void serialize(Hashable h, DataOutputPlus out, int version) throws IOException + if (!computed) // hash and size haven't been calculated; compute children then compute this { - if (h instanceof Inner) - { - out.writeByte(Inner.IDENT); - Inner.serializer.serialize((Inner)h, out, version); - } - else if (h instanceof Leaf) - { - out.writeByte(Leaf.IDENT); - Leaf.serializer.serialize((Leaf) h, out, version); - } - else - throw new IOException("Unexpected Hashable: " + h.getClass().getCanonicalName()); - } + left.compute(); + right.compute(); - public Hashable deserialize(DataInput in, IPartitioner p, int version) throws IOException - { - byte ident = in.readByte(); - if (Inner.IDENT == ident) - return Inner.serializer.deserialize(in, p, version); - else if (Leaf.IDENT == ident) - return Leaf.serializer.deserialize(in, p, version); - else - throw new IOException("Unexpected Hashable: " + ident); + if (!left.hasEmptyHash() && !right.hasEmptyHash()) + hash = FBUtilities.xor(left.hash(), right.hash()); + else if (left.hasEmptyHash()) + hash = right.hash(); + else if (right.hasEmptyHash()) + hash = left.hash(); + + sizeOfRange = left.sizeOfRange() + right.sizeOfRange(); + partitionsInRange = left.partitionsInRange() + right.partitionsInRange(); + + computed = true; } - public long serializedSize(Hashable h, int version) + return this; + } + + static OnHeapInner deserializeWithoutIdent(DataInputPlus in, IPartitioner p, int version) throws IOException + { + Token token = Token.serializer.deserialize(in, p, version); + OnHeapNode left = OnHeapNode.deserialize(in, p, version); + OnHeapNode right = OnHeapNode.deserialize(in, p, version); + return new OnHeapInner(token, left, right); + } + + int serializeOffHeap(ByteBuffer buffer, IPartitioner partitioner) throws IOException + { + if (buffer.remaining() < OffHeapInner.maxOffHeapSize(partitioner)) + throw new IllegalStateException("Insufficient remaining bytes to deserialize Inner node off-heap"); + + final int offset = buffer.position(); + + int tokenSize = partitioner.getTokenFactory().byteSize(token); + buffer.putShort(offset + OffHeapInner.TOKEN_LENGTH_OFFSET, Shorts.checkedCast(tokenSize)); + buffer.position(offset + OffHeapInner.TOKEN_BYTES_OFFSET); + partitioner.getTokenFactory().serialize(token, buffer); + + int leftPointer = left.serializeOffHeap(buffer, partitioner); + int rightPointer = right.serializeOffHeap(buffer, partitioner); + + buffer.putInt(offset + OffHeapInner.LEFT_CHILD_POINTER_OFFSET, leftPointer); + buffer.putInt(offset + OffHeapInner.RIGHT_CHILD_POINTER_OFFSET, rightPointer); + + int leftHashOffset = 
OffHeapInner.hashBytesOffset(leftPointer);
+            int rightHashOffset = OffHeapInner.hashBytesOffset(rightPointer);
+
+            for (int i = 0; i < HASH_SIZE; i += 8)
             {
-                if (h instanceof Inner)
-                    return 1 + Inner.serializer.serializedSize((Inner) h, version);
-                else if (h instanceof Leaf)
-                    return 1 + Leaf.serializer.serializedSize((Leaf) h, version);
-                throw new AssertionError(h.getClass());
+                buffer.putLong(offset + OffHeapInner.HASH_BYTES_OFFSET + i,
+                               buffer.getLong(leftHashOffset + i) ^ buffer.getLong(rightHashOffset + i));
             }
+
+            return offset;
+        }
+
+        @Override
+        public String toString()
+        {
+            StringBuilder buff = new StringBuilder();
+            toString(buff, 1);
+            return buff.toString();
         }
     }

-    /**
-     * Exceptions that stop recursion early when we are sure that no answer
-     * can be found.
-     */
-    static abstract class StopRecursion extends Exception
+    static class OffHeapInner extends OffHeapNode implements Inner
     {
-        static class BadRange extends StopRecursion
+        /**
+         * All we want to keep here is just a pointer to the start of the Inner node in the
+         * direct buffer. From there, we'll be able to deserialize the following, in this order:
+         *
+         * 1. pointer to left child (int)
+         * 2. pointer to right child (int)
+         * 3. hash bytes (space allocated as HASH_SIZE)
+         * 4. token length (short)
+         * 5. token bytes (variable length)
+         */
+        static final int LEFT_CHILD_POINTER_OFFSET = 0;
+        static final int RIGHT_CHILD_POINTER_OFFSET = 4;
+        static final int HASH_BYTES_OFFSET = 8;
+        static final int TOKEN_LENGTH_OFFSET = 8 + HASH_SIZE;
+        static final int TOKEN_BYTES_OFFSET = TOKEN_LENGTH_OFFSET + 2;
+
+        private final IPartitioner partitioner;
+
+        OffHeapInner(ByteBuffer buffer, int offset, IPartitioner partitioner)
         {
-            public BadRange(){ super(); }
+            super(buffer, offset);
+            this.partitioner = partitioner;
         }

-        static class InvalidHash extends StopRecursion
+        public Token token()
         {
-            public InvalidHash(){ super(); }
+            int length = buffer.getShort(offset + TOKEN_LENGTH_OFFSET);
+            return partitioner.getTokenFactory().fromByteBuffer(buffer, offset + TOKEN_BYTES_OFFSET, length);
         }

-        static class TooDeep extends StopRecursion
+        public Node left()
         {
-            public TooDeep(){ super(); }
+            int pointer = buffer.getInt(offset + LEFT_CHILD_POINTER_OFFSET);
+            return pointer >= 0 ? new OffHeapInner(buffer, pointer, partitioner) : new OffHeapLeaf(buffer, ~pointer);
+        }
+
+        public Node right()
+        {
+            int pointer = buffer.getInt(offset + RIGHT_CHILD_POINTER_OFFSET);
+            return pointer >= 0 ? new OffHeapInner(buffer, pointer, partitioner) : new OffHeapLeaf(buffer, ~pointer);
+        }
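[The left()/right() accessors above decode the same pointer tagging that fromPointer() dispatches on: a non-negative int is the buffer offset of an inner node, while a leaf is stored as the bitwise complement of its offset, which is strictly negative for any offset >= 0. Because complement is its own inverse, no per-child type byte is needed. A minimal sketch of the encoding; the helper names are hypothetical, not from the patch.]

    // Sketch only: the child pointer encoding used by fromPointer()/left()/right().
    static int innerPointer(int offset)  { return offset; }  // offset >= 0 encodes an Inner
    static int leafPointer(int offset)   { return ~offset; } // ~offset < 0 encodes a Leaf
    static boolean isInner(int pointer)  { return pointer >= 0; }
    static int offsetOf(int pointer)     { return isInner(pointer) ? pointer : ~pointer; } // ~ undoes ~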
+
+        public int hashBytesOffset()
+        {
+            return offset + HASH_BYTES_OFFSET;
+        }
+
+        static int deserializeWithoutIdent(DataInputPlus in, ByteBuffer buffer, IPartitioner partitioner, int version) throws IOException
+        {
+            if (buffer.remaining() < maxOffHeapSize(partitioner))
+                throw new IllegalStateException("Insufficient remaining bytes to deserialize Inner node off-heap");
+
+            final int offset = buffer.position();
+
+            int tokenSize = Token.serializer.deserializeSize(in);
+            byte[] tokenBytes = getTempArray(tokenSize);
+            in.readFully(tokenBytes, 0, tokenSize);
+
+            buffer.putShort(offset + OffHeapInner.TOKEN_LENGTH_OFFSET, Shorts.checkedCast(tokenSize));
+            buffer.position(offset + OffHeapInner.TOKEN_BYTES_OFFSET);
+            buffer.put(tokenBytes, 0, tokenSize);
+
+            int leftPointer = OffHeapNode.deserialize(in, buffer, partitioner, version);
+            int rightPointer = OffHeapNode.deserialize(in, buffer, partitioner, version);
+
+            buffer.putInt(offset + OffHeapInner.LEFT_CHILD_POINTER_OFFSET, leftPointer);
+            buffer.putInt(offset + OffHeapInner.RIGHT_CHILD_POINTER_OFFSET, rightPointer);
+
+            int leftHashOffset = hashBytesOffset(leftPointer);
+            int rightHashOffset = hashBytesOffset(rightPointer);
+
+            for (int i = 0; i < HASH_SIZE; i += 8)
+            {
+                buffer.putLong(offset + OffHeapInner.HASH_BYTES_OFFSET + i,
+                               buffer.getLong(leftHashOffset + i) ^ buffer.getLong(rightHashOffset + i));
+            }
+
+            return offset;
+        }
+
+        static int maxOffHeapSize(IPartitioner partitioner)
+        {
+            return 4 // left pointer
+                 + 4 // right pointer
+                 + HASH_SIZE
+                 + 2 + partitioner.getMaxTokenSize();
+        }
+
+        static int hashBytesOffset(int pointer)
+        {
+            return pointer >= 0 ? pointer + OffHeapInner.HASH_BYTES_OFFSET : ~pointer + OffHeapLeaf.HASH_BYTES_OFFSET;
+        }
+
+        @Override
+        public String toString()
+        {
+            StringBuilder buff = new StringBuilder();
+            toString(buff, 1);
+            return buff.toString();
        }
    }
@@ -1183,10 +1565,10 @@ public class MerkleTree implements Serializable
    {
        byte[] hashLeft = new byte[bytesPerHash];
        byte[] hashRigth = new byte[bytesPerHash];
-        Leaf left = new Leaf(hashLeft);
-        Leaf right = new Leaf(hashRigth);
-        Inner inner = new Inner(partitioner.getMinimumToken(), left, right);
-        inner.calc();
+        OnHeapLeaf left = new OnHeapLeaf(hashLeft);
+        OnHeapLeaf right = new OnHeapLeaf(hashRigth);
+        Inner inner = new OnHeapInner(partitioner.getMinimumToken(), left, right);
+        inner.compute();

        // Some partitioners have variable token sizes, try to estimate as close as we can by using the same
        // heap estimate as the memtables use.
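[Before the MerkleTrees.java hunks below, a rough end-to-end sketch of how the reworked pieces fit together: build the trees, populate them through rangeIterator() (renamed from invalids() in the diff that follows), diff them, and release() any backing direct buffers when done. The helper below is illustrative only, not part of the patch, and relies solely on signatures visible in this diff and its tests.]

    // Illustrative only; mirrors what the updated tests do. Assumes imports of
    // java.util.*, org.apache.cassandra.dht.*, and org.apache.cassandra.utils.*.
    static List<Range<Token>> diffIdenticalTrees(IPartitioner partitioner,
                                                 Collection<Range<Token>> ranges,
                                                 byte[] digest)
    {
        MerkleTrees ltrees = new MerkleTrees(partitioner);
        MerkleTrees rtrees = new MerkleTrees(partitioner);
        for (MerkleTrees trees : Arrays.asList(ltrees, rtrees))
        {
            trees.addMerkleTrees((int) Math.pow(2, 15), ranges);
            trees.init();
            for (MerkleTree.TreeRange r : trees.rangeIterator()) // was invalids()
                r.addHash(new MerkleTree.RowHash(r.right, digest.clone(), digest.length));
        }
        try
        {
            return MerkleTrees.difference(ltrees, rtrees); // empty: the trees are identical
        }
        finally
        {
            ltrees.release(); // frees direct buffers backing any off-heap trees
            rtrees.release();
        }
    }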
diff --git a/src/java/org/apache/cassandra/utils/MerkleTrees.java b/src/java/org/apache/cassandra/utils/MerkleTrees.java
index d2a8058..9cf04c5 100644
--- a/src/java/org/apache/cassandra/utils/MerkleTrees.java
+++ b/src/java/org/apache/cassandra/utils/MerkleTrees.java
@@ -44,9 +44,9 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
 {
     public static final MerkleTreesSerializer serializer = new MerkleTreesSerializer();

-    private Map<Range<Token>, MerkleTree> merkleTrees = new TreeMap<>(new TokenRangeComparator());
+    private final Map<Range<Token>, MerkleTree> merkleTrees = new TreeMap<>(new TokenRangeComparator());

-    private IPartitioner partitioner;
+    private final IPartitioner partitioner;

     /**
      * Creates empty MerkleTrees object.
@@ -143,6 +143,14 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
     }

     /**
+     * Dereference all merkle trees and release direct memory for all off-heap trees.
+     */
+    public void release()
+    {
+        merkleTrees.values().forEach(MerkleTree::release);
+        merkleTrees.clear();
+    }
+
+    /**
      * Init a selected MerkleTree with an even tree distribution.
      *
      * @param range
@@ -247,11 +255,11 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
     }

     /**
-     * Get an iterator for all the invalids generated by the MerkleTrees.
+     * Get an iterator for all the ranges generated by the MerkleTrees.
      *
      * @return
      */
-    public TreeRangeIterator invalids()
+    public TreeRangeIterator rangeIterator()
     {
         return new TreeRangeIterator();
     }
@@ -285,30 +293,20 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
     @VisibleForTesting
     public byte[] hash(Range<Token> range)
     {
-        ByteArrayOutputStream baos = new ByteArrayOutputStream();
-        boolean hashed = false;
-
-        try
+        try (ByteArrayOutputStream baos = new ByteArrayOutputStream())
         {
-            for (Range<Token> rt : merkleTrees.keySet())
-            {
-                if (rt.intersects(range))
-                {
-                    byte[] bytes = merkleTrees.get(rt).hash(range);
-                    if (bytes != null)
-                    {
-                        baos.write(bytes);
-                        hashed = true;
-                    }
-                }
-            }
+            boolean hashed = false;
+
+            for (Map.Entry<Range<Token>, MerkleTree> entry : merkleTrees.entrySet())
+                if (entry.getKey().intersects(range))
+                    hashed |= entry.getValue().ifHashesRange(range, n -> baos.write(n.hash()));
+
+            return hashed ? baos.toByteArray() : null;
         }
         catch (IOException e)
         {
             throw new RuntimeException("Unable to append merkle tree hash to result");
         }
-
-        return hashed ? baos.toByteArray() : null;
     }

     /**
@@ -354,7 +352,7 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
     {
         if (it.hasNext())
         {
-            current = it.next().invalids();
+            current = it.next().rangeIterator();
             return current.next();
         }

@@ -369,6 +367,17 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree>
     }

     /**
+     * @return a new {@link MerkleTrees} instance with all trees moved off heap.
+     */
+    public MerkleTrees tryMoveOffHeap() throws IOException
+    {
+        Map<Range<Token>, MerkleTree> movedTrees = new TreeMap<>(new TokenRangeComparator());
+        for (Map.Entry<Range<Token>, MerkleTree> entry : merkleTrees.entrySet())
+            movedTrees.put(entry.getKey(), entry.getValue().tryMoveOffHeap());
+        return new MerkleTrees(partitioner, movedTrees.values());
+    }
+
+    /**
      * Get the differences between the two sets of MerkleTrees.
* * @param ltree @@ -379,9 +388,7 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree> { List<Range<Token>> differences = new ArrayList<>(); for (MerkleTree tree : ltree.merkleTrees.values()) - { differences.addAll(MerkleTree.difference(tree, rtree.getMerkleTree(tree.fullRange))); - } return differences; } @@ -392,7 +399,7 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree> out.writeInt(trees.merkleTrees.size()); for (MerkleTree tree : trees.merkleTrees.values()) { - MerkleTree.serializer.serialize(tree, out, version); + tree.serialize(out, version); } } @@ -405,7 +412,7 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree> { for (int i = 0; i < nTrees; i++) { - MerkleTree tree = MerkleTree.serializer.deserialize(in, version); + MerkleTree tree = MerkleTree.deserialize(in, version); trees.add(tree); if (partitioner == null) @@ -425,7 +432,7 @@ public class MerkleTrees implements Iterable<Map.Entry<Range<Token>, MerkleTree> long size = TypeSizes.sizeof(trees.merkleTrees.size()); for (MerkleTree tree : trees.merkleTrees.values()) { - size += MerkleTree.serializer.serializedSize(tree, version); + size += tree.serializedSize(version); } return size; } diff --git a/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java b/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java index e787595..443d59e 100644 --- a/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java +++ b/test/unit/org/apache/cassandra/repair/LocalSyncTaskTest.java @@ -238,10 +238,6 @@ public class LocalSyncTaskTest extends AbstractRepairTest MerkleTrees tree = new MerkleTrees(partitioner); tree.addMerkleTrees((int) Math.pow(2, 15), desc.ranges); tree.init(); - for (MerkleTree.TreeRange r : tree.invalids()) - { - r.ensureHashInitialised(); - } return tree; } diff --git a/test/unit/org/apache/cassandra/repair/RepairJobTest.java b/test/unit/org/apache/cassandra/repair/RepairJobTest.java index 78fa588..b84adaa 100644 --- a/test/unit/org/apache/cassandra/repair/RepairJobTest.java +++ b/test/unit/org/apache/cassandra/repair/RepairJobTest.java @@ -774,10 +774,6 @@ public class RepairJobTest MerkleTrees tree = new MerkleTrees(MURMUR3_PARTITIONER); tree.addMerkleTrees((int) Math.pow(2, 15), fullRange); tree.init(); - for (MerkleTree.TreeRange r : tree.invalids()) - { - r.ensureHashInitialised(); - } if (invalidate) { diff --git a/test/unit/org/apache/cassandra/repair/ValidatorTest.java b/test/unit/org/apache/cassandra/repair/ValidatorTest.java index aec2612..9e848a9 100644 --- a/test/unit/org/apache/cassandra/repair/ValidatorTest.java +++ b/test/unit/org/apache/cassandra/repair/ValidatorTest.java @@ -203,12 +203,14 @@ public class ValidatorTest cfs.getTableName(), Collections.singletonList(new Range<>(sstable.first.getToken(), sstable.last.getToken()))); - ActiveRepairService.instance.registerParentRepairSession(repairSessionId, FBUtilities.getBroadcastAddressAndPort(), + InetAddressAndPort host = InetAddressAndPort.getByName("127.0.0.2"); + + ActiveRepairService.instance.registerParentRepairSession(repairSessionId, host, Collections.singletonList(cfs), desc.ranges, false, ActiveRepairService.UNREPAIRED_SSTABLE, false, PreviewKind.NONE); final CompletableFuture<Message> outgoingMessageSink = registerOutgoingMessageSink(); - Validator validator = new Validator(desc, FBUtilities.getBroadcastAddressAndPort(), 0, true, false, PreviewKind.NONE); + Validator validator = new Validator(desc, host, 0, true, false, PreviewKind.NONE); 
ValidationManager.instance.submitValidation(cfs, validator); Message message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); @@ -260,12 +262,14 @@ public class ValidatorTest cfs.getTableName(), Collections.singletonList(new Range<>(sstable.first.getToken(), sstable.last.getToken()))); - ActiveRepairService.instance.registerParentRepairSession(repairSessionId, FBUtilities.getBroadcastAddressAndPort(), + InetAddressAndPort host = InetAddressAndPort.getByName("127.0.0.2"); + + ActiveRepairService.instance.registerParentRepairSession(repairSessionId, host, Collections.singletonList(cfs), desc.ranges, false, ActiveRepairService.UNREPAIRED_SSTABLE, false, PreviewKind.NONE); final CompletableFuture<Message> outgoingMessageSink = registerOutgoingMessageSink(); - Validator validator = new Validator(desc, FBUtilities.getBroadcastAddressAndPort(), 0, true, false, PreviewKind.NONE); + Validator validator = new Validator(desc, host, 0, true, false, PreviewKind.NONE); ValidationManager.instance.submitValidation(cfs, validator); Message message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); @@ -320,12 +324,14 @@ public class ValidatorTest final RepairJobDesc desc = new RepairJobDesc(repairSessionId, UUIDGen.getTimeUUID(), cfs.keyspace.getName(), cfs.getTableName(), ranges); - ActiveRepairService.instance.registerParentRepairSession(repairSessionId, FBUtilities.getBroadcastAddressAndPort(), + InetAddressAndPort host = InetAddressAndPort.getByName("127.0.0.2"); + + ActiveRepairService.instance.registerParentRepairSession(repairSessionId, host, Collections.singletonList(cfs), desc.ranges, false, ActiveRepairService.UNREPAIRED_SSTABLE, false, PreviewKind.NONE); final CompletableFuture<Message> outgoingMessageSink = registerOutgoingMessageSink(); - Validator validator = new Validator(desc, FBUtilities.getBroadcastAddressAndPort(), 0, true, false, PreviewKind.NONE); + Validator validator = new Validator(desc, host, 0, true, false, PreviewKind.NONE); ValidationManager.instance.submitValidation(cfs, validator); Message message = outgoingMessageSink.get(TEST_TIMEOUT, TimeUnit.SECONDS); diff --git a/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java b/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java index 9693010..de69cd7 100644 --- a/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java +++ b/test/unit/org/apache/cassandra/repair/asymmetric/DifferenceHolderTest.java @@ -74,9 +74,9 @@ public class DifferenceHolderTest mt1.init(); mt2.init(); // add dummy hashes to both trees - for (MerkleTree.TreeRange range : mt1.invalids()) + for (MerkleTree.TreeRange range : mt1.rangeIterator()) range.addAll(new MerkleTreesTest.HIterator(range.right)); - for (MerkleTree.TreeRange range : mt2.invalids()) + for (MerkleTree.TreeRange range : mt2.rangeIterator()) range.addAll(new MerkleTreesTest.HIterator(range.right)); MerkleTree.TreeRange leftmost = null; @@ -85,7 +85,7 @@ public class DifferenceHolderTest mt1.maxsize(fullRange, maxsize + 2); // give some room for splitting // split the leftmost - Iterator<MerkleTree.TreeRange> ranges = mt1.invalids(); + Iterator<MerkleTree.TreeRange> ranges = mt1.rangeIterator(); leftmost = ranges.next(); mt1.split(leftmost.right); diff --git a/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java b/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java index c213271..d8491d0 100644 --- a/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java +++ 
b/test/unit/org/apache/cassandra/utils/MerkleTreeTest.java @@ -1,21 +1,20 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyten ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.cassandra.utils; import java.math.BigInteger; @@ -36,10 +35,8 @@ import org.apache.cassandra.dht.RandomPartitioner.BigIntegerToken; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; import org.apache.cassandra.io.util.DataInputBuffer; -import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputBuffer; import org.apache.cassandra.net.MessagingService; -import org.apache.cassandra.utils.MerkleTree.Hashable; import org.apache.cassandra.utils.MerkleTree.RowHash; import org.apache.cassandra.utils.MerkleTree.TreeRange; import org.apache.cassandra.utils.MerkleTree.TreeRangeIterator; @@ -49,7 +46,7 @@ import static org.junit.Assert.*; public class MerkleTreeTest { - public static byte[] DUMMY = "blah".getBytes(); + private static final byte[] DUMMY = HashingUtils.newMessageDigest("SHA-256").digest("dummy".getBytes()); /** * If a test assumes that the tree is 8 units wide, then it should set this value @@ -68,6 +65,9 @@ public class MerkleTreeTest @Before public void setup() { + DatabaseDescriptor.clientInitialization(); + DatabaseDescriptor.setOffheapMerkleTreesEnabled(false); + TOKEN_SCALE = new BigInteger("8"); partitioner = RandomPartitioner.instance; // TODO need to trickle TokenSerializer @@ -171,7 +171,7 @@ public class MerkleTreeTest Iterator<TreeRange> ranges; // (zero, zero] - ranges = mt.invalids(); + ranges = mt.rangeIterator(); assertEquals(new Range<>(tok(-1), tok(-1)), ranges.next()); assertFalse(ranges.hasNext()); @@ -181,7 +181,7 @@ public class MerkleTreeTest mt.split(tok(6)); mt.split(tok(3)); mt.split(tok(5)); - ranges = mt.invalids(); + ranges = mt.rangeIterator(); assertEquals(new Range<>(tok(6), tok(-1)), ranges.next()); assertEquals(new Range<>(tok(-1), tok(2)), ranges.next()); assertEquals(new Range<>(tok(2), tok(3)), ranges.next()); 
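[One pattern worth calling out before the next hunks: hashes are now initialised to EMPTY_HASH instead of null (see OnHeapLeaf() and hasEmptyHash() earlier in this diff), so "this range has not been hashed yet" is asked through an explicit predicate rather than a null check. The assertion rewrite that repeats throughout the updated tests is sketched here; `expected` is a placeholder, not a name from the tests.]

    // before: an unhashed range was signalled by hash() returning null
    assertNull(mt.hash(range));
    // after: callers ask the tree directly whether it hashes the range
    assertFalse(mt.hashesRange(range));
    // positive checks pair the predicate with the hash comparison
    assertTrue(mt.hashesRange(range));
    assertHashEquals(expected, mt.hash(range));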
@@ -200,7 +200,7 @@ public class MerkleTreeTest Range<Token> range = new Range<>(tok(-1), tok(-1)); // (zero, zero] - assertNull(mt.hash(range)); + assertFalse(mt.hashesRange(range)); // validate the range mt.get(tok(-1)).hash(val); @@ -223,11 +223,12 @@ public class MerkleTreeTest // (zero,two] (two,four] (four, zero] mt.split(tok(4)); mt.split(tok(2)); - assertNull(mt.hash(left)); - assertNull(mt.hash(partial)); - assertNull(mt.hash(right)); - assertNull(mt.hash(linvalid)); - assertNull(mt.hash(rinvalid)); + + assertFalse(mt.hashesRange(left)); + assertFalse(mt.hashesRange(partial)); + assertFalse(mt.hashesRange(right)); + assertFalse(mt.hashesRange(linvalid)); + assertFalse(mt.hashesRange(rinvalid)); // validate the range mt.get(tok(2)).hash(val); @@ -237,8 +238,8 @@ public class MerkleTreeTest assertHashEquals(leftval, mt.hash(left)); assertHashEquals(partialval, mt.hash(partial)); assertHashEquals(val, mt.hash(right)); - assertNull(mt.hash(linvalid)); - assertNull(mt.hash(rinvalid)); + assertFalse(mt.hashesRange(linvalid)); + assertFalse(mt.hashesRange(rinvalid)); } @Test @@ -258,10 +259,6 @@ public class MerkleTreeTest mt.split(tok(2)); mt.split(tok(6)); mt.split(tok(1)); - assertNull(mt.hash(full)); - assertNull(mt.hash(lchild)); - assertNull(mt.hash(rchild)); - assertNull(mt.hash(invalid)); // validate the range mt.get(tok(1)).hash(val); @@ -270,10 +267,14 @@ public class MerkleTreeTest mt.get(tok(6)).hash(val); mt.get(tok(-1)).hash(val); + assertTrue(mt.hashesRange(full)); + assertTrue(mt.hashesRange(lchild)); + assertTrue(mt.hashesRange(rchild)); + assertFalse(mt.hashesRange(invalid)); + assertHashEquals(fullval, mt.hash(full)); assertHashEquals(lchildval, mt.hash(lchild)); assertHashEquals(rchildval, mt.hash(rchild)); - assertNull(mt.hash(invalid)); } @Test @@ -294,9 +295,6 @@ public class MerkleTreeTest mt.split(tok(4)); mt.split(tok(2)); mt.split(tok(1)); - assertNull(mt.hash(full)); - assertNull(mt.hash(childfull)); - assertNull(mt.hash(invalid)); // validate the range mt.get(tok(1)).hash(val); @@ -306,9 +304,12 @@ public class MerkleTreeTest mt.get(tok(16)).hash(val); mt.get(tok(-1)).hash(val); + assertTrue(mt.hashesRange(full)); + assertTrue(mt.hashesRange(childfull)); + assertFalse(mt.hashesRange(invalid)); + assertHashEquals(fullval, mt.hash(full)); assertHashEquals(childfullval, mt.hash(childfull)); - assertNull(mt.hash(invalid)); } @Test @@ -326,7 +327,7 @@ public class MerkleTreeTest } // validate the tree - TreeRangeIterator ranges = mt.invalids(); + TreeRangeIterator ranges = mt.rangeIterator(); for (TreeRange range : ranges) range.addHash(new RowHash(range.right, new byte[0], 0)); @@ -355,7 +356,7 @@ public class MerkleTreeTest mt.split(tok(6)); mt.split(tok(10)); - ranges = mt.invalids(); + ranges = mt.rangeIterator(); ranges.next().addAll(new HIterator(2, 4)); // (-1,4]: depth 2 ranges.next().addAll(new HIterator(6)); // (4,6] ranges.next().addAll(new HIterator(8)); // (6,8] @@ -372,7 +373,7 @@ public class MerkleTreeTest mt2.split(tok(9)); mt2.split(tok(11)); - ranges = mt2.invalids(); + ranges = mt2.rangeIterator(); ranges.next().addAll(new HIterator(2)); // (-1,2] ranges.next().addAll(new HIterator(4)); // (2,4] ranges.next().addAll(new HIterator(6, 8)); // (4,8]: depth 2 @@ -395,19 +396,33 @@ public class MerkleTreeTest // populate and validate the tree mt.maxsize(256); mt.init(); - for (TreeRange range : mt.invalids()) + for (TreeRange range : mt.rangeIterator()) range.addAll(new HIterator(range.right)); byte[] initialhash = mt.hash(full); DataOutputBuffer 
out = new DataOutputBuffer(); - MerkleTree.serializer.serialize(mt, out, MessagingService.current_version); + mt.serialize(out, MessagingService.current_version); byte[] serialized = out.toByteArray(); - DataInputPlus in = new DataInputBuffer(serialized); - MerkleTree restored = MerkleTree.serializer.deserialize(in, MessagingService.current_version); + MerkleTree restoredOnHeap = + MerkleTree.deserialize(new DataInputBuffer(serialized), false, MessagingService.current_version); + MerkleTree restoredOffHeap = + MerkleTree.deserialize(new DataInputBuffer(serialized), true, MessagingService.current_version); + MerkleTree movedOffHeap = mt.tryMoveOffHeap(); + + assertHashEquals(initialhash, restoredOnHeap.hash(full)); + assertHashEquals(initialhash, restoredOffHeap.hash(full)); + assertHashEquals(initialhash, movedOffHeap.hash(full)); + + assertEquals(mt, restoredOnHeap); + assertEquals(mt, restoredOffHeap); + assertEquals(mt, movedOffHeap); + + assertEquals(restoredOnHeap, restoredOffHeap); + assertEquals(restoredOnHeap, movedOffHeap); - assertHashEquals(initialhash, restored.hash(full)); + assertEquals(restoredOffHeap, movedOffHeap); } @Test @@ -420,9 +435,9 @@ public class MerkleTreeTest mt2.init(); // add dummy hashes to both trees - for (TreeRange range : mt.invalids()) + for (TreeRange range : mt.rangeIterator()) range.addAll(new HIterator(range.right)); - for (TreeRange range : mt2.invalids()) + for (TreeRange range : mt2.rangeIterator()) range.addAll(new HIterator(range.right)); TreeRange leftmost = null; @@ -431,7 +446,7 @@ public class MerkleTreeTest mt.maxsize(maxsize + 2); // give some room for splitting // split the leftmost - Iterator<TreeRange> ranges = mt.invalids(); + Iterator<TreeRange> ranges = mt.rangeIterator(); leftmost = ranges.next(); mt.split(leftmost.right); @@ -465,18 +480,18 @@ public class MerkleTreeTest byte[] h2 = "hjkl".getBytes(); // add dummy hashes to both trees - for (TreeRange tree : ltree.invalids()) + for (TreeRange tree : ltree.rangeIterator()) { tree.addHash(new RowHash(range.right, h1, h1.length)); } - for (TreeRange tree : rtree.invalids()) + for (TreeRange tree : rtree.rangeIterator()) { tree.addHash(new RowHash(range.right, h2, h2.length)); } List<TreeRange> diffs = MerkleTree.difference(ltree, rtree); assertEquals(Lists.newArrayList(range), diffs); - assertEquals(MerkleTree.FULLY_INCONSISTENT, MerkleTree.differenceHelper(ltree, rtree, new ArrayList<>(), new MerkleTree.TreeDifference(ltree.fullRange.left, ltree.fullRange.right, (byte) 0))); + assertEquals(MerkleTree.FULLY_INCONSISTENT, MerkleTree.differenceHelper(ltree, rtree, new ArrayList<>(), new MerkleTree.TreeRange(ltree.fullRange.left, ltree.fullRange.right, (byte)0))); } /** @@ -499,11 +514,11 @@ public class MerkleTreeTest // add dummy hashes to both trees - for (TreeRange tree : ltree.invalids()) + for (TreeRange tree : ltree.rangeIterator()) { tree.addHash(new RowHash(range.right, h1, h1.length)); } - for (TreeRange tree : rtree.invalids()) + for (TreeRange tree : rtree.rangeIterator()) { tree.addHash(new RowHash(range.right, h2, h2.length)); } @@ -533,7 +548,7 @@ public class MerkleTreeTest while (depth.equals(dstack.peek())) { // consume the stack - hash = Hashable.binaryHash(hstack.pop(), hash); + hash = FBUtilities.xor(hstack.pop(), hash); depth = dstack.pop() - 1; } dstack.push(depth); diff --git a/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java b/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java index b40f6c4..a1f6068 100644 --- 
a/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java +++ b/test/unit/org/apache/cassandra/utils/MerkleTreesTest.java @@ -34,7 +34,6 @@ import org.apache.cassandra.io.util.DataInputBuffer; import org.apache.cassandra.io.util.DataOutputBuffer; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.service.StorageService; -import org.apache.cassandra.utils.MerkleTree.Hashable; import org.apache.cassandra.utils.MerkleTree.RowHash; import org.apache.cassandra.utils.MerkleTree.TreeRange; import org.apache.cassandra.utils.MerkleTrees.TreeRangeIterator; @@ -43,7 +42,7 @@ import static org.junit.Assert.*; public class MerkleTreesTest { - public static byte[] DUMMY = "blah".getBytes(); + private static final byte[] DUMMY = HashingUtils.newMessageDigest("SHA-256").digest("dummy".getBytes()); /** * If a test assumes that the tree is 8 units wide, then it should set this value @@ -193,7 +192,7 @@ public class MerkleTreesTest Iterator<TreeRange> ranges; // (zero, zero] - ranges = mts.invalids(); + ranges = mts.rangeIterator(); assertEquals(new Range<>(tok(-1), tok(-1)), ranges.next()); assertFalse(ranges.hasNext()); @@ -203,7 +202,7 @@ public class MerkleTreesTest mts.split(tok(6)); mts.split(tok(3)); mts.split(tok(5)); - ranges = mts.invalids(); + ranges = mts.rangeIterator(); assertEquals(new Range<>(tok(6), tok(-1)), ranges.next()); assertEquals(new Range<>(tok(-1), tok(2)), ranges.next()); assertEquals(new Range<>(tok(2), tok(3)), ranges.next()); @@ -245,11 +244,6 @@ public class MerkleTreesTest // (zero,two] (two,four] (four, zero] mts.split(tok(4)); mts.split(tok(2)); - assertNull(mts.hash(left)); - assertNull(mts.hash(partial)); - assertNull(mts.hash(right)); - assertNull(mts.hash(linvalid)); - assertNull(mts.hash(rinvalid)); // validate the range mts.get(tok(2)).hash(val); @@ -280,10 +274,6 @@ public class MerkleTreesTest mts.split(tok(2)); mts.split(tok(6)); mts.split(tok(1)); - assertNull(mts.hash(full)); - assertNull(mts.hash(lchild)); - assertNull(mts.hash(rchild)); - assertNull(mts.hash(invalid)); // validate the range mts.get(tok(1)).hash(val); @@ -315,9 +305,6 @@ public class MerkleTreesTest mts.split(tok(4)); mts.split(tok(2)); mts.split(tok(1)); - assertNull(mts.hash(full)); - assertNull(mts.hash(childfull)); - assertNull(mts.hash(invalid)); // validate the range mts.get(tok(1)).hash(val); @@ -349,7 +336,7 @@ public class MerkleTreesTest } // validate the tree - TreeRangeIterator ranges = mts.invalids(); + TreeRangeIterator ranges = mts.rangeIterator(); for (TreeRange range : ranges) range.addHash(new RowHash(range.right, new byte[0], 0)); @@ -378,13 +365,16 @@ public class MerkleTreesTest mts.split(tok(6)); mts.split(tok(10)); - ranges = mts.invalids(); - ranges.next().addAll(new HIterator(2, 4)); // (-1,4]: depth 2 - ranges.next().addAll(new HIterator(6)); // (4,6] - ranges.next().addAll(new HIterator(8)); // (6,8] + int seed = 123456789; + + Random random1 = new Random(seed); + ranges = mts.rangeIterator(); + ranges.next().addAll(new HIterator(random1, 2, 4)); // (-1,4]: depth 2 + ranges.next().addAll(new HIterator(random1, 6)); // (4,6] + ranges.next().addAll(new HIterator(random1, 8)); // (6,8] ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (8,10] - ranges.next().addAll(new HIterator(12)); // (10,12] - ranges.next().addAll(new HIterator(14, -1)); // (12,-1]: depth 2 + ranges.next().addAll(new HIterator(random1, 12)); // (10,12] + ranges.next().addAll(new HIterator(random1, 14, -1)); // (12,-1]: depth 2 mts2.split(tok(8)); @@ -395,15 
+385,16 @@ public class MerkleTreesTest mts2.split(tok(9)); mts2.split(tok(11)); - ranges = mts2.invalids(); - ranges.next().addAll(new HIterator(2)); // (-1,2] - ranges.next().addAll(new HIterator(4)); // (2,4] - ranges.next().addAll(new HIterator(6, 8)); // (4,8]: depth 2 + Random random2 = new Random(seed); + ranges = mts2.rangeIterator(); + ranges.next().addAll(new HIterator(random2, 2)); // (-1,2] + ranges.next().addAll(new HIterator(random2, 4)); // (2,4] + ranges.next().addAll(new HIterator(random2, 6, 8)); // (4,8]: depth 2 ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (8,9] ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (9,10] ranges.next().addAll(new HIterator(/*empty*/ new int[0])); // (10,11]: depth 4 - ranges.next().addAll(new HIterator(12)); // (11,12]: depth 4 - ranges.next().addAll(new HIterator(14, -1)); // (12,-1]: depth 2 + ranges.next().addAll(new HIterator(random2, 12)); // (11,12]: depth 4 + ranges.next().addAll(new HIterator(random2, 14, -1)); // (12,-1]: depth 2 byte[] mthash = mts.hash(full); byte[] mt2hash = mts2.hash(full); @@ -425,7 +416,7 @@ public class MerkleTreesTest // populate and validate the tree mts.init(); - for (TreeRange range : mts.invalids()) + for (TreeRange range : mts.rangeIterator()) range.addAll(new HIterator(range.right)); byte[] initialhash = mts.hash(first); @@ -456,11 +447,15 @@ public class MerkleTreesTest mts.init(); mts2.init(); + int seed = 123456789; // add dummy hashes to both trees - for (TreeRange range : mts.invalids()) - range.addAll(new HIterator(range.right)); - for (TreeRange range : mts2.invalids()) - range.addAll(new HIterator(range.right)); + Random random1 = new Random(seed); + for (TreeRange range : mts.rangeIterator()) + range.addAll(new HIterator(random1, range.right)); + + Random random2 = new Random(seed); + for (TreeRange range : mts2.rangeIterator()) + range.addAll(new HIterator(random2, range.right)); TreeRange leftmost = null; TreeRange middle = null; @@ -468,7 +463,7 @@ public class MerkleTreesTest mts.maxsize(fullRange(), maxsize + 2); // give some room for splitting // split the leftmost - Iterator<TreeRange> ranges = mts.invalids(); + Iterator<TreeRange> ranges = mts.rangeIterator(); leftmost = ranges.next(); mts.split(leftmost.right); @@ -504,7 +499,7 @@ public class MerkleTreesTest while (depth.equals(dstack.peek())) { // consume the stack - hash = Hashable.binaryHash(hstack.pop(), hash); + hash = FBUtilities.xor(hstack.pop(), hash); depth = dstack.pop()-1; } dstack.push(depth); @@ -516,25 +511,47 @@ public class MerkleTreesTest public static class HIterator extends AbstractIterator<RowHash> { - private Iterator<Token> tokens; + private final Random random; + private final Iterator<Token> tokens; - public HIterator(int... tokens) + HIterator(int... tokens) { - List<Token> tlist = new LinkedList<Token>(); + this(new Random(), tokens); + } + + HIterator(Random random, int... tokens) + { + List<Token> tlist = new ArrayList<>(tokens.length); for (int token : tokens) tlist.add(tok(token)); this.tokens = tlist.iterator(); + this.random = random; } public HIterator(Token... tokens) { - this.tokens = Arrays.asList(tokens).iterator(); + this(new Random(), tokens); + } + + HIterator(Random random, Token... 
tokens) + { + this(random, Arrays.asList(tokens).iterator()); + } + + private HIterator(Random random, Iterator<Token> tokens) + { + this.random = random; + this.tokens = tokens; } public RowHash computeNext() { if (tokens.hasNext()) - return new RowHash(tokens.next(), DUMMY, DUMMY.length); + { + byte[] digest = new byte[32]; + random.nextBytes(digest); + return new RowHash(tokens.next(), digest, 12345L); + } return endOfData(); } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
